From 72ccdb1a72e18ec1ad68314474fbe5f07fc2104c Mon Sep 17 00:00:00 2001 From: "Xingjun.Wang" Date: Fri, 10 Jan 2025 23:38:46 +0800 Subject: [PATCH] Add repo_id and repo_type in snapshot_download (#1172) * add repo_id and repo_type in snapshot_download * fix positional args * update --- modelscope/hub/snapshot_download.py | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/modelscope/hub/snapshot_download.py b/modelscope/hub/snapshot_download.py index 2f7f4790..31d1f091 100644 --- a/modelscope/hub/snapshot_download.py +++ b/modelscope/hub/snapshot_download.py @@ -19,6 +19,7 @@ from modelscope.hub.utils.utils import (get_model_masked_directory, model_id_to_group_owner_name) from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, DEFAULT_MODEL_REVISION, + DEFAULT_REPOSITORY_REVISION, REPO_TYPE_DATASET, REPO_TYPE_MODEL, REPO_TYPE_SUPPORT) from modelscope.utils.logger import get_logger @@ -28,8 +29,8 @@ logger = get_logger() def snapshot_download( - model_id: str, - revision: Optional[str] = DEFAULT_MODEL_REVISION, + model_id: str = None, + revision: Optional[str] = None, cache_dir: Union[str, Path, None] = None, user_agent: Optional[Union[Dict, str]] = None, local_files_only: Optional[bool] = False, @@ -40,6 +41,8 @@ def snapshot_download( allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None, max_workers: int = 8, + repo_id: str = None, + repo_type: Optional[str] = REPO_TYPE_MODEL, ) -> str: """Download all files of a repo. Downloads a whole snapshot of a repo's files at the specified revision. This @@ -51,7 +54,10 @@ def snapshot_download( user always has git and git-lfs installed, and properly configured. Args: - model_id (str): A user or an organization name and a repo name separated by a `/`. + repo_id (str): A user or an organization name and a repo name separated by a `/`. + model_id (str): A user or an organization name and a model name separated by a `/`. + if `repo_id` is provided, `model_id` will be ignored. + repo_type (str, optional): The type of the repo, either 'model' or 'dataset'. revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a commit hash. NOTE: currently only branch and tag name is supported cache_dir (str, Path, optional): Path to the folder where cached files are stored, model will @@ -87,9 +93,22 @@ def snapshot_download( - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid """ + + repo_id = repo_id or model_id + if not repo_id: + raise ValueError('Please provide a valid model_id or repo_id') + + if repo_type not in REPO_TYPE_SUPPORT: + raise ValueError( + f'Invalid repo type: {repo_type}, only support: {REPO_TYPE_SUPPORT}' + ) + + if revision is None: + revision = DEFAULT_DATASET_REVISION if repo_type == REPO_TYPE_DATASET else DEFAULT_MODEL_REVISION + return _snapshot_download( - model_id, - repo_type=REPO_TYPE_MODEL, + repo_id, + repo_type=repo_type, revision=revision, cache_dir=cache_dir, user_agent=user_agent,