refactor code

This commit is contained in:
mulin.lyh
2024-07-12 14:29:23 +08:00
parent 759c564659
commit df81f3de95
6 changed files with 329 additions and 351 deletions

View File

@@ -10,8 +10,8 @@ if TYPE_CHECKING:
from .hub.api import HubApi
from .hub.check_model import check_local_model_is_latest, check_model_is_id
from .hub.push_to_hub import push_to_hub, push_to_hub_async
from .hub.snapshot_download import snapshot_download
from .hub.dataset_download import dataset_snapshot_download, dataset_file_download
from .hub.snapshot_download import snapshot_download, dataset_snapshot_download
from .hub.file_download import model_file_download, dataset_file_download
from .metrics import (
AccuracyMetric, AudioNoiseMetric, BleuMetric, ImageColorEnhanceMetric,
@@ -61,9 +61,9 @@ else:
'TorchModelExporter',
],
'hub.api': ['HubApi'],
'hub.snapshot_download': ['snapshot_download'],
'hub.dataset_download':
['dataset_snapshot_download', 'dataset_file_download'],
'hub.snapshot_download':
['snapshot_download', 'dataset_snapshot_download'],
'hub.file_download': ['model_file_download', 'dataset_file_download'],
'hub.push_to_hub': ['push_to_hub', 'push_to_hub_async'],
'hub.check_model':
['check_model_is_id', 'check_local_model_is_latest'],

View File

@@ -1,278 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import fnmatch
import os
from http.cookiejar import CookieJar
from pathlib import Path
from typing import Dict, List, Optional, Union
from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.errors import NotExistError
from modelscope.utils.constant import DEFAULT_DATASET_REVISION
from modelscope.utils.file_utils import get_dataset_cache_root
from modelscope.utils.logger import get_logger
from .file_download import create_temporary_directory_and_cache, download_file
from .utils.utils import model_id_to_group_owner_name
logger = get_logger()
def dataset_file_download(
        dataset_id: str,
        file_path: str,
        revision: Optional[str] = DEFAULT_DATASET_REVISION,
        cache_dir: Union[str, Path, None] = None,
        local_dir: Optional[str] = None,
        user_agent: Optional[Union[Dict, str]] = None,
        local_files_only: Optional[bool] = False,
        cookies: Optional[CookieJar] = None,
) -> str:
    """Download a single raw file of a dataset into the local cache.

    Args:
        dataset_id (str): A user or an organization name and a dataset name
            separated by a `/`.
        file_path (str): The relative path of the file to download.
        revision (str, optional): An optional Git revision id which can be a
            branch name, a tag, or a commit hash. NOTE: currently only branch
            and tag name is supported.
        cache_dir (str, Path, optional): Path to the folder where cached files
            are stored, dataset file will be saved as
            cache_dir/dataset_id/THE_DATASET_FILES.
        local_dir (str, optional): Specific local directory path to which the
            file will be downloaded.
        user_agent (str, dict, optional): The user-agent info in the form of a
            dictionary or a string.
        local_files_only (bool, optional): If `True`, avoid downloading the
            file and return the path to the local cached file if it exists.
        cookies (CookieJar, optional): The cookie of the request, default None.

    Raises:
        ValueError: If `local_files_only` is True and the file is not cached.
        NotExistError: If `file_path` does not exist in the dataset repo.

    Returns:
        str: Local path of the downloaded (or already cached) file; ``None``
        when the remote file listing request fails.
    """
    temporary_cache_dir, cache = create_temporary_directory_and_cache(
        dataset_id,
        local_dir=local_dir,
        cache_dir=cache_dir,
        default_cache_root=get_dataset_cache_root())
    if local_files_only:
        cached_file_path = cache.get_file_by_path(file_path)
        if cached_file_path is not None:
            # We cannot validate freshness without network access.
            logger.warning(
                "File exists in local cache, but we're not sure it's up to date"
            )
            return cached_file_path
        raise ValueError(
            'Cannot find the requested files in the cached path and outgoing'
            ' traffic has been disabled. To enable dataset look-ups and downloads'
            " online, set 'local_files_only' to False.")
    # make headers
    headers = {
        'user-agent':
        ModelScopeConfig.get_user_agent(user_agent=user_agent, )
    }
    _api = HubApi()
    if cookies is None:
        cookies = ModelScopeConfig.get_cookies()
    group_or_owner, name = model_id_to_group_owner_name(dataset_id)
    if not revision:
        revision = DEFAULT_DATASET_REVISION
    files_list_tree = _api.list_repo_tree(
        dataset_name=name,
        namespace=group_or_owner,
        revision=revision,
        root_path='/',
        recursive=True)
    if not ('Code' in files_list_tree and files_list_tree['Code'] == 200):
        # Fix: report through the module logger instead of bare print so the
        # failure shows up in normal log capture.
        logger.error(
            'Get dataset: %s file list failed, request_id: %s, message: %s'
            % (dataset_id, files_list_tree['RequestId'],
               files_list_tree['Message']))
        return None
    file_meta_to_download = None
    # find the file to download.
    for file_meta in files_list_tree['Data']['Files']:
        if file_meta['Type'] == 'tree':
            continue
        if fnmatch.fnmatch(file_meta['Path'], file_path):
            # check file is exist in cache, if existed, skip download, otherwise download
            if cache.exists(file_meta):
                file_name = os.path.basename(file_meta['Name'])
                logger.debug(
                    f'File {file_name} already in cache, skip downloading!')
                return cache.get_file_by_info(file_meta)
            file_meta_to_download = file_meta
            break
    if file_meta_to_download is None:
        raise NotExistError('The file path: %s not exist in: %s' %
                            (file_path, dataset_id))
    # start download file: resolve the download url then fetch into the cache.
    url = _api.get_dataset_file_url(
        file_name=file_meta_to_download['Path'],
        dataset_name=name,
        namespace=group_or_owner,
        revision=revision)
    return download_file(url, file_meta_to_download, temporary_cache_dir,
                         cache, headers, cookies)
def dataset_snapshot_download(
        dataset_id: str,
        revision: Optional[str] = DEFAULT_DATASET_REVISION,
        cache_dir: Union[str, Path, None] = None,
        local_dir: Optional[str] = None,
        user_agent: Optional[Union[Dict, str]] = None,
        local_files_only: Optional[bool] = False,
        cookies: Optional[CookieJar] = None,
        ignore_file_pattern: Optional[Union[str, List[str]]] = None,
        allow_file_pattern: Optional[Union[str, List[str]]] = None,
) -> str:
    """Download all raw files of a dataset at the specified revision.

    This is useful when you want all files from a dataset, because you don't
    know which ones you will need a priori. All files are nested inside a
    folder in order to keep their actual filename relative to that folder.
    An alternative would be to just clone a dataset but this would require
    that the user always has git and git-lfs installed, and properly
    configured.

    Args:
        dataset_id (str): A user or an organization name and a dataset name
            separated by a `/`.
        revision (str, optional): An optional Git revision id which can be a
            branch name, a tag, or a commit hash. NOTE: currently only branch
            and tag name is supported.
        cache_dir (str, Path, optional): Path to the folder where cached files
            are stored, dataset will be saved as
            cache_dir/dataset_id/THE_DATASET_FILES.
        local_dir (str, optional): Specific local directory path to which the
            files will be downloaded.
        user_agent (str, dict, optional): The user-agent info in the form of a
            dictionary or a string.
        local_files_only (bool, optional): If `True`, avoid downloading and
            return the path to the local cached snapshot if it exists.
        cookies (CookieJar, optional): The cookie of the request, default None.
        ignore_file_pattern (`str` or `List`, *optional*, default to `None`):
            Any file pattern to be ignored in downloading, like exact file
            names or file extensions.
        allow_file_pattern (`str` or `List`, *optional*, default to `None`):
            Any file pattern to be downloading, like exact file names or file
            extensions.

    Raises:
        ValueError: If `local_files_only` is True and nothing is cached.

    Returns:
        str: Local folder path (string) of repo snapshot; ``None`` when the
        remote file listing request fails.
    """
    temporary_cache_dir, cache = create_temporary_directory_and_cache(
        dataset_id,
        local_dir=local_dir,
        cache_dir=cache_dir,
        default_cache_root=get_dataset_cache_root())
    if local_files_only:
        if len(cache.cached_files) == 0:
            raise ValueError(
                'Cannot find the requested files in the cached path and outgoing'
                ' traffic has been disabled. To enable dataset look-ups and downloads'
                " online, set 'local_files_only' to False.")
        logger.warning('We can not confirm the cached file is for revision: %s'
                       % revision)
        # we can not confirm the cached file is for snapshot 'revision'
        return cache.get_root_location()
    # make headers
    headers = {
        'user-agent':
        ModelScopeConfig.get_user_agent(user_agent=user_agent, )
    }
    _api = HubApi()
    if cookies is None:
        cookies = ModelScopeConfig.get_cookies()
    group_or_owner, name = model_id_to_group_owner_name(dataset_id)
    if not revision:
        revision = DEFAULT_DATASET_REVISION
    files_list_tree = _api.list_repo_tree(
        dataset_name=name,
        namespace=group_or_owner,
        revision=revision,
        root_path='/',
        recursive=True)
    if not ('Code' in files_list_tree and files_list_tree['Code'] == 200):
        # Fix: report through the module logger instead of bare print so the
        # failure shows up in normal log capture.
        logger.error(
            'Get dataset: %s file list failed, request_id: %s, message: %s'
            % (dataset_id, files_list_tree['RequestId'],
               files_list_tree['Message']))
        return None
    # Normalize patterns: a trailing '/' means "everything under this
    # directory", so turn it into a glob.
    if ignore_file_pattern is None:
        ignore_file_pattern = []
    if isinstance(ignore_file_pattern, str):
        ignore_file_pattern = [ignore_file_pattern]
    ignore_file_pattern = [
        item if not item.endswith('/') else item + '*'
        for item in ignore_file_pattern
    ]
    if allow_file_pattern is not None:
        if isinstance(allow_file_pattern, str):
            allow_file_pattern = [allow_file_pattern]
        allow_file_pattern = [
            item if not item.endswith('/') else item + '*'
            for item in allow_file_pattern
        ]
    for file_meta in files_list_tree['Data']['Files']:
        if file_meta['Type'] == 'tree' or \
                any(fnmatch.fnmatch(file_meta['Path'], pattern)
                    for pattern in ignore_file_pattern):
            continue
        if allow_file_pattern is not None and allow_file_pattern:
            if not any(
                    fnmatch.fnmatch(file_meta['Path'], pattern)
                    for pattern in allow_file_pattern):
                continue
        # check file is exist in cache, if existed, skip download, otherwise download
        if cache.exists(file_meta):
            file_name = os.path.basename(file_meta['Name'])
            logger.debug(
                f'File {file_name} already in cache, skip downloading!')
            continue
        # get download url
        url = _api.get_dataset_file_url(
            file_name=file_meta['Path'],
            dataset_name=name,
            namespace=group_or_owner,
            revision=revision)
        download_file(url, file_meta, temporary_cache_dir, cache, headers,
                      cookies)
    cache.save_model_version(revision_info=revision)
    # Fix: os.path.join with a single argument was a no-op wrapper.
    return cache.get_root_location()

View File

@@ -21,10 +21,14 @@ from modelscope.hub.constants import (
API_FILE_DOWNLOAD_CHUNK_SIZE, API_FILE_DOWNLOAD_RETRY_TIMES,
API_FILE_DOWNLOAD_TIMEOUT, FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS,
MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB, TEMPORARY_FOLDER_NAME)
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
from modelscope.utils.file_utils import get_model_cache_root
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
DEFAULT_MODEL_REVISION,
REPO_TYPE_DATASET, REPO_TYPE_MODEL,
REPO_TYPE_SUPPORT)
from modelscope.utils.file_utils import (get_dataset_cache_root,
get_model_cache_root)
from modelscope.utils.logger import get_logger
from .errors import FileDownloadError, NotExistError
from .errors import FileDownloadError, InvalidParameter, NotExistError
from .utils.caching import ModelFileSystemCache
from .utils.utils import (file_integrity_validation, get_endpoint,
model_id_to_group_owner_name)
@@ -78,8 +82,97 @@ def model_file_download(
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
if some parameter value is invalid
"""
return _repo_file_download(
model_id,
file_path,
repo_type=REPO_TYPE_MODEL,
revision=revision,
cache_dir=cache_dir,
user_agent=user_agent,
local_files_only=local_files_only,
cookies=cookies,
local_dir=local_dir)
def dataset_file_download(
dataset_id: str,
file_path: str,
revision: Optional[str] = DEFAULT_DATASET_REVISION,
cache_dir: Union[str, Path, None] = None,
local_dir: Optional[str] = None,
user_agent: Optional[Union[Dict, str]] = None,
local_files_only: Optional[bool] = False,
cookies: Optional[CookieJar] = None,
) -> str:
"""Download raw files of a dataset.
Downloads all files at the specified revision. This
is useful when you want all files from a dataset, because you don't know which
ones you will need a priori. All files are nested inside a folder in order
to keep their actual filename relative to that folder.
An alternative would be to just clone a dataset but this would require that the
user always has git and git-lfs installed, and properly configured.
Args:
dataset_id (str): A user or an organization name and a dataset name separated by a `/`.
file_path (str): The relative path of the file to download.
revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
commit hash. NOTE: currently only branch and tag name is supported
cache_dir (str, Path, optional): Path to the folder where cached files are stored, dataset file will
be save as cache_dir/dataset_id/THE_DATASET_FILES.
local_dir (str, optional): Specific local directory path to which the file will be downloaded.
user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string.
local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the
local cached file if it exists.
cookies (CookieJar, optional): The cookie of the request, default None.
Raises:
ValueError: the value details.
Returns:
str: Local folder path (string) of repo snapshot
Note:
Raises the following errors:
- [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
if `use_auth_token=True` and the token cannot be found.
- [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if
ETag cannot be determined.
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
if some parameter value is invalid
"""
return _repo_file_download(
dataset_id,
file_path,
repo_type=REPO_TYPE_DATASET,
revision=revision,
cache_dir=cache_dir,
user_agent=user_agent,
local_files_only=local_files_only,
cookies=cookies,
local_dir=local_dir)
def _repo_file_download(
repo_id: str,
file_path: str,
*,
repo_type: str = None,
revision: Optional[str] = DEFAULT_MODEL_REVISION,
cache_dir: Optional[str] = None,
user_agent: Union[Dict, str, None] = None,
local_files_only: Optional[bool] = False,
cookies: Optional[CookieJar] = None,
local_dir: Optional[str] = None,
) -> Optional[str]: # pragma: no cover
if not repo_type:
repo_type = REPO_TYPE_MODEL
if repo_type not in REPO_TYPE_SUPPORT:
raise InvalidParameter('Invalid repo type: %s, only support: %s' %
(repo_type, REPO_TYPE_SUPPORT))
temporary_cache_dir, cache = create_temporary_directory_and_cache(
model_id, local_dir, cache_dir, get_model_cache_root())
repo_id, local_dir=local_dir, cache_dir=cache_dir, repo_type=repo_type)
# if local_files_only is `True` and the file already exists in cached_path
# return the cached path
@@ -93,7 +186,7 @@ def model_file_download(
else:
raise ValueError(
'Cannot find the requested files in the cached path and outgoing'
' traffic has been disabled. To enable model look-ups and downloads'
' traffic has been disabled. To enable look-ups and downloads'
" online, set 'local_files_only' to False.")
_api = HubApi()
@@ -102,45 +195,77 @@ def model_file_download(
}
if cookies is None:
cookies = ModelScopeConfig.get_cookies()
repo_files = []
if repo_type == REPO_TYPE_MODEL:
revision = _api.get_valid_revision(
repo_id, revision=revision, cookies=cookies)
file_to_download_meta = None
# we need to confirm the version is up-to-date
# we need to get the file list to check if the latest version is cached, if so return, otherwise download
repo_files = _api.get_model_files(
model_id=repo_id,
revision=revision,
recursive=True,
use_cookies=False if cookies is None else cookies)
elif repo_type == REPO_TYPE_DATASET:
group_or_owner, name = model_id_to_group_owner_name(repo_id)
if not revision:
revision = DEFAULT_DATASET_REVISION
files_list_tree = _api.list_repo_tree(
dataset_name=name,
namespace=group_or_owner,
revision=revision,
root_path='/',
recursive=True)
if not ('Code' in files_list_tree and files_list_tree['Code'] == 200):
print(
'Get dataset: %s file list failed, request_id: %s, message: %s'
% (repo_id, files_list_tree['RequestId'],
files_list_tree['Message']))
return None
repo_files = files_list_tree['Data']['Files']
revision = _api.get_valid_revision(
model_id, revision=revision, cookies=cookies)
file_to_download_info = None
# we need to confirm the version is up-to-date
# we need to get the file list to check if the latest version is cached, if so return, otherwise download
model_files = _api.get_model_files(
model_id=model_id,
revision=revision,
recursive=True,
use_cookies=False if cookies is None else cookies)
for model_file in model_files:
if model_file['Type'] == 'tree':
file_to_download_meta = None
for repo_file in repo_files:
if repo_file['Type'] == 'tree':
continue
if model_file['Path'] == file_path:
if cache.exists(model_file):
if repo_file['Path'] == file_path:
if cache.exists(repo_file):
logger.debug(
f'File {model_file["Name"]} already in cache, skip downloading!'
f'File {repo_file["Name"]} already in cache, skip downloading!'
)
return cache.get_file_by_info(model_file)
return cache.get_file_by_info(repo_file)
else:
file_to_download_info = model_file
file_to_download_meta = repo_file
break
if file_to_download_info is None:
if file_to_download_meta is None:
raise NotExistError('The file path: %s not exist in: %s' %
(file_path, model_id))
(file_path, repo_id))
# we need to download again
url_to_download = get_file_download_url(model_id, file_path, revision)
return download_file(url_to_download, file_to_download_info,
if repo_type == REPO_TYPE_MODEL:
url_to_download = get_file_download_url(repo_id, file_path, revision)
elif repo_type == REPO_TYPE_DATASET:
url_to_download = _api.get_dataset_file_url(
file_name=file_to_download_meta['Path'],
dataset_name=name,
namespace=group_or_owner,
revision=revision)
return download_file(url_to_download, file_to_download_meta,
temporary_cache_dir, cache, headers, cookies)
def create_temporary_directory_and_cache(model_id: str, local_dir: str,
cache_dir: str,
default_cache_root: str):
def create_temporary_directory_and_cache(model_id: str,
local_dir: str = None,
cache_dir: str = None,
repo_type: str = REPO_TYPE_MODEL):
if repo_type == REPO_TYPE_MODEL:
default_cache_root = get_model_cache_root()
elif repo_type == REPO_TYPE_DATASET:
default_cache_root = get_dataset_cache_root()
group_or_owner, name = model_id_to_group_owner_name(model_id)
if local_dir is not None:
temporary_cache_dir = os.path.join(local_dir, TEMPORARY_FOLDER_NAME)

View File

@@ -8,8 +8,14 @@ from pathlib import Path
from typing import Dict, List, Optional, Union
from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
from modelscope.utils.file_utils import get_model_cache_root
from modelscope.hub.errors import InvalidParameter
from modelscope.hub.utils.utils import model_id_to_group_owner_name
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
DEFAULT_MODEL_REVISION,
REPO_TYPE_DATASET, REPO_TYPE_MODEL,
REPO_TYPE_SUPPORT)
from modelscope.utils.file_utils import (get_dataset_cache_root,
get_model_cache_root)
from modelscope.utils.logger import get_logger
from .file_download import (create_temporary_directory_and_cache,
download_file, get_file_download_url)
@@ -67,14 +73,109 @@ def snapshot_download(
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
if some parameter value is invalid
"""
return _snapshot_download(
model_id,
repo_type=REPO_TYPE_MODEL,
revision=revision,
cache_dir=cache_dir,
user_agent=user_agent,
local_files_only=local_files_only,
cookies=cookies,
ignore_file_pattern=ignore_file_pattern,
allow_file_pattern=allow_file_pattern,
local_dir=local_dir)
def dataset_snapshot_download(
dataset_id: str,
revision: Optional[str] = DEFAULT_DATASET_REVISION,
cache_dir: Union[str, Path, None] = None,
local_dir: Optional[str] = None,
user_agent: Optional[Union[Dict, str]] = None,
local_files_only: Optional[bool] = False,
cookies: Optional[CookieJar] = None,
ignore_file_pattern: Optional[Union[str, List[str]]] = None,
allow_file_pattern: Optional[Union[str, List[str]]] = None,
) -> str:
"""Download raw files of a dataset.
Downloads all files at the specified revision. This
is useful when you want all files from a dataset, because you don't know which
ones you will need a priori. All files are nested inside a folder in order
to keep their actual filename relative to that folder.
An alternative would be to just clone a dataset but this would require that the
user always has git and git-lfs installed, and properly configured.
Args:
dataset_id (str): A user or an organization name and a dataset name separated by a `/`.
revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
commit hash. NOTE: currently only branch and tag name is supported
cache_dir (str, Path, optional): Path to the folder where cached files are stored, dataset will
be save as cache_dir/dataset_id/THE_DATASET_FILES.
local_dir (str, optional): Specific local directory path to which the file will be downloaded.
user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string.
local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the
local cached file if it exists.
cookies (CookieJar, optional): The cookie of the request, default None.
ignore_file_pattern (`str` or `List`, *optional*, default to `None`):
Any file pattern to be ignored in downloading, like exact file names or file extensions.
allow_file_pattern (`str` or `List`, *optional*, default to `None`):
Any file pattern to be downloading, like exact file names or file extensions.
Raises:
ValueError: the value details.
Returns:
str: Local folder path (string) of repo snapshot
Note:
Raises the following errors:
- [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
if `use_auth_token=True` and the token cannot be found.
- [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if
ETag cannot be determined.
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
if some parameter value is invalid
"""
return _snapshot_download(
dataset_id,
repo_type=REPO_TYPE_DATASET,
revision=revision,
cache_dir=cache_dir,
user_agent=user_agent,
local_files_only=local_files_only,
cookies=cookies,
ignore_file_pattern=ignore_file_pattern,
allow_file_pattern=allow_file_pattern,
local_dir=local_dir)
def _snapshot_download(
repo_id: str,
*,
repo_type: Optional[str] = None,
revision: Optional[str] = DEFAULT_MODEL_REVISION,
cache_dir: Union[str, Path, None] = None,
user_agent: Optional[Union[Dict, str]] = None,
local_files_only: Optional[bool] = False,
cookies: Optional[CookieJar] = None,
ignore_file_pattern: Optional[Union[str, List[str]]] = None,
allow_file_pattern: Optional[Union[str, List[str]]] = None,
local_dir: Optional[str] = None,
):
if not repo_type:
repo_type = REPO_TYPE_MODEL
if repo_type not in REPO_TYPE_SUPPORT:
raise InvalidParameter('Invalid repo type: %s, only support: %s' %
(repo_type, REPO_TYPE_SUPPORT))
temporary_cache_dir, cache = create_temporary_directory_and_cache(
model_id, local_dir, cache_dir, get_model_cache_root())
repo_id, local_dir=local_dir, cache_dir=cache_dir, repo_type=repo_type)
if local_files_only:
if len(cache.cached_files) == 0:
raise ValueError(
'Cannot find the requested files in the cached path and outgoing'
' traffic has been disabled. To enable model look-ups and downloads'
' traffic has been disabled. To enable look-ups and downloads'
" online, set 'local_files_only' to False.")
logger.warning('We can not confirm the cached file is for revision: %s'
% revision)
@@ -89,27 +190,48 @@ def snapshot_download(
_api = HubApi()
if cookies is None:
cookies = ModelScopeConfig.get_cookies()
revision_detail = _api.get_valid_revision_detail(
model_id, revision=revision, cookies=cookies)
revision = revision_detail['Revision']
repo_files = []
if repo_type == REPO_TYPE_MODEL:
revision_detail = _api.get_valid_revision_detail(
repo_id, revision=revision, cookies=cookies)
revision = revision_detail['Revision']
snapshot_header = headers if 'CI_TEST' in os.environ else {
**headers,
**{
'Snapshot': 'True'
snapshot_header = headers if 'CI_TEST' in os.environ else {
**headers,
**{
'Snapshot': 'True'
}
}
}
if cache.cached_model_revision is not None:
snapshot_header[
'cached_model_revision'] = cache.cached_model_revision
if cache.cached_model_revision is not None:
snapshot_header[
'cached_model_revision'] = cache.cached_model_revision
model_files = _api.get_model_files(
model_id=model_id,
revision=revision,
recursive=True,
use_cookies=False if cookies is None else cookies,
headers=snapshot_header,
)
repo_files = _api.get_model_files(
model_id=repo_id,
revision=revision,
recursive=True,
use_cookies=False if cookies is None else cookies,
headers=snapshot_header,
)
elif repo_type == REPO_TYPE_DATASET:
group_or_owner, name = model_id_to_group_owner_name(repo_id)
if not revision:
revision = DEFAULT_DATASET_REVISION
revision_detail = revision
files_list_tree = _api.list_repo_tree(
dataset_name=name,
namespace=group_or_owner,
revision=revision,
root_path='/',
recursive=True)
if not ('Code' in files_list_tree
and files_list_tree['Code'] == 200):
print(
'Get dataset: %s file list failed, request_id: %s, message: %s'
% (repo_id, files_list_tree['RequestId'],
files_list_tree['Message']))
return None
repo_files = files_list_tree['Data']['Files']
if ignore_file_pattern is None:
ignore_file_pattern = []
@@ -128,32 +250,38 @@ def snapshot_download(
for item in allow_file_pattern
]
for model_file in model_files:
if model_file['Type'] == 'tree' or \
any(fnmatch.fnmatch(model_file['Path'], pattern) for pattern in ignore_file_pattern) or \
any([re.search(pattern, model_file['Name']) is not None for pattern in ignore_file_pattern]):
for repo_file in repo_files:
if repo_file['Type'] == 'tree' or \
any(fnmatch.fnmatch(repo_file['Path'], pattern) for pattern in ignore_file_pattern) or \
any([re.search(re.escape(pattern), repo_file['Name']) is not None for pattern in ignore_file_pattern]): # noqa E501
continue
if allow_file_pattern is not None and allow_file_pattern:
if not any(
fnmatch.fnmatch(model_file['Path'], pattern)
fnmatch.fnmatch(repo_file['Path'], pattern)
for pattern in allow_file_pattern):
continue
# check model_file is exist in cache, if existed, skip download, otherwise download
if cache.exists(model_file):
file_name = os.path.basename(model_file['Name'])
if cache.exists(repo_file):
file_name = os.path.basename(repo_file['Name'])
logger.debug(
f'File {file_name} already in cache, skip downloading!')
continue
if repo_type == REPO_TYPE_MODEL:
# get download url
url = get_file_download_url(
model_id=repo_id,
file_path=repo_file['Path'],
revision=revision)
elif repo_type == REPO_TYPE_DATASET:
url = _api.get_dataset_file_url(
file_name=repo_file['Path'],
dataset_name=name,
namespace=group_or_owner,
revision=revision)
# get download url
url = get_file_download_url(
model_id=model_id,
file_path=model_file['Path'],
revision=revision)
download_file(url, model_file, temporary_cache_dir, cache, headers,
download_file(url, repo_file, temporary_cache_dir, cache, headers,
cookies)
cache.save_model_version(revision_info=revision_detail)

View File

@@ -493,6 +493,9 @@ class Frameworks(object):
kaldi = 'kaldi'
REPO_TYPE_MODEL = 'model'
REPO_TYPE_DATASET = 'dataset'
REPO_TYPE_SUPPORT = [REPO_TYPE_MODEL, REPO_TYPE_DATASET]
DEFAULT_MODEL_REVISION = None
MASTER_MODEL_BRANCH = 'master'
DEFAULT_REPOSITORY_REVISION = 'master'

View File

@@ -5,8 +5,8 @@ import tempfile
import time
import unittest
from modelscope.hub.dataset_download import (dataset_file_download,
dataset_snapshot_download)
from modelscope.hub.file_download import dataset_file_download
from modelscope.hub.snapshot_download import dataset_snapshot_download
class DownloadDatasetTest(unittest.TestCase):