allow using unconditional as prompts

This commit is contained in:
Georg Kucsko
2023-04-21 16:14:10 -04:00
parent c372430112
commit 7d39f48c7a
2 changed files with 57 additions and 19 deletions

View File

@@ -365,10 +365,13 @@ def generate_text_semantic(
text = _normalize_whitespace(text)
assert len(text.strip()) > 0
if history_prompt is not None:
assert (history_prompt in ALLOWED_PROMPTS)
semantic_history = np.load(
os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
)["semantic_prompt"]
if history_prompt.endswith(".npz"):
semantic_history = np.load(history_prompt)["semantic_prompt"]
else:
assert (history_prompt in ALLOWED_PROMPTS)
semantic_history = np.load(
os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
)["semantic_prompt"]
assert (
isinstance(semantic_history, np.ndarray)
and len(semantic_history.shape) == 1
@@ -509,10 +512,13 @@ def generate_coarse(
semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS
max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))
if history_prompt is not None:
assert (history_prompt in ALLOWED_PROMPTS)
x_history = np.load(
os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
)
if history_prompt.endswith(".npz"):
x_history = np.load(history_prompt)
else:
assert (history_prompt in ALLOWED_PROMPTS)
x_history = np.load(
os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
)
x_semantic_history = x_history["semantic_prompt"]
x_coarse_history = x_history["coarse_prompt"]
assert (
@@ -652,10 +658,13 @@ def generate_fine(
and x_coarse_gen.max() <= CODEBOOK_SIZE - 1
)
if history_prompt is not None:
assert (history_prompt in ALLOWED_PROMPTS)
x_fine_history = np.load(
os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
)["fine_prompt"]
if history_prompt.endswith(".npz"):
x_fine_history = np.load(history_prompt)["fine_prompt"]
else:
assert (history_prompt in ALLOWED_PROMPTS)
x_fine_history = np.load(
os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
)["fine_prompt"]
assert (
isinstance(x_fine_history, np.ndarray)
and len(x_fine_history.shape) == 2