diff --git a/bark/generation.py b/bark/generation.py
index a3c8a0b..82994e2 100644
--- a/bark/generation.py
+++ b/bark/generation.py
@@ -146,7 +146,7 @@ REMOTE_MODEL_PATHS = {
         "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
     },
     "fine": {
-        "repo_id": "reach-vb/bark-small",
+        "repo_id": "reach-vb/bark",
         "file_name": "fine_2.pt",
         "checksum": "59d184ed44e3650774a2f0503a48a97b",
     },
@@ -199,24 +199,12 @@ def _parse_s3_filepath(s3_filepath):
     return bucket_name, rel_s3_filepath
 
 
-# def _download(from_s3_path, to_local_path):
-#     os.makedirs(CACHE_DIR, exist_ok=True)
-#     response = requests.get(from_s3_path, stream=True)
-#     total_size_in_bytes = int(response.headers.get("content-length", 0))
-#     block_size = 1024
-#     progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
-#     with open(to_local_path, "wb") as file:
-#         for data in response.iter_content(block_size):
-#             progress_bar.update(len(data))
-#             file.write(data)
-#     progress_bar.close()
-#     if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
-#         raise ValueError("ERROR, something went wrong")
-
-
 def _download(from_hf_path, file_name, to_local_path):
     os.makedirs(CACHE_DIR, exist_ok=True)
-    hf_hub_download(repo_id=from_hf_path, filename=file_name, cache_dir=to_local_path)
+    destination_file_name = to_local_path.split("/")[-1]
+    file_dir = CACHE_DIR
+    hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=file_dir)
+    os.replace(f"{CACHE_DIR}/{file_name}", to_local_path)
 
 
 class InferenceContext:
     def __init__(self, benchmark=False):
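
For reference, below is a minimal self-contained sketch of what the reworked _download helper does, assuming huggingface_hub's hf_hub_download accepts local_dir and writes the file to <local_dir>/<file_name>. The CACHE_DIR value, the destination file name "fine.pt", and the usage call at the bottom are illustrative only and are not taken from the repository.

import os
from huggingface_hub import hf_hub_download

# Illustrative cache location; the real CACHE_DIR is defined elsewhere in generation.py.
CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "bark_models")

def _download(from_hf_path, file_name, to_local_path):
    # Make sure the cache directory exists before downloading into it.
    os.makedirs(CACHE_DIR, exist_ok=True)
    # With local_dir, hf_hub_download stores the file at <local_dir>/<file_name>.
    hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR)
    # Move the checkpoint to the exact path the caller asked for.
    os.replace(f"{CACHE_DIR}/{file_name}", to_local_path)

# Hypothetical usage: fetch the fine checkpoint and store it under a caller-chosen name.
_download("reach-vb/bark", "fine_2.pt", os.path.join(CACHE_DIR, "fine.pt"))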