diff --git a/bark/generation.py b/bark/generation.py
index 4ac165c..3b12757 100644
--- a/bark/generation.py
+++ b/bark/generation.py
@@ -14,6 +14,7 @@ import torch
 import torch.nn.functional as F
 import tqdm
 from transformers import BertTokenizer
+from huggingface_hub import hf_hub_download
 
 from .model import GPTConfig, GPT
 from .model_fine import FineGPT, FineGPTConfig
@@ -89,31 +90,64 @@ USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False)
 GLOBAL_ENABLE_MPS = os.environ.get("SUNO_ENABLE_MPS", False)
 OFFLOAD_CPU = os.environ.get("SUNO_OFFLOAD_CPU", False)
 
-REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/"
+# REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/"
+
+# REMOTE_MODEL_PATHS = {
+#     "text_small": {
+#         "path": os.path.join(REMOTE_BASE_URL, "text.pt"),
+#         "checksum": "b3e42bcbab23b688355cd44128c4cdd3",
+#     },
+#     "coarse_small": {
+#         "path": os.path.join(REMOTE_BASE_URL, "coarse.pt"),
+#         "checksum": "5fe964825e3b0321f9d5f3857b89194d",
+#     },
+#     "fine_small": {
+#         "path": os.path.join(REMOTE_BASE_URL, "fine.pt"),
+#         "checksum": "5428d1befe05be2ba32195496e58dc90",
+#     },
+#     "text": {
+#         "path": os.path.join(REMOTE_BASE_URL, "text_2.pt"),
+#         "checksum": "54afa89d65e318d4f5f80e8e8799026a",
+#     },
+#     "coarse": {
+#         "path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"),
+#         "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
+#     },
+#     "fine": {
+#         "path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"),
+#         "checksum": "59d184ed44e3650774a2f0503a48a97b",
+#     },
+# }
 
 REMOTE_MODEL_PATHS = {
     "text_small": {
-        "path": os.path.join(REMOTE_BASE_URL, "text.pt"),
+        "repo_id": "reach-vb/bark-small",
+        "file_name": "text.pt",
         "checksum": "b3e42bcbab23b688355cd44128c4cdd3",
     },
     "coarse_small": {
-        "path": os.path.join(REMOTE_BASE_URL, "coarse.pt"),
+        "repo_id": "reach-vb/bark-small",
+        "file_name": "coarse.pt",
         "checksum": "5fe964825e3b0321f9d5f3857b89194d",
     },
     "fine_small": {
-        "path": os.path.join(REMOTE_BASE_URL, "fine.pt"),
+        "repo_id": "reach-vb/bark-small",
+        "file_name": "fine.pt",
         "checksum": "5428d1befe05be2ba32195496e58dc90",
     },
     "text": {
-        "path": os.path.join(REMOTE_BASE_URL, "text_2.pt"),
+        "repo_id": "reach-vb/bark",
+        "file_name": "text_2.pt",
         "checksum": "54afa89d65e318d4f5f80e8e8799026a",
     },
     "coarse": {
-        "path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"),
+        "repo_id": "reach-vb/bark",
+        "file_name": "coarse_2.pt",
         "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
     },
     "fine": {
-        "path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"),
+        "repo_id": "reach-vb/bark",
+        "file_name": "fine_2.pt",
         "checksum": "59d184ed44e3650774a2f0503a48a97b",
     },
 }
@@ -165,21 +199,25 @@ def _parse_s3_filepath(s3_filepath):
     return bucket_name, rel_s3_filepath
 
 
-def _download(from_s3_path, to_local_path):
-    os.makedirs(CACHE_DIR, exist_ok=True)
-    response = requests.get(from_s3_path, stream=True)
-    total_size_in_bytes = int(response.headers.get("content-length", 0))
-    block_size = 1024
-    progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
-    with open(to_local_path, "wb") as file:
-        for data in response.iter_content(block_size):
-            progress_bar.update(len(data))
-            file.write(data)
-    progress_bar.close()
-    if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
-        raise ValueError("ERROR, something went wrong")
+# def _download(from_s3_path, to_local_path):
+#     os.makedirs(CACHE_DIR, exist_ok=True)
+#     response = requests.get(from_s3_path, stream=True)
+#     total_size_in_bytes = int(response.headers.get("content-length", 0))
+#     block_size = 1024
+#     progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
+#     with open(to_local_path, "wb") as file:
+#         for data in response.iter_content(block_size):
+#             progress_bar.update(len(data))
+#             file.write(data)
+#     progress_bar.close()
+#     if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
+#         raise ValueError("ERROR, something went wrong")
+def _download(from_hf_path, file_name):
+    os.makedirs(CACHE_DIR, exist_ok=True)
+    # local_dir (not cache_dir) places the file at CACHE_DIR/file_name,
+    # which is where _load_model expects ckpt_path to be
+    hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR)
+
 
 class InferenceContext:
     def __init__(self, benchmark=False):
         # we can't expect inputs to be the same length, so disable benchmarking by default
@@ -243,7 +281,7 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"):
         os.remove(ckpt_path)
     if not os.path.exists(ckpt_path):
         logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
-        _download(model_info["path"], ckpt_path)
+        _download(model_info["repo_id"], model_info["file_name"])
     checkpoint = torch.load(ckpt_path, map_location=device)
     # this is a hack
     model_args = checkpoint["model_args"]
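
Reviewer note: a minimal sketch (not part of the diff) of how the new download path
is expected to resolve on disk. It assumes huggingface_hub >= 0.14 for local_dir
support and bark's default CACHE_DIR; the fetch_checkpoint helper is hypothetical
and only mirrors the _download/_load_model interaction shown above.

import os
from huggingface_hub import hf_hub_download

# assumption: bark's default checkpoint cache location
CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "suno", "bark_v0")

def fetch_checkpoint(model_info):
    """Download one REMOTE_MODEL_PATHS entry; return the path torch.load reads."""
    os.makedirs(CACHE_DIR, exist_ok=True)
    # with local_dir set, hf_hub_download materializes the file at
    # <local_dir>/<file_name> and returns that path, so the caller's
    # os.path.exists(ckpt_path) check keeps working unchanged
    return hf_hub_download(
        repo_id=model_info["repo_id"],
        filename=model_info["file_name"],
        local_dir=CACHE_DIR,
    )

# e.g. fetch_checkpoint(REMOTE_MODEL_PATHS["text_small"]) pulls text.pt from
# the reach-vb/bark-small repo into ~/.cache/suno/bark_v0/text.pt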
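Reviewer note: the md5 checksums in REMOTE_MODEL_PATHS are unchanged, so the
pre-existing integrity gate in _load_model (the os.remove(ckpt_path) branch
visible in the last hunk) still applies to Hub downloads. Below is a hedged
sketch of that check, assuming a chunked-md5 helper like the _md5 function
bark already ships; verify_or_remove is hypothetical and only restates the
inline logic.

import hashlib
import os

def _md5(fname):
    # hash in 4 KiB chunks so multi-GB checkpoints never load fully into memory
    hash_md5 = hashlib.md5()
    with open(fname, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()

def verify_or_remove(ckpt_path, expected_checksum):
    # standalone version of the check that _load_model performs inline
    if os.path.exists(ckpt_path) and _md5(ckpt_path) != expected_checksum:
        os.remove(ckpt_path)  # forces a fresh _download on the next load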