Fix dataset file-list paging bug; add transformers dependency to framework requirements

mulin.lyh
2024-07-25 19:35:52 +08:00
parent 4e2555c5a3
commit d5d3d49206
5 changed files with 170 additions and 78 deletions

View File

@@ -9,6 +9,7 @@ from modelscope.cli.modelcard import ModelCardCMD
from modelscope.cli.pipeline import PipelineCMD
from modelscope.cli.plugins import PluginsCMD
from modelscope.cli.server import ServerCMD
from modelscope.hub.api import HubApi
from modelscope.utils.logger import get_logger
logger = get_logger(log_level=logging.WARNING)
@@ -17,6 +18,8 @@ logger = get_logger(log_level=logging.WARNING)
def run_cmd():
parser = argparse.ArgumentParser(
'ModelScope Command Line tool', usage='modelscope <command> [<args>]')
parser.add_argument(
'--token', default=None, help='Specify modelscope token.')
subparsers = parser.add_subparsers(help='modelscope commands helpers')
DownloadCMD.define_args(subparsers)
@@ -31,7 +34,9 @@ def run_cmd():
if not hasattr(args, 'func'):
parser.print_help()
exit(1)
if args.token is not None:
api = HubApi()
api.login(args.token)
cmd = args.func(args)
cmd.execute()
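In effect, the CLI now accepts a global --token option and logs in before dispatching any subcommand. A minimal sketch of the flow, with the subcommand wiring elided and a placeholder token value:

import argparse

from modelscope.hub.api import HubApi

parser = argparse.ArgumentParser(
    'ModelScope Command Line tool', usage='modelscope <command> [<args>]')
parser.add_argument('--token', default=None, help='Specify modelscope token.')

# e.g. `modelscope --token <YOUR_SDK_TOKEN> download ...`
args, _ = parser.parse_known_args(['--token', 'YOUR_SDK_TOKEN'])

if args.token is not None:
    api = HubApi()
    api.login(args.token)  # credentials are cached for subsequent hub calls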

View File

@@ -735,7 +735,9 @@ class HubApi:
namespace: str,
revision: str,
root_path: str,
recursive: bool = True):
recursive: bool = True,
page_number: int = 1,
page_size: int = 100):
dataset_hub_id, dataset_type = self.get_dataset_id_and_type(
dataset_name=dataset_name, namespace=namespace)
@@ -743,7 +745,8 @@ class HubApi:
recursive = 'True' if recursive else 'False'
datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_hub_id}/repo/tree'
params = {'Revision': revision if revision else 'master',
'Root': root_path if root_path else '/', 'Recursive': recursive}
'Root': root_path if root_path else '/', 'Recursive': recursive,
'PageNumber': page_number, 'PageSize': page_size}
cookies = ModelScopeConfig.get_cookies()
r = self.session.get(datahub_url, params=params, cookies=cookies)
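Callers are expected to page through the tree until a short page comes back, which is exactly how the reworked snapshot download below consumes it. A hedged sketch with placeholder dataset name and namespace; the response shape ('Code', 'Data' -> 'Files') follows the checks used in snapshot_download.py:

from modelscope.hub.api import HubApi

api = HubApi()
page_number, page_size = 1, 100
files = []
while True:
    resp = api.list_repo_tree(
        dataset_name='my-dataset',    # placeholder
        namespace='my-namespace',     # placeholder
        revision='master',
        root_path='/',
        recursive=True,
        page_number=page_number,
        page_size=page_size)
    if not ('Code' in resp and resp['Code'] == 200):
        raise RuntimeError(resp.get('Message', 'list_repo_tree failed'))
    page = resp['Data']['Files']
    files.extend(page)
    if len(page) < page_size:  # a short page means the listing is complete
        break
    page_number += 1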

View File

@@ -3,12 +3,14 @@
import fnmatch
import os
import re
import uuid
from http.cookiejar import CookieJar
from pathlib import Path
from typing import Dict, List, Optional, Union
from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.errors import InvalidParameter
from modelscope.hub.utils.caching import ModelFileSystemCache
from modelscope.hub.utils.utils import model_id_to_group_owner_name
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
DEFAULT_MODEL_REVISION,
@@ -31,6 +33,8 @@ def snapshot_download(
ignore_file_pattern: Optional[Union[str, List[str]]] = None,
allow_file_pattern: Optional[Union[str, List[str]]] = None,
local_dir: Optional[str] = None,
allow_patterns: Optional[Union[List[str], str]] = None,
ignore_patterns: Optional[Union[List[str], str]] = None,
) -> str:
"""Download all files of a repo.
Downloads a whole snapshot of a repo's files at the specified revision. This
@@ -56,6 +60,10 @@ def snapshot_download(
allow_file_pattern (`str` or `List`, *optional*, default to `None`):
Any file pattern to be downloaded, like exact file names or file extensions.
local_dir (str, optional): Specific local directory path to which the file will be downloaded.
allow_patterns (`str` or `List`, *optional*, default to `None`):
If provided, only files matching at least one pattern are downloaded; takes priority over allow_file_pattern.
ignore_patterns (`str` or `List`, *optional*, default to `None`):
If provided, files matching any of the patterns are not downloaded; takes priority over ignore_file_pattern.
Raises:
ValueError: the value details.
@@ -71,6 +79,10 @@ def snapshot_download(
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
if some parameter value is invalid
"""
if allow_patterns:
allow_file_pattern = allow_patterns
if ignore_patterns:
ignore_file_pattern = ignore_patterns
return _snapshot_download(
model_id,
repo_type=REPO_TYPE_MODEL,
@@ -81,7 +93,9 @@ def snapshot_download(
cookies=cookies,
ignore_file_pattern=ignore_file_pattern,
allow_file_pattern=allow_file_pattern,
local_dir=local_dir)
local_dir=local_dir,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns)
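The new allow_patterns / ignore_patterns arguments are thin aliases: when provided they simply override allow_file_pattern / ignore_file_pattern, so existing callers keep working; the same pair is added to dataset_snapshot_download below. A hedged usage sketch with a hypothetical model id:

from modelscope.hub.snapshot_download import snapshot_download

# Download only the config and safetensors weights, skipping any .bin files.
local_path = snapshot_download(
    'my-namespace/my-model',  # hypothetical model id
    allow_patterns=['config.json', '*.safetensors'],
    ignore_patterns='*.bin')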
def dataset_snapshot_download(
@@ -94,6 +108,8 @@ def dataset_snapshot_download(
cookies: Optional[CookieJar] = None,
ignore_file_pattern: Optional[Union[str, List[str]]] = None,
allow_file_pattern: Optional[Union[str, List[str]]] = None,
allow_patterns: Optional[Union[List[str], str]] = None,
ignore_patterns: Optional[Union[List[str], str]] = None,
) -> str:
"""Download raw files of a dataset.
Downloads all files at the specified revision. This
@@ -120,6 +136,10 @@ def dataset_snapshot_download(
Using regular expressions here is deprecated.
allow_file_pattern (`str` or `List`, *optional*, default to `None`):
Any file pattern to be downloaded, like exact file names or file extensions.
allow_patterns (`str` or `List`, *optional*, default to `None`):
If provided, only files matching at least one pattern are downloaded; takes priority over allow_file_pattern.
ignore_patterns (`str` or `List`, *optional*, default to `None`):
If provided, files matching any of the patterns are not downloaded; takes priority over ignore_file_pattern.
Raises:
ValueError: the value details.
@@ -135,6 +155,10 @@ def dataset_snapshot_download(
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
if some parameter value is invalid
"""
if allow_patterns:
allow_file_pattern = allow_patterns
if ignore_patterns:
ignore_file_pattern = ignore_patterns
return _snapshot_download(
dataset_id,
repo_type=REPO_TYPE_DATASET,
@@ -164,8 +188,8 @@ def _snapshot_download(
if not repo_type:
repo_type = REPO_TYPE_MODEL
if repo_type not in REPO_TYPE_SUPPORT:
raise InvalidParameter('Invalid repo type: %s, only support: %s' (
repo_type, REPO_TYPE_SUPPORT))
raise InvalidParameter('Invalid repo type: %s, only support: %s' %
(repo_type, REPO_TYPE_SUPPORT))
temporary_cache_dir, cache = create_temporary_directory_and_cache(
repo_id, local_dir=local_dir, cache_dir=cache_dir, repo_type=repo_type)
@@ -184,8 +208,10 @@ def _snapshot_download(
# make headers
headers = {
'user-agent':
ModelScopeConfig.get_user_agent(user_agent=user_agent, )
ModelScopeConfig.get_user_agent(user_agent=user_agent, ),
}
if 'CI_TEST' not in os.environ:
headers['snapshot_identifier'] = str(uuid.uuid4())
_api = HubApi()
if cookies is None:
cookies = ModelScopeConfig.get_cookies()
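Outside of CI runs, every snapshot download now stamps its requests with a one-off identifier, presumably so the hub can correlate all file requests belonging to a single snapshot. The header construction in isolation:

import os
import uuid

headers = {'user-agent': 'modelscope/...'}  # real value comes from ModelScopeConfig.get_user_agent()
if 'CI_TEST' not in os.environ:
    # one UUID per _snapshot_download call, shared by every file request in it
    headers['snapshot_identifier'] = str(uuid.uuid4())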
@@ -212,82 +238,138 @@ def _snapshot_download(
use_cookies=False if cookies is None else cookies,
headers=snapshot_header,
)
_download_file_lists(
repo_files,
cache,
temporary_cache_dir,
repo_id,
_api,
None,
None,
headers,
revision_detail=revision_detail,
repo_type=repo_type,
revision=revision,
cookies=cookies,
ignore_file_pattern=ignore_file_pattern,
allow_file_pattern=allow_file_pattern)
elif repo_type == REPO_TYPE_DATASET:
group_or_owner, name = model_id_to_group_owner_name(repo_id)
if not revision:
revision = DEFAULT_DATASET_REVISION
_api.dataset_download_statistics(name, group_or_owner)
revision_detail = revision
files_list_tree = _api.list_repo_tree(
dataset_name=name,
namespace=group_or_owner,
revision=revision,
root_path='/',
recursive=True)
if not ('Code' in files_list_tree
and files_list_tree['Code'] == 200):
print(
'Failed to get the file list of dataset %s, request_id: %s, message: %s'
% (repo_id, files_list_tree['RequestId'],
files_list_tree['Message']))
return None
repo_files = files_list_tree['Data']['Files']
if ignore_file_pattern is None:
ignore_file_pattern = []
if isinstance(ignore_file_pattern, str):
ignore_file_pattern = [ignore_file_pattern]
ignore_file_pattern = [
item if not item.endswith('/') else item + '*'
for item in ignore_file_pattern
]
ignore_regex_pattern = []
for file_pattern in ignore_file_pattern:
if file_pattern.startswith('*'):
ignore_regex_pattern.append('.' + file_pattern)
else:
ignore_regex_pattern.append(file_pattern)
if allow_file_pattern is not None:
if isinstance(allow_file_pattern, str):
allow_file_pattern = [allow_file_pattern]
allow_file_pattern = [
item if not item.endswith('/') else item + '*'
for item in allow_file_pattern
]
for repo_file in repo_files:
if repo_file['Type'] == 'tree' or \
any([fnmatch.fnmatch(repo_file['Path'], pattern) for pattern in ignore_file_pattern]) or \
any([re.search(pattern, repo_file['Name']) is not None for pattern in ignore_regex_pattern]): # noqa E501
continue
if allow_file_pattern is not None and allow_file_pattern:
if not any(
fnmatch.fnmatch(repo_file['Path'], pattern)
for pattern in allow_file_pattern):
continue
# Check whether the file already exists in the cache; skip the download if it does.
if cache.exists(repo_file):
file_name = os.path.basename(repo_file['Name'])
logger.debug(
f'File {file_name} already in cache, skip downloading!')
continue
if repo_type == REPO_TYPE_MODEL:
# get download url
url = get_file_download_url(
model_id=repo_id,
file_path=repo_file['Path'],
revision=revision)
elif repo_type == REPO_TYPE_DATASET:
url = _api.get_dataset_file_url(
file_name=repo_file['Path'],
page_number = 1
page_size = 100
while True:
files_list_tree = _api.list_repo_tree(
dataset_name=name,
namespace=group_or_owner,
revision=revision)
revision=revision,
root_path='/',
recursive=True,
page_number=page_number,
page_size=page_size)
if not ('Code' in files_list_tree
and files_list_tree['Code'] == 200):
print(
'Failed to get the file list of dataset %s, request_id: %s, message: %s'
% (repo_id, files_list_tree['RequestId'],
files_list_tree['Message']))
return None
repo_files = files_list_tree['Data']['Files']
_download_file_lists(
repo_files,
cache,
temporary_cache_dir,
repo_id,
_api,
name,
group_or_owner,
headers,
revision_detail=revision_detail,
repo_type=repo_type,
revision=revision,
cookies=cookies,
ignore_file_pattern=ignore_file_pattern,
allow_file_pattern=allow_file_pattern)
if len(repo_files) < page_size:
break
page_number += 1
download_file(url, repo_file, temporary_cache_dir, cache, headers,
cookies)
cache.save_model_version(revision_info=revision_detail)
return os.path.join(cache.get_root_location())
def _download_file_lists(
repo_files: List[str],
cache: ModelFileSystemCache,
temporary_cache_dir: str,
repo_id: str,
api: HubApi,
name: str,
group_or_owner: str,
headers,
revision_detail: str,
repo_type: Optional[str] = None,
revision: Optional[str] = DEFAULT_MODEL_REVISION,
cookies: Optional[CookieJar] = None,
ignore_file_pattern: Optional[Union[str, List[str]]] = None,
allow_file_pattern: Optional[Union[str, List[str]]] = None,
):
if ignore_file_pattern is None:
ignore_file_pattern = []
if isinstance(ignore_file_pattern, str):
ignore_file_pattern = [ignore_file_pattern]
ignore_file_pattern = [
item if not item.endswith('/') else item + '*'
for item in ignore_file_pattern
]
ignore_regex_pattern = []
for file_pattern in ignore_file_pattern:
if file_pattern.startswith('*'):
ignore_regex_pattern.append('.' + file_pattern)
else:
ignore_regex_pattern.append(file_pattern)
if allow_file_pattern is not None:
if isinstance(allow_file_pattern, str):
allow_file_pattern = [allow_file_pattern]
allow_file_pattern = [
item if not item.endswith('/') else item + '*'
for item in allow_file_pattern
]
for repo_file in repo_files:
if repo_file['Type'] == 'tree' or \
any([fnmatch.fnmatch(repo_file['Path'], pattern) for pattern in ignore_file_pattern]) or \
any([re.search(pattern, repo_file['Name']) is not None for pattern in ignore_regex_pattern]): # noqa E501
continue
if allow_file_pattern is not None and allow_file_pattern:
if not any(
fnmatch.fnmatch(repo_file['Path'], pattern)
for pattern in allow_file_pattern):
continue
# Check whether the file already exists in the cache; skip the download if it does.
if cache.exists(repo_file):
file_name = os.path.basename(repo_file['Name'])
logger.debug(
f'File {file_name} already in cache, skip downloading!')
continue
if repo_type == REPO_TYPE_MODEL:
# get download url
url = get_file_download_url(
model_id=repo_id,
file_path=repo_file['Path'],
revision=revision)
elif repo_type == REPO_TYPE_DATASET:
url = api.get_dataset_file_url(
file_name=repo_file['Path'],
dataset_name=name,
namespace=group_or_owner,
revision=revision)
download_file(url, repo_file, temporary_cache_dir, cache, headers,
cookies)
cache.save_model_version(revision_info=revision_detail)
return os.path.join(cache.get_root_location())
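The filtering rules that moved into _download_file_lists are easy to misread in diff form, so here is a self-contained sketch of the same logic: a trailing '/' in a pattern is expanded to cover the whole directory, ignore patterns starting with '*' are additionally tried as regexes against the bare file name, and when allow patterns are present the path must match at least one of them:

import fnmatch
import re

def _expand(patterns):
    # 'logs/' means everything under logs/, i.e. 'logs/*'
    return [p + '*' if p.endswith('/') else p for p in patterns]

def should_download(path, name, ignore_file_pattern=None, allow_file_pattern=None):
    ignore = _expand(ignore_file_pattern or [])
    # mirror of the diff: '*.bin' is also tried as the regex '.*.bin'
    regexes = ['.' + p if p.startswith('*') else p for p in ignore]
    if any(fnmatch.fnmatch(path, p) for p in ignore):
        return False
    if any(re.search(p, name) for p in regexes):
        return False
    if allow_file_pattern:
        return any(fnmatch.fnmatch(path, p) for p in _expand(allow_file_pattern))
    return True

print(should_download('data/train.json', 'train.json', ignore_file_pattern=['logs/']))  # True
print(should_download('weights.bin', 'weights.bin', ignore_file_pattern=['*.bin']))     # False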

View File

@@ -6,6 +6,7 @@ from typing import Any, Optional
from modelscope.metainfo import Tasks
from modelscope.utils.ast_utils import INDEX_KEY
from modelscope.utils.import_utils import (LazyImportModule,
is_torch_available,
is_transformers_available)
@@ -36,7 +37,7 @@ def post_init(self, *args, **kwargs):
def fix_transformers_upgrade():
if is_transformers_available():
if is_transformers_available() and is_torch_available():
# from 4.35.0, transformers changes its arguments of _set_gradient_checkpointing
import transformers
from transformers import PreTrainedModel
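The extra is_torch_available() check matters because PreTrainedModel is a torch-backed class: in an environment with transformers but no torch, the old guard passed and the patch could then fail at import time. The guarded shape of the function, roughly:

from modelscope.utils.import_utils import (is_torch_available,
                                           is_transformers_available)

def fix_transformers_upgrade():
    # Patch only when both libraries are present; transformers alone
    # is not enough to subclass or monkey-patch PreTrainedModel.
    if is_transformers_available() and is_torch_available():
        import transformers
        from transformers import PreTrainedModel
        # ... patch _set_gradient_checkpointing for transformers >= 4.35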

View File

@@ -10,4 +10,5 @@ scipy
setuptools==69.5.1
simplejson>=3.3.0
sortedcontainers>=1.5.9
transformers
urllib3>=1.26