allow using unconditional as prompts

This commit is contained in:
Georg Kucsko
2023-04-21 16:14:10 -04:00
parent c372430112
commit 7d39f48c7a
2 changed files with 57 additions and 19 deletions

View File

@@ -36,6 +36,7 @@ def semantic_to_waveform(
    history_prompt: Optional[str] = None,
    temp: float = 0.7,
    silent: bool = False,
    output_full: bool = False,
):
    """Generate audio array from semantic input.
@@ -44,31 +45,49 @@ def semantic_to_waveform(
        history_prompt: history choice for audio cloning
        temp: generation temperature (1.0 more diverse, 0.0 more conservative)
        silent: disable progress bar
        output_full: return full generation to be used as a history prompt

    Returns:
        numpy audio array at sample frequency 24khz
    """
    coarse_tokens = generate_coarse(
        semantic_tokens,
        history_prompt=history_prompt,
        temp=temp,
        silent=silent,
    )
    fine_tokens = generate_fine(
        coarse_tokens,
        history_prompt=history_prompt,
        temp=0.5,
    )
    audio_arr = codec_decode(fine_tokens)
    if output_full:
        full_generation = {
            "semantic_prompt": semantic_tokens,
            "coarse_prompt": coarse_tokens,
            "fine_prompt": fine_tokens,
        }
        return full_generation, audio_arr
    return audio_arr
def save_as_prompt(filepath, full_generation):
    assert(filepath.endswith(".npz"))
    assert(isinstance(full_generation, dict))
    assert("semantic_prompt" in full_generation)
    assert("coarse_prompt" in full_generation)
    assert("fine_prompt" in full_generation)
    np.savez(filepath, **full_generation)
def generate_audio(
    text: str,
    history_prompt: Optional[str] = None,
    text_temp: float = 0.7,
    waveform_temp: float = 0.7,
    silent: bool = False,
    output_full: bool = False,
):
    """Generate audio array from input text.
@@ -78,14 +97,24 @@ def generate_audio(
        text_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
        waveform_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
        silent: disable progress bar
        output_full: return full generation to be used as a history prompt

    Returns:
        numpy audio array at sample frequency 24khz
    """
    semantic_tokens = text_to_semantic(
        text, history_prompt=history_prompt, temp=text_temp, silent=silent,
    )
    out = semantic_to_waveform(
        semantic_tokens,
        history_prompt=history_prompt,
        temp=waveform_temp,
        silent=silent,
        output_full=output_full,
    )
    if output_full:
        full_generation, audio_arr = out
        return full_generation, audio_arr
    else:
        audio_arr = out
    return audio_arr

View File

@@ -365,10 +365,13 @@ def generate_text_semantic(
    text = _normalize_whitespace(text)
    assert len(text.strip()) > 0
    if history_prompt is not None:
        if history_prompt.endswith(".npz"):
            semantic_history = np.load(history_prompt)["semantic_prompt"]
        else:
            assert (history_prompt in ALLOWED_PROMPTS)
            semantic_history = np.load(
                os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
            )["semantic_prompt"]
        assert (
            isinstance(semantic_history, np.ndarray)
            and len(semantic_history.shape) == 1
@@ -509,10 +512,13 @@ def generate_coarse(
    semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS
    max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))
    if history_prompt is not None:
        if history_prompt.endswith(".npz"):
            x_history = np.load(history_prompt)
        else:
            assert (history_prompt in ALLOWED_PROMPTS)
            x_history = np.load(
                os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
            )
        x_semantic_history = x_history["semantic_prompt"]
        x_coarse_history = x_history["coarse_prompt"]
        assert (
@@ -652,10 +658,13 @@ def generate_fine(
        and x_coarse_gen.max() <= CODEBOOK_SIZE - 1
    )
    if history_prompt is not None:
        if history_prompt.endswith(".npz"):
            x_fine_history = np.load(history_prompt)["fine_prompt"]
        else:
            assert (history_prompt in ALLOWED_PROMPTS)
            x_fine_history = np.load(
                os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
            )["fine_prompt"]
        assert (
            isinstance(x_fine_history, np.ndarray)
            and len(x_fine_history.shape) == 2