mirror of
https://github.com/serp-ai/bark-with-voice-clone.git
synced 2026-04-03 01:36:17 +02:00
small updates
This commit is contained in:
16
bark/api.py
16
bark/api.py
@@ -9,6 +9,7 @@ def text_to_semantic(
|
||||
text: str,
|
||||
history_prompt: Optional[str] = None,
|
||||
temp: float = 0.7,
|
||||
silent: bool = False,
|
||||
):
|
||||
"""Generate semantic array from text.
|
||||
|
||||
@@ -16,6 +17,7 @@ def text_to_semantic(
|
||||
text: text to be turned into audio
|
||||
history_prompt: history choice for audio cloning
|
||||
temp: generation temperature (1.0 more diverse, 0.0 more conservative)
|
||||
silent: disable progress bar
|
||||
|
||||
Returns:
|
||||
numpy semantic array to be fed into `semantic_to_waveform`
|
||||
@@ -24,6 +26,7 @@ def text_to_semantic(
|
||||
text,
|
||||
history_prompt=history_prompt,
|
||||
temp=temp,
|
||||
silent=silent,
|
||||
)
|
||||
return x_semantic
|
||||
|
||||
@@ -32,6 +35,7 @@ def semantic_to_waveform(
|
||||
semantic_tokens: np.ndarray,
|
||||
history_prompt: Optional[str] = None,
|
||||
temp: float = 0.7,
|
||||
silent: bool = False,
|
||||
):
|
||||
"""Generate audio array from semantic input.
|
||||
|
||||
@@ -39,6 +43,7 @@ def semantic_to_waveform(
|
||||
semantic_tokens: semantic token output from `text_to_semantic`
|
||||
history_prompt: history choice for audio cloning
|
||||
temp: generation temperature (1.0 more diverse, 0.0 more conservative)
|
||||
silent: disable progress bar
|
||||
|
||||
Returns:
|
||||
numpy audio array at sample frequency 24khz
|
||||
@@ -47,6 +52,7 @@ def semantic_to_waveform(
|
||||
semantic_tokens,
|
||||
history_prompt=history_prompt,
|
||||
temp=temp,
|
||||
silent=silent,
|
||||
)
|
||||
x_fine_gen = generate_fine(
|
||||
x_coarse_gen,
|
||||
@@ -62,6 +68,7 @@ def generate_audio(
|
||||
history_prompt: Optional[str] = None,
|
||||
text_temp: float = 0.7,
|
||||
waveform_temp: float = 0.7,
|
||||
silent: bool = False,
|
||||
):
|
||||
"""Generate audio array from input text.
|
||||
|
||||
@@ -70,10 +77,15 @@ def generate_audio(
|
||||
history_prompt: history choice for audio cloning
|
||||
text_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
|
||||
waveform_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
|
||||
silent: disable progress bar
|
||||
|
||||
Returns:
|
||||
numpy audio array at sample frequency 24khz
|
||||
"""
|
||||
x_semantic = text_to_semantic(text, history_prompt=history_prompt, temp=text_temp)
|
||||
audio_arr = semantic_to_waveform(x_semantic, history_prompt=history_prompt, temp=waveform_temp)
|
||||
x_semantic = text_to_semantic(
|
||||
text, history_prompt=history_prompt, temp=text_temp, silent=silent,
|
||||
)
|
||||
audio_arr = semantic_to_waveform(
|
||||
x_semantic, history_prompt=history_prompt, temp=waveform_temp, silent=silent,
|
||||
)
|
||||
return audio_arr
|
||||
|
||||
@@ -137,9 +137,9 @@ def _parse_s3_filepath(s3_filepath):
|
||||
def _download(from_s3_path, to_local_path):
|
||||
os.makedirs(CACHE_DIR, exist_ok=True)
|
||||
response = requests.get(from_s3_path, stream=True)
|
||||
total_size_in_bytes = int(response.headers.get('content-length', 0))
|
||||
block_size = 1024 # 1 Kibibyte
|
||||
progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
|
||||
total_size_in_bytes = int(response.headers.get("content-length", 0))
|
||||
block_size = 1024
|
||||
progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
|
||||
with open(to_local_path, "wb") as file:
|
||||
for data in response.iter_content(block_size):
|
||||
progress_bar.update(len(data))
|
||||
@@ -191,7 +191,7 @@ def clean_models(model_key=None):
|
||||
|
||||
def _load_model(ckpt_path, device, model_type="text"):
|
||||
if "cuda" not in device:
|
||||
logger.warning("No GPU being used. Careful, Inference might be extremely slow!")
|
||||
logger.warning("No GPU being used. Careful, inference might be extremely slow!")
|
||||
if model_type == "text":
|
||||
ConfigClass = GPTConfig
|
||||
ModelClass = GPT
|
||||
@@ -207,10 +207,10 @@ def _load_model(ckpt_path, device, model_type="text"):
|
||||
os.path.exists(ckpt_path) and
|
||||
_md5(ckpt_path) != REMOTE_MODEL_PATHS[model_type]["checksum"]
|
||||
):
|
||||
logger.warning(f"found outdated {model_type} model, removing...")
|
||||
logger.warning(f"found outdated {model_type} model, removing.")
|
||||
os.remove(ckpt_path)
|
||||
if not os.path.exists(ckpt_path):
|
||||
logger.info(f"{model_type} model not found, downloading...")
|
||||
logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
|
||||
_download(REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path)
|
||||
checkpoint = torch.load(ckpt_path, map_location=device)
|
||||
# this is a hack
|
||||
|
||||
Reference in New Issue
Block a user