Merge branch 'master' into release/1.16

Author: mulin.lyh
Date: 2024-07-12 20:51:24 +08:00
19 changed files with 752 additions and 237 deletions

View File

@@ -14,6 +14,7 @@ echo "PR modified files: $PR_CHANGED_FILES"
PR_CHANGED_FILES=${PR_CHANGED_FILES//[ ]/#}
echo "PR_CHANGED_FILES: $PR_CHANGED_FILES"
idx=0
sleep 65
for gpu in $gpus
do
exec {lock_fd}>"/tmp/gpu$gpu" || exit 1

View File

@@ -34,22 +34,7 @@ RUN if [ "$USE_GPU" = "True" ] ; then \
echo 'cpu unsupport detectron2'; \
fi
# torchmetrics==0.11.4 for ofa
# tinycudann for cuda12.1.0 pytorch 2.1.2
# fix pip compatible issue.
# pip install --no-cache-dir --upgrade pip && \
RUN if [ "$USE_GPU" = "True" ] ; then \
pip install --no-cache-dir torchsde jupyterlab torchmetrics==0.11.4 tiktoken transformers_stream_generator bitsandbytes basicsr optimum && \
pip install --no-cache-dir flash_attn==2.5.9.post1 -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
pip install --no-cache-dir auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu121/ && \
pip install --no-cache-dir -U 'xformers' --index-url https://download.pytorch.org/whl/cu121 && \
pip install --no-cache-dir --force tinycudann==1.7 -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
pip uninstall -y torch-scatter && TORCH_CUDA_ARCH_LIST="6.0;6.1;6.2;7.0;7.5;8.0;8.6;8.9;9.0" pip install --no-cache-dir -U torch-scatter && \
pip install --no-cache-dir -U vllm; \
else \
echo 'cpu unsupport vllm auto-gptq'; \
fi
# install dependencies
COPY requirements /var/modelscope
RUN pip install --no-cache-dir -r /var/modelscope/framework.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
pip install --no-cache-dir -r /var/modelscope/audio.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
@@ -64,5 +49,21 @@ RUN pip install --no-cache-dir -r /var/modelscope/framework.txt -f https://model
pip cache purge
# 'scipy<1.13.0' for cannot import name 'kaiser' from 'scipy.signal'
COPY examples /modelscope/examples
# torchmetrics==0.11.4 for ofa
# tinycudann for cuda12.1.0 pytorch 2.1.2
RUN if [ "$USE_GPU" = "True" ] ; then \
pip install --no-cache-dir --upgrade pip && \
pip install --no-cache-dir torchsde jupyterlab torchmetrics==0.11.4 tiktoken transformers_stream_generator bitsandbytes basicsr optimum && \
pip install --no-cache-dir flash_attn==2.5.9.post1 -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
pip install --no-cache-dir auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu121/ && \
pip install --no-cache-dir -U 'xformers' --index-url https://download.pytorch.org/whl/cu121 && \
pip install --no-cache-dir --force tinycudann==1.7 -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
pip uninstall -y torch-scatter && TORCH_CUDA_ARCH_LIST="6.0;6.1;6.2;7.0;7.5;8.0;8.6;8.9;9.0" pip install --no-cache-dir -U torch-scatter && \
pip install --no-cache-dir -U triton vllm https://modelscope.oss-cn-beijing.aliyuncs.com/packages/lmdeploy-0.5.0-cp310-cp310-linux_x86_64.whl; \
else \
echo 'cpu unsupport vllm auto-gptq'; \
fi
ENV SETUPTOOLS_USE_DISTUTILS=stdlib
ENV VLLM_USE_MODELSCOPE=True
ENV LMDEPLOY_USE_MODELSCOPE=True

View File

@@ -10,7 +10,9 @@ if TYPE_CHECKING:
from .hub.api import HubApi
from .hub.check_model import check_local_model_is_latest, check_model_is_id
from .hub.push_to_hub import push_to_hub, push_to_hub_async
from .hub.snapshot_download import snapshot_download
from .hub.snapshot_download import snapshot_download, dataset_snapshot_download
from .hub.file_download import model_file_download, dataset_file_download
from .metrics import (
AccuracyMetric, AudioNoiseMetric, BleuMetric, ImageColorEnhanceMetric,
ImageColorizationMetric, ImageDenoiseMetric, ImageInpaintingMetric,
@@ -59,7 +61,9 @@ else:
'TorchModelExporter',
],
'hub.api': ['HubApi'],
'hub.snapshot_download': ['snapshot_download'],
'hub.snapshot_download':
['snapshot_download', 'dataset_snapshot_download'],
'hub.file_download': ['model_file_download', 'dataset_file_download'],
'hub.push_to_hub': ['push_to_hub', 'push_to_hub_async'],
'hub.check_model':
['check_model_is_id', 'check_local_model_is_latest'],
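
With these exports in place, the new dataset helpers can be called directly. A minimal usage sketch follows; the dataset id is the one used by the tests added in this commit, and any public 'namespace/name' id should behave the same way:

# Minimal sketch of the newly exported dataset download helpers.
from modelscope.hub.file_download import dataset_file_download
from modelscope.hub.snapshot_download import dataset_snapshot_download

# Download a single raw file into the default dataset cache.
local_file = dataset_file_download(
    dataset_id='citest/test_dataset_download',
    file_path='open_qa.jsonl')
print(local_file)

# Download the whole dataset repo, keeping only *.jsonl files.
local_path = dataset_snapshot_download(
    'citest/test_dataset_download',
    allow_file_pattern='*.jsonl')
print(local_path)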

View File

@@ -3,8 +3,10 @@
from argparse import ArgumentParser
from modelscope.cli.base import CLICommand
from modelscope.hub.file_download import model_file_download
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.hub.file_download import (dataset_file_download,
model_file_download)
from modelscope.hub.snapshot_download import (dataset_snapshot_download,
snapshot_download)
def subparser_func(args):
@@ -24,11 +26,17 @@ class DownloadCMD(CLICommand):
""" define args for download command.
"""
parser: ArgumentParser = parsers.add_parser(DownloadCMD.name)
parser.add_argument(
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
'--model',
type=str,
required=True,
help='The model id to be downloaded.')
help='The id of the model to be downloaded. For download, '
'the id of either a model or dataset must be provided.')
group.add_argument(
'--dataset',
type=str,
help='The id of the dataset to be downloaded. For download, '
'the id of either a model or dataset must be provided.')
parser.add_argument(
'--revision',
type=str,
@@ -69,27 +77,57 @@ class DownloadCMD(CLICommand):
parser.set_defaults(func=subparser_func)
def execute(self):
if len(self.args.files) == 1: # download single file
model_file_download(
self.args.model,
self.args.files[0],
cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir,
revision=self.args.revision)
elif len(self.args.files) > 1: # download specified multiple files.
snapshot_download(
self.args.model,
revision=self.args.revision,
cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir,
allow_file_pattern=self.args.files,
)
else: # download repo
snapshot_download(
self.args.model,
revision=self.args.revision,
cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir,
allow_file_pattern=self.args.include,
ignore_file_pattern=self.args.exclude,
)
if self.args.model:
if len(self.args.files) == 1: # download single file
model_file_download(
self.args.model,
self.args.files[0],
cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir,
revision=self.args.revision)
elif len(
self.args.files) > 1: # download specified multiple files.
snapshot_download(
self.args.model,
revision=self.args.revision,
cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir,
allow_file_pattern=self.args.files,
)
else: # download repo
snapshot_download(
self.args.model,
revision=self.args.revision,
cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir,
allow_file_pattern=self.args.include,
ignore_file_pattern=self.args.exclude,
)
elif self.args.dataset:
if len(self.args.files) == 1: # download single file
dataset_file_download(
self.args.dataset,
self.args.files[0],
cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir,
revision=self.args.revision)
elif len(
self.args.files) > 1: # download specified multiple files.
dataset_snapshot_download(
self.args.dataset,
revision=self.args.revision,
cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir,
allow_file_pattern=self.args.files,
)
else: # download repo
dataset_snapshot_download(
self.args.dataset,
revision=self.args.revision,
cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir,
allow_file_pattern=self.args.include,
ignore_file_pattern=self.args.exclude,
)
else:
pass # noop
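
The branching above relies on argparse's mutually exclusive groups; a standalone sketch of the same pattern (a hypothetical script, not part of this commit):

# Standalone sketch of the argparse pattern used above: exactly one of
# --model / --dataset must be provided, and the handler branches on which.
from argparse import ArgumentParser

parser = ArgumentParser(prog='download')
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('--model', type=str, help='The model id to download.')
group.add_argument('--dataset', type=str, help='The dataset id to download.')
args = parser.parse_args(['--dataset', 'citest/test_dataset_download'])

if args.model:
    print('would download model', args.model)
elif args.dataset:
    print('would download dataset', args.dataset)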

View File

@@ -76,6 +76,7 @@ class HubApi:
connect=2,
backoff_factor=1,
status_forcelist=(500, 502, 503, 504),
respect_retry_after_header=False,
)
adapter = HTTPAdapter(max_retries=retry)
self.session.mount('http://', adapter)
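
For reference, a self-contained version of the retry-enabled session configured above; the values mirror this hunk, and total=2 is an assumption since the diff does not show the full call:

# Self-contained sketch of a requests session with urllib3 retries.
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

retry = Retry(
    total=2,  # assumption: not visible in the hunk above
    connect=2,
    backoff_factor=1,
    status_forcelist=(500, 502, 503, 504),
    # ignore server-provided Retry-After and rely on backoff_factor only
    respect_retry_after_header=False,
)
adapter = HTTPAdapter(max_retries=retry)
session = requests.Session()
session.mount('http://', adapter)
session.mount('https://', adapter)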
@@ -741,7 +742,8 @@ class HubApi:
recursive = 'True' if recursive else 'False'
datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_hub_id}/repo/tree'
params = {'Revision': revision, 'Root': root_path, 'Recursive': recursive}
params = {'Revision': revision if revision else 'master',
'Root': root_path if root_path else '/', 'Recursive': recursive}
cookies = ModelScopeConfig.get_cookies()
r = self.session.get(datahub_url, params=params, cookies=cookies)
@@ -771,8 +773,10 @@ class HubApi:
@staticmethod
def dump_datatype_file(dataset_type: int, meta_cache_dir: str):
"""
Dump the data_type as a local file, in order to get the dataset formation without calling the datahub.
More details, please refer to the class `modelscope.utils.constant.DatasetFormations`.
Dump the data_type as a local file, in order to get the dataset
formation without calling the datahub.
More details, please refer to the class
`modelscope.utils.constant.DatasetFormations`.
"""
dataset_type_file_path = os.path.join(meta_cache_dir,
f'{str(dataset_type)}{DatasetFormations.formation_mark_ext.value}')
@@ -874,13 +878,14 @@ class HubApi:
dataset_name: str,
namespace: str,
revision: Optional[str] = DEFAULT_DATASET_REVISION,
view: Optional[bool] = False,
extension_filter: Optional[bool] = True):
if not file_name or not dataset_name or not namespace:
raise ValueError('Args (file_name, dataset_name, namespace) cannot be empty!')
# Note: make sure the FilePath is the last parameter in the url
params: dict = {'Source': 'SDK', 'Revision': revision, 'FilePath': file_name}
params: dict = {'Source': 'SDK', 'Revision': revision, 'FilePath': file_name, 'View': view}
params: str = urlencode(params)
file_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?{params}'
@@ -1113,7 +1118,8 @@ class ModelScopeConfig:
ModelScopeConfig.cookie_expired_warning = True
logger.warning(
'Authentication has expired, '
'please re-login if you need to access private models or datasets.')
'please re-login with modelscope login --token "YOUR_SDK_TOKEN" '
'if you need to access private models or datasets.')
return None
return cookies
return None

View File

@@ -153,9 +153,9 @@ def datahub_raise_on_error(url, rsp, http_response: requests.Response):
if rsp.get('Code') == HTTPStatus.OK:
return True
else:
request_id = get_request_id(http_response)
request_id = rsp['RequestId']
raise RequestError(
f"Url = {url}, Request id={request_id} Message = {rsp.get('Message')},\
f"Url = {url}, Request id={request_id} Code = {rsp['Code']} Message = {rsp['Message']},\
Please specify correct dataset_name and namespace.")

View File

@@ -21,10 +21,14 @@ from modelscope.hub.constants import (
API_FILE_DOWNLOAD_CHUNK_SIZE, API_FILE_DOWNLOAD_RETRY_TIMES,
API_FILE_DOWNLOAD_TIMEOUT, FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS,
MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB, TEMPORARY_FOLDER_NAME)
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
from modelscope.utils.file_utils import get_model_cache_root
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
DEFAULT_MODEL_REVISION,
REPO_TYPE_DATASET, REPO_TYPE_MODEL,
REPO_TYPE_SUPPORT)
from modelscope.utils.file_utils import (get_dataset_cache_root,
get_model_cache_root)
from modelscope.utils.logger import get_logger
from .errors import FileDownloadError, NotExistError
from .errors import FileDownloadError, InvalidParameter, NotExistError
from .utils.caching import ModelFileSystemCache
from .utils.utils import (file_integrity_validation, get_endpoint,
model_id_to_group_owner_name)
@@ -78,8 +82,97 @@ def model_file_download(
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
if some parameter value is invalid
"""
return _repo_file_download(
model_id,
file_path,
repo_type=REPO_TYPE_MODEL,
revision=revision,
cache_dir=cache_dir,
user_agent=user_agent,
local_files_only=local_files_only,
cookies=cookies,
local_dir=local_dir)
def dataset_file_download(
dataset_id: str,
file_path: str,
revision: Optional[str] = DEFAULT_DATASET_REVISION,
cache_dir: Union[str, Path, None] = None,
local_dir: Optional[str] = None,
user_agent: Optional[Union[Dict, str]] = None,
local_files_only: Optional[bool] = False,
cookies: Optional[CookieJar] = None,
) -> str:
"""Download raw files of a dataset.
Downloads all files at the specified revision. This
is useful when you want all files from a dataset, because you don't know which
ones you will need a priori. All files are nested inside a folder in order
to keep their actual filename relative to that folder.
An alternative would be to just clone a dataset but this would require that the
user always has git and git-lfs installed, and properly configured.
Args:
dataset_id (str): A user or an organization name and a dataset name separated by a `/`.
file_path (str): The relative path of the file to download.
revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
commit hash. NOTE: currently only branch and tag names are supported.
cache_dir (str, Path, optional): Path to the folder where cached files are stored; the dataset file will
be saved as cache_dir/dataset_id/THE_DATASET_FILE.
local_dir (str, optional): Specific local directory path to which the file will be downloaded.
user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string.
local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the
local cached file if it exists.
cookies (CookieJar, optional): The cookie of the request, default None.
Raises:
ValueError: the value details.
Returns:
str: Local path (string) of the downloaded file.
Note:
Raises the following errors:
- [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
if `use_auth_token=True` and the token cannot be found.
- [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if
ETag cannot be determined.
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
if some parameter value is invalid
"""
return _repo_file_download(
dataset_id,
file_path,
repo_type=REPO_TYPE_DATASET,
revision=revision,
cache_dir=cache_dir,
user_agent=user_agent,
local_files_only=local_files_only,
cookies=cookies,
local_dir=local_dir)
def _repo_file_download(
repo_id: str,
file_path: str,
*,
repo_type: str = None,
revision: Optional[str] = DEFAULT_MODEL_REVISION,
cache_dir: Optional[str] = None,
user_agent: Union[Dict, str, None] = None,
local_files_only: Optional[bool] = False,
cookies: Optional[CookieJar] = None,
local_dir: Optional[str] = None,
) -> Optional[str]: # pragma: no cover
if not repo_type:
repo_type = REPO_TYPE_MODEL
if repo_type not in REPO_TYPE_SUPPORT:
raise InvalidParameter('Invalid repo type: %s, only support: %s' %
(repo_type, REPO_TYPE_SUPPORT))
temporary_cache_dir, cache = create_temporary_directory_and_cache(
model_id, local_dir, cache_dir)
repo_id, local_dir=local_dir, cache_dir=cache_dir, repo_type=repo_type)
# if local_files_only is `True` and the file already exists in cached_path
# return the cached path
@@ -93,7 +186,7 @@ def model_file_download(
else:
raise ValueError(
'Cannot find the requested files in the cached path and outgoing'
' traffic has been disabled. To enable model look-ups and downloads'
' traffic has been disabled. To enable look-ups and downloads'
" online, set 'local_files_only' to False.")
_api = HubApi()
@@ -102,75 +195,84 @@ def model_file_download(
}
if cookies is None:
cookies = ModelScopeConfig.get_cookies()
repo_files = []
if repo_type == REPO_TYPE_MODEL:
revision = _api.get_valid_revision(
repo_id, revision=revision, cookies=cookies)
file_to_download_meta = None
# we need to confirm the version is up-to-date
# we need to get the file list to check if the latest version is cached, if so return, otherwise download
repo_files = _api.get_model_files(
model_id=repo_id,
revision=revision,
recursive=True,
use_cookies=False if cookies is None else cookies)
elif repo_type == REPO_TYPE_DATASET:
group_or_owner, name = model_id_to_group_owner_name(repo_id)
if not revision:
revision = DEFAULT_DATASET_REVISION
files_list_tree = _api.list_repo_tree(
dataset_name=name,
namespace=group_or_owner,
revision=revision,
root_path='/',
recursive=True)
if not ('Code' in files_list_tree and files_list_tree['Code'] == 200):
print(
'Get dataset: %s file list failed, request_id: %s, message: %s'
% (repo_id, files_list_tree['RequestId'],
files_list_tree['Message']))
return None
repo_files = files_list_tree['Data']['Files']
revision = _api.get_valid_revision(
model_id, revision=revision, cookies=cookies)
file_to_download_info = None
# we need to confirm the version is up-to-date
# we need to get the file list to check if the latest version is cached, if so return, otherwise download
model_files = _api.get_model_files(
model_id=model_id,
revision=revision,
recursive=True,
use_cookies=False if cookies is None else cookies)
for model_file in model_files:
if model_file['Type'] == 'tree':
file_to_download_meta = None
for repo_file in repo_files:
if repo_file['Type'] == 'tree':
continue
if model_file['Path'] == file_path:
if cache.exists(model_file):
if repo_file['Path'] == file_path:
if cache.exists(repo_file):
logger.debug(
f'File {model_file["Name"]} already in cache, skip downloading!'
f'File {repo_file["Name"]} already in cache, skip downloading!'
)
return cache.get_file_by_info(model_file)
return cache.get_file_by_info(repo_file)
else:
file_to_download_info = model_file
file_to_download_meta = repo_file
break
if file_to_download_info is None:
if file_to_download_meta is None:
raise NotExistError('The file path: %s not exist in: %s' %
(file_path, model_id))
(file_path, repo_id))
# we need to download again
url_to_download = get_file_download_url(model_id, file_path, revision)
if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_to_download_info[
'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1:
parallel_download(
url_to_download,
temporary_cache_dir,
file_path,
headers=headers,
cookies=None if cookies is None else cookies.get_dict(),
file_size=file_to_download_info['Size'])
else:
http_get_model_file(
url_to_download,
temporary_cache_dir,
file_path,
file_size=file_to_download_info['Size'],
headers=headers,
cookies=None if cookies is None else cookies.get_dict())
temp_file_path = os.path.join(temporary_cache_dir, file_path)
# for download with commit we can't get Sha256
if file_to_download_info[FILE_HASH] is not None:
file_integrity_validation(temp_file_path,
file_to_download_info[FILE_HASH])
return cache.put_file(file_to_download_info,
os.path.join(temporary_cache_dir, file_path))
if repo_type == REPO_TYPE_MODEL:
url_to_download = get_file_download_url(repo_id, file_path, revision)
elif repo_type == REPO_TYPE_DATASET:
url_to_download = _api.get_dataset_file_url(
file_name=file_to_download_meta['Path'],
dataset_name=name,
namespace=group_or_owner,
revision=revision)
return download_file(url_to_download, file_to_download_meta,
temporary_cache_dir, cache, headers, cookies)
def create_temporary_directory_and_cache(model_id: str, local_dir: str,
cache_dir: str):
def create_temporary_directory_and_cache(model_id: str,
local_dir: str = None,
cache_dir: str = None,
repo_type: str = REPO_TYPE_MODEL):
if repo_type == REPO_TYPE_MODEL:
default_cache_root = get_model_cache_root()
elif repo_type == REPO_TYPE_DATASET:
default_cache_root = get_dataset_cache_root()
group_or_owner, name = model_id_to_group_owner_name(model_id)
if local_dir is not None:
temporary_cache_dir = os.path.join(local_dir, TEMPORARY_FOLDER_NAME)
cache = ModelFileSystemCache(local_dir)
else:
if cache_dir is None:
cache_dir = get_model_cache_root()
cache_dir = default_cache_root
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
temporary_cache_dir = os.path.join(cache_dir, TEMPORARY_FOLDER_NAME,
@@ -269,6 +371,7 @@ def parallel_download(
PART_SIZE = 160 * 1024 * 1024 # every part is 160M
tasks = []
file_path = os.path.join(local_dir, file_name)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
for idx in range(int(file_size / PART_SIZE)):
start = idx * PART_SIZE
end = (idx + 1) * PART_SIZE - 1
@@ -323,6 +426,7 @@ def http_get_model_file(
get_headers = {} if headers is None else copy.deepcopy(headers)
get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
temp_file_path = os.path.join(local_dir, file_name)
os.makedirs(os.path.dirname(temp_file_path), exist_ok=True)
logger.debug('downloading %s to %s', url, temp_file_path)
# retry sleep 0.5s, 1s, 2s, 4s
retry = Retry(
@@ -349,7 +453,7 @@ def http_get_model_file(
break
get_headers['Range'] = 'bytes=%s-%s' % (partial_length,
file_size - 1)
with open(temp_file_path, 'ab') as f:
with open(temp_file_path, 'ab+') as f:
r = requests.get(
url,
stream=True,
@@ -451,3 +555,31 @@ def http_get_file(
logger.error(msg)
raise FileDownloadError(msg)
os.replace(temp_file.name, os.path.join(local_dir, file_name))
def download_file(url, file_meta, temporary_cache_dir, cache, headers,
cookies):
if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[
'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1: # parallel download large file.
parallel_download(
url,
temporary_cache_dir,
file_meta['Path'],
headers=headers,
cookies=None if cookies is None else cookies.get_dict(),
file_size=file_meta['Size'])
else:
http_get_model_file(
url,
temporary_cache_dir,
file_meta['Path'],
file_size=file_meta['Size'],
headers=headers,
cookies=cookies)
# check file integrity
temp_file = os.path.join(temporary_cache_dir, file_meta['Path'])
if FILE_HASH in file_meta:
file_integrity_validation(temp_file, file_meta[FILE_HASH])
# put file into to cache
return cache.put_file(file_meta, temp_file)
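
The parallel branch above splits large files into 160 MB parts and fetches each part with an HTTP Range request; a simplified, self-contained sketch of that idea (URL, file size and worker count are placeholders):

# Simplified sketch of ranged, multi-part HTTP download, the idea behind
# parallel_download above. URL, file size and worker count are placeholders.
import requests
from concurrent.futures import ThreadPoolExecutor

def fetch_range(url, start, end, out_path):
    # Request only bytes [start, end] of the remote file and write them in place.
    headers = {'Range': 'bytes=%d-%d' % (start, end)}
    r = requests.get(url, headers=headers, stream=True)
    r.raise_for_status()
    with open(out_path, 'r+b') as f:
        f.seek(start)
        for chunk in r.iter_content(chunk_size=1 << 20):
            f.write(chunk)

def ranged_download(url, file_size, out_path, part_size=160 * 1024 * 1024):
    # Pre-allocate the output file, then fill every part concurrently.
    with open(out_path, 'wb') as f:
        f.truncate(file_size)
    ranges = [(s, min(s + part_size, file_size) - 1)
              for s in range(0, file_size, part_size)]
    with ThreadPoolExecutor(max_workers=4) as pool:
        futures = [pool.submit(fetch_range, url, s, e, out_path)
                   for s, e in ranges]
        for fut in futures:
            fut.result()  # propagate any download error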

View File

@@ -8,14 +8,15 @@ from pathlib import Path
from typing import Dict, List, Optional, Union
from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
from modelscope.hub.errors import InvalidParameter
from modelscope.hub.utils.utils import model_id_to_group_owner_name
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
DEFAULT_MODEL_REVISION,
REPO_TYPE_DATASET, REPO_TYPE_MODEL,
REPO_TYPE_SUPPORT)
from modelscope.utils.logger import get_logger
from .constants import (FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS,
MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB)
from .file_download import (create_temporary_directory_and_cache,
get_file_download_url, http_get_model_file,
parallel_download)
from .utils.utils import file_integrity_validation
download_file, get_file_download_url)
logger = get_logger()
@@ -70,14 +71,110 @@ def snapshot_download(
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
if some parameter value is invalid
"""
return _snapshot_download(
model_id,
repo_type=REPO_TYPE_MODEL,
revision=revision,
cache_dir=cache_dir,
user_agent=user_agent,
local_files_only=local_files_only,
cookies=cookies,
ignore_file_pattern=ignore_file_pattern,
allow_file_pattern=allow_file_pattern,
local_dir=local_dir)
def dataset_snapshot_download(
dataset_id: str,
revision: Optional[str] = DEFAULT_DATASET_REVISION,
cache_dir: Union[str, Path, None] = None,
local_dir: Optional[str] = None,
user_agent: Optional[Union[Dict, str]] = None,
local_files_only: Optional[bool] = False,
cookies: Optional[CookieJar] = None,
ignore_file_pattern: Optional[Union[str, List[str]]] = None,
allow_file_pattern: Optional[Union[str, List[str]]] = None,
) -> str:
"""Download raw files of a dataset.
Downloads all files at the specified revision. This
is useful when you want all files from a dataset, because you don't know which
ones you will need a priori. All files are nested inside a folder in order
to keep their actual filename relative to that folder.
An alternative would be to just clone a dataset but this would require that the
user always has git and git-lfs installed, and properly configured.
Args:
dataset_id (str): A user or an organization name and a dataset name separated by a `/`.
revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
commit hash. NOTE: currently only branch and tag names are supported.
cache_dir (str, Path, optional): Path to the folder where cached files are stored; the dataset will
be saved as cache_dir/dataset_id/THE_DATASET_FILES.
local_dir (str, optional): Specific local directory path to which the file will be downloaded.
user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string.
local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the
local cached file if it exists.
cookies (CookieJar, optional): The cookie of the request, default None.
ignore_file_pattern (`str` or `List`, *optional*, default to `None`):
Any file pattern to be ignored when downloading, like exact file names or file extensions.
Using regular expressions here is deprecated.
allow_file_pattern (`str` or `List`, *optional*, default to `None`):
Any file pattern to be downloaded, like exact file names or file extensions.
Raises:
ValueError: the value details.
Returns:
str: Local folder path (string) of repo snapshot
Note:
Raises the following errors:
- [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
if `use_auth_token=True` and the token cannot be found.
- [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if
ETag cannot be determined.
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
if some parameter value is invalid
"""
return _snapshot_download(
dataset_id,
repo_type=REPO_TYPE_DATASET,
revision=revision,
cache_dir=cache_dir,
user_agent=user_agent,
local_files_only=local_files_only,
cookies=cookies,
ignore_file_pattern=ignore_file_pattern,
allow_file_pattern=allow_file_pattern,
local_dir=local_dir)
def _snapshot_download(
repo_id: str,
*,
repo_type: Optional[str] = None,
revision: Optional[str] = DEFAULT_MODEL_REVISION,
cache_dir: Union[str, Path, None] = None,
user_agent: Optional[Union[Dict, str]] = None,
local_files_only: Optional[bool] = False,
cookies: Optional[CookieJar] = None,
ignore_file_pattern: Optional[Union[str, List[str]]] = None,
allow_file_pattern: Optional[Union[str, List[str]]] = None,
local_dir: Optional[str] = None,
):
if not repo_type:
repo_type = REPO_TYPE_MODEL
if repo_type not in REPO_TYPE_SUPPORT:
raise InvalidParameter('Invalid repo type: %s, only support: %s' %
(repo_type, REPO_TYPE_SUPPORT))
temporary_cache_dir, cache = create_temporary_directory_and_cache(
model_id, local_dir, cache_dir)
repo_id, local_dir=local_dir, cache_dir=cache_dir, repo_type=repo_type)
if local_files_only:
if len(cache.cached_files) == 0:
raise ValueError(
'Cannot find the requested files in the cached path and outgoing'
' traffic has been disabled. To enable model look-ups and downloads'
' traffic has been disabled. To enable look-ups and downloads'
" online, set 'local_files_only' to False.")
logger.warning('We can not confirm the cached file is for revision: %s'
% revision)
@@ -92,27 +189,48 @@ def snapshot_download(
_api = HubApi()
if cookies is None:
cookies = ModelScopeConfig.get_cookies()
revision_detail = _api.get_valid_revision_detail(
model_id, revision=revision, cookies=cookies)
revision = revision_detail['Revision']
repo_files = []
if repo_type == REPO_TYPE_MODEL:
revision_detail = _api.get_valid_revision_detail(
repo_id, revision=revision, cookies=cookies)
revision = revision_detail['Revision']
snapshot_header = headers if 'CI_TEST' in os.environ else {
**headers,
**{
'Snapshot': 'True'
snapshot_header = headers if 'CI_TEST' in os.environ else {
**headers,
**{
'Snapshot': 'True'
}
}
}
if cache.cached_model_revision is not None:
snapshot_header[
'cached_model_revision'] = cache.cached_model_revision
if cache.cached_model_revision is not None:
snapshot_header[
'cached_model_revision'] = cache.cached_model_revision
model_files = _api.get_model_files(
model_id=model_id,
revision=revision,
recursive=True,
use_cookies=False if cookies is None else cookies,
headers=snapshot_header,
)
repo_files = _api.get_model_files(
model_id=repo_id,
revision=revision,
recursive=True,
use_cookies=False if cookies is None else cookies,
headers=snapshot_header,
)
elif repo_type == REPO_TYPE_DATASET:
group_or_owner, name = model_id_to_group_owner_name(repo_id)
if not revision:
revision = DEFAULT_DATASET_REVISION
revision_detail = revision
files_list_tree = _api.list_repo_tree(
dataset_name=name,
namespace=group_or_owner,
revision=revision,
root_path='/',
recursive=True)
if not ('Code' in files_list_tree
and files_list_tree['Code'] == 200):
print(
'Get dataset: %s file list failed, request_id: %s, message: %s'
% (repo_id, files_list_tree['RequestId'],
files_list_tree['Message']))
return None
repo_files = files_list_tree['Data']['Files']
if ignore_file_pattern is None:
ignore_file_pattern = []
@@ -122,6 +240,12 @@ def snapshot_download(
item if not item.endswith('/') else item + '*'
for item in ignore_file_pattern
]
ignore_regex_pattern = []
for file_pattern in ignore_file_pattern:
if file_pattern.startswith('*'):
ignore_regex_pattern.append('.' + file_pattern)
else:
ignore_regex_pattern.append(file_pattern)
if allow_file_pattern is not None:
if isinstance(allow_file_pattern, str):
@@ -131,55 +255,39 @@ def snapshot_download(
for item in allow_file_pattern
]
for model_file in model_files:
if model_file['Type'] == 'tree' or \
any(fnmatch.fnmatch(model_file['Path'], pattern) for pattern in ignore_file_pattern) or \
any([re.search(pattern, model_file['Name']) is not None for pattern in ignore_file_pattern]):
for repo_file in repo_files:
if repo_file['Type'] == 'tree' or \
any([fnmatch.fnmatch(repo_file['Path'], pattern) for pattern in ignore_file_pattern]) or \
any([re.search(pattern, repo_file['Name']) is not None for pattern in ignore_regex_pattern]): # noqa E501
continue
if allow_file_pattern is not None and allow_file_pattern:
if not any(
fnmatch.fnmatch(model_file['Path'], pattern)
fnmatch.fnmatch(repo_file['Path'], pattern)
for pattern in allow_file_pattern):
continue
# check model_file is exist in cache, if existed, skip download, otherwise download
if cache.exists(model_file):
file_name = os.path.basename(model_file['Name'])
if cache.exists(repo_file):
file_name = os.path.basename(repo_file['Name'])
logger.debug(
f'File {file_name} already in cache, skip downloading!')
continue
if repo_type == REPO_TYPE_MODEL:
# get download url
url = get_file_download_url(
model_id=repo_id,
file_path=repo_file['Path'],
revision=revision)
elif repo_type == REPO_TYPE_DATASET:
url = _api.get_dataset_file_url(
file_name=repo_file['Path'],
dataset_name=name,
namespace=group_or_owner,
revision=revision)
# get download url
url = get_file_download_url(
model_id=model_id,
file_path=model_file['Path'],
revision=revision)
if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < model_file[
'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1:
parallel_download(
url,
temporary_cache_dir,
model_file['Name'],
headers=headers,
cookies=None if cookies is None else cookies.get_dict(),
file_size=model_file['Size'])
else:
http_get_model_file(
url,
temporary_cache_dir,
model_file['Name'],
file_size=model_file['Size'],
headers=headers,
cookies=cookies)
# check file integrity
temp_file = os.path.join(temporary_cache_dir, model_file['Name'])
if FILE_HASH in model_file:
file_integrity_validation(temp_file, model_file[FILE_HASH])
# put file into to cache
cache.put_file(model_file, temp_file)
download_file(url, repo_file, temporary_cache_dir, cache, headers,
cookies)
cache.save_model_version(revision_info=revision_detail)
return os.path.join(cache.get_root_location())
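
The skip logic above combines fnmatch globs with a regex fallback (patterns starting with '*' get a leading '.'); a small standalone illustration of that filtering, using the file names from the tests below:

# Standalone illustration of the allow/ignore filtering used above.
import fnmatch
import re

repo_files = ['open_qa.jsonl', '111/222/333/shijian.jpeg', 'README.md']
ignore_file_pattern = ['*.jpeg']
# '*.jpeg' becomes the regex '.*.jpeg' for the regex-based check.
ignore_regex_pattern = ['.' + p if p.startswith('*') else p
                        for p in ignore_file_pattern]
allow_file_pattern = ['*.jsonl', '*.md']

kept = []
for path in repo_files:
    if any(fnmatch.fnmatch(path, p) for p in ignore_file_pattern) or \
            any(re.search(p, path) for p in ignore_regex_pattern):
        continue
    if allow_file_pattern and not any(
            fnmatch.fnmatch(path, p) for p in allow_file_pattern):
        continue
    kept.append(path)
print(kept)  # ['open_qa.jsonl', 'README.md']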

View File

@@ -164,9 +164,12 @@ class ModelFileSystemCache(FileSystemCache):
model_version_file_path = os.path.join(
self.cache_root_location, FileSystemCache.MODEL_VERSION_FILE_NAME)
with open(model_version_file_path, 'w') as f:
version_info_str = 'Revision:%s,CreatedAt:%s' % (
revision_info['Revision'], revision_info['CreatedAt'])
f.write(version_info_str)
if isinstance(revision_info, dict):
version_info_str = 'Revision:%s,CreatedAt:%s' % (
revision_info['Revision'], revision_info['CreatedAt'])
f.write(version_info_str)
else:
f.write(revision_info)
def get_model_id(self):
return self.model_meta[FileSystemCache.MODEL_META_MODEL_ID]

View File

@@ -7,22 +7,22 @@ import os
import os.path as osp
import time
import traceback
from datetime import datetime
from functools import reduce
from pathlib import Path
from typing import Union
import json
from modelscope import version
# do not delete
from modelscope.metainfo import (CustomDatasets, Heads, Hooks, LR_Schedulers,
Metrics, Models, Optimizers, Pipelines,
Preprocessors, TaskModels, Trainers)
from modelscope.utils.constant import Fields, Tasks
from modelscope.utils.file_utils import get_modelscope_cache_dir
from modelscope.utils.logger import get_logger
from modelscope.utils.registry import default_group
logger = get_logger(log_level=logging.WARNING)
p = Path(__file__)
# get the path of package 'modelscope'
@@ -56,6 +56,15 @@ TEMPLATE_PATH = 'TEMPLATE_PATH'
TEMPLATE_FILE = 'ast_index_file.py'
def get_ast_logger():
ast_logger = logging.getLogger('modelscope.ast')
ast_logger.setLevel(logging.INFO)
return ast_logger
logger = get_ast_logger()
class AstScanning(object):
def __init__(self) -> None:
@@ -658,6 +667,19 @@ def _update_index(index, files_mtime):
index[REQUIREMENT_KEY].update(updated_index[REQUIREMENT_KEY])
def __is_develop_model():
# trick: a recorded release datetime more than a year in the future means this is a development build
release_timestamp = int(
round(
datetime.strptime(version.__release_datetime__,
'%Y-%m-%d %H:%M:%S').timestamp()))
SECONDS_PER_YEAR = 24 * 365 * 60 * 60
current_timestamp = int(round(datetime.now().timestamp()))
if release_timestamp > current_timestamp + SECONDS_PER_YEAR:
return True
return False
def load_index(
file_list=None,
force_rebuild=False,
@@ -697,53 +719,45 @@ def load_index(
cache_dir = os.getenv('MODELSCOPE_CACHE', indexer_file_dir)
index_file = os.getenv('MODELSCOPE_INDEX_FILE', indexer_file)
file_path = os.path.join(cache_dir, index_file)
logger.info(f'Loading ast index from {file_path}')
index = None
local_changed = False
if not force_rebuild and os.path.exists(file_path):
wrapped_index = _load_index(file_path)
md5, files_mtime = file_scanner.files_mtime_md5(file_list=file_list)
from modelscope.version import __version__
if (wrapped_index[VERSION_KEY] == __version__):
if force_rebuild:
logger.info('Force rebuilding ast index from scanning every file!')
index = file_scanner.get_files_scan_results(file_list)
return index
# when developing, we need to generator as need.
if __is_develop_model():
logger.info(f'Loading ast index from {file_path}')
if os.path.exists(file_path): # already exist, check it's latest
wrapped_index = _load_index(file_path)
md5, files_mtime = file_scanner.files_mtime_md5(
file_list=file_list)
index = wrapped_index
if (wrapped_index[MD5_KEY] != md5):
local_changed = True
full_index_flag = False
if index is None:
full_index_flag = True
elif index and local_changed and FILES_MTIME_KEY not in index:
full_index_flag = True
elif index and local_changed and MODELSCOPE_PATH_KEY not in index:
full_index_flag = True
elif index and local_changed and index[
MODELSCOPE_PATH_KEY] != MODELSCOPE_PATH.as_posix():
full_index_flag = True
if full_index_flag:
if force_rebuild:
logger.info('Force rebuilding ast index from scanning every file!')
index = file_scanner.get_files_scan_results(file_list)
from modelscope.version import __version__
if (wrapped_index[VERSION_KEY] == __version__
and wrapped_index[MD5_KEY] != md5) or \
wrapped_index[VERSION_KEY] != __version__:
logger.info(
'Updating the files for the changes of local files, '
'first time updating will take longer time! Please wait till updating done!'
)
_update_index(index, files_mtime)
_save_index(index, file_path, file_list)
else:
logger.info(
f'No valid ast index found from {file_path}, generating ast index from prebuilt!'
f'No valid ast index found from {file_path}, generating ast index from scratch!'
)
index = load_from_prebuilt()
if index is None:
index = file_scanner.get_files_scan_results(file_list)
_save_index(index, file_path, file_list)
elif local_changed and not full_index_flag:
index = file_scanner.get_files_scan_results(
file_list) # generate new
_save_index(index, file_path, file_list) # save to generate path.
logger.info(
'Updating the files for the changes of local files, '
'first time updating will take longer time! Please wait till updating done!'
)
_update_index(index, files_mtime)
_save_index(index, file_path, file_list)
f'Loading done! Current index file version is {index[VERSION_KEY]}, '
f'with md5 {index[MD5_KEY]} and a total number of '
f'{len(index[INDEX_KEY])} components indexed')
else: # just load the prebuild index file.
index = load_from_prebuilt()
logger.info(
f'Loading done! Current index file version is {index[VERSION_KEY]}, '
f'with md5 {index[MD5_KEY]} and a total number of '
f'{len(index[INDEX_KEY])} components indexed')
return index
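
A hedged summary of the decision flow implemented above, with hypothetical stubs standing in for the real scanner and loader:

# Hedged summary of the load_index flow above; the returned strings stand in
# for the real scan/load/save calls and are not part of the implementation.
def load_index_flow(force_rebuild, develop_mode, index_exists,
                    version_changed, md5_changed):
    if force_rebuild:
        return 'full scan of every file'
    if develop_mode:
        if not index_exists:
            return 'generate index from scratch and save it'
        if version_changed or md5_changed:
            return 'update index for changed files and save it'
        return 'use the cached index as-is'
    # released package: trust the prebuilt index shipped with the wheel
    return 'load prebuilt index'

print(load_index_flow(False, True, True, False, True))
# -> 'update index for changed files and save it'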

View File

@@ -5,6 +5,8 @@ import enum
class Fields(object):
""" Names for different application fields
"""
hub = 'hub'
datasets = 'datasets'
framework = 'framework'
cv = 'cv'
nlp = 'nlp'
@@ -491,6 +493,9 @@ class Frameworks(object):
kaldi = 'kaldi'
REPO_TYPE_MODEL = 'model'
REPO_TYPE_DATASET = 'dataset'
REPO_TYPE_SUPPORT = [REPO_TYPE_MODEL, REPO_TYPE_DATASET]
DEFAULT_MODEL_REVISION = None
MASTER_MODEL_BRANCH = 'master'
DEFAULT_REPOSITORY_REVISION = 'master'

View File

@@ -53,11 +53,35 @@ def get_model_cache_root() -> str:
"""Get model cache root path.
Returns:
str: the modelscope cache root.
str: the modelscope model cache root.
"""
return os.path.join(get_modelscope_cache_dir(), 'hub')
def get_dataset_cache_root() -> str:
"""Get dataset raw file cache root path.
Returns:
str: the modelscope dataset raw file cache root.
"""
return os.path.join(get_modelscope_cache_dir(), 'datasets')
def get_dataset_cache_dir(dataset_id: str) -> str:
"""Get the dataset_id's path.
dataset_cache_root/dataset_id.
Args:
dataset_id (str): The dataset id.
Returns:
str: The dataset_id's cache root path.
"""
dataset_root = get_dataset_cache_root()
return dataset_root if dataset_id is None else os.path.join(
dataset_root, dataset_id + '/')
def get_model_cache_dir(model_id: str) -> str:
"""cache dir precedence:
function parameter > environment > ~/.cache/modelscope/hub/model_id
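
Assuming the default cache root of ~/.cache/modelscope (overridable via the MODELSCOPE_CACHE environment variable), the new helpers resolve to paths like the following:

# Hedged illustration of the dataset cache layout added above, assuming the
# default cache root ~/.cache/modelscope (overridable via MODELSCOPE_CACHE).
from modelscope.utils.file_utils import (get_dataset_cache_dir,
                                         get_dataset_cache_root)

print(get_dataset_cache_root())
# e.g. /home/<user>/.cache/modelscope/datasets
print(get_dataset_cache_dir('citest/test_dataset_download'))
# e.g. /home/<user>/.cache/modelscope/datasets/citest/test_dataset_download/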

View File

@@ -395,7 +395,8 @@ def import_module_from_model_dir(model_dir):
]
create_module_from_files(relative_file_dirs, model_dir, module_name)
for file in relative_file_dirs:
submodule = module_name + '.' + file.replace('.py', '').replace(os.sep, '.')
submodule = module_name + '.' + file.replace('.py', '').replace(
os.sep, '.')
importlib.import_module(submodule)

View File

@@ -25,6 +25,7 @@ imgaug>=0.4.0
kornia>=0.5.0
lmdb
lpips
matplotlib>=3.8.0
ml_collections
mmcls>=0.21.0
mmdet>=2.25.0,<=2.28.2
@@ -44,10 +45,11 @@ opencv-python
paint_ldm
pandas
panopticapi
Pillow>=6.2.0
plyfile>=0.7.4
psutil
pyclipper
PyMCubes
PyMCubes<=0.1.4
pytorch-lightning
regex
# <0.20.0 for compatible python3.7 python3.8

requirements/datasets.txt (new file, 12 additions)
View File

@@ -0,0 +1,12 @@
addict
attrs
datasets>=2.16.0,<2.19.0
einops
oss2
python-dateutil>=2.1
scipy
# the latest version has a compatibility issue.
setuptools==69.5.1
simplejson>=3.3.0
sortedcontainers>=1.5.9
urllib3>=1.26

View File

@@ -2,21 +2,12 @@ addict
attrs
datasets>=2.16.0,<2.19.0
einops
filelock>=3.3.0
huggingface_hub
numpy
oss2
pandas
Pillow>=6.2.0
# pyarrow 9.0.0 introduced event_loop core dump
pyarrow>=6.0.0,!=9.0.0
python-dateutil>=2.1
pyyaml
requests>=2.25
scipy
setuptools
# the latest version has a compatibility issue.
setuptools==69.5.1
simplejson>=3.3.0
sortedcontainers>=1.5.9
tqdm>=4.64.0
transformers
urllib3>=1.26
yapf

View File

@@ -196,7 +196,7 @@ if __name__ == '__main__':
# add framework dependencies to every field
for field, requires in extra_requires.items():
if field not in [
'server', 'framework'
'server', 'framework', 'hub', 'datasets'
]: # server need install model's field dependencies before.
extra_requires[field] = framework_requires + extra_requires[field]
extra_requires['all'] = all_requires
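
A simplified illustration of what this composition does (the requirement names are placeholders, not the project's real lists):

# Simplified illustration of the extras composition above: every field extra
# except 'server', 'framework', 'hub' and 'datasets' also pulls in the
# framework requirements. Requirement names here are placeholders.
framework_requires = ['numpy', 'requests']
extra_requires = {
    'framework': framework_requires,
    'hub': ['oss2'],
    'datasets': ['datasets'],
    'cv': ['opencv-python'],
    'nlp': ['transformers'],
}
for field, requires in extra_requires.items():
    if field not in ['server', 'framework', 'hub', 'datasets']:
        extra_requires[field] = framework_requires + requires
print(extra_requires['cv'])   # ['numpy', 'requests', 'opencv-python']
print(extra_requires['hub'])  # ['oss2'] (left untouched)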

View File

@@ -0,0 +1,169 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import tempfile
import time
import unittest
from modelscope.hub.file_download import dataset_file_download
from modelscope.hub.snapshot_download import dataset_snapshot_download
class DownloadDatasetTest(unittest.TestCase):
def setUp(self):
pass
def test_dataset_file_download(self):
dataset_id = 'citest/test_dataset_download'
file_path = 'open_qa.jsonl'
deep_file_path = '111/222/333/shijian.jpeg'
start_time = time.time()
# test download to cache dir.
with tempfile.TemporaryDirectory() as temp_cache_dir:
# first download to cache.
cache_file_path = dataset_file_download(
dataset_id=dataset_id,
file_path=file_path,
cache_dir=temp_cache_dir)
file_modify_time = os.path.getmtime(cache_file_path)
print(cache_file_path)
assert cache_file_path == os.path.join(temp_cache_dir, dataset_id,
file_path)
assert file_modify_time > start_time
# download again, will get cached file.
cache_file_path = dataset_file_download(
dataset_id=dataset_id,
file_path=file_path,
cache_dir=temp_cache_dir)
file_modify_time2 = os.path.getmtime(cache_file_path)
assert file_modify_time == file_modify_time2
deep_cache_file_path = dataset_file_download(
dataset_id=dataset_id,
file_path=deep_file_path,
cache_dir=temp_cache_dir)
deep_file_cache_path = os.path.join(temp_cache_dir, dataset_id,
deep_file_path)
assert deep_cache_file_path == deep_file_cache_path
assert os.path.exists(deep_cache_file_path)
# test download to local dir
with tempfile.TemporaryDirectory() as temp_local_dir:
# first download to cache.
cache_file_path = dataset_file_download(
dataset_id=dataset_id,
file_path=file_path,
local_dir=temp_local_dir)
assert cache_file_path == os.path.join(temp_local_dir, file_path)
file_modify_time = os.path.getmtime(cache_file_path)
assert file_modify_time > start_time
# download again, will get cached file.
cache_file_path = dataset_file_download(
dataset_id=dataset_id,
file_path=file_path,
local_dir=temp_local_dir)
file_modify_time2 = os.path.getmtime(cache_file_path)
assert file_modify_time == file_modify_time2
def test_dataset_snapshot_download(self):
dataset_id = 'citest/test_dataset_download'
file_path = 'open_qa.jsonl'
deep_file_path = '111/222/333/shijian.jpeg'
start_time = time.time()
# test download to cache dir.
with tempfile.TemporaryDirectory() as temp_cache_dir:
# first download to cache.
dataset_cache_path = dataset_snapshot_download(
dataset_id=dataset_id, cache_dir=temp_cache_dir)
file_modify_time = os.path.getmtime(
os.path.join(dataset_cache_path, file_path))
assert dataset_cache_path == os.path.join(temp_cache_dir,
dataset_id)
assert file_modify_time > start_time
assert os.path.exists(
os.path.join(temp_cache_dir, dataset_id, deep_file_path))
# download again, will get cached file.
dataset_cache_path2 = dataset_snapshot_download(
dataset_id=dataset_id, cache_dir=temp_cache_dir)
file_modify_time2 = os.path.getmtime(
os.path.join(dataset_cache_path2, file_path))
assert file_modify_time == file_modify_time2
# test download to local dir
with tempfile.TemporaryDirectory() as temp_local_dir:
# first download to cache.
dataset_cache_path = dataset_snapshot_download(
dataset_id=dataset_id, local_dir=temp_local_dir)
# root path is temp_local_dir, file download to local_dir
assert dataset_cache_path == temp_local_dir
file_modify_time = os.path.getmtime(
os.path.join(dataset_cache_path, file_path))
assert file_modify_time > start_time
# download again, will get cached file.
dataset_cache_path2 = dataset_snapshot_download(
dataset_id=dataset_id, local_dir=temp_local_dir)
file_modify_time2 = os.path.getmtime(
os.path.join(dataset_cache_path2, file_path))
assert file_modify_time == file_modify_time2
# test download with wild pattern, ignore_file_pattern
with tempfile.TemporaryDirectory() as temp_cache_dir:
# first download to cache.
dataset_cache_path = dataset_snapshot_download(
dataset_id=dataset_id,
cache_dir=temp_cache_dir,
ignore_file_pattern='*.jpeg')
assert dataset_cache_path == os.path.join(temp_cache_dir,
dataset_id)
assert not os.path.exists(
os.path.join(temp_cache_dir, dataset_id, deep_file_path))
assert not os.path.exists(
os.path.join(temp_cache_dir, dataset_id, '111/shijian.jpeg'))
assert not os.path.exists(
os.path.join(temp_cache_dir, dataset_id,
'111/222/shijian.jpeg'))
assert os.path.exists(
os.path.join(temp_cache_dir, dataset_id, file_path))
# test download with wild pattern, allow_file_pattern
with tempfile.TemporaryDirectory() as temp_cache_dir:
# first download to cache.
dataset_cache_path = dataset_snapshot_download(
dataset_id=dataset_id,
cache_dir=temp_cache_dir,
allow_file_pattern='*.jpeg')
assert dataset_cache_path == os.path.join(temp_cache_dir,
dataset_id)
assert os.path.exists(
os.path.join(temp_cache_dir, dataset_id, deep_file_path))
assert os.path.exists(
os.path.join(temp_cache_dir, dataset_id, '111/shijian.jpeg'))
assert os.path.exists(
os.path.join(temp_cache_dir, dataset_id,
'111/222/shijian.jpeg'))
assert not os.path.exists(
os.path.join(temp_cache_dir, dataset_id, file_path))
# test download with wild pattern, allow_file_pattern and ignore file pattern.
with tempfile.TemporaryDirectory() as temp_cache_dir:
# first download to cache.
dataset_cache_path = dataset_snapshot_download(
dataset_id=dataset_id,
cache_dir=temp_cache_dir,
ignore_file_pattern='*.jpeg',
allow_file_pattern='*.xxx')
assert dataset_cache_path == os.path.join(temp_cache_dir,
dataset_id)
assert not os.path.exists(
os.path.join(temp_cache_dir, dataset_id, deep_file_path))
assert not os.path.exists(
os.path.join(temp_cache_dir, dataset_id, '111/shijian.jpeg'))
assert not os.path.exists(
os.path.join(temp_cache_dir, dataset_id,
'111/222/shijian.jpeg'))
assert not os.path.exists(
os.path.join(temp_cache_dir, dataset_id, file_path))

View File

@@ -12,7 +12,7 @@ from modelscope.hub.api import HubApi
from modelscope.hub.file_download import http_get_model_file
class HubOperationTest(unittest.TestCase):
class HubRetryTest(unittest.TestCase):
def setUp(self):
self.api = HubApi()
@@ -56,6 +56,8 @@ class HubOperationTest(unittest.TestCase):
rsp.msg = HTTPMessage()
rsp.read = get_content
rsp.chunked = False
rsp.length_remaining = 0
rsp.headers = {}
# retry 2 times and success.
getconn_mock.return_value.getresponse.side_effect = [
Mock(status=500, msg=HTTPMessage()),
@@ -88,16 +90,18 @@ class HubOperationTest(unittest.TestCase):
success_rsp = HTTPResponse(getconn_mock)
success_rsp.status = 200
success_rsp.msg = HTTPMessage()
success_rsp.msg.add_header('Content-Length', '2957783')
success_rsp.read = get_content
success_rsp.chunked = True
success_rsp.length_remaining = 0
success_rsp.headers = {'Content-Length': '2957783'}
failed_rsp = HTTPResponse(getconn_mock)
failed_rsp.status = 502
failed_rsp.msg = HTTPMessage()
failed_rsp.msg.add_header('Content-Length', '2957783')
failed_rsp.read = get_content
failed_rsp.chunked = True
failed_rsp.length_remaining = 2957783
failed_rsp.headers = {'Content-Length': '2957783'}
# retry 5 times and success.
getconn_mock.return_value.getresponse.side_effect = [