mirror of
https://github.com/serp-ai/bark-with-voice-clone.git
synced 2025-12-15 03:07:58 +01:00
switch some prints to logging
This commit is contained in:
@@ -184,10 +184,10 @@ def _load_model(ckpt_path, device, model_type="text"):
|
||||
os.path.exists(ckpt_path) and
|
||||
_md5(ckpt_path) != REMOTE_MODEL_PATHS[model_type]["checksum"]
|
||||
):
|
||||
print(f"found outdated {model_type} model, removing...")
|
||||
logger.warning(f"found outdated {model_type} model, removing...")
|
||||
os.remove(ckpt_path)
|
||||
if not os.path.exists(ckpt_path):
|
||||
print(f"{model_type} model not found, downloading...")
|
||||
logger.info(f"{model_type} model not found, downloading...")
|
||||
_download(REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path)
|
||||
checkpoint = torch.load(ckpt_path, map_location=device)
|
||||
# this is a hack
|
||||
@@ -215,7 +215,7 @@ def _load_model(ckpt_path, device, model_type="text"):
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
n_params = model.get_num_params()
|
||||
val_loss = checkpoint["best_val_loss"].item()
|
||||
print(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
|
||||
logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
|
||||
model.eval()
|
||||
model.to(device)
|
||||
del checkpoint, state_dict
|
||||
@@ -346,7 +346,7 @@ def generate_text_semantic(
|
||||
device = "cuda" if use_gpu and torch.cuda.device_count() > 0 else "cpu"
|
||||
if len(encoded_text) > 256:
|
||||
p = round((len(encoded_text) - 256) / len(encoded_text) * 100, 1)
|
||||
print(f"warning, text too long, lopping of last {p}%")
|
||||
logger.warning(f"warning, text too long, lopping of last {p}%")
|
||||
encoded_text = encoded_text[:256]
|
||||
encoded_text = np.pad(
|
||||
encoded_text,
|
||||
|
||||
Reference in New Issue
Block a user