From 083aa80d097744e0402bd5d02e98aaff1bb9d864 Mon Sep 17 00:00:00 2001 From: liuyhwangyh Date: Wed, 3 Jul 2024 18:23:50 +0800 Subject: [PATCH 1/3] refactor ast index file generator (#897) * refactor ast index file generator * fix lint issue --------- Co-authored-by: mulin.lyh --- modelscope/utils/ast_utils.py | 85 +++++++++++++++++++---------------- modelscope/utils/plugins.py | 3 +- 2 files changed, 48 insertions(+), 40 deletions(-) diff --git a/modelscope/utils/ast_utils.py b/modelscope/utils/ast_utils.py index 72627e2d..3554b0ea 100644 --- a/modelscope/utils/ast_utils.py +++ b/modelscope/utils/ast_utils.py @@ -7,12 +7,14 @@ import os import os.path as osp import time import traceback +from datetime import datetime from functools import reduce from pathlib import Path from typing import Union import json +from modelscope import version # do not delete from modelscope.metainfo import (CustomDatasets, Heads, Hooks, LR_Schedulers, Metrics, Models, Optimizers, Pipelines, @@ -658,6 +660,19 @@ def _update_index(index, files_mtime): index[REQUIREMENT_KEY].update(updated_index[REQUIREMENT_KEY]) +def __is_develop_model(): + # use the trick of release time check is in development + release_timestamp = int( + round( + datetime.strptime(version.__release_datetime__, + '%Y-%m-%d %H:%M:%S').timestamp())) + SECONDS_PER_YEAR = 24 * 365 * 60 * 60 + current_timestamp = int(round(datetime.now().timestamp())) + if release_timestamp > current_timestamp + SECONDS_PER_YEAR: + return True + return False + + def load_index( file_list=None, force_rebuild=False, @@ -699,51 +714,43 @@ def load_index( file_path = os.path.join(cache_dir, index_file) logger.info(f'Loading ast index from {file_path}') index = None - local_changed = False - if not force_rebuild and os.path.exists(file_path): - wrapped_index = _load_index(file_path) - md5, files_mtime = file_scanner.files_mtime_md5(file_list=file_list) - from modelscope.version import __version__ - if (wrapped_index[VERSION_KEY] == __version__): + + if force_rebuild: + logger.info('Force rebuilding ast index from scanning every file!') + index = file_scanner.get_files_scan_results(file_list) + return index + + # when developing, we need to generator as need. + if __is_develop_model(): + if os.path.exists(file_path): # already exist, check it's latest + wrapped_index = _load_index(file_path) + md5, files_mtime = file_scanner.files_mtime_md5( + file_list=file_list) index = wrapped_index - if (wrapped_index[MD5_KEY] != md5): - local_changed = True - full_index_flag = False - - if index is None: - full_index_flag = True - elif index and local_changed and FILES_MTIME_KEY not in index: - full_index_flag = True - elif index and local_changed and MODELSCOPE_PATH_KEY not in index: - full_index_flag = True - elif index and local_changed and index[ - MODELSCOPE_PATH_KEY] != MODELSCOPE_PATH.as_posix(): - full_index_flag = True - - if full_index_flag: - if force_rebuild: - logger.info('Force rebuilding ast index from scanning every file!') - index = file_scanner.get_files_scan_results(file_list) + from modelscope.version import __version__ + if (wrapped_index[VERSION_KEY] == __version__ + and wrapped_index[MD5_KEY] != md5) or \ + wrapped_index[VERSION_KEY] != __version__: + logger.info( + 'Updating the files for the changes of local files, ' + 'first time updating will take longer time! Please wait till updating done!' 
+ ) + _update_index(index, files_mtime) + _save_index(index, file_path, file_list) else: logger.info( - f'No valid ast index found from {file_path}, generating ast index from prebuilt!' + f'No valid ast index found from {file_path}, generating ast index from scratch!' ) - index = load_from_prebuilt() - if index is None: - index = file_scanner.get_files_scan_results(file_list) - _save_index(index, file_path, file_list) - elif local_changed and not full_index_flag: + index = file_scanner.get_files_scan_results( + file_list) # generate new + _save_index(index, file_path, file_list) # save to generate path. logger.info( - 'Updating the files for the changes of local files, ' - 'first time updating will take longer time! Please wait till updating done!' - ) - _update_index(index, files_mtime) - _save_index(index, file_path, file_list) + f'Loading done! Current index file version is {index[VERSION_KEY]}, ' + f'with md5 {index[MD5_KEY]} and a total number of ' + f'{len(index[INDEX_KEY])} components indexed') + else: # just load the prebuild index file. + index = load_from_prebuilt() - logger.info( - f'Loading done! Current index file version is {index[VERSION_KEY]}, ' - f'with md5 {index[MD5_KEY]} and a total number of ' - f'{len(index[INDEX_KEY])} components indexed') return index diff --git a/modelscope/utils/plugins.py b/modelscope/utils/plugins.py index 8aff0fb1..1f191a8d 100644 --- a/modelscope/utils/plugins.py +++ b/modelscope/utils/plugins.py @@ -395,7 +395,8 @@ def import_module_from_model_dir(model_dir): ] create_module_from_files(relative_file_dirs, model_dir, module_name) for file in relative_file_dirs: - submodule = module_name + '.' + file.replace('.py', '').replace(os.sep, '.') + submodule = module_name + '.' + file.replace('.py', '').replace( + os.sep, '.') importlib.import_module(submodule) From 53acf79d042bbc029f8d380bba21a2234398b6e4 Mon Sep 17 00:00:00 2001 From: liuyhwangyh Date: Wed, 3 Jul 2024 20:03:46 +0800 Subject: [PATCH 2/3] fix logger issue (#899) Co-authored-by: mulin.lyh --- modelscope/utils/ast_utils.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/modelscope/utils/ast_utils.py b/modelscope/utils/ast_utils.py index 3554b0ea..657ebb33 100644 --- a/modelscope/utils/ast_utils.py +++ b/modelscope/utils/ast_utils.py @@ -21,10 +21,8 @@ from modelscope.metainfo import (CustomDatasets, Heads, Hooks, LR_Schedulers, Preprocessors, TaskModels, Trainers) from modelscope.utils.constant import Fields, Tasks from modelscope.utils.file_utils import get_modelscope_cache_dir -from modelscope.utils.logger import get_logger from modelscope.utils.registry import default_group -logger = get_logger(log_level=logging.WARNING) p = Path(__file__) # get the path of package 'modelscope' @@ -58,6 +56,15 @@ TEMPLATE_PATH = 'TEMPLATE_PATH' TEMPLATE_FILE = 'ast_index_file.py' +def get_ast_logger(): + ast_logger = logging.getLogger('modelscope.ast') + ast_logger.setLevel(logging.INFO) + return ast_logger + + +logger = get_ast_logger() + + class AstScanning(object): def __init__(self) -> None: @@ -712,7 +719,6 @@ def load_index( cache_dir = os.getenv('MODELSCOPE_CACHE', indexer_file_dir) index_file = os.getenv('MODELSCOPE_INDEX_FILE', indexer_file) file_path = os.path.join(cache_dir, index_file) - logger.info(f'Loading ast index from {file_path}') index = None if force_rebuild: @@ -722,6 +728,7 @@ def load_index( # when developing, we need to generator as need. 
if __is_develop_model(): + logger.info(f'Loading ast index from {file_path}') if os.path.exists(file_path): # already exist, check it's latest wrapped_index = _load_index(file_path) md5, files_mtime = file_scanner.files_mtime_md5( From f7a32e48e301d7f918b4caf4d3d0c65df9030d50 Mon Sep 17 00:00:00 2001 From: liuyhwangyh Date: Fri, 12 Jul 2024 19:06:44 +0800 Subject: [PATCH 3/3] add dataset download (#906) * add dataset download * fix cr issue * fix cv matplotlib issue * refactor code * fix ut issue * remove debug code * remove unused import * fix import issue * sleep 65s before start docker avoid kill and run failed --------- Co-authored-by: mulin.lyh --- .dev_scripts/dockerci.sh | 1 + docker/Dockerfile.ubuntu | 32 ++-- modelscope/__init__.py | 8 +- modelscope/cli/download.py | 96 +++++++--- modelscope/hub/api.py | 16 +- modelscope/hub/errors.py | 4 +- modelscope/hub/file_download.py | 244 ++++++++++++++++++------ modelscope/hub/snapshot_download.py | 234 +++++++++++++++++------ modelscope/hub/utils/caching.py | 9 +- modelscope/utils/constant.py | 5 + modelscope/utils/file_utils.py | 26 ++- requirements/cv.txt | 4 +- requirements/datasets.txt | 12 ++ requirements/framework.txt | 15 +- setup.py | 2 +- tests/hub/test_download_dataset_file.py | 169 ++++++++++++++++ tests/hub/test_hub_retry.py | 10 +- 17 files changed, 694 insertions(+), 193 deletions(-) create mode 100644 requirements/datasets.txt create mode 100644 tests/hub/test_download_dataset_file.py diff --git a/.dev_scripts/dockerci.sh b/.dev_scripts/dockerci.sh index 0278a785..4f66073c 100644 --- a/.dev_scripts/dockerci.sh +++ b/.dev_scripts/dockerci.sh @@ -14,6 +14,7 @@ echo "PR modified files: $PR_CHANGED_FILES" PR_CHANGED_FILES=${PR_CHANGED_FILES//[ ]/#} echo "PR_CHANGED_FILES: $PR_CHANGED_FILES" idx=0 +sleep 65 for gpu in $gpus do exec {lock_fd}>"/tmp/gpu$gpu" || exit 1 diff --git a/docker/Dockerfile.ubuntu b/docker/Dockerfile.ubuntu index 437ee9e2..8e7543cd 100644 --- a/docker/Dockerfile.ubuntu +++ b/docker/Dockerfile.ubuntu @@ -34,21 +34,7 @@ RUN if [ "$USE_GPU" = "True" ] ; then \ echo 'cpu unsupport detectron2'; \ fi -# torchmetrics==0.11.4 for ofa -# tinycudann for cuda12.1.0 pytorch 2.1.2 -RUN if [ "$USE_GPU" = "True" ] ; then \ - pip install --no-cache-dir --upgrade pip && \ - pip install --no-cache-dir torchsde jupyterlab torchmetrics==0.11.4 tiktoken transformers_stream_generator bitsandbytes basicsr optimum && \ - pip install --no-cache-dir flash_attn==2.5.9.post1 -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \ - pip install --no-cache-dir auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu121/ && \ - pip install --no-cache-dir -U 'xformers' --index-url https://download.pytorch.org/whl/cu121 && \ - pip install --no-cache-dir --force tinycudann==1.7 -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \ - pip uninstall -y torch-scatter && TORCH_CUDA_ARCH_LIST="6.0;6.1;6.2;7.0;7.5;8.0;8.6;8.9;9.0" pip install --no-cache-dir -U torch-scatter && \ - pip install --no-cache-dir -U vllm; \ - else \ - echo 'cpu unsupport vllm auto-gptq'; \ - fi - +# install dependencies COPY requirements /var/modelscope RUN pip install --no-cache-dir --upgrade pip && \ pip install --no-cache-dir -r /var/modelscope/framework.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \ @@ -64,5 +50,21 @@ RUN pip install --no-cache-dir --upgrade pip && \ pip cache purge # 'scipy<1.13.0' for cannot import name 'kaiser' from 'scipy.signal' COPY examples 
/modelscope/examples +# torchmetrics==0.11.4 for ofa +# tinycudann for cuda12.1.0 pytorch 2.1.2 +RUN if [ "$USE_GPU" = "True" ] ; then \ + pip install --no-cache-dir --upgrade pip && \ + pip install --no-cache-dir torchsde jupyterlab torchmetrics==0.11.4 tiktoken transformers_stream_generator bitsandbytes basicsr optimum && \ + pip install --no-cache-dir flash_attn==2.5.9.post1 -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \ + pip install --no-cache-dir auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu121/ && \ + pip install --no-cache-dir -U 'xformers' --index-url https://download.pytorch.org/whl/cu121 && \ + pip install --no-cache-dir --force tinycudann==1.7 -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \ + pip uninstall -y torch-scatter && TORCH_CUDA_ARCH_LIST="6.0;6.1;6.2;7.0;7.5;8.0;8.6;8.9;9.0" pip install --no-cache-dir -U torch-scatter && \ + pip install --no-cache-dir -U triton vllm https://modelscope.oss-cn-beijing.aliyuncs.com/packages/lmdeploy-0.5.0-cp310-cp310-linux_x86_64.whl; \ + else \ + echo 'cpu unsupport vllm auto-gptq'; \ + fi + ENV SETUPTOOLS_USE_DISTUTILS=stdlib ENV VLLM_USE_MODELSCOPE=True +ENV LMDEPLOY_USE_MODELSCOPE=True diff --git a/modelscope/__init__.py b/modelscope/__init__.py index 8f2fbbee..34a9bfc2 100644 --- a/modelscope/__init__.py +++ b/modelscope/__init__.py @@ -10,7 +10,9 @@ if TYPE_CHECKING: from .hub.api import HubApi from .hub.check_model import check_local_model_is_latest, check_model_is_id from .hub.push_to_hub import push_to_hub, push_to_hub_async - from .hub.snapshot_download import snapshot_download + from .hub.snapshot_download import snapshot_download, dataset_snapshot_download + from .hub.file_download import model_file_download, dataset_file_download + from .metrics import ( AccuracyMetric, AudioNoiseMetric, BleuMetric, ImageColorEnhanceMetric, ImageColorizationMetric, ImageDenoiseMetric, ImageInpaintingMetric, @@ -59,7 +61,9 @@ else: 'TorchModelExporter', ], 'hub.api': ['HubApi'], - 'hub.snapshot_download': ['snapshot_download'], + 'hub.snapshot_download': + ['snapshot_download', 'dataset_snapshot_download'], + 'hub.file_download': ['model_file_download', 'dataset_file_download'], 'hub.push_to_hub': ['push_to_hub', 'push_to_hub_async'], 'hub.check_model': ['check_model_is_id', 'check_local_model_is_latest'], diff --git a/modelscope/cli/download.py b/modelscope/cli/download.py index 9adae9a2..fe2b37e9 100644 --- a/modelscope/cli/download.py +++ b/modelscope/cli/download.py @@ -3,8 +3,10 @@ from argparse import ArgumentParser from modelscope.cli.base import CLICommand -from modelscope.hub.file_download import model_file_download -from modelscope.hub.snapshot_download import snapshot_download +from modelscope.hub.file_download import (dataset_file_download, + model_file_download) +from modelscope.hub.snapshot_download import (dataset_snapshot_download, + snapshot_download) def subparser_func(args): @@ -24,11 +26,17 @@ class DownloadCMD(CLICommand): """ define args for download command. """ parser: ArgumentParser = parsers.add_parser(DownloadCMD.name) - parser.add_argument( + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument( '--model', type=str, - required=True, - help='The model id to be downloaded.') + help='The id of the model to be downloaded. For download, ' + 'the id of either a model or dataset must be provided.') + group.add_argument( + '--dataset', + type=str, + help='The id of the dataset to be downloaded. 
For download, ' + 'the id of either a model or dataset must be provided.') parser.add_argument( '--revision', type=str, @@ -69,27 +77,57 @@ class DownloadCMD(CLICommand): parser.set_defaults(func=subparser_func) def execute(self): - if len(self.args.files) == 1: # download single file - model_file_download( - self.args.model, - self.args.files[0], - cache_dir=self.args.cache_dir, - local_dir=self.args.local_dir, - revision=self.args.revision) - elif len(self.args.files) > 1: # download specified multiple files. - snapshot_download( - self.args.model, - revision=self.args.revision, - cache_dir=self.args.cache_dir, - local_dir=self.args.local_dir, - allow_file_pattern=self.args.files, - ) - else: # download repo - snapshot_download( - self.args.model, - revision=self.args.revision, - cache_dir=self.args.cache_dir, - local_dir=self.args.local_dir, - allow_file_pattern=self.args.include, - ignore_file_pattern=self.args.exclude, - ) + if self.args.model: + if len(self.args.files) == 1: # download single file + model_file_download( + self.args.model, + self.args.files[0], + cache_dir=self.args.cache_dir, + local_dir=self.args.local_dir, + revision=self.args.revision) + elif len( + self.args.files) > 1: # download specified multiple files. + snapshot_download( + self.args.model, + revision=self.args.revision, + cache_dir=self.args.cache_dir, + local_dir=self.args.local_dir, + allow_file_pattern=self.args.files, + ) + else: # download repo + snapshot_download( + self.args.model, + revision=self.args.revision, + cache_dir=self.args.cache_dir, + local_dir=self.args.local_dir, + allow_file_pattern=self.args.include, + ignore_file_pattern=self.args.exclude, + ) + elif self.args.dataset: + if len(self.args.files) == 1: # download single file + dataset_file_download( + self.args.dataset, + self.args.files[0], + cache_dir=self.args.cache_dir, + local_dir=self.args.local_dir, + revision=self.args.revision) + elif len( + self.args.files) > 1: # download specified multiple files. + dataset_snapshot_download( + self.args.dataset, + revision=self.args.revision, + cache_dir=self.args.cache_dir, + local_dir=self.args.local_dir, + allow_file_pattern=self.args.files, + ) + else: # download repo + dataset_snapshot_download( + self.args.dataset, + revision=self.args.revision, + cache_dir=self.args.cache_dir, + local_dir=self.args.local_dir, + allow_file_pattern=self.args.include, + ignore_file_pattern=self.args.exclude, + ) + else: + pass # noop diff --git a/modelscope/hub/api.py b/modelscope/hub/api.py index da4ce5d4..26d82bee 100644 --- a/modelscope/hub/api.py +++ b/modelscope/hub/api.py @@ -76,6 +76,7 @@ class HubApi: connect=2, backoff_factor=1, status_forcelist=(500, 502, 503, 504), + respect_retry_after_header=False, ) adapter = HTTPAdapter(max_retries=retry) self.session.mount('http://', adapter) @@ -741,7 +742,8 @@ class HubApi: recursive = 'True' if recursive else 'False' datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_hub_id}/repo/tree' - params = {'Revision': revision, 'Root': root_path, 'Recursive': recursive} + params = {'Revision': revision if revision else 'master', + 'Root': root_path if root_path else '/', 'Recursive': recursive} cookies = ModelScopeConfig.get_cookies() r = self.session.get(datahub_url, params=params, cookies=cookies) @@ -771,8 +773,10 @@ class HubApi: @staticmethod def dump_datatype_file(dataset_type: int, meta_cache_dir: str): """ - Dump the data_type as a local file, in order to get the dataset formation without calling the datahub. 
- More details, please refer to the class `modelscope.utils.constant.DatasetFormations`. + Dump the data_type as a local file, in order to get the dataset + formation without calling the datahub. + More details, please refer to the class + `modelscope.utils.constant.DatasetFormations`. """ dataset_type_file_path = os.path.join(meta_cache_dir, f'{str(dataset_type)}{DatasetFormations.formation_mark_ext.value}') @@ -874,13 +878,14 @@ class HubApi: dataset_name: str, namespace: str, revision: Optional[str] = DEFAULT_DATASET_REVISION, + view: Optional[bool] = False, extension_filter: Optional[bool] = True): if not file_name or not dataset_name or not namespace: raise ValueError('Args (file_name, dataset_name, namespace) cannot be empty!') # Note: make sure the FilePath is the last parameter in the url - params: dict = {'Source': 'SDK', 'Revision': revision, 'FilePath': file_name} + params: dict = {'Source': 'SDK', 'Revision': revision, 'FilePath': file_name, 'View': view} params: str = urlencode(params) file_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?{params}' @@ -1113,7 +1118,8 @@ class ModelScopeConfig: ModelScopeConfig.cookie_expired_warning = True logger.warning( 'Authentication has expired, ' - 'please re-login if you need to access private models or datasets.') + 'please re-login with modelscope login --token "YOUR_SDK_TOKEN" ' + 'if you need to access private models or datasets.') return None return cookies return None diff --git a/modelscope/hub/errors.py b/modelscope/hub/errors.py index 8258399d..e2288787 100644 --- a/modelscope/hub/errors.py +++ b/modelscope/hub/errors.py @@ -153,9 +153,9 @@ def datahub_raise_on_error(url, rsp, http_response: requests.Response): if rsp.get('Code') == HTTPStatus.OK: return True else: - request_id = get_request_id(http_response) + request_id = rsp['RequestId'] raise RequestError( - f"Url = {url}, Request id={request_id} Message = {rsp.get('Message')},\ + f"Url = {url}, Request id={request_id} Code = {rsp['Code']} Message = {rsp['Message']},\ Please specify correct dataset_name and namespace.") diff --git a/modelscope/hub/file_download.py b/modelscope/hub/file_download.py index 8b4f6bc5..29202404 100644 --- a/modelscope/hub/file_download.py +++ b/modelscope/hub/file_download.py @@ -21,10 +21,14 @@ from modelscope.hub.constants import ( API_FILE_DOWNLOAD_CHUNK_SIZE, API_FILE_DOWNLOAD_RETRY_TIMES, API_FILE_DOWNLOAD_TIMEOUT, FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS, MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB, TEMPORARY_FOLDER_NAME) -from modelscope.utils.constant import DEFAULT_MODEL_REVISION -from modelscope.utils.file_utils import get_model_cache_root +from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, + DEFAULT_MODEL_REVISION, + REPO_TYPE_DATASET, REPO_TYPE_MODEL, + REPO_TYPE_SUPPORT) +from modelscope.utils.file_utils import (get_dataset_cache_root, + get_model_cache_root) from modelscope.utils.logger import get_logger -from .errors import FileDownloadError, NotExistError +from .errors import FileDownloadError, InvalidParameter, NotExistError from .utils.caching import ModelFileSystemCache from .utils.utils import (file_integrity_validation, get_endpoint, model_id_to_group_owner_name) @@ -78,8 +82,97 @@ def model_file_download( - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid """ + return _repo_file_download( + model_id, + file_path, + repo_type=REPO_TYPE_MODEL, + revision=revision, + cache_dir=cache_dir, + user_agent=user_agent, + 
local_files_only=local_files_only, + cookies=cookies, + local_dir=local_dir) + + +def dataset_file_download( + dataset_id: str, + file_path: str, + revision: Optional[str] = DEFAULT_DATASET_REVISION, + cache_dir: Union[str, Path, None] = None, + local_dir: Optional[str] = None, + user_agent: Optional[Union[Dict, str]] = None, + local_files_only: Optional[bool] = False, + cookies: Optional[CookieJar] = None, +) -> str: + """Download raw files of a dataset. + Downloads all files at the specified revision. This + is useful when you want all files from a dataset, because you don't know which + ones you will need a priori. All files are nested inside a folder in order + to keep their actual filename relative to that folder. + + An alternative would be to just clone a dataset but this would require that the + user always has git and git-lfs installed, and properly configured. + + Args: + dataset_id (str): A user or an organization name and a dataset name separated by a `/`. + file_path (str): The relative path of the file to download. + revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a + commit hash. NOTE: currently only branch and tag name is supported + cache_dir (str, Path, optional): Path to the folder where cached files are stored, dataset file will + be save as cache_dir/dataset_id/THE_DATASET_FILES. + local_dir (str, optional): Specific local directory path to which the file will be downloaded. + user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string. + local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the + local cached file if it exists. + cookies (CookieJar, optional): The cookie of the request, default None. + Raises: + ValueError: the value details. + + Returns: + str: Local folder path (string) of repo snapshot + + Note: + Raises the following errors: + - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + if `use_auth_token=True` and the token cannot be found. + - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if + ETag cannot be determined. 
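
For reference, a minimal usage sketch of this helper; it mirrors the new test in tests/hub/test_download_dataset_file.py, and the dataset id and file name below are the ones used there:

from modelscope.hub.file_download import dataset_file_download

# Download one raw file from a dataset repo into the default cache;
# pass cache_dir or local_dir to control where the file lands.
local_path = dataset_file_download(
    dataset_id='citest/test_dataset_download',
    file_path='open_qa.jsonl')
print(local_path)
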
+ - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + """ + return _repo_file_download( + dataset_id, + file_path, + repo_type=REPO_TYPE_DATASET, + revision=revision, + cache_dir=cache_dir, + user_agent=user_agent, + local_files_only=local_files_only, + cookies=cookies, + local_dir=local_dir) + + +def _repo_file_download( + repo_id: str, + file_path: str, + *, + repo_type: str = None, + revision: Optional[str] = DEFAULT_MODEL_REVISION, + cache_dir: Optional[str] = None, + user_agent: Union[Dict, str, None] = None, + local_files_only: Optional[bool] = False, + cookies: Optional[CookieJar] = None, + local_dir: Optional[str] = None, +) -> Optional[str]: # pragma: no cover + + if not repo_type: + repo_type = REPO_TYPE_MODEL + if repo_type not in REPO_TYPE_SUPPORT: + raise InvalidParameter('Invalid repo type: %s, only support: %s' ( + repo_type, REPO_TYPE_SUPPORT)) + temporary_cache_dir, cache = create_temporary_directory_and_cache( - model_id, local_dir, cache_dir) + repo_id, local_dir=local_dir, cache_dir=cache_dir, repo_type=repo_type) # if local_files_only is `True` and the file already exists in cached_path # return the cached path @@ -93,7 +186,7 @@ def model_file_download( else: raise ValueError( 'Cannot find the requested files in the cached path and outgoing' - ' traffic has been disabled. To enable model look-ups and downloads' + ' traffic has been disabled. To enable look-ups and downloads' " online, set 'local_files_only' to False.") _api = HubApi() @@ -102,75 +195,84 @@ def model_file_download( } if cookies is None: cookies = ModelScopeConfig.get_cookies() + repo_files = [] + if repo_type == REPO_TYPE_MODEL: + revision = _api.get_valid_revision( + repo_id, revision=revision, cookies=cookies) + file_to_download_meta = None + # we need to confirm the version is up-to-date + # we need to get the file list to check if the latest version is cached, if so return, otherwise download + repo_files = _api.get_model_files( + model_id=repo_id, + revision=revision, + recursive=True, + use_cookies=False if cookies is None else cookies) + elif repo_type == REPO_TYPE_DATASET: + group_or_owner, name = model_id_to_group_owner_name(repo_id) + if not revision: + revision = DEFAULT_DATASET_REVISION + files_list_tree = _api.list_repo_tree( + dataset_name=name, + namespace=group_or_owner, + revision=revision, + root_path='/', + recursive=True) + if not ('Code' in files_list_tree and files_list_tree['Code'] == 200): + print( + 'Get dataset: %s file list failed, request_id: %s, message: %s' + % (repo_id, files_list_tree['RequestId'], + files_list_tree['Message'])) + return None + repo_files = files_list_tree['Data']['Files'] - revision = _api.get_valid_revision( - model_id, revision=revision, cookies=cookies) - file_to_download_info = None - # we need to confirm the version is up-to-date - # we need to get the file list to check if the latest version is cached, if so return, otherwise download - model_files = _api.get_model_files( - model_id=model_id, - revision=revision, - recursive=True, - use_cookies=False if cookies is None else cookies) - - for model_file in model_files: - if model_file['Type'] == 'tree': + file_to_download_meta = None + for repo_file in repo_files: + if repo_file['Type'] == 'tree': continue - if model_file['Path'] == file_path: - if cache.exists(model_file): + if repo_file['Path'] == file_path: + if cache.exists(repo_file): logger.debug( - f'File {model_file["Name"]} already in cache, skip downloading!' 
+ f'File {repo_file["Name"]} already in cache, skip downloading!' ) - return cache.get_file_by_info(model_file) + return cache.get_file_by_info(repo_file) else: - file_to_download_info = model_file + file_to_download_meta = repo_file break - if file_to_download_info is None: + if file_to_download_meta is None: raise NotExistError('The file path: %s not exist in: %s' % - (file_path, model_id)) + (file_path, repo_id)) # we need to download again - url_to_download = get_file_download_url(model_id, file_path, revision) - - if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_to_download_info[ - 'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1: - parallel_download( - url_to_download, - temporary_cache_dir, - file_path, - headers=headers, - cookies=None if cookies is None else cookies.get_dict(), - file_size=file_to_download_info['Size']) - else: - http_get_model_file( - url_to_download, - temporary_cache_dir, - file_path, - file_size=file_to_download_info['Size'], - headers=headers, - cookies=None if cookies is None else cookies.get_dict()) - - temp_file_path = os.path.join(temporary_cache_dir, file_path) - # for download with commit we can't get Sha256 - if file_to_download_info[FILE_HASH] is not None: - file_integrity_validation(temp_file_path, - file_to_download_info[FILE_HASH]) - return cache.put_file(file_to_download_info, - os.path.join(temporary_cache_dir, file_path)) + if repo_type == REPO_TYPE_MODEL: + url_to_download = get_file_download_url(repo_id, file_path, revision) + elif repo_type == REPO_TYPE_DATASET: + url_to_download = _api.get_dataset_file_url( + file_name=file_to_download_meta['Path'], + dataset_name=name, + namespace=group_or_owner, + revision=revision) + return download_file(url_to_download, file_to_download_meta, + temporary_cache_dir, cache, headers, cookies) -def create_temporary_directory_and_cache(model_id: str, local_dir: str, - cache_dir: str): +def create_temporary_directory_and_cache(model_id: str, + local_dir: str = None, + cache_dir: str = None, + repo_type: str = REPO_TYPE_MODEL): + if repo_type == REPO_TYPE_MODEL: + default_cache_root = get_model_cache_root() + elif repo_type == REPO_TYPE_DATASET: + default_cache_root = get_dataset_cache_root() + group_or_owner, name = model_id_to_group_owner_name(model_id) if local_dir is not None: temporary_cache_dir = os.path.join(local_dir, TEMPORARY_FOLDER_NAME) cache = ModelFileSystemCache(local_dir) else: if cache_dir is None: - cache_dir = get_model_cache_root() + cache_dir = default_cache_root if isinstance(cache_dir, Path): cache_dir = str(cache_dir) temporary_cache_dir = os.path.join(cache_dir, TEMPORARY_FOLDER_NAME, @@ -269,6 +371,7 @@ def parallel_download( PART_SIZE = 160 * 1024 * 1024 # every part is 160M tasks = [] file_path = os.path.join(local_dir, file_name) + os.makedirs(os.path.dirname(file_path), exist_ok=True) for idx in range(int(file_size / PART_SIZE)): start = idx * PART_SIZE end = (idx + 1) * PART_SIZE - 1 @@ -323,6 +426,7 @@ def http_get_model_file( get_headers = {} if headers is None else copy.deepcopy(headers) get_headers['X-Request-ID'] = str(uuid.uuid4().hex) temp_file_path = os.path.join(local_dir, file_name) + os.makedirs(os.path.dirname(temp_file_path), exist_ok=True) logger.debug('downloading %s to %s', url, temp_file_path) # retry sleep 0.5s, 1s, 2s, 4s retry = Retry( @@ -349,7 +453,7 @@ def http_get_model_file( break get_headers['Range'] = 'bytes=%s-%s' % (partial_length, file_size - 1) - with open(temp_file_path, 'ab') as f: + with open(temp_file_path, 'ab+') as f: r = 
requests.get( url, stream=True, @@ -451,3 +555,31 @@ def http_get_file( logger.error(msg) raise FileDownloadError(msg) os.replace(temp_file.name, os.path.join(local_dir, file_name)) + + +def download_file(url, file_meta, temporary_cache_dir, cache, headers, + cookies): + if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[ + 'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1: # parallel download large file. + parallel_download( + url, + temporary_cache_dir, + file_meta['Path'], + headers=headers, + cookies=None if cookies is None else cookies.get_dict(), + file_size=file_meta['Size']) + else: + http_get_model_file( + url, + temporary_cache_dir, + file_meta['Path'], + file_size=file_meta['Size'], + headers=headers, + cookies=cookies) + + # check file integrity + temp_file = os.path.join(temporary_cache_dir, file_meta['Path']) + if FILE_HASH in file_meta: + file_integrity_validation(temp_file, file_meta[FILE_HASH]) + # put file into to cache + return cache.put_file(file_meta, temp_file) diff --git a/modelscope/hub/snapshot_download.py b/modelscope/hub/snapshot_download.py index 053b4c3b..4e89a6b0 100644 --- a/modelscope/hub/snapshot_download.py +++ b/modelscope/hub/snapshot_download.py @@ -8,14 +8,15 @@ from pathlib import Path from typing import Dict, List, Optional, Union from modelscope.hub.api import HubApi, ModelScopeConfig -from modelscope.utils.constant import DEFAULT_MODEL_REVISION +from modelscope.hub.errors import InvalidParameter +from modelscope.hub.utils.utils import model_id_to_group_owner_name +from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, + DEFAULT_MODEL_REVISION, + REPO_TYPE_DATASET, REPO_TYPE_MODEL, + REPO_TYPE_SUPPORT) from modelscope.utils.logger import get_logger -from .constants import (FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS, - MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB) from .file_download import (create_temporary_directory_and_cache, - get_file_download_url, http_get_model_file, - parallel_download) -from .utils.utils import file_integrity_validation + download_file, get_file_download_url) logger = get_logger() @@ -70,14 +71,110 @@ def snapshot_download( - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid """ + return _snapshot_download( + model_id, + repo_type=REPO_TYPE_MODEL, + revision=revision, + cache_dir=cache_dir, + user_agent=user_agent, + local_files_only=local_files_only, + cookies=cookies, + ignore_file_pattern=ignore_file_pattern, + allow_file_pattern=allow_file_pattern, + local_dir=local_dir) + + +def dataset_snapshot_download( + dataset_id: str, + revision: Optional[str] = DEFAULT_DATASET_REVISION, + cache_dir: Union[str, Path, None] = None, + local_dir: Optional[str] = None, + user_agent: Optional[Union[Dict, str]] = None, + local_files_only: Optional[bool] = False, + cookies: Optional[CookieJar] = None, + ignore_file_pattern: Optional[Union[str, List[str]]] = None, + allow_file_pattern: Optional[Union[str, List[str]]] = None, +) -> str: + """Download raw files of a dataset. + Downloads all files at the specified revision. This + is useful when you want all files from a dataset, because you don't know which + ones you will need a priori. All files are nested inside a folder in order + to keep their actual filename relative to that folder. + + An alternative would be to just clone a dataset but this would require that the + user always has git and git-lfs installed, and properly configured. 
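
For reference, a minimal usage sketch of this helper. The dataset id again comes from the new tests; the '*.jsonl' pattern is only an illustration of the allow_file_pattern argument described below:

from modelscope.hub.snapshot_download import dataset_snapshot_download

# Download the whole dataset repo, skipping jpeg files.
local_dir = dataset_snapshot_download(
    dataset_id='citest/test_dataset_download',
    ignore_file_pattern='*.jpeg')

# Or restrict the download to jsonl files only.
local_dir = dataset_snapshot_download(
    dataset_id='citest/test_dataset_download',
    allow_file_pattern='*.jsonl')
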
+ + Args: + dataset_id (str): A user or an organization name and a dataset name separated by a `/`. + revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a + commit hash. NOTE: currently only branch and tag name is supported + cache_dir (str, Path, optional): Path to the folder where cached files are stored, dataset will + be save as cache_dir/dataset_id/THE_DATASET_FILES. + local_dir (str, optional): Specific local directory path to which the file will be downloaded. + user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string. + local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the + local cached file if it exists. + cookies (CookieJar, optional): The cookie of the request, default None. + ignore_file_pattern (`str` or `List`, *optional*, default to `None`): + Any file pattern to be ignored in downloading, like exact file names or file extensions. + Use regression is deprecated. + allow_file_pattern (`str` or `List`, *optional*, default to `None`): + Any file pattern to be downloading, like exact file names or file extensions. + Raises: + ValueError: the value details. + + Returns: + str: Local folder path (string) of repo snapshot + + Note: + Raises the following errors: + - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + if `use_auth_token=True` and the token cannot be found. + - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if + ETag cannot be determined. + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + """ + return _snapshot_download( + dataset_id, + repo_type=REPO_TYPE_DATASET, + revision=revision, + cache_dir=cache_dir, + user_agent=user_agent, + local_files_only=local_files_only, + cookies=cookies, + ignore_file_pattern=ignore_file_pattern, + allow_file_pattern=allow_file_pattern, + local_dir=local_dir) + + +def _snapshot_download( + repo_id: str, + *, + repo_type: Optional[str] = None, + revision: Optional[str] = DEFAULT_MODEL_REVISION, + cache_dir: Union[str, Path, None] = None, + user_agent: Optional[Union[Dict, str]] = None, + local_files_only: Optional[bool] = False, + cookies: Optional[CookieJar] = None, + ignore_file_pattern: Optional[Union[str, List[str]]] = None, + allow_file_pattern: Optional[Union[str, List[str]]] = None, + local_dir: Optional[str] = None, +): + if not repo_type: + repo_type = REPO_TYPE_MODEL + if repo_type not in REPO_TYPE_SUPPORT: + raise InvalidParameter('Invalid repo type: %s, only support: %s' ( + repo_type, REPO_TYPE_SUPPORT)) + temporary_cache_dir, cache = create_temporary_directory_and_cache( - model_id, local_dir, cache_dir) + repo_id, local_dir=local_dir, cache_dir=cache_dir, repo_type=repo_type) if local_files_only: if len(cache.cached_files) == 0: raise ValueError( 'Cannot find the requested files in the cached path and outgoing' - ' traffic has been disabled. To enable model look-ups and downloads' + ' traffic has been disabled. 
To enable look-ups and downloads' " online, set 'local_files_only' to False.") logger.warning('We can not confirm the cached file is for revision: %s' % revision) @@ -92,27 +189,48 @@ def snapshot_download( _api = HubApi() if cookies is None: cookies = ModelScopeConfig.get_cookies() - revision_detail = _api.get_valid_revision_detail( - model_id, revision=revision, cookies=cookies) - revision = revision_detail['Revision'] + repo_files = [] + if repo_type == REPO_TYPE_MODEL: + revision_detail = _api.get_valid_revision_detail( + repo_id, revision=revision, cookies=cookies) + revision = revision_detail['Revision'] - snapshot_header = headers if 'CI_TEST' in os.environ else { - **headers, - **{ - 'Snapshot': 'True' + snapshot_header = headers if 'CI_TEST' in os.environ else { + **headers, + **{ + 'Snapshot': 'True' + } } - } - if cache.cached_model_revision is not None: - snapshot_header[ - 'cached_model_revision'] = cache.cached_model_revision + if cache.cached_model_revision is not None: + snapshot_header[ + 'cached_model_revision'] = cache.cached_model_revision - model_files = _api.get_model_files( - model_id=model_id, - revision=revision, - recursive=True, - use_cookies=False if cookies is None else cookies, - headers=snapshot_header, - ) + repo_files = _api.get_model_files( + model_id=repo_id, + revision=revision, + recursive=True, + use_cookies=False if cookies is None else cookies, + headers=snapshot_header, + ) + elif repo_type == REPO_TYPE_DATASET: + group_or_owner, name = model_id_to_group_owner_name(repo_id) + if not revision: + revision = DEFAULT_DATASET_REVISION + revision_detail = revision + files_list_tree = _api.list_repo_tree( + dataset_name=name, + namespace=group_or_owner, + revision=revision, + root_path='/', + recursive=True) + if not ('Code' in files_list_tree + and files_list_tree['Code'] == 200): + print( + 'Get dataset: %s file list failed, request_id: %s, message: %s' + % (repo_id, files_list_tree['RequestId'], + files_list_tree['Message'])) + return None + repo_files = files_list_tree['Data']['Files'] if ignore_file_pattern is None: ignore_file_pattern = [] @@ -122,6 +240,12 @@ def snapshot_download( item if not item.endswith('/') else item + '*' for item in ignore_file_pattern ] + ignore_regex_pattern = [] + for file_pattern in ignore_file_pattern: + if file_pattern.startswith('*'): + ignore_regex_pattern.append('.' 
+ file_pattern) + else: + ignore_regex_pattern.append(file_pattern) if allow_file_pattern is not None: if isinstance(allow_file_pattern, str): @@ -131,55 +255,39 @@ def snapshot_download( for item in allow_file_pattern ] - for model_file in model_files: - if model_file['Type'] == 'tree' or \ - any(fnmatch.fnmatch(model_file['Path'], pattern) for pattern in ignore_file_pattern) or \ - any([re.search(pattern, model_file['Name']) is not None for pattern in ignore_file_pattern]): + for repo_file in repo_files: + if repo_file['Type'] == 'tree' or \ + any([fnmatch.fnmatch(repo_file['Path'], pattern) for pattern in ignore_file_pattern]) or \ + any([re.search(pattern, repo_file['Name']) is not None for pattern in ignore_regex_pattern]): # noqa E501 continue if allow_file_pattern is not None and allow_file_pattern: if not any( - fnmatch.fnmatch(model_file['Path'], pattern) + fnmatch.fnmatch(repo_file['Path'], pattern) for pattern in allow_file_pattern): continue # check model_file is exist in cache, if existed, skip download, otherwise download - if cache.exists(model_file): - file_name = os.path.basename(model_file['Name']) + if cache.exists(repo_file): + file_name = os.path.basename(repo_file['Name']) logger.debug( f'File {file_name} already in cache, skip downloading!') continue + if repo_type == REPO_TYPE_MODEL: + # get download url + url = get_file_download_url( + model_id=repo_id, + file_path=repo_file['Path'], + revision=revision) + elif repo_type == REPO_TYPE_DATASET: + url = _api.get_dataset_file_url( + file_name=repo_file['Path'], + dataset_name=name, + namespace=group_or_owner, + revision=revision) - # get download url - url = get_file_download_url( - model_id=model_id, - file_path=model_file['Path'], - revision=revision) - - if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < model_file[ - 'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1: - parallel_download( - url, - temporary_cache_dir, - model_file['Name'], - headers=headers, - cookies=None if cookies is None else cookies.get_dict(), - file_size=model_file['Size']) - else: - http_get_model_file( - url, - temporary_cache_dir, - model_file['Name'], - file_size=model_file['Size'], - headers=headers, - cookies=cookies) - - # check file integrity - temp_file = os.path.join(temporary_cache_dir, model_file['Name']) - if FILE_HASH in model_file: - file_integrity_validation(temp_file, model_file[FILE_HASH]) - # put file into to cache - cache.put_file(model_file, temp_file) + download_file(url, repo_file, temporary_cache_dir, cache, headers, + cookies) cache.save_model_version(revision_info=revision_detail) return os.path.join(cache.get_root_location()) diff --git a/modelscope/hub/utils/caching.py b/modelscope/hub/utils/caching.py index cfa20f07..ecdec7cc 100644 --- a/modelscope/hub/utils/caching.py +++ b/modelscope/hub/utils/caching.py @@ -164,9 +164,12 @@ class ModelFileSystemCache(FileSystemCache): model_version_file_path = os.path.join( self.cache_root_location, FileSystemCache.MODEL_VERSION_FILE_NAME) with open(model_version_file_path, 'w') as f: - version_info_str = 'Revision:%s,CreatedAt:%s' % ( - revision_info['Revision'], revision_info['CreatedAt']) - f.write(version_info_str) + if isinstance(revision_info, dict): + version_info_str = 'Revision:%s,CreatedAt:%s' % ( + revision_info['Revision'], revision_info['CreatedAt']) + f.write(version_info_str) + else: + f.write(revision_info) def get_model_id(self): return self.model_meta[FileSystemCache.MODEL_META_MODEL_ID] diff --git a/modelscope/utils/constant.py 
b/modelscope/utils/constant.py index 11be17fb..3570c0cb 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -5,6 +5,8 @@ import enum class Fields(object): """ Names for different application fields """ + hub = 'hub' + datasets = 'datasets' framework = 'framework' cv = 'cv' nlp = 'nlp' @@ -491,6 +493,9 @@ class Frameworks(object): kaldi = 'kaldi' +REPO_TYPE_MODEL = 'model' +REPO_TYPE_DATASET = 'dataset' +REPO_TYPE_SUPPORT = [REPO_TYPE_MODEL, REPO_TYPE_DATASET] DEFAULT_MODEL_REVISION = None MASTER_MODEL_BRANCH = 'master' DEFAULT_REPOSITORY_REVISION = 'master' diff --git a/modelscope/utils/file_utils.py b/modelscope/utils/file_utils.py index 56c32441..c00e8d26 100644 --- a/modelscope/utils/file_utils.py +++ b/modelscope/utils/file_utils.py @@ -53,11 +53,35 @@ def get_model_cache_root() -> str: """Get model cache root path. Returns: - str: the modelscope cache root. + str: the modelscope model cache root. """ return os.path.join(get_modelscope_cache_dir(), 'hub') +def get_dataset_cache_root() -> str: + """Get dataset raw file cache root path. + + Returns: + str: the modelscope dataset raw file cache root. + """ + return os.path.join(get_modelscope_cache_dir(), 'datasets') + + +def get_dataset_cache_dir(dataset_id: str) -> str: + """Get the dataset_id's path. + dataset_cache_root/dataset_id. + + Args: + dataset_id (str): The dataset id. + + Returns: + str: The dataset_id's cache root path. + """ + dataset_root = get_dataset_cache_root() + return dataset_root if dataset_id is None else os.path.join( + dataset_root, dataset_id + '/') + + def get_model_cache_dir(model_id: str) -> str: """cache dir precedence: function parameter > environment > ~/.cache/modelscope/hub/model_id diff --git a/requirements/cv.txt b/requirements/cv.txt index 5935fd91..d54e5dc5 100644 --- a/requirements/cv.txt +++ b/requirements/cv.txt @@ -25,6 +25,7 @@ imgaug>=0.4.0 kornia>=0.5.0 lmdb lpips +matplotlib>=3.8.0 ml_collections mmcls>=0.21.0 mmdet>=2.25.0,<=2.28.2 @@ -44,10 +45,11 @@ opencv-python paint_ldm pandas panopticapi +Pillow>=6.2.0 plyfile>=0.7.4 psutil pyclipper -PyMCubes +PyMCubes<=0.1.4 pytorch-lightning regex # <0.20.0 for compatible python3.7 python3.8 diff --git a/requirements/datasets.txt b/requirements/datasets.txt new file mode 100644 index 00000000..35924919 --- /dev/null +++ b/requirements/datasets.txt @@ -0,0 +1,12 @@ +addict +attrs +datasets>=2.16.0,<2.19.0 +einops +oss2 +python-dateutil>=2.1 +scipy +# latest version has some compatible issue. +setuptools==69.5.1 +simplejson>=3.3.0 +sortedcontainers>=1.5.9 +urllib3>=1.26 diff --git a/requirements/framework.txt b/requirements/framework.txt index d4428ae9..d6317bf2 100644 --- a/requirements/framework.txt +++ b/requirements/framework.txt @@ -2,21 +2,12 @@ addict attrs datasets>=2.16.0,<2.19.0 einops -filelock>=3.3.0 -huggingface_hub -numpy oss2 -pandas -Pillow>=6.2.0 -# pyarrow 9.0.0 introduced event_loop core dump -pyarrow>=6.0.0,!=9.0.0 python-dateutil>=2.1 -pyyaml -requests>=2.25 scipy -setuptools +# latest version has some compatible issue. +setuptools==69.5.1 simplejson>=3.3.0 sortedcontainers>=1.5.9 -tqdm>=4.64.0 +transformers urllib3>=1.26 -yapf diff --git a/setup.py b/setup.py index d722ee41..76a6a7bf 100644 --- a/setup.py +++ b/setup.py @@ -196,7 +196,7 @@ if __name__ == '__main__': # add framework dependencies to every field for field, requires in extra_requires.items(): if field not in [ - 'server', 'framework' + 'server', 'framework', 'hub', 'datasets' ]: # server need install model's field dependencies before. 
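
A short sketch of where raw dataset files land by default, based on the helpers added to modelscope/utils/file_utils.py above (the printed paths assume the stock cache root, i.e. no MODELSCOPE_CACHE override):

from modelscope.utils.file_utils import (get_dataset_cache_dir,
                                         get_dataset_cache_root)

print(get_dataset_cache_root())
# e.g. ~/.cache/modelscope/datasets
print(get_dataset_cache_dir('citest/test_dataset_download'))
# e.g. ~/.cache/modelscope/datasets/citest/test_dataset_download/
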
extra_requires[field] = framework_requires + extra_requires[field] extra_requires['all'] = all_requires diff --git a/tests/hub/test_download_dataset_file.py b/tests/hub/test_download_dataset_file.py new file mode 100644 index 00000000..0c4e9307 --- /dev/null +++ b/tests/hub/test_download_dataset_file.py @@ -0,0 +1,169 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import tempfile +import time +import unittest + +from modelscope.hub.file_download import dataset_file_download +from modelscope.hub.snapshot_download import dataset_snapshot_download + + +class DownloadDatasetTest(unittest.TestCase): + + def setUp(self): + pass + + def test_dataset_file_download(self): + dataset_id = 'citest/test_dataset_download' + file_path = 'open_qa.jsonl' + deep_file_path = '111/222/333/shijian.jpeg' + start_time = time.time() + + # test download to cache dir. + with tempfile.TemporaryDirectory() as temp_cache_dir: + # first download to cache. + cache_file_path = dataset_file_download( + dataset_id=dataset_id, + file_path=file_path, + cache_dir=temp_cache_dir) + file_modify_time = os.path.getmtime(cache_file_path) + print(cache_file_path) + assert cache_file_path == os.path.join(temp_cache_dir, dataset_id, + file_path) + assert file_modify_time > start_time + # download again, will get cached file. + cache_file_path = dataset_file_download( + dataset_id=dataset_id, + file_path=file_path, + cache_dir=temp_cache_dir) + file_modify_time2 = os.path.getmtime(cache_file_path) + assert file_modify_time == file_modify_time2 + + deep_cache_file_path = dataset_file_download( + dataset_id=dataset_id, + file_path=deep_file_path, + cache_dir=temp_cache_dir) + deep_file_cath_path = os.path.join(temp_cache_dir, dataset_id, + deep_file_path) + assert deep_cache_file_path == deep_file_cath_path + os.path.exists(deep_cache_file_path) + + # test download to local dir + with tempfile.TemporaryDirectory() as temp_local_dir: + # first download to cache. + cache_file_path = dataset_file_download( + dataset_id=dataset_id, + file_path=file_path, + local_dir=temp_local_dir) + assert cache_file_path == os.path.join(temp_local_dir, file_path) + file_modify_time = os.path.getmtime(cache_file_path) + assert file_modify_time > start_time + # download again, will get cached file. + cache_file_path = dataset_file_download( + dataset_id=dataset_id, + file_path=file_path, + local_dir=temp_local_dir) + file_modify_time2 = os.path.getmtime(cache_file_path) + assert file_modify_time == file_modify_time2 + + def test_dataset_snapshot_download(self): + dataset_id = 'citest/test_dataset_download' + file_path = 'open_qa.jsonl' + deep_file_path = '111/222/333/shijian.jpeg' + start_time = time.time() + + # test download to cache dir. + with tempfile.TemporaryDirectory() as temp_cache_dir: + # first download to cache. + dataset_cache_path = dataset_snapshot_download( + dataset_id=dataset_id, cache_dir=temp_cache_dir) + file_modify_time = os.path.getmtime( + os.path.join(dataset_cache_path, file_path)) + assert dataset_cache_path == os.path.join(temp_cache_dir, + dataset_id) + assert file_modify_time > start_time + assert os.path.exists( + os.path.join(temp_cache_dir, dataset_id, deep_file_path)) + + # download again, will get cached file. 
+ dataset_cache_path2 = dataset_snapshot_download( + dataset_id=dataset_id, cache_dir=temp_cache_dir) + file_modify_time2 = os.path.getmtime( + os.path.join(dataset_cache_path2, file_path)) + assert file_modify_time == file_modify_time2 + + # test download to local dir + with tempfile.TemporaryDirectory() as temp_local_dir: + # first download to cache. + dataset_cache_path = dataset_snapshot_download( + dataset_id=dataset_id, local_dir=temp_local_dir) + # root path is temp_local_dir, file download to local_dir + assert dataset_cache_path == temp_local_dir + file_modify_time = os.path.getmtime( + os.path.join(dataset_cache_path, file_path)) + assert file_modify_time > start_time + # download again, will get cached file. + dataset_cache_path2 = dataset_snapshot_download( + dataset_id=dataset_id, local_dir=temp_local_dir) + file_modify_time2 = os.path.getmtime( + os.path.join(dataset_cache_path2, file_path)) + assert file_modify_time == file_modify_time2 + + # test download with wild pattern, ignore_file_pattern + with tempfile.TemporaryDirectory() as temp_cache_dir: + # first download to cache. + dataset_cache_path = dataset_snapshot_download( + dataset_id=dataset_id, + cache_dir=temp_cache_dir, + ignore_file_pattern='*.jpeg') + assert dataset_cache_path == os.path.join(temp_cache_dir, + dataset_id) + assert not os.path.exists( + os.path.join(temp_cache_dir, dataset_id, deep_file_path)) + assert not os.path.exists( + os.path.join(temp_cache_dir, dataset_id, '111/shijian.jpeg')) + assert not os.path.exists( + os.path.join(temp_cache_dir, dataset_id, + '111/222/shijian.jpeg')) + assert os.path.exists( + os.path.join(temp_cache_dir, dataset_id, file_path)) + + # test download with wild pattern, allow_file_pattern + with tempfile.TemporaryDirectory() as temp_cache_dir: + # first download to cache. + dataset_cache_path = dataset_snapshot_download( + dataset_id=dataset_id, + cache_dir=temp_cache_dir, + allow_file_pattern='*.jpeg') + assert dataset_cache_path == os.path.join(temp_cache_dir, + dataset_id) + assert os.path.exists( + os.path.join(temp_cache_dir, dataset_id, deep_file_path)) + assert os.path.exists( + os.path.join(temp_cache_dir, dataset_id, '111/shijian.jpeg')) + assert os.path.exists( + os.path.join(temp_cache_dir, dataset_id, + '111/222/shijian.jpeg')) + assert not os.path.exists( + os.path.join(temp_cache_dir, dataset_id, file_path)) + + # test download with wild pattern, allow_file_pattern and ignore file pattern. + with tempfile.TemporaryDirectory() as temp_cache_dir: + # first download to cache. 
+ dataset_cache_path = dataset_snapshot_download( + dataset_id=dataset_id, + cache_dir=temp_cache_dir, + ignore_file_pattern='*.jpeg', + allow_file_pattern='*.xxx') + assert dataset_cache_path == os.path.join(temp_cache_dir, + dataset_id) + assert not os.path.exists( + os.path.join(temp_cache_dir, dataset_id, deep_file_path)) + assert not os.path.exists( + os.path.join(temp_cache_dir, dataset_id, '111/shijian.jpeg')) + assert not os.path.exists( + os.path.join(temp_cache_dir, dataset_id, + '111/222/shijian.jpeg')) + assert not os.path.exists( + os.path.join(temp_cache_dir, dataset_id, file_path)) diff --git a/tests/hub/test_hub_retry.py b/tests/hub/test_hub_retry.py index 87f209cf..beab2f1e 100644 --- a/tests/hub/test_hub_retry.py +++ b/tests/hub/test_hub_retry.py @@ -12,7 +12,7 @@ from modelscope.hub.api import HubApi from modelscope.hub.file_download import http_get_model_file -class HubOperationTest(unittest.TestCase): +class HubRetryTest(unittest.TestCase): def setUp(self): self.api = HubApi() @@ -56,6 +56,8 @@ class HubOperationTest(unittest.TestCase): rsp.msg = HTTPMessage() rsp.read = get_content rsp.chunked = False + rsp.length_remaining = 0 + rsp.headers = {} # retry 2 times and success. getconn_mock.return_value.getresponse.side_effect = [ Mock(status=500, msg=HTTPMessage()), @@ -88,16 +90,18 @@ class HubOperationTest(unittest.TestCase): success_rsp = HTTPResponse(getconn_mock) success_rsp.status = 200 success_rsp.msg = HTTPMessage() - success_rsp.msg.add_header('Content-Length', '2957783') success_rsp.read = get_content success_rsp.chunked = True + success_rsp.length_remaining = 0 + success_rsp.headers = {'Content-Length': '2957783'} failed_rsp = HTTPResponse(getconn_mock) failed_rsp.status = 502 failed_rsp.msg = HTTPMessage() - failed_rsp.msg.add_header('Content-Length', '2957783') failed_rsp.read = get_content failed_rsp.chunked = True + success_rsp.length_remaining = 2957783 + success_rsp.headers = {'Content-Length': '2957783'} # retry 5 times and success. getconn_mock.return_value.getresponse.side_effect = [