Merge pull request #188 from Vaibhavs10/add_hf_hub_support

Add support to download models from Hugging Face Hub
Authored by Georg Kucsko on 2023-04-28 12:44:00 -04:00; committed by GitHub.


@@ -14,6 +14,7 @@ import torch
 import torch.nn.functional as F
 import tqdm
 from transformers import BertTokenizer
+from huggingface_hub import hf_hub_download
 
 from .model import GPTConfig, GPT
 from .model_fine import FineGPT, FineGPTConfig
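
For readers new to the library: `hf_hub_download` resolves a single file in a Hub repo, downloads it (with its own caching and progress bar), and returns the local path. A minimal sketch of the call this import enables, using the repo and file names from the diff below (`local_dir` here is just an example destination):

    from huggingface_hub import hf_hub_download

    # Fetch text.pt from the suno/bark Hub repo into ./models and
    # return the local path of the downloaded file.
    local_path = hf_hub_download(repo_id="suno/bark", filename="text.pt", local_dir="./models")
    print(local_path)  # e.g. ./models/text.pt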
@@ -89,31 +90,36 @@ USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False)
 GLOBAL_ENABLE_MPS = os.environ.get("SUNO_ENABLE_MPS", False)
 OFFLOAD_CPU = os.environ.get("SUNO_OFFLOAD_CPU", False)
 
-REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/"
 
 REMOTE_MODEL_PATHS = {
     "text_small": {
-        "path": os.path.join(REMOTE_BASE_URL, "text.pt"),
+        "repo_id": "suno/bark",
+        "file_name": "text.pt",
         "checksum": "b3e42bcbab23b688355cd44128c4cdd3",
     },
     "coarse_small": {
-        "path": os.path.join(REMOTE_BASE_URL, "coarse.pt"),
+        "repo_id": "suno/bark",
+        "file_name": "coarse.pt",
         "checksum": "5fe964825e3b0321f9d5f3857b89194d",
     },
     "fine_small": {
-        "path": os.path.join(REMOTE_BASE_URL, "fine.pt"),
+        "repo_id": "suno/bark",
+        "file_name": "fine.pt",
         "checksum": "5428d1befe05be2ba32195496e58dc90",
     },
     "text": {
-        "path": os.path.join(REMOTE_BASE_URL, "text_2.pt"),
+        "repo_id": "suno/bark",
+        "file_name": "text_2.pt",
         "checksum": "54afa89d65e318d4f5f80e8e8799026a",
     },
     "coarse": {
-        "path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"),
+        "repo_id": "suno/bark",
+        "file_name": "coarse_2.pt",
         "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
     },
     "fine": {
-        "path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"),
+        "repo_id": "suno/bark",
+        "file_name": "fine_2.pt",
         "checksum": "59d184ed44e3650774a2f0503a48a97b",
     },
 }
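
Each entry keeps its md5 checksum, so integrity checking works the same regardless of where the file was downloaded from. A hedged sketch of how such a checksum can be verified (the repo's own `_md5` helper, named in the next hunk header, presumably does something similar):

    import hashlib

    def md5_of_file(fname, chunk_size=8192):
        # Stream the checkpoint through MD5 so multi-gigabyte files
        # never have to be read into memory at once.
        h = hashlib.md5()
        with open(fname, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                h.update(chunk)
        return h.hexdigest()

    # Compare against e.g. REMOTE_MODEL_PATHS["text_small"]["checksum"].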
@@ -142,7 +148,7 @@ def _md5(fname):
 
 
 def _get_ckpt_path(model_type, use_small=False):
     model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
-    model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["path"])
+    model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["file_name"])
     return os.path.join(CACHE_DIR, f"{model_name}.pt")
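
The local cache name is derived by hashing a key string; after this change the key is the bare file name rather than the full download URL, so existing cache entries hash to new names and the checkpoints are fetched again. Assuming `_string_md5` (defined outside this diff) is a plain md5-of-string helper, it would look roughly like:

    import hashlib

    def string_md5_sketch(text):
        # Hypothetical stand-in for _string_md5: a stable hash of the key
        # string, used to build a cache file name like "<hash>.pt".
        return hashlib.md5(text.encode("utf-8")).hexdigest()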
@@ -156,29 +162,11 @@ def _grab_best_device(use_gpu=True):
     return device
 
 
-S3_BUCKET_PATH_RE = r"s3\:\/\/(.+?)\/"
-
-
-def _parse_s3_filepath(s3_filepath):
-    bucket_name = re.search(S3_BUCKET_PATH_RE, s3_filepath).group(1)
-    rel_s3_filepath = re.sub(S3_BUCKET_PATH_RE, "", s3_filepath)
-    return bucket_name, rel_s3_filepath
-
-
-def _download(from_s3_path, to_local_path):
+def _download(from_hf_path, file_name, to_local_path):
     os.makedirs(CACHE_DIR, exist_ok=True)
-    response = requests.get(from_s3_path, stream=True)
-    total_size_in_bytes = int(response.headers.get("content-length", 0))
-    block_size = 1024
-    progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
-    with open(to_local_path, "wb") as file:
-        for data in response.iter_content(block_size):
-            progress_bar.update(len(data))
-            file.write(data)
-    progress_bar.close()
-    if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
-        raise ValueError("ERROR, something went wrong")
+    destination_file_name = to_local_path.split("/")[-1]
+    hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR)
+    os.replace(os.path.join(CACHE_DIR, file_name), to_local_path)
 
 
 class InferenceContext:
     def __init__(self, benchmark=False):
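
The rewritten helper delegates the transfer and progress reporting to `hf_hub_download`, then moves the result onto the hashed cache path the rest of the loader expects. A condensed sketch of that flow (`CACHE_DIR` here is a hypothetical value; the real one is defined elsewhere in this module):

    import os
    from huggingface_hub import hf_hub_download

    CACHE_DIR = "/tmp/bark_models"  # hypothetical; the module defines its own

    def download_sketch(repo_id, file_name, to_local_path):
        # hf_hub_download writes <CACHE_DIR>/<file_name>; os.replace then
        # renames it to the hashed cache path in a single step.
        os.makedirs(CACHE_DIR, exist_ok=True)
        hf_hub_download(repo_id=repo_id, filename=file_name, local_dir=CACHE_DIR)
        os.replace(os.path.join(CACHE_DIR, file_name), to_local_path)

Since `os.replace` is atomic when source and destination sit on the same filesystem, a half-finished move never masquerades as a complete checkpoint.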
@@ -243,7 +231,7 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"):
         os.remove(ckpt_path)
     if not os.path.exists(ckpt_path):
         logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
-        _download(model_info["path"], ckpt_path)
+        _download(model_info["repo_id"], model_info["file_name"], ckpt_path)
     checkpoint = torch.load(ckpt_path, map_location=device)
     # this is a hack
     model_args = checkpoint["model_args"]
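
Putting the hunks together, a hedged condensation of the resulting load path (the small-model switch, checksum re-validation, and error handling are omitted):

    # Condensed from the diff above; "cpu" is just an example device.
    model_info = REMOTE_MODEL_PATHS["text"]    # {"repo_id": "suno/bark", "file_name": "text_2.pt", ...}
    ckpt_path = _get_ckpt_path("text")         # CACHE_DIR/<md5(file_name)>.pt
    if not os.path.exists(ckpt_path):
        _download(model_info["repo_id"], model_info["file_name"], ckpt_path)
    checkpoint = torch.load(ckpt_path, map_location="cpu")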