From df81f3de959f890e2b72741780affd63d75d96e7 Mon Sep 17 00:00:00 2001 From: "mulin.lyh" Date: Fri, 12 Jul 2024 14:29:23 +0800 Subject: [PATCH] refactor code --- modelscope/__init__.py | 10 +- modelscope/hub/dataset_download.py | 278 ------------------------ modelscope/hub/file_download.py | 185 +++++++++++++--- modelscope/hub/snapshot_download.py | 200 ++++++++++++++--- modelscope/utils/constant.py | 3 + tests/hub/test_download_dataset_file.py | 4 +- 6 files changed, 329 insertions(+), 351 deletions(-) delete mode 100644 modelscope/hub/dataset_download.py diff --git a/modelscope/__init__.py b/modelscope/__init__.py index b5e3cac8..34a9bfc2 100644 --- a/modelscope/__init__.py +++ b/modelscope/__init__.py @@ -10,8 +10,8 @@ if TYPE_CHECKING: from .hub.api import HubApi from .hub.check_model import check_local_model_is_latest, check_model_is_id from .hub.push_to_hub import push_to_hub, push_to_hub_async - from .hub.snapshot_download import snapshot_download - from .hub.dataset_download import dataset_snapshot_download, dataset_file_download + from .hub.snapshot_download import snapshot_download, dataset_snapshot_download + from .hub.file_download import model_file_download, dataset_file_download from .metrics import ( AccuracyMetric, AudioNoiseMetric, BleuMetric, ImageColorEnhanceMetric, @@ -61,9 +61,9 @@ else: 'TorchModelExporter', ], 'hub.api': ['HubApi'], - 'hub.snapshot_download': ['snapshot_download'], - 'hub.dataset_download': - ['dataset_snapshot_download', 'dataset_file_download'], + 'hub.snapshot_download': + ['snapshot_download', 'dataset_snapshot_download'], + 'hub.file_download': ['model_file_download', 'dataset_file_download'], 'hub.push_to_hub': ['push_to_hub', 'push_to_hub_async'], 'hub.check_model': ['check_model_is_id', 'check_local_model_is_latest'], diff --git a/modelscope/hub/dataset_download.py b/modelscope/hub/dataset_download.py deleted file mode 100644 index 413b900b..00000000 --- a/modelscope/hub/dataset_download.py +++ /dev/null @@ -1,278 
+0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - -import fnmatch -import os -from http.cookiejar import CookieJar -from pathlib import Path -from typing import Dict, List, Optional, Union - -from modelscope.hub.api import HubApi, ModelScopeConfig -from modelscope.hub.errors import NotExistError -from modelscope.utils.constant import DEFAULT_DATASET_REVISION -from modelscope.utils.file_utils import get_dataset_cache_root -from modelscope.utils.logger import get_logger -from .file_download import create_temporary_directory_and_cache, download_file -from .utils.utils import model_id_to_group_owner_name - -logger = get_logger() - - -def dataset_file_download( - dataset_id: str, - file_path: str, - revision: Optional[str] = DEFAULT_DATASET_REVISION, - cache_dir: Union[str, Path, None] = None, - local_dir: Optional[str] = None, - user_agent: Optional[Union[Dict, str]] = None, - local_files_only: Optional[bool] = False, - cookies: Optional[CookieJar] = None, -) -> str: - """Download raw files of a dataset. - Downloads all files at the specified revision. This - is useful when you want all files from a dataset, because you don't know which - ones you will need a priori. All files are nested inside a folder in order - to keep their actual filename relative to that folder. - - An alternative would be to just clone a dataset but this would require that the - user always has git and git-lfs installed, and properly configured. - - Args: - dataset_id (str): A user or an organization name and a dataset name separated by a `/`. - file_path (str): The relative path of the file to download. - revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a - commit hash. NOTE: currently only branch and tag name is supported - cache_dir (str, Path, optional): Path to the folder where cached files are stored, dataset file will - be save as cache_dir/dataset_id/THE_DATASET_FILES. 
- local_dir (str, optional): Specific local directory path to which the file will be downloaded. - user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string. - local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the - local cached file if it exists. - cookies (CookieJar, optional): The cookie of the request, default None. - Raises: - ValueError: the value details. - - Returns: - str: Local folder path (string) of repo snapshot - - Note: - Raises the following errors: - - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) - if `use_auth_token=True` and the token cannot be found. - - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if - ETag cannot be determined. - - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) - if some parameter value is invalid - """ - temporary_cache_dir, cache = create_temporary_directory_and_cache( - dataset_id, - local_dir=local_dir, - cache_dir=cache_dir, - default_cache_root=get_dataset_cache_root()) - - if local_files_only: - cached_file_path = cache.get_file_by_path(file_path) - if cached_file_path is not None: - logger.warning( - "File exists in local cache, but we're not sure it's up to date" - ) - return cached_file_path - else: - raise ValueError( - 'Cannot find the requested files in the cached path and outgoing' - ' traffic has been disabled. 
To enable dataset look-ups and downloads' - " online, set 'local_files_only' to False.") - else: - # make headers - headers = { - 'user-agent': - ModelScopeConfig.get_user_agent(user_agent=user_agent, ) - } - _api = HubApi() - if cookies is None: - cookies = ModelScopeConfig.get_cookies() - group_or_owner, name = model_id_to_group_owner_name(dataset_id) - if not revision: - revision = DEFAULT_DATASET_REVISION - files_list_tree = _api.list_repo_tree( - dataset_name=name, - namespace=group_or_owner, - revision=revision, - root_path='/', - recursive=True) - if not ('Code' in files_list_tree and files_list_tree['Code'] == 200): - print( - 'Get dataset: %s file list failed, request_id: %s, message: %s' - % (dataset_id, files_list_tree['RequestId'], - files_list_tree['Message'])) - return None - file_meta_to_download = None - # find the file to download. - for file_meta in files_list_tree['Data']['Files']: - if file_meta['Type'] == 'tree': - continue - - if fnmatch.fnmatch(file_meta['Path'], file_path): - # check file is exist in cache, if existed, skip download, otherwise download - if cache.exists(file_meta): - file_name = os.path.basename(file_meta['Name']) - logger.debug( - f'File {file_name} already in cache, skip downloading!' - ) - return cache.get_file_by_info(file_meta) - else: - file_meta_to_download = file_meta - break - if file_meta_to_download is None: - raise NotExistError('The file path: %s not exist in: %s' % - (file_path, dataset_id)) - - # start download file. 
- # get download url - url = _api.get_dataset_file_url( - file_name=file_meta_to_download['Path'], - dataset_name=name, - namespace=group_or_owner, - revision=revision) - - return download_file(url, file_meta_to_download, temporary_cache_dir, - cache, headers, cookies) - - -def dataset_snapshot_download( - dataset_id: str, - revision: Optional[str] = DEFAULT_DATASET_REVISION, - cache_dir: Union[str, Path, None] = None, - local_dir: Optional[str] = None, - user_agent: Optional[Union[Dict, str]] = None, - local_files_only: Optional[bool] = False, - cookies: Optional[CookieJar] = None, - ignore_file_pattern: Optional[Union[str, List[str]]] = None, - allow_file_pattern: Optional[Union[str, List[str]]] = None, -) -> str: - """Download raw files of a dataset. - Downloads all files at the specified revision. This - is useful when you want all files from a dataset, because you don't know which - ones you will need a priori. All files are nested inside a folder in order - to keep their actual filename relative to that folder. - - An alternative would be to just clone a dataset but this would require that the - user always has git and git-lfs installed, and properly configured. - - Args: - dataset_id (str): A user or an organization name and a dataset name separated by a `/`. - revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a - commit hash. NOTE: currently only branch and tag name is supported - cache_dir (str, Path, optional): Path to the folder where cached files are stored, dataset will - be save as cache_dir/dataset_id/THE_DATASET_FILES. - local_dir (str, optional): Specific local directory path to which the file will be downloaded. - user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string. - local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the - local cached file if it exists. 
- cookies (CookieJar, optional): The cookie of the request, default None. - ignore_file_pattern (`str` or `List`, *optional*, default to `None`): - Any file pattern to be ignored in downloading, like exact file names or file extensions. - allow_file_pattern (`str` or `List`, *optional*, default to `None`): - Any file pattern to be downloading, like exact file names or file extensions. - Raises: - ValueError: the value details. - - Returns: - str: Local folder path (string) of repo snapshot - - Note: - Raises the following errors: - - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) - if `use_auth_token=True` and the token cannot be found. - - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if - ETag cannot be determined. - - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) - if some parameter value is invalid - """ - temporary_cache_dir, cache = create_temporary_directory_and_cache( - dataset_id, - local_dir=local_dir, - cache_dir=cache_dir, - default_cache_root=get_dataset_cache_root()) - - if local_files_only: - if len(cache.cached_files) == 0: - raise ValueError( - 'Cannot find the requested files in the cached path and outgoing' - ' traffic has been disabled. 
To enable dataset look-ups and downloads' - " online, set 'local_files_only' to False.") - logger.warning('We can not confirm the cached file is for revision: %s' - % revision) - return cache.get_root_location( - ) # we can not confirm the cached file is for snapshot 'revision' - else: - # make headers - headers = { - 'user-agent': - ModelScopeConfig.get_user_agent(user_agent=user_agent, ) - } - _api = HubApi() - if cookies is None: - cookies = ModelScopeConfig.get_cookies() - group_or_owner, name = model_id_to_group_owner_name(dataset_id) - if not revision: - revision = DEFAULT_DATASET_REVISION - files_list_tree = _api.list_repo_tree( - dataset_name=name, - namespace=group_or_owner, - revision=revision, - root_path='/', - recursive=True) - if not ('Code' in files_list_tree and files_list_tree['Code'] == 200): - print( - 'Get dataset: %s file list failed, request_id: %s, message: %s' - % (dataset_id, files_list_tree['RequestId'], - files_list_tree['Message'])) - return None - - if ignore_file_pattern is None: - ignore_file_pattern = [] - if isinstance(ignore_file_pattern, str): - ignore_file_pattern = [ignore_file_pattern] - ignore_file_pattern = [ - item if not item.endswith('/') else item + '*' - for item in ignore_file_pattern - ] - - if allow_file_pattern is not None: - if isinstance(allow_file_pattern, str): - allow_file_pattern = [allow_file_pattern] - allow_file_pattern = [ - item if not item.endswith('/') else item + '*' - for item in allow_file_pattern - ] - - for file_meta in files_list_tree['Data']['Files']: - if file_meta['Type'] == 'tree' or \ - any(fnmatch.fnmatch(file_meta['Path'], pattern) for pattern in ignore_file_pattern): - continue - - if allow_file_pattern is not None and allow_file_pattern: - if not any( - fnmatch.fnmatch(file_meta['Path'], pattern) - for pattern in allow_file_pattern): - continue - - # check file is exist in cache, if existed, skip download, otherwise download - if cache.exists(file_meta): - file_name = 
os.path.basename(file_meta['Name']) - logger.debug( - f'File {file_name} already in cache, skip downloading!') - continue - - # get download url - url = _api.get_dataset_file_url( - file_name=file_meta['Path'], - dataset_name=name, - namespace=group_or_owner, - revision=revision) - - download_file(url, file_meta, temporary_cache_dir, cache, headers, - cookies) - - cache.save_model_version(revision_info=revision) - return os.path.join(cache.get_root_location()) diff --git a/modelscope/hub/file_download.py b/modelscope/hub/file_download.py index dce4a7aa..29202404 100644 --- a/modelscope/hub/file_download.py +++ b/modelscope/hub/file_download.py @@ -21,10 +21,14 @@ from modelscope.hub.constants import ( API_FILE_DOWNLOAD_CHUNK_SIZE, API_FILE_DOWNLOAD_RETRY_TIMES, API_FILE_DOWNLOAD_TIMEOUT, FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS, MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB, TEMPORARY_FOLDER_NAME) -from modelscope.utils.constant import DEFAULT_MODEL_REVISION -from modelscope.utils.file_utils import get_model_cache_root +from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, + DEFAULT_MODEL_REVISION, + REPO_TYPE_DATASET, REPO_TYPE_MODEL, + REPO_TYPE_SUPPORT) +from modelscope.utils.file_utils import (get_dataset_cache_root, + get_model_cache_root) from modelscope.utils.logger import get_logger -from .errors import FileDownloadError, NotExistError +from .errors import FileDownloadError, InvalidParameter, NotExistError from .utils.caching import ModelFileSystemCache from .utils.utils import (file_integrity_validation, get_endpoint, model_id_to_group_owner_name) @@ -78,8 +82,97 @@ def model_file_download( - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid """ + return _repo_file_download( + model_id, + file_path, + repo_type=REPO_TYPE_MODEL, + revision=revision, + cache_dir=cache_dir, + user_agent=user_agent, + local_files_only=local_files_only, + cookies=cookies, + local_dir=local_dir) + + +def 
dataset_file_download( + dataset_id: str, + file_path: str, + revision: Optional[str] = DEFAULT_DATASET_REVISION, + cache_dir: Union[str, Path, None] = None, + local_dir: Optional[str] = None, + user_agent: Optional[Union[Dict, str]] = None, + local_files_only: Optional[bool] = False, + cookies: Optional[CookieJar] = None, +) -> str: + """Download raw files of a dataset. + Downloads all files at the specified revision. This + is useful when you want all files from a dataset, because you don't know which + ones you will need a priori. All files are nested inside a folder in order + to keep their actual filename relative to that folder. + + An alternative would be to just clone a dataset but this would require that the + user always has git and git-lfs installed, and properly configured. + + Args: + dataset_id (str): A user or an organization name and a dataset name separated by a `/`. + file_path (str): The relative path of the file to download. + revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a + commit hash. NOTE: currently only branch and tag name is supported + cache_dir (str, Path, optional): Path to the folder where cached files are stored, dataset file will + be save as cache_dir/dataset_id/THE_DATASET_FILES. + local_dir (str, optional): Specific local directory path to which the file will be downloaded. + user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string. + local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the + local cached file if it exists. + cookies (CookieJar, optional): The cookie of the request, default None. + Raises: + ValueError: the value details. + + Returns: + str: Local folder path (string) of repo snapshot + + Note: + Raises the following errors: + - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + if `use_auth_token=True` and the token cannot be found. 
+ - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if + ETag cannot be determined. + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + """ + return _repo_file_download( + dataset_id, + file_path, + repo_type=REPO_TYPE_DATASET, + revision=revision, + cache_dir=cache_dir, + user_agent=user_agent, + local_files_only=local_files_only, + cookies=cookies, + local_dir=local_dir) + + +def _repo_file_download( + repo_id: str, + file_path: str, + *, + repo_type: str = None, + revision: Optional[str] = DEFAULT_MODEL_REVISION, + cache_dir: Optional[str] = None, + user_agent: Union[Dict, str, None] = None, + local_files_only: Optional[bool] = False, + cookies: Optional[CookieJar] = None, + local_dir: Optional[str] = None, +) -> Optional[str]: # pragma: no cover + + if not repo_type: + repo_type = REPO_TYPE_MODEL + if repo_type not in REPO_TYPE_SUPPORT: + raise InvalidParameter('Invalid repo type: %s, only support: %s' % ( + repo_type, REPO_TYPE_SUPPORT)) + temporary_cache_dir, cache = create_temporary_directory_and_cache( - model_id, local_dir, cache_dir, get_model_cache_root()) + repo_id, local_dir=local_dir, cache_dir=cache_dir, repo_type=repo_type) # if local_files_only is `True` and the file already exists in cached_path # return the cached path @@ -93,7 +186,7 @@ def model_file_download( else: raise ValueError( 'Cannot find the requested files in the cached path and outgoing' - ' traffic has been disabled. 
To enable look-ups and downloads' " online, set 'local_files_only' to False.") _api = HubApi() @@ -102,45 +195,77 @@ def model_file_download( } if cookies is None: cookies = ModelScopeConfig.get_cookies() + repo_files = [] + if repo_type == REPO_TYPE_MODEL: + revision = _api.get_valid_revision( + repo_id, revision=revision, cookies=cookies) + file_to_download_meta = None + # we need to confirm the version is up-to-date + # we need to get the file list to check if the latest version is cached, if so return, otherwise download + repo_files = _api.get_model_files( + model_id=repo_id, + revision=revision, + recursive=True, + use_cookies=False if cookies is None else cookies) + elif repo_type == REPO_TYPE_DATASET: + group_or_owner, name = model_id_to_group_owner_name(repo_id) + if not revision: + revision = DEFAULT_DATASET_REVISION + files_list_tree = _api.list_repo_tree( + dataset_name=name, + namespace=group_or_owner, + revision=revision, + root_path='/', + recursive=True) + if not ('Code' in files_list_tree and files_list_tree['Code'] == 200): + print( + 'Get dataset: %s file list failed, request_id: %s, message: %s' + % (repo_id, files_list_tree['RequestId'], + files_list_tree['Message'])) + return None + repo_files = files_list_tree['Data']['Files'] - revision = _api.get_valid_revision( - model_id, revision=revision, cookies=cookies) - file_to_download_info = None - # we need to confirm the version is up-to-date - # we need to get the file list to check if the latest version is cached, if so return, otherwise download - model_files = _api.get_model_files( - model_id=model_id, - revision=revision, - recursive=True, - use_cookies=False if cookies is None else cookies) - - for model_file in model_files: - if model_file['Type'] == 'tree': + file_to_download_meta = None + for repo_file in repo_files: + if repo_file['Type'] == 'tree': continue - if model_file['Path'] == file_path: - if cache.exists(model_file): + if repo_file['Path'] == file_path: + if 
cache.exists(repo_file): logger.debug( - f'File {model_file["Name"]} already in cache, skip downloading!' + f'File {repo_file["Name"]} already in cache, skip downloading!' ) - return cache.get_file_by_info(model_file) + return cache.get_file_by_info(repo_file) else: - file_to_download_info = model_file + file_to_download_meta = repo_file break - if file_to_download_info is None: + if file_to_download_meta is None: raise NotExistError('The file path: %s not exist in: %s' % - (file_path, model_id)) + (file_path, repo_id)) # we need to download again - url_to_download = get_file_download_url(model_id, file_path, revision) - return download_file(url_to_download, file_to_download_info, + if repo_type == REPO_TYPE_MODEL: + url_to_download = get_file_download_url(repo_id, file_path, revision) + elif repo_type == REPO_TYPE_DATASET: + url_to_download = _api.get_dataset_file_url( + file_name=file_to_download_meta['Path'], + dataset_name=name, + namespace=group_or_owner, + revision=revision) + return download_file(url_to_download, file_to_download_meta, temporary_cache_dir, cache, headers, cookies) -def create_temporary_directory_and_cache(model_id: str, local_dir: str, - cache_dir: str, - default_cache_root: str): +def create_temporary_directory_and_cache(model_id: str, + local_dir: str = None, + cache_dir: str = None, + repo_type: str = REPO_TYPE_MODEL): + if repo_type == REPO_TYPE_MODEL: + default_cache_root = get_model_cache_root() + elif repo_type == REPO_TYPE_DATASET: + default_cache_root = get_dataset_cache_root() + group_or_owner, name = model_id_to_group_owner_name(model_id) if local_dir is not None: temporary_cache_dir = os.path.join(local_dir, TEMPORARY_FOLDER_NAME) diff --git a/modelscope/hub/snapshot_download.py b/modelscope/hub/snapshot_download.py index 1ede1521..e4bf53b8 100644 --- a/modelscope/hub/snapshot_download.py +++ b/modelscope/hub/snapshot_download.py @@ -8,8 +8,14 @@ from pathlib import Path from typing import Dict, List, Optional, Union from 
modelscope.hub.api import HubApi, ModelScopeConfig -from modelscope.utils.constant import DEFAULT_MODEL_REVISION -from modelscope.utils.file_utils import get_model_cache_root +from modelscope.hub.errors import InvalidParameter +from modelscope.hub.utils.utils import model_id_to_group_owner_name +from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, + DEFAULT_MODEL_REVISION, + REPO_TYPE_DATASET, REPO_TYPE_MODEL, + REPO_TYPE_SUPPORT) +from modelscope.utils.file_utils import (get_dataset_cache_root, + get_model_cache_root) from modelscope.utils.logger import get_logger from .file_download import (create_temporary_directory_and_cache, download_file, get_file_download_url) @@ -67,14 +73,109 @@ def snapshot_download( - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid """ + return _snapshot_download( + model_id, + repo_type=REPO_TYPE_MODEL, + revision=revision, + cache_dir=cache_dir, + user_agent=user_agent, + local_files_only=local_files_only, + cookies=cookies, + ignore_file_pattern=ignore_file_pattern, + allow_file_pattern=allow_file_pattern, + local_dir=local_dir) + + +def dataset_snapshot_download( + dataset_id: str, + revision: Optional[str] = DEFAULT_DATASET_REVISION, + cache_dir: Union[str, Path, None] = None, + local_dir: Optional[str] = None, + user_agent: Optional[Union[Dict, str]] = None, + local_files_only: Optional[bool] = False, + cookies: Optional[CookieJar] = None, + ignore_file_pattern: Optional[Union[str, List[str]]] = None, + allow_file_pattern: Optional[Union[str, List[str]]] = None, +) -> str: + """Download raw files of a dataset. + Downloads all files at the specified revision. This + is useful when you want all files from a dataset, because you don't know which + ones you will need a priori. All files are nested inside a folder in order + to keep their actual filename relative to that folder. 
+ + An alternative would be to just clone a dataset but this would require that the + user always has git and git-lfs installed, and properly configured. + + Args: + dataset_id (str): A user or an organization name and a dataset name separated by a `/`. + revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a + commit hash. NOTE: currently only branch and tag name is supported + cache_dir (str, Path, optional): Path to the folder where cached files are stored, dataset will + be save as cache_dir/dataset_id/THE_DATASET_FILES. + local_dir (str, optional): Specific local directory path to which the file will be downloaded. + user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string. + local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the + local cached file if it exists. + cookies (CookieJar, optional): The cookie of the request, default None. + ignore_file_pattern (`str` or `List`, *optional*, default to `None`): + Any file pattern to be ignored in downloading, like exact file names or file extensions. + allow_file_pattern (`str` or `List`, *optional*, default to `None`): + Any file pattern to be downloading, like exact file names or file extensions. + Raises: + ValueError: the value details. + + Returns: + str: Local folder path (string) of repo snapshot + + Note: + Raises the following errors: + - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + if `use_auth_token=True` and the token cannot be found. + - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if + ETag cannot be determined. 
+ - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + """ + return _snapshot_download( + dataset_id, + repo_type=REPO_TYPE_DATASET, + revision=revision, + cache_dir=cache_dir, + user_agent=user_agent, + local_files_only=local_files_only, + cookies=cookies, + ignore_file_pattern=ignore_file_pattern, + allow_file_pattern=allow_file_pattern, + local_dir=local_dir) + + +def _snapshot_download( + repo_id: str, + *, + repo_type: Optional[str] = None, + revision: Optional[str] = DEFAULT_MODEL_REVISION, + cache_dir: Union[str, Path, None] = None, + user_agent: Optional[Union[Dict, str]] = None, + local_files_only: Optional[bool] = False, + cookies: Optional[CookieJar] = None, + ignore_file_pattern: Optional[Union[str, List[str]]] = None, + allow_file_pattern: Optional[Union[str, List[str]]] = None, + local_dir: Optional[str] = None, +): + if not repo_type: + repo_type = REPO_TYPE_MODEL + if repo_type not in REPO_TYPE_SUPPORT: + raise InvalidParameter('Invalid repo type: %s, only support: %s' % ( + repo_type, REPO_TYPE_SUPPORT)) + temporary_cache_dir, cache = create_temporary_directory_and_cache( - model_id, local_dir, cache_dir, get_model_cache_root()) + repo_id, local_dir=local_dir, cache_dir=cache_dir, repo_type=repo_type) if local_files_only: if len(cache.cached_files) == 0: raise ValueError( 'Cannot find the requested files in the cached path and outgoing' - ' traffic has been disabled. 
To enable look-ups and downloads' " online, set 'local_files_only' to False.") logger.warning('We can not confirm the cached file is for revision: %s' % revision) @@ -89,27 +190,48 @@ def snapshot_download( _api = HubApi() if cookies is None: cookies = ModelScopeConfig.get_cookies() - revision_detail = _api.get_valid_revision_detail( - model_id, revision=revision, cookies=cookies) - revision = revision_detail['Revision'] + repo_files = [] + if repo_type == REPO_TYPE_MODEL: + revision_detail = _api.get_valid_revision_detail( + repo_id, revision=revision, cookies=cookies) + revision = revision_detail['Revision'] - snapshot_header = headers if 'CI_TEST' in os.environ else { - **headers, - **{ - 'Snapshot': 'True' + snapshot_header = headers if 'CI_TEST' in os.environ else { + **headers, + **{ + 'Snapshot': 'True' + } } - } - if cache.cached_model_revision is not None: - snapshot_header[ - 'cached_model_revision'] = cache.cached_model_revision + if cache.cached_model_revision is not None: + snapshot_header[ + 'cached_model_revision'] = cache.cached_model_revision - model_files = _api.get_model_files( - model_id=model_id, - revision=revision, - recursive=True, - use_cookies=False if cookies is None else cookies, - headers=snapshot_header, - ) + repo_files = _api.get_model_files( + model_id=repo_id, + revision=revision, + recursive=True, + use_cookies=False if cookies is None else cookies, + headers=snapshot_header, + ) + elif repo_type == REPO_TYPE_DATASET: + group_or_owner, name = model_id_to_group_owner_name(repo_id) + if not revision: + revision = DEFAULT_DATASET_REVISION + revision_detail = revision + files_list_tree = _api.list_repo_tree( + dataset_name=name, + namespace=group_or_owner, + revision=revision, + root_path='/', + recursive=True) + if not ('Code' in files_list_tree + and files_list_tree['Code'] == 200): + print( + 'Get dataset: %s file list failed, request_id: %s, message: %s' + % (repo_id, files_list_tree['RequestId'], + files_list_tree['Message'])) + 
return None + repo_files = files_list_tree['Data']['Files'] if ignore_file_pattern is None: ignore_file_pattern = [] @@ -128,32 +250,38 @@ def snapshot_download( for item in allow_file_pattern ] - for model_file in model_files: - if model_file['Type'] == 'tree' or \ - any(fnmatch.fnmatch(model_file['Path'], pattern) for pattern in ignore_file_pattern) or \ - any([re.search(pattern, model_file['Name']) is not None for pattern in ignore_file_pattern]): + for repo_file in repo_files: + if repo_file['Type'] == 'tree' or \ + any(fnmatch.fnmatch(repo_file['Path'], pattern) for pattern in ignore_file_pattern) or \ + any([re.search(re.escape(pattern), repo_file['Name']) is not None for pattern in ignore_file_pattern]): # noqa E501 continue if allow_file_pattern is not None and allow_file_pattern: if not any( - fnmatch.fnmatch(model_file['Path'], pattern) + fnmatch.fnmatch(repo_file['Path'], pattern) for pattern in allow_file_pattern): continue # check model_file is exist in cache, if existed, skip download, otherwise download - if cache.exists(model_file): - file_name = os.path.basename(model_file['Name']) + if cache.exists(repo_file): + file_name = os.path.basename(repo_file['Name']) logger.debug( f'File {file_name} already in cache, skip downloading!') continue + if repo_type == REPO_TYPE_MODEL: + # get download url + url = get_file_download_url( + model_id=repo_id, + file_path=repo_file['Path'], + revision=revision) + elif repo_type == REPO_TYPE_DATASET: + url = _api.get_dataset_file_url( + file_name=repo_file['Path'], + dataset_name=name, + namespace=group_or_owner, + revision=revision) - # get download url - url = get_file_download_url( - model_id=model_id, - file_path=model_file['Path'], - revision=revision) - - download_file(url, model_file, temporary_cache_dir, cache, headers, + download_file(url, repo_file, temporary_cache_dir, cache, headers, cookies) cache.save_model_version(revision_info=revision_detail) diff --git a/modelscope/utils/constant.py 
b/modelscope/utils/constant.py index 1b10d771..3570c0cb 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -493,6 +493,9 @@ class Frameworks(object): kaldi = 'kaldi' +REPO_TYPE_MODEL = 'model' +REPO_TYPE_DATASET = 'dataset' +REPO_TYPE_SUPPORT = [REPO_TYPE_MODEL, REPO_TYPE_DATASET] DEFAULT_MODEL_REVISION = None MASTER_MODEL_BRANCH = 'master' DEFAULT_REPOSITORY_REVISION = 'master' diff --git a/tests/hub/test_download_dataset_file.py b/tests/hub/test_download_dataset_file.py index 49d8c238..0c4e9307 100644 --- a/tests/hub/test_download_dataset_file.py +++ b/tests/hub/test_download_dataset_file.py @@ -5,8 +5,8 @@ import tempfile import time import unittest -from modelscope.hub.dataset_download import (dataset_file_download, - dataset_snapshot_download) +from modelscope.hub.file_download import dataset_file_download +from modelscope.hub.snapshot_download import dataset_snapshot_download class DownloadDatasetTest(unittest.TestCase):