Fix dataset file-list paging bug; add transformers dependency to framework requirements
@@ -9,6 +9,7 @@ from modelscope.cli.modelcard import ModelCardCMD
 from modelscope.cli.pipeline import PipelineCMD
 from modelscope.cli.plugins import PluginsCMD
 from modelscope.cli.server import ServerCMD
+from modelscope.hub.api import HubApi
 from modelscope.utils.logger import get_logger

 logger = get_logger(log_level=logging.WARNING)
@@ -17,6 +18,8 @@ logger = get_logger(log_level=logging.WARNING)
 def run_cmd():
     parser = argparse.ArgumentParser(
         'ModelScope Command Line tool', usage='modelscope <command> [<args>]')
+    parser.add_argument(
+        '--token', default=None, help='Specify modelscope token.')
     subparsers = parser.add_subparsers(help='modelscope commands helpers')

     DownloadCMD.define_args(subparsers)
@@ -31,7 +34,9 @@ def run_cmd():
     if not hasattr(args, 'func'):
         parser.print_help()
         exit(1)
+
+    if args.token is not None:
+        api = HubApi()
+        api.login(args.token)
     cmd = args.func(args)
     cmd.execute()
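With this hunk, any `modelscope` CLI invocation can pass `--token` and the tool logs in before dispatching the subcommand. A minimal sketch of the equivalent programmatic call, assuming only the `HubApi.login` method used in the diff (the token string is a placeholder):

    from modelscope.hub.api import HubApi

    api = HubApi()
    api.login('<your-modelscope-token>')  # placeholder; same call the CLI now makes for --token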
@@ -735,7 +735,9 @@ class HubApi:
                        namespace: str,
                        revision: str,
                        root_path: str,
-                       recursive: bool = True):
+                       recursive: bool = True,
+                       page_number: int = 1,
+                       page_size: int = 100):

        dataset_hub_id, dataset_type = self.get_dataset_id_and_type(
            dataset_name=dataset_name, namespace=namespace)
@@ -743,7 +745,8 @@ class HubApi:
        recursive = 'True' if recursive else 'False'
        datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_hub_id}/repo/tree'
        params = {'Revision': revision if revision else 'master',
-                 'Root': root_path if root_path else '/', 'Recursive': recursive}
+                 'Root': root_path if root_path else '/', 'Recursive': recursive,
+                 'PageNumber': page_number, 'PageSize': page_size}
        cookies = ModelScopeConfig.get_cookies()

        r = self.session.get(datahub_url, params=params, cookies=cookies)
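After this change `list_repo_tree` queries the repo-tree endpoint one page at a time. A short sketch of calling the updated method, assuming the response shape used elsewhere in this commit (`Code`/`Data`/`Files`); the dataset name and namespace are placeholders:

    from modelscope.hub.api import HubApi

    api = HubApi()
    resp = api.list_repo_tree(
        dataset_name='my-dataset',    # placeholder
        namespace='my-namespace',     # placeholder
        revision='master',
        root_path='/',
        recursive=True,
        page_number=1,
        page_size=100)
    if resp.get('Code') == 200:
        files = resp['Data']['Files']  # first page, at most 100 entries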
@@ -3,12 +3,14 @@
 import fnmatch
 import os
+import re
+import uuid
 from http.cookiejar import CookieJar
 from pathlib import Path
 from typing import Dict, List, Optional, Union

 from modelscope.hub.api import HubApi, ModelScopeConfig
 from modelscope.hub.errors import InvalidParameter
 from modelscope.hub.utils.caching import ModelFileSystemCache
 from modelscope.hub.utils.utils import model_id_to_group_owner_name
 from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
                                        DEFAULT_MODEL_REVISION,
@@ -31,6 +33,8 @@ def snapshot_download(
        ignore_file_pattern: Optional[Union[str, List[str]]] = None,
        allow_file_pattern: Optional[Union[str, List[str]]] = None,
        local_dir: Optional[str] = None,
+       allow_patterns: Optional[Union[List[str], str]] = None,
+       ignore_patterns: Optional[Union[List[str], str]] = None,
 ) -> str:
     """Download all files of a repo.
     Downloads a whole snapshot of a repo's files at the specified revision. This
@@ -56,6 +60,10 @@ def snapshot_download(
         allow_file_pattern (`str` or `List`, *optional*, default to `None`):
             Any file pattern to be downloading, like exact file names or file extensions.
         local_dir (str, optional): Specific local directory path to which the file will be downloaded.
+        allow_patterns (`str` or `List`, *optional*, default to `None`):
+            If provided, only files matching at least one pattern are downloaded, priority over allow_file_pattern.
+        ignore_patterns (`str` or `List`, *optional*, default to `None`):
+            If provided, files matching any of the patterns are not downloaded, priority over ignore_file_pattern.
     Raises:
         ValueError: the value details.

@@ -71,6 +79,10 @@ def snapshot_download(
        - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            if some parameter value is invalid
    """
+    if allow_patterns:
+        allow_file_pattern = allow_patterns
+    if ignore_patterns:
+        ignore_file_pattern = ignore_patterns
     return _snapshot_download(
         model_id,
         repo_type=REPO_TYPE_MODEL,
@@ -81,7 +93,9 @@ def snapshot_download(
         cookies=cookies,
         ignore_file_pattern=ignore_file_pattern,
         allow_file_pattern=allow_file_pattern,
-        local_dir=local_dir)
+        local_dir=local_dir,
+        allow_patterns=allow_patterns,
+        ignore_patterns=ignore_patterns)


 def dataset_snapshot_download(
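The two new parameters mirror the huggingface_hub naming and, per the docstring, take priority over the older `*_file_pattern` arguments. A usage sketch; the model id is a placeholder:

    from modelscope import snapshot_download

    # Download only weights and configs, skipping any *.bin files.
    path = snapshot_download(
        'my-namespace/my-model',                  # placeholder model id
        allow_patterns=['*.safetensors', '*.json'],
        ignore_patterns=['*.bin'])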
@@ -94,6 +108,8 @@ def dataset_snapshot_download(
        cookies: Optional[CookieJar] = None,
        ignore_file_pattern: Optional[Union[str, List[str]]] = None,
        allow_file_pattern: Optional[Union[str, List[str]]] = None,
+       allow_patterns: Optional[Union[List[str], str]] = None,
+       ignore_patterns: Optional[Union[List[str], str]] = None,
 ) -> str:
     """Download raw files of a dataset.
     Downloads all files at the specified revision. This
@@ -120,6 +136,10 @@ def dataset_snapshot_download(
            Use regression is deprecated.
         allow_file_pattern (`str` or `List`, *optional*, default to `None`):
             Any file pattern to be downloading, like exact file names or file extensions.
+        allow_patterns (`str` or `List`, *optional*, default to `None`):
+            If provided, only files matching at least one pattern are downloaded, priority over allow_file_pattern.
+        ignore_patterns (`str` or `List`, *optional*, default to `None`):
+            If provided, files matching any of the patterns are not downloaded, priority over ignore_file_pattern.
     Raises:
         ValueError: the value details.

@@ -135,6 +155,10 @@ def dataset_snapshot_download(
        - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            if some parameter value is invalid
    """
+    if allow_patterns:
+        allow_file_pattern = allow_patterns
+    if ignore_patterns:
+        ignore_file_pattern = ignore_patterns
     return _snapshot_download(
         dataset_id,
         repo_type=REPO_TYPE_DATASET,
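The dataset variant gets the same aliases; note that both functions simply overwrite `allow_file_pattern`/`ignore_file_pattern` when the new names are given, which is what the docstrings mean by "priority over". A sketch with a placeholder dataset id:

    from modelscope.hub.snapshot_download import dataset_snapshot_download

    # Fetch only the CSV files under data/ from a dataset repo.
    path = dataset_snapshot_download(
        'my-namespace/my-dataset',    # placeholder dataset id
        allow_patterns='data/*.csv')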
@@ -164,8 +188,8 @@ def _snapshot_download(
     if not repo_type:
         repo_type = REPO_TYPE_MODEL
     if repo_type not in REPO_TYPE_SUPPORT:
-        raise InvalidParameter('Invalid repo type: %s, only support: %s' (
-            repo_type, REPO_TYPE_SUPPORT))
+        raise InvalidParameter('Invalid repo type: %s, only support: %s' %
+                               (repo_type, REPO_TYPE_SUPPORT))

     temporary_cache_dir, cache = create_temporary_directory_and_cache(
         repo_id, local_dir=local_dir, cache_dir=cache_dir, repo_type=repo_type)
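The old line was missing the `%` operator, so Python parsed the string literal followed by a parenthesized tuple as a call on the string, and an invalid repo type raised `TypeError: 'str' object is not callable` instead of the intended `InvalidParameter` message. The difference in isolation:

    # Broken: the string is "called" -> TypeError: 'str' object is not callable
    # 'Invalid repo type: %s, only support: %s' (repo_type, REPO_TYPE_SUPPORT)

    # Fixed: % formats the message before InvalidParameter is raised
    msg = 'Invalid repo type: %s, only support: %s' % ('image', ['model', 'dataset'])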
@@ -184,8 +208,10 @@ def _snapshot_download(
     # make headers
     headers = {
         'user-agent':
-        ModelScopeConfig.get_user_agent(user_agent=user_agent, )
+        ModelScopeConfig.get_user_agent(user_agent=user_agent, ),
     }
+    if 'CI_TEST' not in os.environ:
+        headers['snapshot_identifier'] = str(uuid.uuid4())
     _api = HubApi()
     if cookies is None:
         cookies = ModelScopeConfig.get_cookies()
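Outside CI runs, every snapshot request now carries a fresh UUID header, presumably so the server can correlate the individual file downloads belonging to one snapshot (the purpose is my reading, not stated in the diff). The mechanism in isolation:

    import os
    import uuid

    headers = {'user-agent': 'modelscope'}  # placeholder user-agent value
    if 'CI_TEST' not in os.environ:
        headers['snapshot_identifier'] = str(uuid.uuid4())  # unique per snapshot call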
@@ -212,82 +238,138 @@ def _snapshot_download(
             use_cookies=False if cookies is None else cookies,
             headers=snapshot_header,
         )
+        _download_file_lists(
+            repo_files,
+            cache,
+            temporary_cache_dir,
+            repo_id,
+            _api,
+            None,
+            None,
+            headers,
+            revision_detail=revision_detail,
+            repo_type=repo_type,
+            revision=revision,
+            cookies=cookies,
+            ignore_file_pattern=ignore_file_pattern,
+            allow_file_pattern=allow_file_pattern)
     elif repo_type == REPO_TYPE_DATASET:
         group_or_owner, name = model_id_to_group_owner_name(repo_id)
         if not revision:
             revision = DEFAULT_DATASET_REVISION
         _api.dataset_download_statistics(name, group_or_owner)
         revision_detail = revision
-        files_list_tree = _api.list_repo_tree(
-            dataset_name=name,
-            namespace=group_or_owner,
-            revision=revision,
-            root_path='/',
-            recursive=True)
-        if not ('Code' in files_list_tree
-                and files_list_tree['Code'] == 200):
-            print(
-                'Get dataset: %s file list failed, request_id: %s, message: %s'
-                % (repo_id, files_list_tree['RequestId'],
-                   files_list_tree['Message']))
-            return None
-        repo_files = files_list_tree['Data']['Files']
-
-        if ignore_file_pattern is None:
-            ignore_file_pattern = []
-        if isinstance(ignore_file_pattern, str):
-            ignore_file_pattern = [ignore_file_pattern]
-        ignore_file_pattern = [
-            item if not item.endswith('/') else item + '*'
-            for item in ignore_file_pattern
-        ]
-        ignore_regex_pattern = []
-        for file_pattern in ignore_file_pattern:
-            if file_pattern.startswith('*'):
-                ignore_regex_pattern.append('.' + file_pattern)
-            else:
-                ignore_regex_pattern.append(file_pattern)
-
-        if allow_file_pattern is not None:
-            if isinstance(allow_file_pattern, str):
-                allow_file_pattern = [allow_file_pattern]
-            allow_file_pattern = [
-                item if not item.endswith('/') else item + '*'
-                for item in allow_file_pattern
-            ]
-
-        for repo_file in repo_files:
-            if repo_file['Type'] == 'tree' or \
-                    any([fnmatch.fnmatch(repo_file['Path'], pattern) for pattern in ignore_file_pattern]) or \
-                    any([re.search(pattern, repo_file['Name']) is not None for pattern in ignore_regex_pattern]):  # noqa E501
-                continue
-
-            if allow_file_pattern is not None and allow_file_pattern:
-                if not any(
-                        fnmatch.fnmatch(repo_file['Path'], pattern)
-                        for pattern in allow_file_pattern):
-                    continue
-
-            # check model_file is exist in cache, if existed, skip download, otherwise download
-            if cache.exists(repo_file):
-                file_name = os.path.basename(repo_file['Name'])
-                logger.debug(
-                    f'File {file_name} already in cache, skip downloading!')
-                continue
-            if repo_type == REPO_TYPE_MODEL:
-                # get download url
-                url = get_file_download_url(
-                    model_id=repo_id,
-                    file_path=repo_file['Path'],
-                    revision=revision)
-            elif repo_type == REPO_TYPE_DATASET:
-                url = _api.get_dataset_file_url(
-                    file_name=repo_file['Path'],
-                    dataset_name=name,
-                    namespace=group_or_owner,
-                    revision=revision)
-
-            download_file(url, repo_file, temporary_cache_dir, cache, headers,
-                          cookies)
+        page_number = 1
+        page_size = 100
+        while True:
+            files_list_tree = _api.list_repo_tree(
+                dataset_name=name,
+                namespace=group_or_owner,
+                revision=revision,
+                root_path='/',
+                recursive=True,
+                page_number=page_number,
+                page_size=page_size)
+            if not ('Code' in files_list_tree
+                    and files_list_tree['Code'] == 200):
+                print(
+                    'Get dataset: %s file list failed, request_id: %s, message: %s'
+                    % (repo_id, files_list_tree['RequestId'],
+                       files_list_tree['Message']))
+                return None
+            repo_files = files_list_tree['Data']['Files']
+            _download_file_lists(
+                repo_files,
+                cache,
+                temporary_cache_dir,
+                repo_id,
+                _api,
+                name,
+                group_or_owner,
+                headers,
+                revision_detail=revision_detail,
+                repo_type=repo_type,
+                revision=revision,
+                cookies=cookies,
+                ignore_file_pattern=ignore_file_pattern,
+                allow_file_pattern=allow_file_pattern)
+            if len(repo_files) < page_size:
+                break
+            page_number += 1

     cache.save_model_version(revision_info=revision_detail)
     return os.path.join(cache.get_root_location())
+
+
+def _download_file_lists(
+        repo_files: List[str],
+        cache: ModelFileSystemCache,
+        temporary_cache_dir: str,
+        repo_id: str,
+        api: HubApi,
+        name: str,
+        group_or_owner: str,
+        headers,
+        revision_detail: str,
+        repo_type: Optional[str] = None,
+        revision: Optional[str] = DEFAULT_MODEL_REVISION,
+        cookies: Optional[CookieJar] = None,
+        ignore_file_pattern: Optional[Union[str, List[str]]] = None,
+        allow_file_pattern: Optional[Union[str, List[str]]] = None,
+):
+    if ignore_file_pattern is None:
+        ignore_file_pattern = []
+    if isinstance(ignore_file_pattern, str):
+        ignore_file_pattern = [ignore_file_pattern]
+    ignore_file_pattern = [
+        item if not item.endswith('/') else item + '*'
+        for item in ignore_file_pattern
+    ]
+    ignore_regex_pattern = []
+    for file_pattern in ignore_file_pattern:
+        if file_pattern.startswith('*'):
+            ignore_regex_pattern.append('.' + file_pattern)
+        else:
+            ignore_regex_pattern.append(file_pattern)
+
+    if allow_file_pattern is not None:
+        if isinstance(allow_file_pattern, str):
+            allow_file_pattern = [allow_file_pattern]
+        allow_file_pattern = [
+            item if not item.endswith('/') else item + '*'
+            for item in allow_file_pattern
+        ]
+
+    for repo_file in repo_files:
+        if repo_file['Type'] == 'tree' or \
+                any([fnmatch.fnmatch(repo_file['Path'], pattern) for pattern in ignore_file_pattern]) or \
+                any([re.search(pattern, repo_file['Name']) is not None for pattern in ignore_regex_pattern]):  # noqa E501
+            continue
+
+        if allow_file_pattern is not None and allow_file_pattern:
+            if not any(
+                    fnmatch.fnmatch(repo_file['Path'], pattern)
+                    for pattern in allow_file_pattern):
+                continue
+
+        # check model_file is exist in cache, if existed, skip download, otherwise download
+        if cache.exists(repo_file):
+            file_name = os.path.basename(repo_file['Name'])
+            logger.debug(
+                f'File {file_name} already in cache, skip downloading!')
+            continue
+        if repo_type == REPO_TYPE_MODEL:
+            # get download url
+            url = get_file_download_url(
+                model_id=repo_id,
+                file_path=repo_file['Path'],
+                revision=revision)
+        elif repo_type == REPO_TYPE_DATASET:
+            url = api.get_dataset_file_url(
+                file_name=repo_file['Path'],
+                dataset_name=name,
+                namespace=group_or_owner,
+                revision=revision)
+
+        download_file(url, repo_file, temporary_cache_dir, cache, headers,
+                      cookies)
+
+    cache.save_model_version(revision_info=revision_detail)
+    return os.path.join(cache.get_root_location())
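This is the actual paging fix: previously a single `list_repo_tree` call returned at most one server page, so datasets with more files than the default page size were silently truncated. The new loop keeps requesting pages until one comes back shorter than `page_size`. The pattern in isolation, with a hypothetical `fetch_page` callable standing in for `list_repo_tree`:

    from typing import Callable, Iterator, List

    def iter_all_files(fetch_page: Callable[[int, int], List[dict]],
                       page_size: int = 100) -> Iterator[dict]:
        """Yield entries from every page; stop on the first short page."""
        page_number = 1
        while True:
            files = fetch_page(page_number, page_size)
            yield from files
            if len(files) < page_size:  # a short (or empty) page means we are done
                return
            page_number += 1

One corner case worth noting: when the total file count is an exact multiple of `page_size`, the loop issues one extra request that returns an empty page, and the same length check terminates it.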
@@ -6,6 +6,7 @@ from typing import Any, Optional
 from modelscope.metainfo import Tasks
 from modelscope.utils.ast_utils import INDEX_KEY
 from modelscope.utils.import_utils import (LazyImportModule,
+                                           is_torch_available,
                                            is_transformers_available)

@@ -36,7 +37,7 @@ def post_init(self, *args, **kwargs):


 def fix_transformers_upgrade():
-    if is_transformers_available():
+    if is_transformers_available() and is_torch_available():
         # from 4.35.0, transformers changes its arguments of _set_gradient_checkpointing
         import transformers
         from transformers import PreTrainedModel
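`fix_transformers_upgrade` patches `PreTrainedModel`, which requires a torch backend to import, so the guard now demands both libraries. A generic sketch of such an availability check using only the standard library (`importlib.util.find_spec` is stdlib; the helper name is mine):

    import importlib.util

    def _is_available(package: str) -> bool:
        """True if the package can be imported, without importing it."""
        return importlib.util.find_spec(package) is not None

    if _is_available('transformers') and _is_available('torch'):
        from transformers import PreTrainedModel  # safe: torch backend present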
@@ -10,4 +10,5 @@ scipy
 setuptools==69.5.1
 simplejson>=3.3.0
 sortedcontainers>=1.5.9
+transformers
 urllib3>=1.26