small updates

This commit is contained in:
Georg Kucsko
2023-04-21 15:13:16 -04:00
parent d53b43e865
commit 9751cfbfc4
4 changed files with 22 additions and 10 deletions

View File

@@ -9,6 +9,7 @@ def text_to_semantic(
text: str,
history_prompt: Optional[str] = None,
temp: float = 0.7,
silent: bool = False,
):
"""Generate semantic array from text.
@@ -16,6 +17,7 @@ def text_to_semantic(
text: text to be turned into audio
history_prompt: history choice for audio cloning
temp: generation temperature (1.0 more diverse, 0.0 more conservative)
silent: disable progress bar
Returns:
numpy semantic array to be fed into `semantic_to_waveform`
@@ -24,6 +26,7 @@ def text_to_semantic(
text,
history_prompt=history_prompt,
temp=temp,
silent=silent,
)
return x_semantic
@@ -32,6 +35,7 @@ def semantic_to_waveform(
semantic_tokens: np.ndarray,
history_prompt: Optional[str] = None,
temp: float = 0.7,
silent: bool = False,
):
"""Generate audio array from semantic input.
@@ -39,6 +43,7 @@ def semantic_to_waveform(
semantic_tokens: semantic token output from `text_to_semantic`
history_prompt: history choice for audio cloning
temp: generation temperature (1.0 more diverse, 0.0 more conservative)
silent: disable progress bar
Returns:
numpy audio array at sample frequency 24khz
@@ -47,6 +52,7 @@ def semantic_to_waveform(
semantic_tokens,
history_prompt=history_prompt,
temp=temp,
silent=silent,
)
x_fine_gen = generate_fine(
x_coarse_gen,
@@ -62,6 +68,7 @@ def generate_audio(
history_prompt: Optional[str] = None,
text_temp: float = 0.7,
waveform_temp: float = 0.7,
silent: bool = False,
):
"""Generate audio array from input text.
@@ -70,10 +77,15 @@ def generate_audio(
history_prompt: history choice for audio cloning
text_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
waveform_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
silent: disable progress bar
Returns:
numpy audio array at sample frequency 24khz
"""
x_semantic = text_to_semantic(text, history_prompt=history_prompt, temp=text_temp)
audio_arr = semantic_to_waveform(x_semantic, history_prompt=history_prompt, temp=waveform_temp)
x_semantic = text_to_semantic(
text, history_prompt=history_prompt, temp=text_temp, silent=silent,
)
audio_arr = semantic_to_waveform(
x_semantic, history_prompt=history_prompt, temp=waveform_temp, silent=silent,
)
return audio_arr

View File

@@ -137,9 +137,9 @@ def _parse_s3_filepath(s3_filepath):
def _download(from_s3_path, to_local_path):
os.makedirs(CACHE_DIR, exist_ok=True)
response = requests.get(from_s3_path, stream=True)
total_size_in_bytes = int(response.headers.get('content-length', 0))
block_size = 1024 # 1 Kibibyte
progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
total_size_in_bytes = int(response.headers.get("content-length", 0))
block_size = 1024
progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
with open(to_local_path, "wb") as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
@@ -191,7 +191,7 @@ def clean_models(model_key=None):
def _load_model(ckpt_path, device, model_type="text"):
if "cuda" not in device:
logger.warning("No GPU being used. Careful, Inference might be extremely slow!")
logger.warning("No GPU being used. Careful, inference might be extremely slow!")
if model_type == "text":
ConfigClass = GPTConfig
ModelClass = GPT
@@ -207,10 +207,10 @@ def _load_model(ckpt_path, device, model_type="text"):
os.path.exists(ckpt_path) and
_md5(ckpt_path) != REMOTE_MODEL_PATHS[model_type]["checksum"]
):
logger.warning(f"found outdated {model_type} model, removing...")
logger.warning(f"found outdated {model_type} model, removing.")
os.remove(ckpt_path)
if not os.path.exists(ckpt_path):
logger.info(f"{model_type} model not found, downloading...")
logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
_download(REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path)
checkpoint = torch.load(ckpt_path, map_location=device)
# this is a hack