add fake classifier

This commit is contained in:
Georg Kucsko
2023-04-12 14:31:00 -04:00
parent 0981150e6a
commit 2c038176b3
3 changed files with 146 additions and 11 deletions

View File

@@ -3,7 +3,6 @@ import hashlib
import os
import re
import requests
import sys
from encodec import EncodecModel
import funcy
@@ -63,27 +62,43 @@ CUR_PATH = os.path.dirname(os.path.abspath(__file__))
default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0")
os.makedirs(CACHE_DIR, exist_ok=True)
REMOTE_BASE_URL = "http://s3.amazonaws.com/suno-public/bark/models/v0/"
# Remote location of each model checkpoint, keyed by model type.
# Each location defaults to a file under REMOTE_BASE_URL and can be overridden
# with the matching SUNO_*_MODEL_PATH environment variable. In the dict-valued
# entries, "checksum" is the MD5 hex digest that a cached checkpoint file is
# compared against before it is reused.
REMOTE_MODEL_PATHS = {
"text": os.environ.get("SUNO_TEXT_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "text_2.pt")),
"coarse": os.environ.get(
"SUNO_COARSE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "coarse_2.pt")
),
"fine": os.environ.get("SUNO_FINE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "fine_2.pt")),
"text": {
"path": os.environ.get("SUNO_TEXT_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "text_2.pt")),
"checksum": "54afa89d65e318d4f5f80e8e8799026a",
},
"coarse": {
"path": os.environ.get(
"SUNO_COARSE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "coarse_2.pt")
),
"checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
},
"fine": {
"path": os.environ.get("SUNO_FINE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "fine_2.pt")),
"checksum": "59d184ed44e3650774a2f0503a48a97b",
},
}
def _compute_md5(s):
def _string_md5(s):
    """Return the hexadecimal MD5 digest of the UTF-8 encoding of ``s``."""
    return hashlib.md5(s.encode("utf-8")).hexdigest()
def _md5(fname, chunk_size=1024 * 1024):
    """Return the hexadecimal MD5 digest of the file at ``fname``.

    The file is read in fixed-size chunks, so memory use stays bounded
    regardless of the file's size. The chunk size does not affect the
    resulting digest, only how many read calls are needed.

    Args:
        fname: Path of the file to hash.
        chunk_size: Number of bytes to read per iteration; must be positive.

    Returns:
        The MD5 digest of the file contents as a hexadecimal string.

    Raises:
        ValueError: If ``chunk_size`` is not a positive number.
        OSError: If the file cannot be opened or read.
    """
    # read(0) would end the loop immediately and read(-1) would load the
    # whole file at once, so reject sizes that defeat chunked reading.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    hash_md5 = hashlib.md5()
    with open(fname, "rb") as f:
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                break
            hash_md5.update(chunk)
    return hash_md5.hexdigest()
# Local cache location for a model checkpoint: a ".pt" file inside CACHE_DIR
# whose name is the MD5 hex digest of the model's remote path string.
def _get_ckpt_path(model_type):
model_name = _compute_md5(REMOTE_MODEL_PATHS[model_type])
model_name = _string_md5(REMOTE_MODEL_PATHS[model_type]["path"])
return os.path.join(CACHE_DIR, f"{model_name}.pt")
@@ -97,6 +112,7 @@ def _parse_s3_filepath(s3_filepath):
def _download(from_s3_path, to_local_path):
os.makedirs(CACHE_DIR, exist_ok=True)
response = requests.get(from_s3_path, stream=True)
total_size_in_bytes = int(response.headers.get('content-length', 0))
block_size = 1024 # 1 Kibibyte
@@ -164,9 +180,15 @@ def _load_model(ckpt_path, device, model_type="text"):
ModelClass = FineGPT
else:
raise NotImplementedError()
if (
os.path.exists(ckpt_path) and
_md5(ckpt_path) != REMOTE_MODEL_PATHS[model_type]["checksum"]
):
print(f"found outdated {model_type} model, removing...")
os.remove(ckpt_path)
if not os.path.exists(ckpt_path):
print(f"{model_type} model not found, downloading...")
_download(REMOTE_MODEL_PATHS[model_type], ckpt_path)
_download(REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path)
checkpoint = torch.load(ckpt_path, map_location=device)
# this is a hack
model_args = checkpoint["model_args"]

File diff suppressed because one or more lines are too long

View File

@@ -36,4 +36,5 @@ Straightforward improvements will allow models to run faster than realtime, rend
While we hope that this release will enable users to express their creativity and build applications that are a force
for good, we acknowledge that any text to audio model has the potential for dual use. While it is not straightforward
to voice clone known people with Bark, they can still be used for nefarious purposes.
to voice clone known people with Bark, it can still be used for nefarious purposes. To further reduce the chances of unintended use of Bark,
we also release a simple classifier to detect Bark-generated audio with high accuracy (see the notebooks section of the main repository).