Add repo_id and repo_type in snapshot_download (#1172)

* add repo_id and repo_type in snapshot_download

* fix positional args

* update
This commit is contained in:
Xingjun.Wang
2025-01-10 23:38:46 +08:00
committed by suluyana
parent cfd32abab2
commit 72ccdb1a72

View File

@@ -19,6 +19,7 @@ from modelscope.hub.utils.utils import (get_model_masked_directory,
model_id_to_group_owner_name)
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
DEFAULT_MODEL_REVISION,
DEFAULT_REPOSITORY_REVISION,
REPO_TYPE_DATASET, REPO_TYPE_MODEL,
REPO_TYPE_SUPPORT)
from modelscope.utils.logger import get_logger
@@ -28,8 +29,8 @@ logger = get_logger()
def snapshot_download(
model_id: str,
revision: Optional[str] = DEFAULT_MODEL_REVISION,
model_id: str = None,
revision: Optional[str] = None,
cache_dir: Union[str, Path, None] = None,
user_agent: Optional[Union[Dict, str]] = None,
local_files_only: Optional[bool] = False,
@@ -40,6 +41,8 @@ def snapshot_download(
allow_patterns: Optional[Union[List[str], str]] = None,
ignore_patterns: Optional[Union[List[str], str]] = None,
max_workers: int = 8,
repo_id: str = None,
repo_type: Optional[str] = REPO_TYPE_MODEL,
) -> str:
"""Download all files of a repo.
Downloads a whole snapshot of a repo's files at the specified revision. This
@@ -51,7 +54,10 @@ def snapshot_download(
user always has git and git-lfs installed, and properly configured.
Args:
model_id (str): A user or an organization name and a repo name separated by a `/`.
repo_id (str): A user or an organization name and a repo name separated by a `/`.
model_id (str): A user or an organization name and a model name separated by a `/`.
If `repo_id` is provided, `model_id` will be ignored.
repo_type (str, optional): The type of the repo, either 'model' or 'dataset'.
revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
commit hash. NOTE: currently only branch and tag names are supported
cache_dir (str, Path, optional): Path to the folder where cached files are stored, model will
@@ -87,9 +93,22 @@ def snapshot_download(
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
if some parameter value is invalid
"""
repo_id = repo_id or model_id
if not repo_id:
raise ValueError('Please provide a valid model_id or repo_id')
if repo_type not in REPO_TYPE_SUPPORT:
raise ValueError(
f'Invalid repo type: {repo_type}, only support: {REPO_TYPE_SUPPORT}'
)
if revision is None:
revision = DEFAULT_DATASET_REVISION if repo_type == REPO_TYPE_DATASET else DEFAULT_MODEL_REVISION
return _snapshot_download(
model_id,
repo_type=REPO_TYPE_MODEL,
repo_id,
repo_type=repo_type,
revision=revision,
cache_dir=cache_dir,
user_agent=user_agent,