diff --git a/modelscope/cli/cli.py b/modelscope/cli/cli.py
index 5c60da8d..aa7dad02 100644
--- a/modelscope/cli/cli.py
+++ b/modelscope/cli/cli.py
@@ -9,6 +9,7 @@ from modelscope.cli.modelcard import ModelCardCMD
 from modelscope.cli.pipeline import PipelineCMD
 from modelscope.cli.plugins import PluginsCMD
 from modelscope.cli.server import ServerCMD
+from modelscope.hub.api import HubApi
 from modelscope.utils.logger import get_logger
 
 logger = get_logger(log_level=logging.WARNING)
@@ -17,6 +18,8 @@ logger = get_logger(log_level=logging.WARNING)
 def run_cmd():
     parser = argparse.ArgumentParser(
         'ModelScope Command Line tool', usage='modelscope <command> [<args>]')
+    parser.add_argument(
+        '--token', default=None, help='Specify modelscope token.')
     subparsers = parser.add_subparsers(help='modelscope commands helpers')
 
     DownloadCMD.define_args(subparsers)
@@ -31,7 +34,9 @@ def run_cmd():
     if not hasattr(args, 'func'):
         parser.print_help()
         exit(1)
-
+    if args.token is not None:
+        api = HubApi()
+        api.login(args.token)
     cmd = args.func(args)
     cmd.execute()
 
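
With this change, a token passed on the command line is applied before the selected subcommand runs, so every subcommand gets an authenticated session. A minimal sketch of the two equivalent entry points; the token value is a placeholder and the exact subcommand arguments depend on the respective CMD classes:

    # Shell: authenticate once for the whole invocation (placeholder token).
    #   modelscope --token YOUR_SDK_TOKEN download some_org/some_model

    # Python: the call run_cmd() now performs internally.
    from modelscope.hub.api import HubApi

    api = HubApi()
    api.login('YOUR_SDK_TOKEN')  # placeholder; copy your SDK token from the ModelScope website
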
diff --git a/modelscope/hub/api.py b/modelscope/hub/api.py
index 26d82bee..59a3b3ba 100644
--- a/modelscope/hub/api.py
+++ b/modelscope/hub/api.py
@@ -735,7 +735,9 @@ class HubApi:
                        namespace: str,
                        revision: str,
                        root_path: str,
-                       recursive: bool = True):
+                       recursive: bool = True,
+                       page_number: int = 1,
+                       page_size: int = 100):
 
         dataset_hub_id, dataset_type = self.get_dataset_id_and_type(
             dataset_name=dataset_name, namespace=namespace)
@@ -743,7 +745,8 @@ class HubApi:
         recursive = 'True' if recursive else 'False'
         datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_hub_id}/repo/tree'
         params = {'Revision': revision if revision else 'master',
-                  'Root': root_path if root_path else '/', 'Recursive': recursive}
+                  'Root': root_path if root_path else '/', 'Recursive': recursive,
+                  'PageNumber': page_number, 'PageSize': page_size}
         cookies = ModelScopeConfig.get_cookies()
 
         r = self.session.get(datahub_url, params=params, cookies=cookies)
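
`list_repo_tree` can now page through large dataset trees instead of fetching the whole listing in one response; the two new arguments map directly onto the `PageNumber` and `PageSize` query parameters. A sketch of the consumer-side loop, mirroring how `_snapshot_download` uses it below (dataset name and namespace are placeholders):

    from modelscope.hub.api import HubApi

    api = HubApi()
    page_number, page_size = 1, 100
    while True:
        resp = api.list_repo_tree(
            dataset_name='my_dataset',    # placeholder
            namespace='my_namespace',     # placeholder
            revision='master',
            root_path='/',
            recursive=True,
            page_number=page_number,
            page_size=page_size)
        files = resp['Data']['Files']     # the real caller also checks resp['Code'] == 200
        for item in files:
            print(item['Path'])
        if len(files) < page_size:        # a short page means the last page was reached
            break
        page_number += 1
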
diff --git a/modelscope/hub/snapshot_download.py b/modelscope/hub/snapshot_download.py
index 4e89a6b0..9f5316df 100644
--- a/modelscope/hub/snapshot_download.py
+++ b/modelscope/hub/snapshot_download.py
@@ -3,12 +3,14 @@
 import fnmatch
 import os
 import re
+import uuid
 from http.cookiejar import CookieJar
 from pathlib import Path
 from typing import Dict, List, Optional, Union
 
 from modelscope.hub.api import HubApi, ModelScopeConfig
 from modelscope.hub.errors import InvalidParameter
+from modelscope.hub.utils.caching import ModelFileSystemCache
 from modelscope.hub.utils.utils import model_id_to_group_owner_name
 from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
                                        DEFAULT_MODEL_REVISION,
@@ -31,6 +33,8 @@ def snapshot_download(
     ignore_file_pattern: Optional[Union[str, List[str]]] = None,
     allow_file_pattern: Optional[Union[str, List[str]]] = None,
     local_dir: Optional[str] = None,
+    allow_patterns: Optional[Union[List[str], str]] = None,
+    ignore_patterns: Optional[Union[List[str], str]] = None,
 ) -> str:
     """Download all files of a repo.
     Downloads a whole snapshot of a repo's files at the specified revision. This
@@ -56,6 +60,10 @@ def snapshot_download(
         allow_file_pattern (`str` or `List`, *optional*, default to `None`):
             Any file pattern to be downloading, like exact file names or file extensions.
         local_dir (str, optional): Specific local directory path to which the file will be downloaded.
+        allow_patterns (`str` or `List`, *optional*, default to `None`):
+            If provided, only files matching at least one pattern are downloaded; takes priority over allow_file_pattern.
+        ignore_patterns (`str` or `List`, *optional*, default to `None`):
+            If provided, files matching any of the patterns are not downloaded; takes priority over ignore_file_pattern.
 
     Raises:
         ValueError: the value details.
@@ -71,6 +79,10 @@ def snapshot_download(
         - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
           if some parameter value is invalid
     """
+    if allow_patterns:
+        allow_file_pattern = allow_patterns
+    if ignore_patterns:
+        ignore_file_pattern = ignore_patterns
     return _snapshot_download(
         model_id,
         repo_type=REPO_TYPE_MODEL,
@@ -81,7 +93,9 @@ def snapshot_download(
         cookies=cookies,
         ignore_file_pattern=ignore_file_pattern,
         allow_file_pattern=allow_file_pattern,
-        local_dir=local_dir)
+        local_dir=local_dir,
+        allow_patterns=allow_patterns,
+        ignore_patterns=ignore_patterns)
 
 
 def dataset_snapshot_download(
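
For callers, `allow_patterns` and `ignore_patterns` are now the preferred spelling: they simply alias the older `allow_file_pattern`/`ignore_file_pattern` arguments and win whenever both are supplied. A hedged usage sketch (model ids are placeholders):

    from modelscope.hub.snapshot_download import snapshot_download

    # Fetch only the weights and config of a model.
    path = snapshot_download(
        'some_org/some_model',            # placeholder model id
        allow_patterns=['*.safetensors', 'config.json'])

    # Or fetch everything except bulky binaries.
    path = snapshot_download(
        'some_org/some_model',
        ignore_patterns=['*.bin', '*.onnx'])
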
@@ -94,6 +108,8 @@ def dataset_snapshot_download(
     cookies: Optional[CookieJar] = None,
     ignore_file_pattern: Optional[Union[str, List[str]]] = None,
     allow_file_pattern: Optional[Union[str, List[str]]] = None,
+    allow_patterns: Optional[Union[List[str], str]] = None,
+    ignore_patterns: Optional[Union[List[str], str]] = None,
 ) -> str:
     """Download raw files of a dataset.
     Downloads all files at the specified revision. This
@@ -120,6 +136,10 @@ def dataset_snapshot_download(
             Use regression is deprecated.
         allow_file_pattern (`str` or `List`, *optional*, default to `None`):
             Any file pattern to be downloading, like exact file names or file extensions.
+        allow_patterns (`str` or `List`, *optional*, default to `None`):
+            If provided, only files matching at least one pattern are downloaded; takes priority over allow_file_pattern.
+        ignore_patterns (`str` or `List`, *optional*, default to `None`):
+            If provided, files matching any of the patterns are not downloaded; takes priority over ignore_file_pattern.
 
     Raises:
         ValueError: the value details.
@@ -135,6 +155,10 @@ def dataset_snapshot_download(
         - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
           if some parameter value is invalid
     """
+    if allow_patterns:
+        allow_file_pattern = allow_patterns
+    if ignore_patterns:
+        ignore_file_pattern = ignore_patterns
     return _snapshot_download(
         dataset_id,
         repo_type=REPO_TYPE_DATASET,
@@ -164,8 +188,8 @@ def _snapshot_download(
     if not repo_type:
         repo_type = REPO_TYPE_MODEL
     if repo_type not in REPO_TYPE_SUPPORT:
-        raise InvalidParameter('Invalid repo type: %s, only support: %s' (
-            repo_type, REPO_TYPE_SUPPORT))
+        raise InvalidParameter('Invalid repo type: %s, only support: %s' %
+                               (repo_type, REPO_TYPE_SUPPORT))
 
     temporary_cache_dir, cache = create_temporary_directory_and_cache(
         repo_id, local_dir=local_dir, cache_dir=cache_dir, repo_type=repo_type)
@@ -184,8 +208,10 @@ def _snapshot_download(
     # make headers
     headers = {
         'user-agent':
-        ModelScopeConfig.get_user_agent(user_agent=user_agent, )
+        ModelScopeConfig.get_user_agent(user_agent=user_agent, ),
     }
+    if 'CI_TEST' not in os.environ:
+        headers['snapshot_identifier'] = str(uuid.uuid4())
     _api = HubApi()
     if cookies is None:
         cookies = ModelScopeConfig.get_cookies()
@@ -212,82 +238,138 @@ def _snapshot_download(
             use_cookies=False if cookies is None else cookies,
             headers=snapshot_header,
         )
+        _download_file_lists(
+            repo_files,
+            cache,
+            temporary_cache_dir,
+            repo_id,
+            _api,
+            None,
+            None,
+            headers,
+            revision_detail=revision_detail,
+            repo_type=repo_type,
+            revision=revision,
+            cookies=cookies,
+            ignore_file_pattern=ignore_file_pattern,
+            allow_file_pattern=allow_file_pattern)
     elif repo_type == REPO_TYPE_DATASET:
         group_or_owner, name = model_id_to_group_owner_name(repo_id)
         if not revision:
             revision = DEFAULT_DATASET_REVISION
+        _api.dataset_download_statistics(name, group_or_owner)
         revision_detail = revision
-        files_list_tree = _api.list_repo_tree(
-            dataset_name=name,
-            namespace=group_or_owner,
-            revision=revision,
-            root_path='/',
-            recursive=True)
-        if not ('Code' in files_list_tree
-                and files_list_tree['Code'] == 200):
-            print(
-                'Get dataset: %s file list failed, request_id: %s, message: %s'
-                % (repo_id, files_list_tree['RequestId'],
-                   files_list_tree['Message']))
-            return None
-        repo_files = files_list_tree['Data']['Files']
-
-        if ignore_file_pattern is None:
-            ignore_file_pattern = []
-        if isinstance(ignore_file_pattern, str):
-            ignore_file_pattern = [ignore_file_pattern]
-        ignore_file_pattern = [
-            item if not item.endswith('/') else item + '*'
-            for item in ignore_file_pattern
-        ]
-        ignore_regex_pattern = []
-        for file_pattern in ignore_file_pattern:
-            if file_pattern.startswith('*'):
-                ignore_regex_pattern.append('.' + file_pattern)
-            else:
-                ignore_regex_pattern.append(file_pattern)
-
-        if allow_file_pattern is not None:
-            if isinstance(allow_file_pattern, str):
-                allow_file_pattern = [allow_file_pattern]
-            allow_file_pattern = [
-                item if not item.endswith('/') else item + '*'
-                for item in allow_file_pattern
-            ]
-
-        for repo_file in repo_files:
-            if repo_file['Type'] == 'tree' or \
-                    any([fnmatch.fnmatch(repo_file['Path'], pattern) for pattern in ignore_file_pattern]) or \
-                    any([re.search(pattern, repo_file['Name']) is not None for pattern in ignore_regex_pattern]):  # noqa E501
-                continue
-
-            if allow_file_pattern is not None and allow_file_pattern:
-                if not any(
-                        fnmatch.fnmatch(repo_file['Path'], pattern)
-                        for pattern in allow_file_pattern):
-                    continue
-
-            # check model_file is exist in cache, if existed, skip download, otherwise download
-            if cache.exists(repo_file):
-                file_name = os.path.basename(repo_file['Name'])
-                logger.debug(
-                    f'File {file_name} already in cache, skip downloading!')
-                continue
-            if repo_type == REPO_TYPE_MODEL:
-                # get download url
-                url = get_file_download_url(
-                    model_id=repo_id,
-                    file_path=repo_file['Path'],
-                    revision=revision)
-            elif repo_type == REPO_TYPE_DATASET:
-                url = _api.get_dataset_file_url(
-                    file_name=repo_file['Path'],
+        page_number = 1
+        page_size = 100
+        while True:
+            files_list_tree = _api.list_repo_tree(
                 dataset_name=name,
                 namespace=group_or_owner,
-                revision=revision)
+                revision=revision,
+                root_path='/',
+                recursive=True,
+                page_number=page_number,
+                page_size=page_size)
+            if not ('Code' in files_list_tree
+                    and files_list_tree['Code'] == 200):
+                print(
+                    'Get dataset: %s file list failed, request_id: %s, message: %s'
+                    % (repo_id, files_list_tree['RequestId'],
+                       files_list_tree['Message']))
+                return None
+            repo_files = files_list_tree['Data']['Files']
+            _download_file_lists(
+                repo_files,
+                cache,
+                temporary_cache_dir,
+                repo_id,
+                _api,
+                name,
+                group_or_owner,
+                headers,
+                revision_detail=revision_detail,
+                repo_type=repo_type,
+                revision=revision,
+                cookies=cookies,
+                ignore_file_pattern=ignore_file_pattern,
+                allow_file_pattern=allow_file_pattern)
+            if len(repo_files) < page_size:
+                break
+            page_number += 1
 
-            download_file(url, repo_file, temporary_cache_dir, cache, headers,
-                          cookies)
-
-    cache.save_model_version(revision_info=revision_detail)
-    return os.path.join(cache.get_root_location())
+
+def _download_file_lists(
+        repo_files: List[dict],
+        cache: ModelFileSystemCache,
+        temporary_cache_dir: str,
+        repo_id: str,
+        api: HubApi,
+        name: str,
+        group_or_owner: str,
+        headers,
+        revision_detail: str,
+        repo_type: Optional[str] = None,
+        revision: Optional[str] = DEFAULT_MODEL_REVISION,
+        cookies: Optional[CookieJar] = None,
+        ignore_file_pattern: Optional[Union[str, List[str]]] = None,
+        allow_file_pattern: Optional[Union[str, List[str]]] = None,
+):
+    if ignore_file_pattern is None:
+        ignore_file_pattern = []
+    if isinstance(ignore_file_pattern, str):
+        ignore_file_pattern = [ignore_file_pattern]
+    # Normalize directory patterns: a trailing '/' means everything below that directory.
+    ignore_file_pattern = [
+        item if not item.endswith('/') else item + '*'
+        for item in ignore_file_pattern
+    ]
+    ignore_regex_pattern = []
+    for file_pattern in ignore_file_pattern:
+        if file_pattern.startswith('*'):
+            ignore_regex_pattern.append('.' + file_pattern)
+        else:
+            ignore_regex_pattern.append(file_pattern)
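
The pattern handling in `_download_file_lists` normalizes a trailing `/` into `dir/*` and then matches with fnmatch; ignore patterns that start with `*` are additionally tried as regexes against the file name. A standalone sketch of the core matching rule, not part of the patch:

    import fnmatch

    # Mirror of the normalization above: 'dir/' means everything under 'dir'.
    patterns = ['images/', '*.json']
    patterns = [p + '*' if p.endswith('/') else p for p in patterns]

    for path in ['images/cat.png', 'config.json', 'weights.bin']:
        keep = any(fnmatch.fnmatch(path, p) for p in patterns)
        print(path, '->', 'download' if keep else 'skip')
    # images/cat.png -> download
    # config.json -> download
    # weights.bin -> skip
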
+
+    if allow_file_pattern is not None:
+        if isinstance(allow_file_pattern, str):
+            allow_file_pattern = [allow_file_pattern]
+        allow_file_pattern = [
+            item if not item.endswith('/') else item + '*'
+            for item in allow_file_pattern
+        ]
+
+    for repo_file in repo_files:
+        if repo_file['Type'] == 'tree' or \
+                any([fnmatch.fnmatch(repo_file['Path'], pattern) for pattern in ignore_file_pattern]) or \
+                any([re.search(pattern, repo_file['Name']) is not None for pattern in ignore_regex_pattern]):  # noqa E501
+            continue
+
+        if allow_file_pattern is not None and allow_file_pattern:
+            if not any(
+                    fnmatch.fnmatch(repo_file['Path'], pattern)
+                    for pattern in allow_file_pattern):
+                continue
+
+        # Check whether the file already exists in the cache; if so, skip the download.
+        if cache.exists(repo_file):
+            file_name = os.path.basename(repo_file['Name'])
+            logger.debug(
+                f'File {file_name} already in cache, skip downloading!')
+            continue
+        if repo_type == REPO_TYPE_MODEL:
+            # get download url
+            url = get_file_download_url(
+                model_id=repo_id,
+                file_path=repo_file['Path'],
+                revision=revision)
+        elif repo_type == REPO_TYPE_DATASET:
+            url = api.get_dataset_file_url(
+                file_name=repo_file['Path'],
+                dataset_name=name,
+                namespace=group_or_owner,
+                revision=revision)
+
+        download_file(url, repo_file, temporary_cache_dir, cache, headers,
+                      cookies)
+
+    cache.save_model_version(revision_info=revision_detail)
+    return os.path.join(cache.get_root_location())
diff --git a/modelscope/utils/automodel_utils.py b/modelscope/utils/automodel_utils.py
index 5cefdafe..df540ba1 100644
--- a/modelscope/utils/automodel_utils.py
+++ b/modelscope/utils/automodel_utils.py
@@ -6,6 +6,7 @@ from typing import Any, Optional
 from modelscope.metainfo import Tasks
 from modelscope.utils.ast_utils import INDEX_KEY
 from modelscope.utils.import_utils import (LazyImportModule,
+                                           is_torch_available,
                                            is_transformers_available)
 
 
@@ -36,7 +37,7 @@ def post_init(self, *args, **kwargs):
 
 
 def fix_transformers_upgrade():
-    if is_transformers_available():
+    if is_transformers_available() and is_torch_available():
         # from 4.35.0, transformers changes its arguments of _set_gradient_checkpointing
         import transformers
         from transformers import PreTrainedModel
diff --git a/requirements/framework.txt b/requirements/framework.txt
index c8a4c277..946023a5 100644
--- a/requirements/framework.txt
+++ b/requirements/framework.txt
@@ -10,4 +10,5 @@ scipy
 setuptools==69.5.1
 simplejson>=3.3.0
 sortedcontainers>=1.5.9
+transformers
 urllib3>=1.26
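
The last two hunks go together: `transformers` becomes an unconditional framework requirement while torch stays optional, so the patch in `fix_transformers_upgrade` must check for both before importing `PreTrainedModel`. A minimal, hypothetical sketch of that guard pattern (the helper name here is illustrative; the real logic lives in automodel_utils.py above):

    from modelscope.utils.import_utils import (is_torch_available,
                                               is_transformers_available)

    def patch_transformers_if_possible():  # hypothetical helper name
        # Importing PreTrainedModel pulls in torch; guard against
        # installs that have transformers but no torch backend.
        if not (is_transformers_available() and is_torch_available()):
            return
        from transformers import PreTrainedModel  # noqa: F401
        # ... apply version-specific compatibility patches here ...
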