add fake classifier
(mirror of https://github.com/serp-ai/bark-with-voice-clone.git)
@@ -3,7 +3,6 @@ import hashlib
 import os
 import re
 import requests
-import sys
 
 from encodec import EncodecModel
 import funcy
@@ -63,27 +62,43 @@ CUR_PATH = os.path.dirname(os.path.abspath(__file__))
 
 default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
 CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0")
-os.makedirs(CACHE_DIR, exist_ok=True)
 
 
 REMOTE_BASE_URL = "http://s3.amazonaws.com/suno-public/bark/models/v0/"
 REMOTE_MODEL_PATHS = {
-    "text": os.environ.get("SUNO_TEXT_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "text_2.pt")),
-    "coarse": os.environ.get(
-        "SUNO_COARSE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "coarse_2.pt")
-    ),
-    "fine": os.environ.get("SUNO_FINE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "fine_2.pt")),
+    "text": {
+        "path": os.environ.get("SUNO_TEXT_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "text_2.pt")),
+        "checksum": "54afa89d65e318d4f5f80e8e8799026a",
+    },
+    "coarse": {
+        "path": os.environ.get(
+            "SUNO_COARSE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "coarse_2.pt")
+        ),
+        "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
+    },
+    "fine": {
+        "path": os.environ.get("SUNO_FINE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "fine_2.pt")),
+        "checksum": "59d184ed44e3650774a2f0503a48a97b",
+    },
 }
 
 
-def _compute_md5(s):
+def _string_md5(s):
     m = hashlib.md5()
     m.update(s.encode("utf-8"))
     return m.hexdigest()
 
 
+def _md5(fname):
+    hash_md5 = hashlib.md5()
+    with open(fname, "rb") as f:
+        for chunk in iter(lambda: f.read(4096), b""):
+            hash_md5.update(chunk)
+    return hash_md5.hexdigest()
+
+
 def _get_ckpt_path(model_type):
-    model_name = _compute_md5(REMOTE_MODEL_PATHS[model_type])
+    model_name = _string_md5(REMOTE_MODEL_PATHS[model_type]["path"])
     return os.path.join(CACHE_DIR, f"{model_name}.pt")
 
 
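This hunk reshapes REMOTE_MODEL_PATHS so each model entry carries both the download path and the expected MD5 of the checkpoint file, while the local cache filename stays derived from the path string alone. A minimal sketch of how the new "text" entry resolves to a cache path, using the values shown in the diff (the standalone variables below are illustrative, not code from the repository):

import hashlib
import os

# Values mirrored from the diff's "text" entry; the dict layout is the new one.
remote_entry = {
    "path": "http://s3.amazonaws.com/suno-public/bark/models/v0/text_2.pt",
    "checksum": "54afa89d65e318d4f5f80e8e8799026a",
}
cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "suno", "bark_v0")

# As in _get_ckpt_path: the local filename is the MD5 of the remote path string,
# so a given URL always maps to the same cache entry.
model_name = hashlib.md5(remote_entry["path"].encode("utf-8")).hexdigest()
ckpt_path = os.path.join(cache_dir, f"{model_name}.pt")
print(ckpt_path)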
@@ -97,6 +112,7 @@ def _parse_s3_filepath(s3_filepath):
 
 
 def _download(from_s3_path, to_local_path):
+    os.makedirs(CACHE_DIR, exist_ok=True)
     response = requests.get(from_s3_path, stream=True)
     total_size_in_bytes = int(response.headers.get('content-length', 0))
     block_size = 1024 # 1 Kibibyte
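The only change in this hunk is creating the cache directory lazily inside _download rather than at import time. The context lines show the streaming-download pattern; here is a self-contained sketch of that pattern, assuming a tqdm progress bar (download_file, its arguments, and the tqdm usage are illustrative assumptions, not taken from the diff):

import os
import requests
from tqdm import tqdm

def download_file(url, dest_path, block_size=1024):
    # Create the destination directory only when a download actually happens.
    dest_dir = os.path.dirname(dest_path) or "."
    os.makedirs(dest_dir, exist_ok=True)
    response = requests.get(url, stream=True)
    total = int(response.headers.get("content-length", 0))
    progress = tqdm(total=total, unit="iB", unit_scale=True)
    with open(dest_path, "wb") as f:
        for chunk in response.iter_content(block_size):
            progress.update(len(chunk))
            f.write(chunk)
    progress.close()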
@@ -164,9 +180,15 @@ def _load_model(ckpt_path, device, model_type="text"):
         ModelClass = FineGPT
     else:
         raise NotImplementedError()
+    if (
+        os.path.exists(ckpt_path) and
+        _md5(ckpt_path) != REMOTE_MODEL_PATHS[model_type]["checksum"]
+    ):
+        print(f"found outdated {model_type} model, removing...")
+        os.remove(ckpt_path)
     if not os.path.exists(ckpt_path):
         print(f"{model_type} model not found, downloading...")
-        _download(REMOTE_MODEL_PATHS[model_type], ckpt_path)
+        _download(REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path)
     checkpoint = torch.load(ckpt_path, map_location=device)
     # this is a hack
     model_args = checkpoint["model_args"]
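With the checksum field in place, _load_model can now evict a cached checkpoint whose on-disk MD5 no longer matches before falling through to the existing download-if-missing logic. A self-contained sketch of that combined flow (ensure_checkpoint, _file_md5, and download_fn are hypothetical names used for illustration, not functions from the repository):

import hashlib
import os

def _file_md5(fname):
    # Hash the file in chunks, same idea as the _md5 helper added in this commit.
    h = hashlib.md5()
    with open(fname, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            h.update(chunk)
    return h.hexdigest()

def ensure_checkpoint(local_path, remote_path, expected_md5, download_fn):
    # Remove a stale cached copy whose checksum no longer matches.
    if os.path.exists(local_path) and _file_md5(local_path) != expected_md5:
        print("found outdated model, removing...")
        os.remove(local_path)
    # Re-download only if nothing valid remains on disk.
    if not os.path.exists(local_path):
        print("model not found, downloading...")
        download_fn(remote_path, local_path)
    return local_path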
bark/notebooks/fake_classifier.ipynb (new file, 112 lines)
File diff suppressed because one or more lines are too long
@@ -36,4 +36,5 @@ Straightforward improvements will allow models to run faster than realtime, rend
 
 While we hope that this release will enable users to express their creativity and build applications that are a force
 for good, we acknowledge that any text to audio model has the potential for dual use. While it is not straightforward
-to voice clone known people with Bark, they can still be used for nefarious purposes.
+to voice clone known people with Bark, they can still be used for nefarious purposes. To further reduce the chances of unintended use of Bark,
+we also release a simple classifier to detect Bark-generated audio with high accuracy (see notebooks section of the main repository).