mirror of
https://github.com/serp-ai/bark-with-voice-clone.git
synced 2025-12-16 11:48:09 +01:00
Merge pull request #146 from jn-jairo/offload-cpu
Option to offload models to cpu
This commit is contained in:
@@ -36,6 +36,9 @@ else:
|
|||||||
global models
|
global models
|
||||||
models = {}
|
models = {}
|
||||||
|
|
||||||
|
global models_devices
|
||||||
|
models_devices = {}
|
||||||
|
|
||||||
|
|
||||||
CONTEXT_WINDOW_SIZE = 1024
|
CONTEXT_WINDOW_SIZE = 1024
|
||||||
|
|
||||||
@@ -84,6 +87,7 @@ CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno",
|
|||||||
|
|
||||||
USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False)
|
USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False)
|
||||||
GLOBAL_ENABLE_MPS = os.environ.get("SUNO_ENABLE_MPS", False)
|
GLOBAL_ENABLE_MPS = os.environ.get("SUNO_ENABLE_MPS", False)
|
||||||
|
OFFLOAD_CPU = os.environ.get("SUNO_OFFLOAD_CPU", False)
|
||||||
|
|
||||||
REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/"
|
REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/"
|
||||||
|
|
||||||
@@ -294,8 +298,12 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
|
|||||||
if model_type not in ("text", "coarse", "fine"):
|
if model_type not in ("text", "coarse", "fine"):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
global models
|
global models
|
||||||
|
global models_devices
|
||||||
device = _grab_best_device(use_gpu=use_gpu)
|
device = _grab_best_device(use_gpu=use_gpu)
|
||||||
model_key = f"{model_type}"
|
model_key = f"{model_type}"
|
||||||
|
if OFFLOAD_CPU:
|
||||||
|
models_devices[model_key] = device
|
||||||
|
device = "cpu"
|
||||||
if model_key not in models or force_reload:
|
if model_key not in models or force_reload:
|
||||||
ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
|
ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
|
||||||
clean_models(model_key=model_key)
|
clean_models(model_key=model_key)
|
||||||
@@ -310,11 +318,15 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
|
|||||||
|
|
||||||
def load_codec_model(use_gpu=True, force_reload=False):
|
def load_codec_model(use_gpu=True, force_reload=False):
|
||||||
global models
|
global models
|
||||||
|
global models_devices
|
||||||
device = _grab_best_device(use_gpu=use_gpu)
|
device = _grab_best_device(use_gpu=use_gpu)
|
||||||
if device == "mps":
|
if device == "mps":
|
||||||
# encodec doesn't support mps
|
# encodec doesn't support mps
|
||||||
device = "cpu"
|
device = "cpu"
|
||||||
model_key = "codec"
|
model_key = "codec"
|
||||||
|
if OFFLOAD_CPU:
|
||||||
|
models_devices[model_key] = device
|
||||||
|
device = "cpu"
|
||||||
if model_key not in models or force_reload:
|
if model_key not in models or force_reload:
|
||||||
clean_models(model_key=model_key)
|
clean_models(model_key=model_key)
|
||||||
model = _load_codec_model(device)
|
model = _load_codec_model(device)
|
||||||
@@ -411,12 +423,15 @@ def generate_text_semantic(
|
|||||||
semantic_history = None
|
semantic_history = None
|
||||||
# load models if not yet exist
|
# load models if not yet exist
|
||||||
global models
|
global models
|
||||||
|
global models_devices
|
||||||
if "text" not in models:
|
if "text" not in models:
|
||||||
preload_models()
|
preload_models()
|
||||||
model_container = models["text"]
|
model_container = models["text"]
|
||||||
model = model_container["model"]
|
model = model_container["model"]
|
||||||
tokenizer = model_container["tokenizer"]
|
tokenizer = model_container["tokenizer"]
|
||||||
encoded_text = np.array(_tokenize(tokenizer, text)) + TEXT_ENCODING_OFFSET
|
encoded_text = np.array(_tokenize(tokenizer, text)) + TEXT_ENCODING_OFFSET
|
||||||
|
if OFFLOAD_CPU:
|
||||||
|
model.to(models_devices["text"])
|
||||||
device = next(model.parameters()).device
|
device = next(model.parameters()).device
|
||||||
if len(encoded_text) > 256:
|
if len(encoded_text) > 256:
|
||||||
p = round((len(encoded_text) - 256) / len(encoded_text) * 100, 1)
|
p = round((len(encoded_text) - 256) / len(encoded_text) * 100, 1)
|
||||||
@@ -514,6 +529,8 @@ def generate_text_semantic(
|
|||||||
pbar_state = req_pbar_state
|
pbar_state = req_pbar_state
|
||||||
pbar.close()
|
pbar.close()
|
||||||
out = x.detach().cpu().numpy().squeeze()[256 + 256 + 1 :]
|
out = x.detach().cpu().numpy().squeeze()[256 + 256 + 1 :]
|
||||||
|
if OFFLOAD_CPU:
|
||||||
|
model.to("cpu")
|
||||||
assert all(0 <= out) and all(out < SEMANTIC_VOCAB_SIZE)
|
assert all(0 <= out) and all(out < SEMANTIC_VOCAB_SIZE)
|
||||||
_clear_cuda_cache()
|
_clear_cuda_cache()
|
||||||
return out
|
return out
|
||||||
@@ -602,9 +619,12 @@ def generate_coarse(
|
|||||||
x_coarse_history = np.array([], dtype=np.int32)
|
x_coarse_history = np.array([], dtype=np.int32)
|
||||||
# load models if not yet exist
|
# load models if not yet exist
|
||||||
global models
|
global models
|
||||||
|
global models_devices
|
||||||
if "coarse" not in models:
|
if "coarse" not in models:
|
||||||
preload_models()
|
preload_models()
|
||||||
model = models["coarse"]
|
model = models["coarse"]
|
||||||
|
if OFFLOAD_CPU:
|
||||||
|
model.to(models_devices["coarse"])
|
||||||
device = next(model.parameters()).device
|
device = next(model.parameters()).device
|
||||||
# start loop
|
# start loop
|
||||||
n_steps = int(
|
n_steps = int(
|
||||||
@@ -691,6 +711,8 @@ def generate_coarse(
|
|||||||
n_step += 1
|
n_step += 1
|
||||||
del x_in
|
del x_in
|
||||||
del x_semantic_in
|
del x_semantic_in
|
||||||
|
if OFFLOAD_CPU:
|
||||||
|
model.to("cpu")
|
||||||
gen_coarse_arr = x_coarse_in.detach().cpu().numpy().squeeze()[len(x_coarse_history) :]
|
gen_coarse_arr = x_coarse_in.detach().cpu().numpy().squeeze()[len(x_coarse_history) :]
|
||||||
del x_coarse_in
|
del x_coarse_in
|
||||||
assert len(gen_coarse_arr) == n_steps
|
assert len(gen_coarse_arr) == n_steps
|
||||||
@@ -737,9 +759,12 @@ def generate_fine(
|
|||||||
n_coarse = x_coarse_gen.shape[0]
|
n_coarse = x_coarse_gen.shape[0]
|
||||||
# load models if not yet exist
|
# load models if not yet exist
|
||||||
global models
|
global models
|
||||||
|
global models_devices
|
||||||
if "fine" not in models:
|
if "fine" not in models:
|
||||||
preload_models()
|
preload_models()
|
||||||
model = models["fine"]
|
model = models["fine"]
|
||||||
|
if OFFLOAD_CPU:
|
||||||
|
model.to(models_devices["fine"])
|
||||||
device = next(model.parameters()).device
|
device = next(model.parameters()).device
|
||||||
# make input arr
|
# make input arr
|
||||||
in_arr = np.vstack(
|
in_arr = np.vstack(
|
||||||
@@ -808,6 +833,8 @@ def generate_fine(
|
|||||||
del in_buffer
|
del in_buffer
|
||||||
gen_fine_arr = in_arr.detach().cpu().numpy().squeeze().T
|
gen_fine_arr = in_arr.detach().cpu().numpy().squeeze().T
|
||||||
del in_arr
|
del in_arr
|
||||||
|
if OFFLOAD_CPU:
|
||||||
|
model.to("cpu")
|
||||||
gen_fine_arr = gen_fine_arr[:, n_history:]
|
gen_fine_arr = gen_fine_arr[:, n_history:]
|
||||||
if n_remove_from_end > 0:
|
if n_remove_from_end > 0:
|
||||||
gen_fine_arr = gen_fine_arr[:, :-n_remove_from_end]
|
gen_fine_arr = gen_fine_arr[:, :-n_remove_from_end]
|
||||||
@@ -820,9 +847,12 @@ def codec_decode(fine_tokens):
|
|||||||
"""Turn quantized audio codes into audio array using encodec."""
|
"""Turn quantized audio codes into audio array using encodec."""
|
||||||
# load models if not yet exist
|
# load models if not yet exist
|
||||||
global models
|
global models
|
||||||
|
global models_devices
|
||||||
if "codec" not in models:
|
if "codec" not in models:
|
||||||
preload_models()
|
preload_models()
|
||||||
model = models["codec"]
|
model = models["codec"]
|
||||||
|
if OFFLOAD_CPU:
|
||||||
|
model.to(models_devices["codec"])
|
||||||
device = next(model.parameters()).device
|
device = next(model.parameters()).device
|
||||||
arr = torch.from_numpy(fine_tokens)[None]
|
arr = torch.from_numpy(fine_tokens)[None]
|
||||||
arr = arr.to(device)
|
arr = arr.to(device)
|
||||||
@@ -831,4 +861,6 @@ def codec_decode(fine_tokens):
|
|||||||
out = model.decoder(emb)
|
out = model.decoder(emb)
|
||||||
audio_arr = out.detach().cpu().numpy().squeeze()
|
audio_arr = out.detach().cpu().numpy().squeeze()
|
||||||
del arr, emb, out
|
del arr, emb, out
|
||||||
|
if OFFLOAD_CPU:
|
||||||
|
model.to("cpu")
|
||||||
return audio_arr
|
return audio_arr
|
||||||
|
|||||||
Reference in New Issue
Block a user