diff --git a/bark/generation.py b/bark/generation.py
index f1c32f5..0c0e8da 100644
--- a/bark/generation.py
+++ b/bark/generation.py
@@ -184,10 +184,10 @@ def _load_model(ckpt_path, device, model_type="text"):
         os.path.exists(ckpt_path)
         and _md5(ckpt_path) != REMOTE_MODEL_PATHS[model_type]["checksum"]
     ):
-        print(f"found outdated {model_type} model, removing...")
+        logger.warning(f"found outdated {model_type} model, removing...")
         os.remove(ckpt_path)
     if not os.path.exists(ckpt_path):
-        print(f"{model_type} model not found, downloading...")
+        logger.info(f"{model_type} model not found, downloading...")
         _download(REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path)
     checkpoint = torch.load(ckpt_path, map_location=device)
     # this is a hack
@@ -215,7 +215,7 @@ def _load_model(ckpt_path, device, model_type="text"):
     model.load_state_dict(state_dict, strict=False)
     n_params = model.get_num_params()
     val_loss = checkpoint["best_val_loss"].item()
-    print(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
+    logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
     model.eval()
     model.to(device)
     del checkpoint, state_dict
@@ -346,7 +346,7 @@ def generate_text_semantic(
     device = "cuda" if use_gpu and torch.cuda.device_count() > 0 else "cpu"
     if len(encoded_text) > 256:
         p = round((len(encoded_text) - 256) / len(encoded_text) * 100, 1)
-        print(f"warning, text too long, lopping of last {p}%")
+        logger.warning(f"text too long, lopping off last {p}%")
         encoded_text = encoded_text[:256]
     encoded_text = np.pad(
         encoded_text,
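
Note: the hunks above call `logger` without defining it, so the change assumes a module-level logger already exists (or is added elsewhere in this PR) near the top of `bark/generation.py`. A minimal sketch of that assumed setup, using only the standard library:

    import logging

    # Assumed setup, not shown in these hunks: a module-level logger named
    # after the module, so downstream users can filter or silence bark's
    # log output via the standard logging hierarchy.
    logger = logging.getLogger(__name__)

Using `logging.getLogger(__name__)` rather than the root logger keeps these messages namespaced under `bark.generation`, which is the usual reason for a print-to-logging migration like this one.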