switch some prints to logging

This commit is contained in:
Georg Kucsko
2023-04-13 14:27:35 -04:00
parent 2c038176b3
commit 76966a8b1f

View File

@@ -184,10 +184,10 @@ def _load_model(ckpt_path, device, model_type="text"):
os.path.exists(ckpt_path) and
_md5(ckpt_path) != REMOTE_MODEL_PATHS[model_type]["checksum"]
):
print(f"found outdated {model_type} model, removing...")
logger.warning(f"found outdated {model_type} model, removing...")
os.remove(ckpt_path)
if not os.path.exists(ckpt_path):
print(f"{model_type} model not found, downloading...")
logger.info(f"{model_type} model not found, downloading...")
_download(REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path)
checkpoint = torch.load(ckpt_path, map_location=device)
# this is a hack
@@ -215,7 +215,7 @@ def _load_model(ckpt_path, device, model_type="text"):
model.load_state_dict(state_dict, strict=False)
n_params = model.get_num_params()
val_loss = checkpoint["best_val_loss"].item()
print(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
model.eval()
model.to(device)
del checkpoint, state_dict
@@ -346,7 +346,7 @@ def generate_text_semantic(
device = "cuda" if use_gpu and torch.cuda.device_count() > 0 else "cpu"
if len(encoded_text) > 256:
p = round((len(encoded_text) - 256) / len(encoded_text) * 100, 1)
print(f"warning, text too long, lopping of last {p}%")
logger.warning(f"warning, text too long, lopping of last {p}%")
encoded_text = encoded_text[:256]
encoded_text = np.pad(
encoded_text,