diff --git a/bark/generation.py b/bark/generation.py
index c97dead..fa54388 100644
--- a/bark/generation.py
+++ b/bark/generation.py
@@ -21,6 +21,7 @@ if (
     torch.cuda.is_available() and
     hasattr(torch.cuda, "amp") and
     hasattr(torch.cuda.amp, "autocast") and
+    hasattr(torch.cuda, "is_bf16_supported") and
     torch.cuda.is_bf16_supported()
 ):
     autocast = funcy.partial(torch.cuda.amp.autocast, dtype=torch.bfloat16)
@@ -80,23 +81,39 @@
 default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
 CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0")
 
+USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False)
+
 REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/"
-REMOTE_MODEL_PATHS = {
-    "text": {
-        "path": os.environ.get("SUNO_TEXT_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "text_2.pt")),
-        "checksum": "54afa89d65e318d4f5f80e8e8799026a",
-    },
-    "coarse": {
-        "path": os.environ.get(
-            "SUNO_COARSE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "coarse_2.pt")
-        ),
-        "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
-    },
-    "fine": {
-        "path": os.environ.get("SUNO_FINE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "fine_2.pt")),
-        "checksum": "59d184ed44e3650774a2f0503a48a97b",
-    },
-}
+if USE_SMALL_MODELS:
+    REMOTE_MODEL_PATHS = {
+        "text": {
+            "path": os.path.join(REMOTE_BASE_URL, "text.pt"),
+            "checksum": "b3e42bcbab23b688355cd44128c4cdd3",
+        },
+        "coarse": {
+            "path": os.path.join(REMOTE_BASE_URL, "coarse.pt"),
+            "checksum": "5fe964825e3b0321f9d5f3857b89194d",
+        },
+        "fine": {
+            "path": os.path.join(REMOTE_BASE_URL, "fine.pt"),
+            "checksum": "5428d1befe05be2ba32195496e58dc90",
+        },
+    }
+else:
+    REMOTE_MODEL_PATHS = {
+        "text": {
+            "path": os.path.join(REMOTE_BASE_URL, "text_2.pt"),
+            "checksum": "54afa89d65e318d4f5f80e8e8799026a",
+        },
+        "coarse": {
+            "path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"),
+            "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
+        },
+        "fine": {
+            "path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"),
+            "checksum": "59d184ed44e3650774a2f0503a48a97b",
+        },
+    }
 
 
 if not hasattr(torch.nn.functional, 'scaled_dot_product_attention')