diff --git a/bark/generation.py b/bark/generation.py
index ec313e7..4ac165c 100644
--- a/bark/generation.py
+++ b/bark/generation.py
@@ -298,6 +298,7 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
     if model_type not in ("text", "coarse", "fine"):
         raise NotImplementedError()
     global models
+    global models_devices
     device = _grab_best_device(use_gpu=use_gpu)
     model_key = f"{model_type}"
     if OFFLOAD_CPU:
@@ -317,6 +318,7 @@

 def load_codec_model(use_gpu=True, force_reload=False):
     global models
+    global models_devices
     device = _grab_best_device(use_gpu=use_gpu)
     if device == "mps":
         # encodec doesn't support mps
@@ -421,6 +423,7 @@ def generate_text_semantic(
         semantic_history = None
     # load models if not yet exist
     global models
+    global models_devices
     if "text" not in models:
         preload_models()
     model_container = models["text"]
@@ -616,6 +619,7 @@ def generate_coarse(
         x_coarse_history = np.array([], dtype=np.int32)
     # load models if not yet exist
     global models
+    global models_devices
     if "coarse" not in models:
         preload_models()
     model = models["coarse"]
@@ -755,6 +759,7 @@ def generate_fine(
     n_coarse = x_coarse_gen.shape[0]
     # load models if not yet exist
     global models
+    global models_devices
     if "fine" not in models:
         preload_models()
     model = models["fine"]
@@ -842,6 +847,7 @@ def codec_decode(fine_tokens):
     """Turn quantized audio codes into audio array using encodec."""
     # load models if not yet exist
     global models
+    global models_devices
     if "codec" not in models:
         preload_models()
     model = models["codec"]
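
For context, the sketch below is not Bark's actual implementation: load_and_register, run_offloaded, the build_fn parameter, and the toy device logic are hypothetical. It only illustrates the module-level state the added declarations refer to, where models caches loaded models and models_devices records the device each model should run on while OFFLOAD_CPU parks weights on the CPU. Note that Python requires a global statement only when a function rebinds a module-level name; in-place mutation such as models_devices[key] = device works without it, so the added declarations mainly make the shared state explicit and guard against a later rebinding turning the reads into an UnboundLocalError.

# Minimal sketch of the global-state pattern the diff touches -- an assumption-laden example, not Bark's code.
import torch

OFFLOAD_CPU = True   # mirrors the flag visible in the diff's context lines
models = {}          # model_key -> loaded torch.nn.Module
models_devices = {}  # model_key -> device the model should actually run on

def load_and_register(model_key, build_fn, use_gpu=True):
    """Hypothetical loader: cache the model and remember its target device."""
    global models
    global models_devices
    device = "cuda" if (use_gpu and torch.cuda.is_available()) else "cpu"
    if OFFLOAD_CPU:
        models_devices[model_key] = device  # remember where inference should happen
        device = "cpu"                      # but keep the weights parked on the CPU
    if model_key not in models:
        models[model_key] = build_fn().to(device)
    return models[model_key]

def run_offloaded(model_key, x):
    """Hypothetical inference wrapper: move weights to the target device only for the call."""
    global models
    global models_devices
    model = models[model_key]
    if OFFLOAD_CPU:
        model.to(models_devices[model_key])
    y = model(x)
    if OFFLOAD_CPU:
        model.to("cpu")  # free accelerator memory between calls
    return y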