From e9ad2d5886388c6188eced86c96331b6a6983984 Mon Sep 17 00:00:00 2001 From: vaibhavs10 Date: Thu, 27 Apr 2023 16:12:54 +0200 Subject: [PATCH 1/8] initial commit --- bark/generation.py | 80 ++++++++++++++++++++++++++++++++++------------ 1 file changed, 59 insertions(+), 21 deletions(-) diff --git a/bark/generation.py b/bark/generation.py index 4ac165c..3b12757 100644 --- a/bark/generation.py +++ b/bark/generation.py @@ -14,6 +14,7 @@ import torch import torch.nn.functional as F import tqdm from transformers import BertTokenizer +from huggingface_hub import hf_hub_download from .model import GPTConfig, GPT from .model_fine import FineGPT, FineGPTConfig @@ -89,31 +90,64 @@ USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False) GLOBAL_ENABLE_MPS = os.environ.get("SUNO_ENABLE_MPS", False) OFFLOAD_CPU = os.environ.get("SUNO_OFFLOAD_CPU", False) -REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/" +# REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/" + +# REMOTE_MODEL_PATHS = { +# "text_small": { +# "path": os.path.join(REMOTE_BASE_URL, "text.pt"), +# "checksum": "b3e42bcbab23b688355cd44128c4cdd3", +# }, +# "coarse_small": { +# "path": os.path.join(REMOTE_BASE_URL, "coarse.pt"), +# "checksum": "5fe964825e3b0321f9d5f3857b89194d", +# }, +# "fine_small": { +# "path": os.path.join(REMOTE_BASE_URL, "fine.pt"), +# "checksum": "5428d1befe05be2ba32195496e58dc90", +# }, +# "text": { +# "path": os.path.join(REMOTE_BASE_URL, "text_2.pt"), +# "checksum": "54afa89d65e318d4f5f80e8e8799026a", +# }, +# "coarse": { +# "path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"), +# "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28", +# }, +# "fine": { +# "path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"), +# "checksum": "59d184ed44e3650774a2f0503a48a97b", +# }, +# } REMOTE_MODEL_PATHS = { "text_small": { - "path": os.path.join(REMOTE_BASE_URL, "text.pt"), + "repo_id": "reach-vb/bark-small", + "file_name": "text.pt", "checksum": "b3e42bcbab23b688355cd44128c4cdd3", }, "coarse_small": { - "path": os.path.join(REMOTE_BASE_URL, "coarse.pt"), + "repo_id": "reach-vb/bark-small", + "file_name": "coarse.pt", "checksum": "5fe964825e3b0321f9d5f3857b89194d", }, "fine_small": { - "path": os.path.join(REMOTE_BASE_URL, "fine.pt"), + "repo_id": "reach-vb/bark-small", + "file_name": "fine.pt", "checksum": "5428d1befe05be2ba32195496e58dc90", }, "text": { - "path": os.path.join(REMOTE_BASE_URL, "text_2.pt"), + "repo_id": "reach-vb/bark", + "file_name": "text_2.pt", "checksum": "54afa89d65e318d4f5f80e8e8799026a", }, "coarse": { - "path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"), + "repo_id": "reach-vb/bark", + "file_name": "coarse_2.pt", "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28", }, "fine": { - "path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"), + "repo_id": "reach-vb/bark-small", + "file_name": "fine_2.pt", "checksum": "59d184ed44e3650774a2f0503a48a97b", }, } @@ -165,21 +199,25 @@ def _parse_s3_filepath(s3_filepath): return bucket_name, rel_s3_filepath -def _download(from_s3_path, to_local_path): - os.makedirs(CACHE_DIR, exist_ok=True) - response = requests.get(from_s3_path, stream=True) - total_size_in_bytes = int(response.headers.get("content-length", 0)) - block_size = 1024 - progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) - with open(to_local_path, "wb") as file: - for data in response.iter_content(block_size): - progress_bar.update(len(data)) - file.write(data) - progress_bar.close() - if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: - raise ValueError("ERROR, something went wrong") +# def _download(from_s3_path, to_local_path): +# os.makedirs(CACHE_DIR, exist_ok=True) +# response = requests.get(from_s3_path, stream=True) +# total_size_in_bytes = int(response.headers.get("content-length", 0)) +# block_size = 1024 +# progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) +# with open(to_local_path, "wb") as file: +# for data in response.iter_content(block_size): +# progress_bar.update(len(data)) +# file.write(data) +# progress_bar.close() +# if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: +# raise ValueError("ERROR, something went wrong") +def _download(from_hf_path, file_name, to_local_path): + os.makedirs(CACHE_DIR, exist_ok=True) + hf_hub_download(repo_id=from_hf_path, filename=file_name, cache_dir=to_local_path) + class InferenceContext: def __init__(self, benchmark=False): # we can't expect inputs to be the same length, so disable benchmarking by default @@ -243,7 +281,7 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"): os.remove(ckpt_path) if not os.path.exists(ckpt_path): logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.") - _download(model_info["path"], ckpt_path) + _download(model_info["repo_id"], model_info["file_name"], ckpt_path) checkpoint = torch.load(ckpt_path, map_location=device) # this is a hack model_args = checkpoint["model_args"] From ac3a7568a7e34e2f90e98b9ee6a31425fc9fe66f Mon Sep 17 00:00:00 2001 From: vaibhavs10 Date: Thu, 27 Apr 2023 16:19:58 +0200 Subject: [PATCH 2/8] up --- bark/generation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bark/generation.py b/bark/generation.py index 3b12757..a3c8a0b 100644 --- a/bark/generation.py +++ b/bark/generation.py @@ -90,7 +90,7 @@ USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False) GLOBAL_ENABLE_MPS = os.environ.get("SUNO_ENABLE_MPS", False) OFFLOAD_CPU = os.environ.get("SUNO_OFFLOAD_CPU", False) -# REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/" +REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/" # REMOTE_MODEL_PATHS = { # "text_small": { @@ -176,7 +176,7 @@ def _md5(fname): def _get_ckpt_path(model_type, use_small=False): model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type - model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["path"]) + model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["file_name"]) return os.path.join(CACHE_DIR, f"{model_name}.pt") From c26a82a4153fe05486aa593498682348f7d6ed42 Mon Sep 17 00:00:00 2001 From: Vaibhav Srivastav Date: Thu, 27 Apr 2023 17:35:28 +0200 Subject: [PATCH 3/8] up --- bark/generation.py | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/bark/generation.py b/bark/generation.py index a3c8a0b..82994e2 100644 --- a/bark/generation.py +++ b/bark/generation.py @@ -146,7 +146,7 @@ REMOTE_MODEL_PATHS = { "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28", }, "fine": { - "repo_id": "reach-vb/bark-small", + "repo_id": "reach-vb/bark", "file_name": "fine_2.pt", "checksum": "59d184ed44e3650774a2f0503a48a97b", }, @@ -199,24 +199,12 @@ def _parse_s3_filepath(s3_filepath): return bucket_name, rel_s3_filepath -# def _download(from_s3_path, to_local_path): -# os.makedirs(CACHE_DIR, exist_ok=True) -# response = requests.get(from_s3_path, stream=True) -# total_size_in_bytes = int(response.headers.get("content-length", 0)) -# block_size = 1024 -# progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) -# with open(to_local_path, "wb") as file: -# for data in response.iter_content(block_size): -# progress_bar.update(len(data)) -# file.write(data) -# progress_bar.close() -# if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: -# raise ValueError("ERROR, something went wrong") - - def _download(from_hf_path, file_name, to_local_path): os.makedirs(CACHE_DIR, exist_ok=True) - hf_hub_download(repo_id=from_hf_path, filename=file_name, cache_dir=to_local_path) + destination_file_name = to_local_path.split("/")[-1] + file_dir = CACHE_DIR + hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=file_dir) + os.replace(f"{CACHE_DIR}/{file_name}", to_local_path) class InferenceContext: def __init__(self, benchmark=False): From 035d08e157d57d04bd1d5891c6464b151bc5acab Mon Sep 17 00:00:00 2001 From: Vaibhav Srivastav Date: Thu, 27 Apr 2023 17:45:59 +0200 Subject: [PATCH 4/8] up --- bark/generation.py | 28 ---------------------------- 1 file changed, 28 deletions(-) diff --git a/bark/generation.py b/bark/generation.py index 82994e2..ab23479 100644 --- a/bark/generation.py +++ b/bark/generation.py @@ -90,34 +90,6 @@ USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False) GLOBAL_ENABLE_MPS = os.environ.get("SUNO_ENABLE_MPS", False) OFFLOAD_CPU = os.environ.get("SUNO_OFFLOAD_CPU", False) -REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/" - -# REMOTE_MODEL_PATHS = { -# "text_small": { -# "path": os.path.join(REMOTE_BASE_URL, "text.pt"), -# "checksum": "b3e42bcbab23b688355cd44128c4cdd3", -# }, -# "coarse_small": { -# "path": os.path.join(REMOTE_BASE_URL, "coarse.pt"), -# "checksum": "5fe964825e3b0321f9d5f3857b89194d", -# }, -# "fine_small": { -# "path": os.path.join(REMOTE_BASE_URL, "fine.pt"), -# "checksum": "5428d1befe05be2ba32195496e58dc90", -# }, -# "text": { -# "path": os.path.join(REMOTE_BASE_URL, "text_2.pt"), -# "checksum": "54afa89d65e318d4f5f80e8e8799026a", -# }, -# "coarse": { -# "path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"), -# "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28", -# }, -# "fine": { -# "path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"), -# "checksum": "59d184ed44e3650774a2f0503a48a97b", -# }, -# } REMOTE_MODEL_PATHS = { "text_small": { From c61ee92ee926f9248c8c2fe040a09dc117f96b84 Mon Sep 17 00:00:00 2001 From: Vaibhav Srivastav Date: Thu, 27 Apr 2023 17:51:43 +0200 Subject: [PATCH 5/8] up --- bark/generation.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/bark/generation.py b/bark/generation.py index ab23479..842b890 100644 --- a/bark/generation.py +++ b/bark/generation.py @@ -162,15 +162,6 @@ def _grab_best_device(use_gpu=True): return device -S3_BUCKET_PATH_RE = r"s3\:\/\/(.+?)\/" - - -def _parse_s3_filepath(s3_filepath): - bucket_name = re.search(S3_BUCKET_PATH_RE, s3_filepath).group(1) - rel_s3_filepath = re.sub(S3_BUCKET_PATH_RE, "", s3_filepath) - return bucket_name, rel_s3_filepath - - def _download(from_hf_path, file_name, to_local_path): os.makedirs(CACHE_DIR, exist_ok=True) destination_file_name = to_local_path.split("/")[-1] From b24dd26d4b18e355c5a425bec794b30216d1d86e Mon Sep 17 00:00:00 2001 From: Vaibhav Srivastav Date: Fri, 28 Apr 2023 16:26:09 +0200 Subject: [PATCH 6/8] add suggestions from code review --- bark/generation.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/bark/generation.py b/bark/generation.py index 842b890..f6980cc 100644 --- a/bark/generation.py +++ b/bark/generation.py @@ -165,9 +165,8 @@ def _grab_best_device(use_gpu=True): def _download(from_hf_path, file_name, to_local_path): os.makedirs(CACHE_DIR, exist_ok=True) destination_file_name = to_local_path.split("/")[-1] - file_dir = CACHE_DIR - hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=file_dir) - os.replace(f"{CACHE_DIR}/{file_name}", to_local_path) + hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR) + os.replace(os.path.join(CACHE_DIR, file_name), to_local_path) class InferenceContext: def __init__(self, benchmark=False): From 27ff4f9db86b5cd22e93872704b7bee5b56e353b Mon Sep 17 00:00:00 2001 From: Vaibhav Srivastav Date: Fri, 28 Apr 2023 17:02:41 +0200 Subject: [PATCH 7/8] new model repo --- bark/generation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bark/generation.py b/bark/generation.py index f6980cc..1e2c8db 100644 --- a/bark/generation.py +++ b/bark/generation.py @@ -93,17 +93,17 @@ OFFLOAD_CPU = os.environ.get("SUNO_OFFLOAD_CPU", False) REMOTE_MODEL_PATHS = { "text_small": { - "repo_id": "reach-vb/bark-small", + "repo_id": "reach-vb/bark", "file_name": "text.pt", "checksum": "b3e42bcbab23b688355cd44128c4cdd3", }, "coarse_small": { - "repo_id": "reach-vb/bark-small", + "repo_id": "reach-vb/bark", "file_name": "coarse.pt", "checksum": "5fe964825e3b0321f9d5f3857b89194d", }, "fine_small": { - "repo_id": "reach-vb/bark-small", + "repo_id": "reach-vb/bark", "file_name": "fine.pt", "checksum": "5428d1befe05be2ba32195496e58dc90", }, From e0f2d117f51eeb6426d02f1ff57e9a2f4ab5f3fa Mon Sep 17 00:00:00 2001 From: Vaibhav Srivastav Date: Fri, 28 Apr 2023 17:54:50 +0200 Subject: [PATCH 8/8] updating model repo organisation reach-vb -> suno --- bark/generation.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/bark/generation.py b/bark/generation.py index 1e2c8db..64b3c47 100644 --- a/bark/generation.py +++ b/bark/generation.py @@ -93,32 +93,32 @@ OFFLOAD_CPU = os.environ.get("SUNO_OFFLOAD_CPU", False) REMOTE_MODEL_PATHS = { "text_small": { - "repo_id": "reach-vb/bark", + "repo_id": "suno/bark", "file_name": "text.pt", "checksum": "b3e42bcbab23b688355cd44128c4cdd3", }, "coarse_small": { - "repo_id": "reach-vb/bark", + "repo_id": "suno/bark", "file_name": "coarse.pt", "checksum": "5fe964825e3b0321f9d5f3857b89194d", }, "fine_small": { - "repo_id": "reach-vb/bark", + "repo_id": "suno/bark", "file_name": "fine.pt", "checksum": "5428d1befe05be2ba32195496e58dc90", }, "text": { - "repo_id": "reach-vb/bark", + "repo_id": "suno/bark", "file_name": "text_2.pt", "checksum": "54afa89d65e318d4f5f80e8e8799026a", }, "coarse": { - "repo_id": "reach-vb/bark", + "repo_id": "suno/bark", "file_name": "coarse_2.pt", "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28", }, "fine": { - "repo_id": "reach-vb/bark", + "repo_id": "suno/bark", "file_name": "fine_2.pt", "checksum": "59d184ed44e3650774a2f0503a48a97b", },