mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 12:10:09 +01:00
Add repo_id and repo_type in snapshot_download (#1172)
* add repo_id and repo_type in snapshot_download * fix positional args * update
This commit is contained in:
@@ -19,6 +19,7 @@ from modelscope.hub.utils.utils import (get_model_masked_directory,
|
||||
model_id_to_group_owner_name)
|
||||
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
|
||||
DEFAULT_MODEL_REVISION,
|
||||
DEFAULT_REPOSITORY_REVISION,
|
||||
REPO_TYPE_DATASET, REPO_TYPE_MODEL,
|
||||
REPO_TYPE_SUPPORT)
|
||||
from modelscope.utils.logger import get_logger
|
||||
@@ -28,8 +29,8 @@ logger = get_logger()
|
||||
|
||||
|
||||
def snapshot_download(
|
||||
model_id: str,
|
||||
revision: Optional[str] = DEFAULT_MODEL_REVISION,
|
||||
model_id: str = None,
|
||||
revision: Optional[str] = None,
|
||||
cache_dir: Union[str, Path, None] = None,
|
||||
user_agent: Optional[Union[Dict, str]] = None,
|
||||
local_files_only: Optional[bool] = False,
|
||||
@@ -40,6 +41,8 @@ def snapshot_download(
|
||||
allow_patterns: Optional[Union[List[str], str]] = None,
|
||||
ignore_patterns: Optional[Union[List[str], str]] = None,
|
||||
max_workers: int = 8,
|
||||
repo_id: str = None,
|
||||
repo_type: Optional[str] = REPO_TYPE_MODEL,
|
||||
) -> str:
|
||||
"""Download all files of a repo.
|
||||
Downloads a whole snapshot of a repo's files at the specified revision. This
|
||||
@@ -51,7 +54,10 @@ def snapshot_download(
|
||||
user always has git and git-lfs installed, and properly configured.
|
||||
|
||||
Args:
|
||||
model_id (str): A user or an organization name and a repo name separated by a `/`.
|
||||
repo_id (str): A user or an organization name and a repo name separated by a `/`.
|
||||
model_id (str): A user or an organization name and a model name separated by a `/`.
|
||||
if `repo_id` is provided, `model_id` will be ignored.
|
||||
repo_type (str, optional): The type of the repo, either 'model' or 'dataset'.
|
||||
revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
|
||||
commit hash. NOTE: currently only branch and tag name is supported
|
||||
cache_dir (str, Path, optional): Path to the folder where cached files are stored, model will
|
||||
@@ -87,9 +93,22 @@ def snapshot_download(
|
||||
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
||||
if some parameter value is invalid
|
||||
"""
|
||||
|
||||
repo_id = repo_id or model_id
|
||||
if not repo_id:
|
||||
raise ValueError('Please provide a valid model_id or repo_id')
|
||||
|
||||
if repo_type not in REPO_TYPE_SUPPORT:
|
||||
raise ValueError(
|
||||
f'Invalid repo type: {repo_type}, only support: {REPO_TYPE_SUPPORT}'
|
||||
)
|
||||
|
||||
if revision is None:
|
||||
revision = DEFAULT_DATASET_REVISION if repo_type == REPO_TYPE_DATASET else DEFAULT_MODEL_REVISION
|
||||
|
||||
return _snapshot_download(
|
||||
model_id,
|
||||
repo_type=REPO_TYPE_MODEL,
|
||||
repo_id,
|
||||
repo_type=repo_type,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
user_agent=user_agent,
|
||||
|
||||
Reference in New Issue
Block a user