mirror of
https://github.com/serp-ai/bark-with-voice-clone.git
synced 2025-12-16 11:48:09 +01:00
initial commit
This commit is contained in:
@@ -14,6 +14,7 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import tqdm
|
import tqdm
|
||||||
from transformers import BertTokenizer
|
from transformers import BertTokenizer
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
from .model import GPTConfig, GPT
|
from .model import GPTConfig, GPT
|
||||||
from .model_fine import FineGPT, FineGPTConfig
|
from .model_fine import FineGPT, FineGPTConfig
|
||||||
@@ -89,31 +90,64 @@ USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False)
|
|||||||
GLOBAL_ENABLE_MPS = os.environ.get("SUNO_ENABLE_MPS", False)
|
GLOBAL_ENABLE_MPS = os.environ.get("SUNO_ENABLE_MPS", False)
|
||||||
OFFLOAD_CPU = os.environ.get("SUNO_OFFLOAD_CPU", False)
|
OFFLOAD_CPU = os.environ.get("SUNO_OFFLOAD_CPU", False)
|
||||||
|
|
||||||
REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/"
|
# REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/"
|
||||||
|
|
||||||
|
# REMOTE_MODEL_PATHS = {
|
||||||
|
# "text_small": {
|
||||||
|
# "path": os.path.join(REMOTE_BASE_URL, "text.pt"),
|
||||||
|
# "checksum": "b3e42bcbab23b688355cd44128c4cdd3",
|
||||||
|
# },
|
||||||
|
# "coarse_small": {
|
||||||
|
# "path": os.path.join(REMOTE_BASE_URL, "coarse.pt"),
|
||||||
|
# "checksum": "5fe964825e3b0321f9d5f3857b89194d",
|
||||||
|
# },
|
||||||
|
# "fine_small": {
|
||||||
|
# "path": os.path.join(REMOTE_BASE_URL, "fine.pt"),
|
||||||
|
# "checksum": "5428d1befe05be2ba32195496e58dc90",
|
||||||
|
# },
|
||||||
|
# "text": {
|
||||||
|
# "path": os.path.join(REMOTE_BASE_URL, "text_2.pt"),
|
||||||
|
# "checksum": "54afa89d65e318d4f5f80e8e8799026a",
|
||||||
|
# },
|
||||||
|
# "coarse": {
|
||||||
|
# "path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"),
|
||||||
|
# "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
|
||||||
|
# },
|
||||||
|
# "fine": {
|
||||||
|
# "path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"),
|
||||||
|
# "checksum": "59d184ed44e3650774a2f0503a48a97b",
|
||||||
|
# },
|
||||||
|
# }
|
||||||
|
|
||||||
REMOTE_MODEL_PATHS = {
|
REMOTE_MODEL_PATHS = {
|
||||||
"text_small": {
|
"text_small": {
|
||||||
"path": os.path.join(REMOTE_BASE_URL, "text.pt"),
|
"repo_id": "reach-vb/bark-small",
|
||||||
|
"file_name": "text.pt",
|
||||||
"checksum": "b3e42bcbab23b688355cd44128c4cdd3",
|
"checksum": "b3e42bcbab23b688355cd44128c4cdd3",
|
||||||
},
|
},
|
||||||
"coarse_small": {
|
"coarse_small": {
|
||||||
"path": os.path.join(REMOTE_BASE_URL, "coarse.pt"),
|
"repo_id": "reach-vb/bark-small",
|
||||||
|
"file_name": "coarse.pt",
|
||||||
"checksum": "5fe964825e3b0321f9d5f3857b89194d",
|
"checksum": "5fe964825e3b0321f9d5f3857b89194d",
|
||||||
},
|
},
|
||||||
"fine_small": {
|
"fine_small": {
|
||||||
"path": os.path.join(REMOTE_BASE_URL, "fine.pt"),
|
"repo_id": "reach-vb/bark-small",
|
||||||
|
"file_name": "fine.pt",
|
||||||
"checksum": "5428d1befe05be2ba32195496e58dc90",
|
"checksum": "5428d1befe05be2ba32195496e58dc90",
|
||||||
},
|
},
|
||||||
"text": {
|
"text": {
|
||||||
"path": os.path.join(REMOTE_BASE_URL, "text_2.pt"),
|
"repo_id": "reach-vb/bark",
|
||||||
|
"file_name": "text_2.pt",
|
||||||
"checksum": "54afa89d65e318d4f5f80e8e8799026a",
|
"checksum": "54afa89d65e318d4f5f80e8e8799026a",
|
||||||
},
|
},
|
||||||
"coarse": {
|
"coarse": {
|
||||||
"path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"),
|
"repo_id": "reach-vb/bark",
|
||||||
|
"file_name": "coarse_2.pt",
|
||||||
"checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
|
"checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
|
||||||
},
|
},
|
||||||
"fine": {
|
"fine": {
|
||||||
"path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"),
|
"repo_id": "reach-vb/bark-small",
|
||||||
|
"file_name": "fine_2.pt",
|
||||||
"checksum": "59d184ed44e3650774a2f0503a48a97b",
|
"checksum": "59d184ed44e3650774a2f0503a48a97b",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -165,21 +199,25 @@ def _parse_s3_filepath(s3_filepath):
|
|||||||
return bucket_name, rel_s3_filepath
|
return bucket_name, rel_s3_filepath
|
||||||
|
|
||||||
|
|
||||||
def _download(from_s3_path, to_local_path):
|
# def _download(from_s3_path, to_local_path):
|
||||||
os.makedirs(CACHE_DIR, exist_ok=True)
|
# os.makedirs(CACHE_DIR, exist_ok=True)
|
||||||
response = requests.get(from_s3_path, stream=True)
|
# response = requests.get(from_s3_path, stream=True)
|
||||||
total_size_in_bytes = int(response.headers.get("content-length", 0))
|
# total_size_in_bytes = int(response.headers.get("content-length", 0))
|
||||||
block_size = 1024
|
# block_size = 1024
|
||||||
progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
|
# progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
|
||||||
with open(to_local_path, "wb") as file:
|
# with open(to_local_path, "wb") as file:
|
||||||
for data in response.iter_content(block_size):
|
# for data in response.iter_content(block_size):
|
||||||
progress_bar.update(len(data))
|
# progress_bar.update(len(data))
|
||||||
file.write(data)
|
# file.write(data)
|
||||||
progress_bar.close()
|
# progress_bar.close()
|
||||||
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
|
# if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
|
||||||
raise ValueError("ERROR, something went wrong")
|
# raise ValueError("ERROR, something went wrong")
|
||||||
|
|
||||||
|
|
||||||
|
def _download(from_hf_path, file_name, to_local_path):
|
||||||
|
os.makedirs(CACHE_DIR, exist_ok=True)
|
||||||
|
hf_hub_download(repo_id=from_hf_path, filename=file_name, cache_dir=to_local_path)
|
||||||
|
|
||||||
class InferenceContext:
|
class InferenceContext:
|
||||||
def __init__(self, benchmark=False):
|
def __init__(self, benchmark=False):
|
||||||
# we can't expect inputs to be the same length, so disable benchmarking by default
|
# we can't expect inputs to be the same length, so disable benchmarking by default
|
||||||
@@ -243,7 +281,7 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"):
|
|||||||
os.remove(ckpt_path)
|
os.remove(ckpt_path)
|
||||||
if not os.path.exists(ckpt_path):
|
if not os.path.exists(ckpt_path):
|
||||||
logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
|
logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
|
||||||
_download(model_info["path"], ckpt_path)
|
_download(model_info["repo_id"], model_info["file_name"], ckpt_path)
|
||||||
checkpoint = torch.load(ckpt_path, map_location=device)
|
checkpoint = torch.load(ckpt_path, map_location=device)
|
||||||
# this is a hack
|
# this is a hack
|
||||||
model_args = checkpoint["model_args"]
|
model_args = checkpoint["model_args"]
|
||||||
|
|||||||
Reference in New Issue
Block a user