mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-25 04:30:48 +01:00
refactor code
This commit is contained in:
@@ -10,8 +10,8 @@ if TYPE_CHECKING:
|
||||
from .hub.api import HubApi
|
||||
from .hub.check_model import check_local_model_is_latest, check_model_is_id
|
||||
from .hub.push_to_hub import push_to_hub, push_to_hub_async
|
||||
from .hub.snapshot_download import snapshot_download
|
||||
from .hub.dataset_download import dataset_snapshot_download, dataset_file_download
|
||||
from .hub.snapshot_download import snapshot_download, dataset_snapshot_download
|
||||
from .hub.file_download import model_file_download, dataset_file_download
|
||||
|
||||
from .metrics import (
|
||||
AccuracyMetric, AudioNoiseMetric, BleuMetric, ImageColorEnhanceMetric,
|
||||
@@ -61,9 +61,9 @@ else:
|
||||
'TorchModelExporter',
|
||||
],
|
||||
'hub.api': ['HubApi'],
|
||||
'hub.snapshot_download': ['snapshot_download'],
|
||||
'hub.dataset_download':
|
||||
['dataset_snapshot_download', 'dataset_file_download'],
|
||||
'hub.snapshot_download':
|
||||
['snapshot_download', 'dataset_snapshot_download'],
|
||||
'hub.file_download': ['model_file_download', 'dataset_file_download'],
|
||||
'hub.push_to_hub': ['push_to_hub', 'push_to_hub_async'],
|
||||
'hub.check_model':
|
||||
['check_model_is_id', 'check_local_model_is_latest'],
|
||||
|
||||
@@ -1,278 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import fnmatch
|
||||
import os
|
||||
from http.cookiejar import CookieJar
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from modelscope.hub.api import HubApi, ModelScopeConfig
|
||||
from modelscope.hub.errors import NotExistError
|
||||
from modelscope.utils.constant import DEFAULT_DATASET_REVISION
|
||||
from modelscope.utils.file_utils import get_dataset_cache_root
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .file_download import create_temporary_directory_and_cache, download_file
|
||||
from .utils.utils import model_id_to_group_owner_name
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def dataset_file_download(
|
||||
dataset_id: str,
|
||||
file_path: str,
|
||||
revision: Optional[str] = DEFAULT_DATASET_REVISION,
|
||||
cache_dir: Union[str, Path, None] = None,
|
||||
local_dir: Optional[str] = None,
|
||||
user_agent: Optional[Union[Dict, str]] = None,
|
||||
local_files_only: Optional[bool] = False,
|
||||
cookies: Optional[CookieJar] = None,
|
||||
) -> str:
|
||||
"""Download raw files of a dataset.
|
||||
Downloads all files at the specified revision. This
|
||||
is useful when you want all files from a dataset, because you don't know which
|
||||
ones you will need a priori. All files are nested inside a folder in order
|
||||
to keep their actual filename relative to that folder.
|
||||
|
||||
An alternative would be to just clone a dataset but this would require that the
|
||||
user always has git and git-lfs installed, and properly configured.
|
||||
|
||||
Args:
|
||||
dataset_id (str): A user or an organization name and a dataset name separated by a `/`.
|
||||
file_path (str): The relative path of the file to download.
|
||||
revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
|
||||
commit hash. NOTE: currently only branch and tag name is supported
|
||||
cache_dir (str, Path, optional): Path to the folder where cached files are stored, dataset file will
|
||||
be save as cache_dir/dataset_id/THE_DATASET_FILES.
|
||||
local_dir (str, optional): Specific local directory path to which the file will be downloaded.
|
||||
user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string.
|
||||
local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the
|
||||
local cached file if it exists.
|
||||
cookies (CookieJar, optional): The cookie of the request, default None.
|
||||
Raises:
|
||||
ValueError: the value details.
|
||||
|
||||
Returns:
|
||||
str: Local folder path (string) of repo snapshot
|
||||
|
||||
Note:
|
||||
Raises the following errors:
|
||||
- [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
|
||||
if `use_auth_token=True` and the token cannot be found.
|
||||
- [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if
|
||||
ETag cannot be determined.
|
||||
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
||||
if some parameter value is invalid
|
||||
"""
|
||||
temporary_cache_dir, cache = create_temporary_directory_and_cache(
|
||||
dataset_id,
|
||||
local_dir=local_dir,
|
||||
cache_dir=cache_dir,
|
||||
default_cache_root=get_dataset_cache_root())
|
||||
|
||||
if local_files_only:
|
||||
cached_file_path = cache.get_file_by_path(file_path)
|
||||
if cached_file_path is not None:
|
||||
logger.warning(
|
||||
"File exists in local cache, but we're not sure it's up to date"
|
||||
)
|
||||
return cached_file_path
|
||||
else:
|
||||
raise ValueError(
|
||||
'Cannot find the requested files in the cached path and outgoing'
|
||||
' traffic has been disabled. To enable dataset look-ups and downloads'
|
||||
" online, set 'local_files_only' to False.")
|
||||
else:
|
||||
# make headers
|
||||
headers = {
|
||||
'user-agent':
|
||||
ModelScopeConfig.get_user_agent(user_agent=user_agent, )
|
||||
}
|
||||
_api = HubApi()
|
||||
if cookies is None:
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
group_or_owner, name = model_id_to_group_owner_name(dataset_id)
|
||||
if not revision:
|
||||
revision = DEFAULT_DATASET_REVISION
|
||||
files_list_tree = _api.list_repo_tree(
|
||||
dataset_name=name,
|
||||
namespace=group_or_owner,
|
||||
revision=revision,
|
||||
root_path='/',
|
||||
recursive=True)
|
||||
if not ('Code' in files_list_tree and files_list_tree['Code'] == 200):
|
||||
print(
|
||||
'Get dataset: %s file list failed, request_id: %s, message: %s'
|
||||
% (dataset_id, files_list_tree['RequestId'],
|
||||
files_list_tree['Message']))
|
||||
return None
|
||||
file_meta_to_download = None
|
||||
# find the file to download.
|
||||
for file_meta in files_list_tree['Data']['Files']:
|
||||
if file_meta['Type'] == 'tree':
|
||||
continue
|
||||
|
||||
if fnmatch.fnmatch(file_meta['Path'], file_path):
|
||||
# check file is exist in cache, if existed, skip download, otherwise download
|
||||
if cache.exists(file_meta):
|
||||
file_name = os.path.basename(file_meta['Name'])
|
||||
logger.debug(
|
||||
f'File {file_name} already in cache, skip downloading!'
|
||||
)
|
||||
return cache.get_file_by_info(file_meta)
|
||||
else:
|
||||
file_meta_to_download = file_meta
|
||||
break
|
||||
if file_meta_to_download is None:
|
||||
raise NotExistError('The file path: %s not exist in: %s' %
|
||||
(file_path, dataset_id))
|
||||
|
||||
# start download file.
|
||||
# get download url
|
||||
url = _api.get_dataset_file_url(
|
||||
file_name=file_meta_to_download['Path'],
|
||||
dataset_name=name,
|
||||
namespace=group_or_owner,
|
||||
revision=revision)
|
||||
|
||||
return download_file(url, file_meta_to_download, temporary_cache_dir,
|
||||
cache, headers, cookies)
|
||||
|
||||
|
||||
def dataset_snapshot_download(
|
||||
dataset_id: str,
|
||||
revision: Optional[str] = DEFAULT_DATASET_REVISION,
|
||||
cache_dir: Union[str, Path, None] = None,
|
||||
local_dir: Optional[str] = None,
|
||||
user_agent: Optional[Union[Dict, str]] = None,
|
||||
local_files_only: Optional[bool] = False,
|
||||
cookies: Optional[CookieJar] = None,
|
||||
ignore_file_pattern: Optional[Union[str, List[str]]] = None,
|
||||
allow_file_pattern: Optional[Union[str, List[str]]] = None,
|
||||
) -> str:
|
||||
"""Download raw files of a dataset.
|
||||
Downloads all files at the specified revision. This
|
||||
is useful when you want all files from a dataset, because you don't know which
|
||||
ones you will need a priori. All files are nested inside a folder in order
|
||||
to keep their actual filename relative to that folder.
|
||||
|
||||
An alternative would be to just clone a dataset but this would require that the
|
||||
user always has git and git-lfs installed, and properly configured.
|
||||
|
||||
Args:
|
||||
dataset_id (str): A user or an organization name and a dataset name separated by a `/`.
|
||||
revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
|
||||
commit hash. NOTE: currently only branch and tag name is supported
|
||||
cache_dir (str, Path, optional): Path to the folder where cached files are stored, dataset will
|
||||
be save as cache_dir/dataset_id/THE_DATASET_FILES.
|
||||
local_dir (str, optional): Specific local directory path to which the file will be downloaded.
|
||||
user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string.
|
||||
local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the
|
||||
local cached file if it exists.
|
||||
cookies (CookieJar, optional): The cookie of the request, default None.
|
||||
ignore_file_pattern (`str` or `List`, *optional*, default to `None`):
|
||||
Any file pattern to be ignored in downloading, like exact file names or file extensions.
|
||||
allow_file_pattern (`str` or `List`, *optional*, default to `None`):
|
||||
Any file pattern to be downloading, like exact file names or file extensions.
|
||||
Raises:
|
||||
ValueError: the value details.
|
||||
|
||||
Returns:
|
||||
str: Local folder path (string) of repo snapshot
|
||||
|
||||
Note:
|
||||
Raises the following errors:
|
||||
- [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
|
||||
if `use_auth_token=True` and the token cannot be found.
|
||||
- [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if
|
||||
ETag cannot be determined.
|
||||
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
||||
if some parameter value is invalid
|
||||
"""
|
||||
temporary_cache_dir, cache = create_temporary_directory_and_cache(
|
||||
dataset_id,
|
||||
local_dir=local_dir,
|
||||
cache_dir=cache_dir,
|
||||
default_cache_root=get_dataset_cache_root())
|
||||
|
||||
if local_files_only:
|
||||
if len(cache.cached_files) == 0:
|
||||
raise ValueError(
|
||||
'Cannot find the requested files in the cached path and outgoing'
|
||||
' traffic has been disabled. To enable dataset look-ups and downloads'
|
||||
" online, set 'local_files_only' to False.")
|
||||
logger.warning('We can not confirm the cached file is for revision: %s'
|
||||
% revision)
|
||||
return cache.get_root_location(
|
||||
) # we can not confirm the cached file is for snapshot 'revision'
|
||||
else:
|
||||
# make headers
|
||||
headers = {
|
||||
'user-agent':
|
||||
ModelScopeConfig.get_user_agent(user_agent=user_agent, )
|
||||
}
|
||||
_api = HubApi()
|
||||
if cookies is None:
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
group_or_owner, name = model_id_to_group_owner_name(dataset_id)
|
||||
if not revision:
|
||||
revision = DEFAULT_DATASET_REVISION
|
||||
files_list_tree = _api.list_repo_tree(
|
||||
dataset_name=name,
|
||||
namespace=group_or_owner,
|
||||
revision=revision,
|
||||
root_path='/',
|
||||
recursive=True)
|
||||
if not ('Code' in files_list_tree and files_list_tree['Code'] == 200):
|
||||
print(
|
||||
'Get dataset: %s file list failed, request_id: %s, message: %s'
|
||||
% (dataset_id, files_list_tree['RequestId'],
|
||||
files_list_tree['Message']))
|
||||
return None
|
||||
|
||||
if ignore_file_pattern is None:
|
||||
ignore_file_pattern = []
|
||||
if isinstance(ignore_file_pattern, str):
|
||||
ignore_file_pattern = [ignore_file_pattern]
|
||||
ignore_file_pattern = [
|
||||
item if not item.endswith('/') else item + '*'
|
||||
for item in ignore_file_pattern
|
||||
]
|
||||
|
||||
if allow_file_pattern is not None:
|
||||
if isinstance(allow_file_pattern, str):
|
||||
allow_file_pattern = [allow_file_pattern]
|
||||
allow_file_pattern = [
|
||||
item if not item.endswith('/') else item + '*'
|
||||
for item in allow_file_pattern
|
||||
]
|
||||
|
||||
for file_meta in files_list_tree['Data']['Files']:
|
||||
if file_meta['Type'] == 'tree' or \
|
||||
any(fnmatch.fnmatch(file_meta['Path'], pattern) for pattern in ignore_file_pattern):
|
||||
continue
|
||||
|
||||
if allow_file_pattern is not None and allow_file_pattern:
|
||||
if not any(
|
||||
fnmatch.fnmatch(file_meta['Path'], pattern)
|
||||
for pattern in allow_file_pattern):
|
||||
continue
|
||||
|
||||
# check file is exist in cache, if existed, skip download, otherwise download
|
||||
if cache.exists(file_meta):
|
||||
file_name = os.path.basename(file_meta['Name'])
|
||||
logger.debug(
|
||||
f'File {file_name} already in cache, skip downloading!')
|
||||
continue
|
||||
|
||||
# get download url
|
||||
url = _api.get_dataset_file_url(
|
||||
file_name=file_meta['Path'],
|
||||
dataset_name=name,
|
||||
namespace=group_or_owner,
|
||||
revision=revision)
|
||||
|
||||
download_file(url, file_meta, temporary_cache_dir, cache, headers,
|
||||
cookies)
|
||||
|
||||
cache.save_model_version(revision_info=revision)
|
||||
return os.path.join(cache.get_root_location())
|
||||
@@ -21,10 +21,14 @@ from modelscope.hub.constants import (
|
||||
API_FILE_DOWNLOAD_CHUNK_SIZE, API_FILE_DOWNLOAD_RETRY_TIMES,
|
||||
API_FILE_DOWNLOAD_TIMEOUT, FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS,
|
||||
MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB, TEMPORARY_FOLDER_NAME)
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
|
||||
from modelscope.utils.file_utils import get_model_cache_root
|
||||
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
|
||||
DEFAULT_MODEL_REVISION,
|
||||
REPO_TYPE_DATASET, REPO_TYPE_MODEL,
|
||||
REPO_TYPE_SUPPORT)
|
||||
from modelscope.utils.file_utils import (get_dataset_cache_root,
|
||||
get_model_cache_root)
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .errors import FileDownloadError, NotExistError
|
||||
from .errors import FileDownloadError, InvalidParameter, NotExistError
|
||||
from .utils.caching import ModelFileSystemCache
|
||||
from .utils.utils import (file_integrity_validation, get_endpoint,
|
||||
model_id_to_group_owner_name)
|
||||
@@ -78,8 +82,97 @@ def model_file_download(
|
||||
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
||||
if some parameter value is invalid
|
||||
"""
|
||||
return _repo_file_download(
|
||||
model_id,
|
||||
file_path,
|
||||
repo_type=REPO_TYPE_MODEL,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
user_agent=user_agent,
|
||||
local_files_only=local_files_only,
|
||||
cookies=cookies,
|
||||
local_dir=local_dir)
|
||||
|
||||
|
||||
def dataset_file_download(
|
||||
dataset_id: str,
|
||||
file_path: str,
|
||||
revision: Optional[str] = DEFAULT_DATASET_REVISION,
|
||||
cache_dir: Union[str, Path, None] = None,
|
||||
local_dir: Optional[str] = None,
|
||||
user_agent: Optional[Union[Dict, str]] = None,
|
||||
local_files_only: Optional[bool] = False,
|
||||
cookies: Optional[CookieJar] = None,
|
||||
) -> str:
|
||||
"""Download raw files of a dataset.
|
||||
Downloads all files at the specified revision. This
|
||||
is useful when you want all files from a dataset, because you don't know which
|
||||
ones you will need a priori. All files are nested inside a folder in order
|
||||
to keep their actual filename relative to that folder.
|
||||
|
||||
An alternative would be to just clone a dataset but this would require that the
|
||||
user always has git and git-lfs installed, and properly configured.
|
||||
|
||||
Args:
|
||||
dataset_id (str): A user or an organization name and a dataset name separated by a `/`.
|
||||
file_path (str): The relative path of the file to download.
|
||||
revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
|
||||
commit hash. NOTE: currently only branch and tag name is supported
|
||||
cache_dir (str, Path, optional): Path to the folder where cached files are stored, dataset file will
|
||||
be save as cache_dir/dataset_id/THE_DATASET_FILES.
|
||||
local_dir (str, optional): Specific local directory path to which the file will be downloaded.
|
||||
user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string.
|
||||
local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the
|
||||
local cached file if it exists.
|
||||
cookies (CookieJar, optional): The cookie of the request, default None.
|
||||
Raises:
|
||||
ValueError: the value details.
|
||||
|
||||
Returns:
|
||||
str: Local folder path (string) of repo snapshot
|
||||
|
||||
Note:
|
||||
Raises the following errors:
|
||||
- [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
|
||||
if `use_auth_token=True` and the token cannot be found.
|
||||
- [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if
|
||||
ETag cannot be determined.
|
||||
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
||||
if some parameter value is invalid
|
||||
"""
|
||||
return _repo_file_download(
|
||||
dataset_id,
|
||||
file_path,
|
||||
repo_type=REPO_TYPE_DATASET,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
user_agent=user_agent,
|
||||
local_files_only=local_files_only,
|
||||
cookies=cookies,
|
||||
local_dir=local_dir)
|
||||
|
||||
|
||||
def _repo_file_download(
|
||||
repo_id: str,
|
||||
file_path: str,
|
||||
*,
|
||||
repo_type: str = None,
|
||||
revision: Optional[str] = DEFAULT_MODEL_REVISION,
|
||||
cache_dir: Optional[str] = None,
|
||||
user_agent: Union[Dict, str, None] = None,
|
||||
local_files_only: Optional[bool] = False,
|
||||
cookies: Optional[CookieJar] = None,
|
||||
local_dir: Optional[str] = None,
|
||||
) -> Optional[str]: # pragma: no cover
|
||||
|
||||
if not repo_type:
|
||||
repo_type = REPO_TYPE_MODEL
|
||||
if repo_type not in REPO_TYPE_SUPPORT:
|
||||
raise InvalidParameter('Invalid repo type: %s, only support: %s' (
|
||||
repo_type, REPO_TYPE_SUPPORT))
|
||||
|
||||
temporary_cache_dir, cache = create_temporary_directory_and_cache(
|
||||
model_id, local_dir, cache_dir, get_model_cache_root())
|
||||
repo_id, local_dir=local_dir, cache_dir=cache_dir, repo_type=repo_type)
|
||||
|
||||
# if local_files_only is `True` and the file already exists in cached_path
|
||||
# return the cached path
|
||||
@@ -93,7 +186,7 @@ def model_file_download(
|
||||
else:
|
||||
raise ValueError(
|
||||
'Cannot find the requested files in the cached path and outgoing'
|
||||
' traffic has been disabled. To enable model look-ups and downloads'
|
||||
' traffic has been disabled. To enable look-ups and downloads'
|
||||
" online, set 'local_files_only' to False.")
|
||||
|
||||
_api = HubApi()
|
||||
@@ -102,45 +195,77 @@ def model_file_download(
|
||||
}
|
||||
if cookies is None:
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
repo_files = []
|
||||
if repo_type == REPO_TYPE_MODEL:
|
||||
revision = _api.get_valid_revision(
|
||||
repo_id, revision=revision, cookies=cookies)
|
||||
file_to_download_meta = None
|
||||
# we need to confirm the version is up-to-date
|
||||
# we need to get the file list to check if the latest version is cached, if so return, otherwise download
|
||||
repo_files = _api.get_model_files(
|
||||
model_id=repo_id,
|
||||
revision=revision,
|
||||
recursive=True,
|
||||
use_cookies=False if cookies is None else cookies)
|
||||
elif repo_type == REPO_TYPE_DATASET:
|
||||
group_or_owner, name = model_id_to_group_owner_name(repo_id)
|
||||
if not revision:
|
||||
revision = DEFAULT_DATASET_REVISION
|
||||
files_list_tree = _api.list_repo_tree(
|
||||
dataset_name=name,
|
||||
namespace=group_or_owner,
|
||||
revision=revision,
|
||||
root_path='/',
|
||||
recursive=True)
|
||||
if not ('Code' in files_list_tree and files_list_tree['Code'] == 200):
|
||||
print(
|
||||
'Get dataset: %s file list failed, request_id: %s, message: %s'
|
||||
% (repo_id, files_list_tree['RequestId'],
|
||||
files_list_tree['Message']))
|
||||
return None
|
||||
repo_files = files_list_tree['Data']['Files']
|
||||
|
||||
revision = _api.get_valid_revision(
|
||||
model_id, revision=revision, cookies=cookies)
|
||||
file_to_download_info = None
|
||||
# we need to confirm the version is up-to-date
|
||||
# we need to get the file list to check if the latest version is cached, if so return, otherwise download
|
||||
model_files = _api.get_model_files(
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
recursive=True,
|
||||
use_cookies=False if cookies is None else cookies)
|
||||
|
||||
for model_file in model_files:
|
||||
if model_file['Type'] == 'tree':
|
||||
file_to_download_meta = None
|
||||
for repo_file in repo_files:
|
||||
if repo_file['Type'] == 'tree':
|
||||
continue
|
||||
|
||||
if model_file['Path'] == file_path:
|
||||
if cache.exists(model_file):
|
||||
if repo_file['Path'] == file_path:
|
||||
if cache.exists(repo_file):
|
||||
logger.debug(
|
||||
f'File {model_file["Name"]} already in cache, skip downloading!'
|
||||
f'File {repo_file["Name"]} already in cache, skip downloading!'
|
||||
)
|
||||
return cache.get_file_by_info(model_file)
|
||||
return cache.get_file_by_info(repo_file)
|
||||
else:
|
||||
file_to_download_info = model_file
|
||||
file_to_download_meta = repo_file
|
||||
break
|
||||
|
||||
if file_to_download_info is None:
|
||||
if file_to_download_meta is None:
|
||||
raise NotExistError('The file path: %s not exist in: %s' %
|
||||
(file_path, model_id))
|
||||
(file_path, repo_id))
|
||||
|
||||
# we need to download again
|
||||
url_to_download = get_file_download_url(model_id, file_path, revision)
|
||||
return download_file(url_to_download, file_to_download_info,
|
||||
if repo_type == REPO_TYPE_MODEL:
|
||||
url_to_download = get_file_download_url(repo_id, file_path, revision)
|
||||
elif repo_type == REPO_TYPE_DATASET:
|
||||
url_to_download = _api.get_dataset_file_url(
|
||||
file_name=file_to_download_meta['Path'],
|
||||
dataset_name=name,
|
||||
namespace=group_or_owner,
|
||||
revision=revision)
|
||||
return download_file(url_to_download, file_to_download_meta,
|
||||
temporary_cache_dir, cache, headers, cookies)
|
||||
|
||||
|
||||
def create_temporary_directory_and_cache(model_id: str, local_dir: str,
|
||||
cache_dir: str,
|
||||
default_cache_root: str):
|
||||
def create_temporary_directory_and_cache(model_id: str,
|
||||
local_dir: str = None,
|
||||
cache_dir: str = None,
|
||||
repo_type: str = REPO_TYPE_MODEL):
|
||||
if repo_type == REPO_TYPE_MODEL:
|
||||
default_cache_root = get_model_cache_root()
|
||||
elif repo_type == REPO_TYPE_DATASET:
|
||||
default_cache_root = get_dataset_cache_root()
|
||||
|
||||
group_or_owner, name = model_id_to_group_owner_name(model_id)
|
||||
if local_dir is not None:
|
||||
temporary_cache_dir = os.path.join(local_dir, TEMPORARY_FOLDER_NAME)
|
||||
|
||||
@@ -8,8 +8,14 @@ from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from modelscope.hub.api import HubApi, ModelScopeConfig
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
|
||||
from modelscope.utils.file_utils import get_model_cache_root
|
||||
from modelscope.hub.errors import InvalidParameter
|
||||
from modelscope.hub.utils.utils import model_id_to_group_owner_name
|
||||
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
|
||||
DEFAULT_MODEL_REVISION,
|
||||
REPO_TYPE_DATASET, REPO_TYPE_MODEL,
|
||||
REPO_TYPE_SUPPORT)
|
||||
from modelscope.utils.file_utils import (get_dataset_cache_root,
|
||||
get_model_cache_root)
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .file_download import (create_temporary_directory_and_cache,
|
||||
download_file, get_file_download_url)
|
||||
@@ -67,14 +73,109 @@ def snapshot_download(
|
||||
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
||||
if some parameter value is invalid
|
||||
"""
|
||||
return _snapshot_download(
|
||||
model_id,
|
||||
repo_type=REPO_TYPE_MODEL,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
user_agent=user_agent,
|
||||
local_files_only=local_files_only,
|
||||
cookies=cookies,
|
||||
ignore_file_pattern=ignore_file_pattern,
|
||||
allow_file_pattern=allow_file_pattern,
|
||||
local_dir=local_dir)
|
||||
|
||||
|
||||
def dataset_snapshot_download(
|
||||
dataset_id: str,
|
||||
revision: Optional[str] = DEFAULT_DATASET_REVISION,
|
||||
cache_dir: Union[str, Path, None] = None,
|
||||
local_dir: Optional[str] = None,
|
||||
user_agent: Optional[Union[Dict, str]] = None,
|
||||
local_files_only: Optional[bool] = False,
|
||||
cookies: Optional[CookieJar] = None,
|
||||
ignore_file_pattern: Optional[Union[str, List[str]]] = None,
|
||||
allow_file_pattern: Optional[Union[str, List[str]]] = None,
|
||||
) -> str:
|
||||
"""Download raw files of a dataset.
|
||||
Downloads all files at the specified revision. This
|
||||
is useful when you want all files from a dataset, because you don't know which
|
||||
ones you will need a priori. All files are nested inside a folder in order
|
||||
to keep their actual filename relative to that folder.
|
||||
|
||||
An alternative would be to just clone a dataset but this would require that the
|
||||
user always has git and git-lfs installed, and properly configured.
|
||||
|
||||
Args:
|
||||
dataset_id (str): A user or an organization name and a dataset name separated by a `/`.
|
||||
revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
|
||||
commit hash. NOTE: currently only branch and tag name is supported
|
||||
cache_dir (str, Path, optional): Path to the folder where cached files are stored, dataset will
|
||||
be save as cache_dir/dataset_id/THE_DATASET_FILES.
|
||||
local_dir (str, optional): Specific local directory path to which the file will be downloaded.
|
||||
user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string.
|
||||
local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the
|
||||
local cached file if it exists.
|
||||
cookies (CookieJar, optional): The cookie of the request, default None.
|
||||
ignore_file_pattern (`str` or `List`, *optional*, default to `None`):
|
||||
Any file pattern to be ignored in downloading, like exact file names or file extensions.
|
||||
allow_file_pattern (`str` or `List`, *optional*, default to `None`):
|
||||
Any file pattern to be downloading, like exact file names or file extensions.
|
||||
Raises:
|
||||
ValueError: the value details.
|
||||
|
||||
Returns:
|
||||
str: Local folder path (string) of repo snapshot
|
||||
|
||||
Note:
|
||||
Raises the following errors:
|
||||
- [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
|
||||
if `use_auth_token=True` and the token cannot be found.
|
||||
- [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if
|
||||
ETag cannot be determined.
|
||||
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
||||
if some parameter value is invalid
|
||||
"""
|
||||
return _snapshot_download(
|
||||
dataset_id,
|
||||
repo_type=REPO_TYPE_DATASET,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
user_agent=user_agent,
|
||||
local_files_only=local_files_only,
|
||||
cookies=cookies,
|
||||
ignore_file_pattern=ignore_file_pattern,
|
||||
allow_file_pattern=allow_file_pattern,
|
||||
local_dir=local_dir)
|
||||
|
||||
|
||||
def _snapshot_download(
|
||||
repo_id: str,
|
||||
*,
|
||||
repo_type: Optional[str] = None,
|
||||
revision: Optional[str] = DEFAULT_MODEL_REVISION,
|
||||
cache_dir: Union[str, Path, None] = None,
|
||||
user_agent: Optional[Union[Dict, str]] = None,
|
||||
local_files_only: Optional[bool] = False,
|
||||
cookies: Optional[CookieJar] = None,
|
||||
ignore_file_pattern: Optional[Union[str, List[str]]] = None,
|
||||
allow_file_pattern: Optional[Union[str, List[str]]] = None,
|
||||
local_dir: Optional[str] = None,
|
||||
):
|
||||
if not repo_type:
|
||||
repo_type = REPO_TYPE_MODEL
|
||||
if repo_type not in REPO_TYPE_SUPPORT:
|
||||
raise InvalidParameter('Invalid repo type: %s, only support: %s' (
|
||||
repo_type, REPO_TYPE_SUPPORT))
|
||||
|
||||
temporary_cache_dir, cache = create_temporary_directory_and_cache(
|
||||
model_id, local_dir, cache_dir, get_model_cache_root())
|
||||
repo_id, local_dir=local_dir, cache_dir=cache_dir, repo_type=repo_type)
|
||||
|
||||
if local_files_only:
|
||||
if len(cache.cached_files) == 0:
|
||||
raise ValueError(
|
||||
'Cannot find the requested files in the cached path and outgoing'
|
||||
' traffic has been disabled. To enable model look-ups and downloads'
|
||||
' traffic has been disabled. To enable look-ups and downloads'
|
||||
" online, set 'local_files_only' to False.")
|
||||
logger.warning('We can not confirm the cached file is for revision: %s'
|
||||
% revision)
|
||||
@@ -89,27 +190,48 @@ def snapshot_download(
|
||||
_api = HubApi()
|
||||
if cookies is None:
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
revision_detail = _api.get_valid_revision_detail(
|
||||
model_id, revision=revision, cookies=cookies)
|
||||
revision = revision_detail['Revision']
|
||||
repo_files = []
|
||||
if repo_type == REPO_TYPE_MODEL:
|
||||
revision_detail = _api.get_valid_revision_detail(
|
||||
repo_id, revision=revision, cookies=cookies)
|
||||
revision = revision_detail['Revision']
|
||||
|
||||
snapshot_header = headers if 'CI_TEST' in os.environ else {
|
||||
**headers,
|
||||
**{
|
||||
'Snapshot': 'True'
|
||||
snapshot_header = headers if 'CI_TEST' in os.environ else {
|
||||
**headers,
|
||||
**{
|
||||
'Snapshot': 'True'
|
||||
}
|
||||
}
|
||||
}
|
||||
if cache.cached_model_revision is not None:
|
||||
snapshot_header[
|
||||
'cached_model_revision'] = cache.cached_model_revision
|
||||
if cache.cached_model_revision is not None:
|
||||
snapshot_header[
|
||||
'cached_model_revision'] = cache.cached_model_revision
|
||||
|
||||
model_files = _api.get_model_files(
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
recursive=True,
|
||||
use_cookies=False if cookies is None else cookies,
|
||||
headers=snapshot_header,
|
||||
)
|
||||
repo_files = _api.get_model_files(
|
||||
model_id=repo_id,
|
||||
revision=revision,
|
||||
recursive=True,
|
||||
use_cookies=False if cookies is None else cookies,
|
||||
headers=snapshot_header,
|
||||
)
|
||||
elif repo_type == REPO_TYPE_DATASET:
|
||||
group_or_owner, name = model_id_to_group_owner_name(repo_id)
|
||||
if not revision:
|
||||
revision = DEFAULT_DATASET_REVISION
|
||||
revision_detail = revision
|
||||
files_list_tree = _api.list_repo_tree(
|
||||
dataset_name=name,
|
||||
namespace=group_or_owner,
|
||||
revision=revision,
|
||||
root_path='/',
|
||||
recursive=True)
|
||||
if not ('Code' in files_list_tree
|
||||
and files_list_tree['Code'] == 200):
|
||||
print(
|
||||
'Get dataset: %s file list failed, request_id: %s, message: %s'
|
||||
% (repo_id, files_list_tree['RequestId'],
|
||||
files_list_tree['Message']))
|
||||
return None
|
||||
repo_files = files_list_tree['Data']['Files']
|
||||
|
||||
if ignore_file_pattern is None:
|
||||
ignore_file_pattern = []
|
||||
@@ -128,32 +250,38 @@ def snapshot_download(
|
||||
for item in allow_file_pattern
|
||||
]
|
||||
|
||||
for model_file in model_files:
|
||||
if model_file['Type'] == 'tree' or \
|
||||
any(fnmatch.fnmatch(model_file['Path'], pattern) for pattern in ignore_file_pattern) or \
|
||||
any([re.search(pattern, model_file['Name']) is not None for pattern in ignore_file_pattern]):
|
||||
for repo_file in repo_files:
|
||||
if repo_file['Type'] == 'tree' or \
|
||||
any(fnmatch.fnmatch(repo_file['Path'], pattern) for pattern in ignore_file_pattern) or \
|
||||
any([re.search(re.escape(pattern), repo_file['Name']) is not None for pattern in ignore_file_pattern]): # noqa E501
|
||||
continue
|
||||
|
||||
if allow_file_pattern is not None and allow_file_pattern:
|
||||
if not any(
|
||||
fnmatch.fnmatch(model_file['Path'], pattern)
|
||||
fnmatch.fnmatch(repo_file['Path'], pattern)
|
||||
for pattern in allow_file_pattern):
|
||||
continue
|
||||
|
||||
# check model_file is exist in cache, if existed, skip download, otherwise download
|
||||
if cache.exists(model_file):
|
||||
file_name = os.path.basename(model_file['Name'])
|
||||
if cache.exists(repo_file):
|
||||
file_name = os.path.basename(repo_file['Name'])
|
||||
logger.debug(
|
||||
f'File {file_name} already in cache, skip downloading!')
|
||||
continue
|
||||
if repo_type == REPO_TYPE_MODEL:
|
||||
# get download url
|
||||
url = get_file_download_url(
|
||||
model_id=repo_id,
|
||||
file_path=repo_file['Path'],
|
||||
revision=revision)
|
||||
elif repo_type == REPO_TYPE_DATASET:
|
||||
url = _api.get_dataset_file_url(
|
||||
file_name=repo_file['Path'],
|
||||
dataset_name=name,
|
||||
namespace=group_or_owner,
|
||||
revision=revision)
|
||||
|
||||
# get download url
|
||||
url = get_file_download_url(
|
||||
model_id=model_id,
|
||||
file_path=model_file['Path'],
|
||||
revision=revision)
|
||||
|
||||
download_file(url, model_file, temporary_cache_dir, cache, headers,
|
||||
download_file(url, repo_file, temporary_cache_dir, cache, headers,
|
||||
cookies)
|
||||
|
||||
cache.save_model_version(revision_info=revision_detail)
|
||||
|
||||
@@ -493,6 +493,9 @@ class Frameworks(object):
|
||||
kaldi = 'kaldi'
|
||||
|
||||
|
||||
REPO_TYPE_MODEL = 'model'
|
||||
REPO_TYPE_DATASET = 'dataset'
|
||||
REPO_TYPE_SUPPORT = [REPO_TYPE_MODEL, REPO_TYPE_DATASET]
|
||||
DEFAULT_MODEL_REVISION = None
|
||||
MASTER_MODEL_BRANCH = 'master'
|
||||
DEFAULT_REPOSITORY_REVISION = 'master'
|
||||
|
||||
@@ -5,8 +5,8 @@ import tempfile
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from modelscope.hub.dataset_download import (dataset_file_download,
|
||||
dataset_snapshot_download)
|
||||
from modelscope.hub.file_download import dataset_file_download
|
||||
from modelscope.hub.snapshot_download import dataset_snapshot_download
|
||||
|
||||
|
||||
class DownloadDatasetTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user