add dataset download (#906)

* add dataset download
* fix cr issue
* fix cv matplotlib issue
* refactor code
* fix ut issue
* remove debug code
* remove unused import
* fix import issue
* sleep 65s before starting docker to avoid kill-and-run failures

Co-authored-by: mulin.lyh <mulin.lyh@taobao.com>
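For orientation, a minimal usage sketch of the dataset download entry points this commit introduces (function names and parameters are taken from the diff below; the dataset id and file name are placeholders):

from modelscope.hub.file_download import dataset_file_download
from modelscope.hub.snapshot_download import dataset_snapshot_download

# Download a single raw file from a dataset repo into the dataset cache
# (get_dataset_cache_root(), i.e. <modelscope cache>/datasets/<dataset_id>/...).
file_path = dataset_file_download(
    'some-namespace/some-dataset',      # placeholder dataset id
    'open_qa.jsonl')                    # placeholder file path inside the repo

# Download a whole dataset snapshot, optionally filtering files by glob pattern.
snapshot_path = dataset_snapshot_download(
    'some-namespace/some-dataset',
    allow_file_pattern='*.jsonl',       # only fetch matching files
    ignore_file_pattern='*.jpeg')       # skip matching files
print(file_path, snapshot_path)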
@@ -14,6 +14,7 @@ echo "PR modified files: $PR_CHANGED_FILES"
PR_CHANGED_FILES=${PR_CHANGED_FILES//[ ]/#}
echo "PR_CHANGED_FILES: $PR_CHANGED_FILES"
idx=0
sleep 65
for gpu in $gpus
do
exec {lock_fd}>"/tmp/gpu$gpu" || exit 1
@@ -34,21 +34,7 @@ RUN if [ "$USE_GPU" = "True" ] ; then \
echo 'cpu unsupport detectron2'; \
fi

# torchmetrics==0.11.4 for ofa
# tinycudann for cuda12.1.0 pytorch 2.1.2
RUN if [ "$USE_GPU" = "True" ] ; then \
pip install --no-cache-dir --upgrade pip && \
pip install --no-cache-dir torchsde jupyterlab torchmetrics==0.11.4 tiktoken transformers_stream_generator bitsandbytes basicsr optimum && \
pip install --no-cache-dir flash_attn==2.5.9.post1 -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
pip install --no-cache-dir auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu121/ && \
pip install --no-cache-dir -U 'xformers' --index-url https://download.pytorch.org/whl/cu121 && \
pip install --no-cache-dir --force tinycudann==1.7 -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
pip uninstall -y torch-scatter && TORCH_CUDA_ARCH_LIST="6.0;6.1;6.2;7.0;7.5;8.0;8.6;8.9;9.0" pip install --no-cache-dir -U torch-scatter && \
pip install --no-cache-dir -U vllm; \
else \
echo 'cpu unsupport vllm auto-gptq'; \
fi

# install dependencies
COPY requirements /var/modelscope
RUN pip install --no-cache-dir --upgrade pip && \
pip install --no-cache-dir -r /var/modelscope/framework.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
@@ -64,5 +50,21 @@ RUN pip install --no-cache-dir --upgrade pip && \
pip cache purge
# 'scipy<1.13.0' for cannot import name 'kaiser' from 'scipy.signal'
COPY examples /modelscope/examples
# torchmetrics==0.11.4 for ofa
# tinycudann for cuda12.1.0 pytorch 2.1.2
RUN if [ "$USE_GPU" = "True" ] ; then \
pip install --no-cache-dir --upgrade pip && \
pip install --no-cache-dir torchsde jupyterlab torchmetrics==0.11.4 tiktoken transformers_stream_generator bitsandbytes basicsr optimum && \
pip install --no-cache-dir flash_attn==2.5.9.post1 -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
pip install --no-cache-dir auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu121/ && \
pip install --no-cache-dir -U 'xformers' --index-url https://download.pytorch.org/whl/cu121 && \
pip install --no-cache-dir --force tinycudann==1.7 -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
pip uninstall -y torch-scatter && TORCH_CUDA_ARCH_LIST="6.0;6.1;6.2;7.0;7.5;8.0;8.6;8.9;9.0" pip install --no-cache-dir -U torch-scatter && \
pip install --no-cache-dir -U triton vllm https://modelscope.oss-cn-beijing.aliyuncs.com/packages/lmdeploy-0.5.0-cp310-cp310-linux_x86_64.whl; \
else \
echo 'cpu unsupport vllm auto-gptq'; \
fi

ENV SETUPTOOLS_USE_DISTUTILS=stdlib
ENV VLLM_USE_MODELSCOPE=True
ENV LMDEPLOY_USE_MODELSCOPE=True
@@ -10,7 +10,9 @@ if TYPE_CHECKING:
from .hub.api import HubApi
from .hub.check_model import check_local_model_is_latest, check_model_is_id
from .hub.push_to_hub import push_to_hub, push_to_hub_async
from .hub.snapshot_download import snapshot_download
from .hub.snapshot_download import snapshot_download, dataset_snapshot_download
from .hub.file_download import model_file_download, dataset_file_download

from .metrics import (
AccuracyMetric, AudioNoiseMetric, BleuMetric, ImageColorEnhanceMetric,
ImageColorizationMetric, ImageDenoiseMetric, ImageInpaintingMetric,
@@ -59,7 +61,9 @@ else:
'TorchModelExporter',
],
'hub.api': ['HubApi'],
'hub.snapshot_download': ['snapshot_download'],
'hub.snapshot_download':
['snapshot_download', 'dataset_snapshot_download'],
'hub.file_download': ['model_file_download', 'dataset_file_download'],
'hub.push_to_hub': ['push_to_hub', 'push_to_hub_async'],
'hub.check_model':
['check_model_is_id', 'check_local_model_is_latest'],
@@ -3,8 +3,10 @@
from argparse import ArgumentParser

from modelscope.cli.base import CLICommand
from modelscope.hub.file_download import model_file_download
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.hub.file_download import (dataset_file_download,
model_file_download)
from modelscope.hub.snapshot_download import (dataset_snapshot_download,
snapshot_download)

def subparser_func(args):
@@ -24,11 +26,17 @@ class DownloadCMD(CLICommand):
""" define args for download command.
"""
parser: ArgumentParser = parsers.add_parser(DownloadCMD.name)
parser.add_argument(
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
'--model',
type=str,
required=True,
help='The model id to be downloaded.')
help='The id of the model to be downloaded. For download, '
'the id of either a model or dataset must be provided.')
group.add_argument(
'--dataset',
type=str,
help='The id of the dataset to be downloaded. For download, '
'the id of either a model or dataset must be provided.')
parser.add_argument(
'--revision',
type=str,
@@ -69,27 +77,57 @@ class DownloadCMD(CLICommand):
parser.set_defaults(func=subparser_func)

def execute(self):
if len(self.args.files) == 1:  # download single file
model_file_download(
self.args.model,
self.args.files[0],
cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir,
revision=self.args.revision)
elif len(self.args.files) > 1:  # download specified multiple files.
snapshot_download(
self.args.model,
revision=self.args.revision,
cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir,
allow_file_pattern=self.args.files,
)
else:  # download repo
snapshot_download(
self.args.model,
revision=self.args.revision,
cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir,
allow_file_pattern=self.args.include,
ignore_file_pattern=self.args.exclude,
)
if self.args.model:
if len(self.args.files) == 1:  # download single file
model_file_download(
self.args.model,
self.args.files[0],
cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir,
revision=self.args.revision)
elif len(
self.args.files) > 1:  # download specified multiple files.
snapshot_download(
self.args.model,
revision=self.args.revision,
cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir,
allow_file_pattern=self.args.files,
)
else:  # download repo
snapshot_download(
self.args.model,
revision=self.args.revision,
cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir,
allow_file_pattern=self.args.include,
ignore_file_pattern=self.args.exclude,
)
elif self.args.dataset:
if len(self.args.files) == 1:  # download single file
dataset_file_download(
self.args.dataset,
self.args.files[0],
cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir,
revision=self.args.revision)
elif len(
self.args.files) > 1:  # download specified multiple files.
dataset_snapshot_download(
self.args.dataset,
revision=self.args.revision,
cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir,
allow_file_pattern=self.args.files,
)
else:  # download repo
dataset_snapshot_download(
self.args.dataset,
revision=self.args.revision,
cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir,
allow_file_pattern=self.args.include,
ignore_file_pattern=self.args.exclude,
)
else:
pass  # noop
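As an aside, the new execute() branches on --model vs --dataset and on how many --files entries were given; a condensed, illustrative restatement of that dispatch (not part of the commit; argument names follow DownloadCMD above):

from modelscope.hub.file_download import dataset_file_download, model_file_download
from modelscope.hub.snapshot_download import dataset_snapshot_download, snapshot_download

def dispatch(args):
    # Pick the single-file and snapshot helpers matching the repo type.
    if args.model:
        repo_id, single, snapshot = args.model, model_file_download, snapshot_download
    elif args.dataset:
        repo_id, single, snapshot = args.dataset, dataset_file_download, dataset_snapshot_download
    else:
        return  # the mutually exclusive group already requires one of the two
    if len(args.files) == 1:            # exactly one file requested
        return single(repo_id, args.files[0], revision=args.revision,
                      cache_dir=args.cache_dir, local_dir=args.local_dir)
    if len(args.files) > 1:             # an explicit list of files
        return snapshot(repo_id, revision=args.revision, cache_dir=args.cache_dir,
                        local_dir=args.local_dir, allow_file_pattern=args.files)
    # whole repo, honouring --include / --exclude patterns
    return snapshot(repo_id, revision=args.revision, cache_dir=args.cache_dir,
                    local_dir=args.local_dir, allow_file_pattern=args.include,
                    ignore_file_pattern=args.exclude)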
@@ -76,6 +76,7 @@ class HubApi:
connect=2,
backoff_factor=1,
status_forcelist=(500, 502, 503, 504),
respect_retry_after_header=False,
)
adapter = HTTPAdapter(max_retries=retry)
self.session.mount('http://', adapter)
@@ -741,7 +742,8 @@ class HubApi:

recursive = 'True' if recursive else 'False'
datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_hub_id}/repo/tree'
params = {'Revision': revision, 'Root': root_path, 'Recursive': recursive}
params = {'Revision': revision if revision else 'master',
'Root': root_path if root_path else '/', 'Recursive': recursive}
cookies = ModelScopeConfig.get_cookies()

r = self.session.get(datahub_url, params=params, cookies=cookies)
@@ -771,8 +773,10 @@ class HubApi:
@staticmethod
def dump_datatype_file(dataset_type: int, meta_cache_dir: str):
"""
Dump the data_type as a local file, in order to get the dataset formation without calling the datahub.
More details, please refer to the class `modelscope.utils.constant.DatasetFormations`.
Dump the data_type as a local file, in order to get the dataset
formation without calling the datahub.
More details, please refer to the class
`modelscope.utils.constant.DatasetFormations`.
"""
dataset_type_file_path = os.path.join(meta_cache_dir,
f'{str(dataset_type)}{DatasetFormations.formation_mark_ext.value}')
@@ -874,13 +878,14 @@ class HubApi:
dataset_name: str,
namespace: str,
revision: Optional[str] = DEFAULT_DATASET_REVISION,
view: Optional[bool] = False,
extension_filter: Optional[bool] = True):

if not file_name or not dataset_name or not namespace:
raise ValueError('Args (file_name, dataset_name, namespace) cannot be empty!')

# Note: make sure the FilePath is the last parameter in the url
params: dict = {'Source': 'SDK', 'Revision': revision, 'FilePath': file_name}
params: dict = {'Source': 'SDK', 'Revision': revision, 'FilePath': file_name, 'View': view}
params: str = urlencode(params)
file_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?{params}'

@@ -1113,7 +1118,8 @@ class ModelScopeConfig:
ModelScopeConfig.cookie_expired_warning = True
logger.warning(
'Authentication has expired, '
'please re-login if you need to access private models or datasets.')
'please re-login with modelscope login --token "YOUR_SDK_TOKEN" '
'if you need to access private models or datasets.')
return None
return cookies
return None
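A small illustration of how get_dataset_file_url now assembles the download URL with the added View flag (the endpoint, namespace and dataset name below are placeholders, not values from the commit):

from urllib.parse import urlencode

endpoint = 'https://www.modelscope.cn'          # placeholder endpoint
namespace, dataset_name = 'some-namespace', 'some-dataset'
file_name, revision, view = 'open_qa.jsonl', 'master', False

# Mirrors the updated params dict in HubApi.get_dataset_file_url.
params = {'Source': 'SDK', 'Revision': revision, 'FilePath': file_name, 'View': view}
file_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?{urlencode(params)}'
print(file_url)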
@@ -153,9 +153,9 @@ def datahub_raise_on_error(url, rsp, http_response: requests.Response):
if rsp.get('Code') == HTTPStatus.OK:
return True
else:
request_id = get_request_id(http_response)
request_id = rsp['RequestId']
raise RequestError(
f"Url = {url}, Request id={request_id} Message = {rsp.get('Message')},\
f"Url = {url}, Request id={request_id} Code = {rsp['Code']} Message = {rsp['Message']},\
Please specify correct dataset_name and namespace.")

@@ -21,10 +21,14 @@ from modelscope.hub.constants import (
API_FILE_DOWNLOAD_CHUNK_SIZE, API_FILE_DOWNLOAD_RETRY_TIMES,
API_FILE_DOWNLOAD_TIMEOUT, FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS,
MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB, TEMPORARY_FOLDER_NAME)
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
from modelscope.utils.file_utils import get_model_cache_root
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
DEFAULT_MODEL_REVISION,
REPO_TYPE_DATASET, REPO_TYPE_MODEL,
REPO_TYPE_SUPPORT)
from modelscope.utils.file_utils import (get_dataset_cache_root,
get_model_cache_root)
from modelscope.utils.logger import get_logger
from .errors import FileDownloadError, NotExistError
from .errors import FileDownloadError, InvalidParameter, NotExistError
from .utils.caching import ModelFileSystemCache
from .utils.utils import (file_integrity_validation, get_endpoint,
model_id_to_group_owner_name)
@@ -78,8 +82,97 @@ def model_file_download(
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
if some parameter value is invalid
"""
return _repo_file_download(
model_id,
file_path,
repo_type=REPO_TYPE_MODEL,
revision=revision,
cache_dir=cache_dir,
user_agent=user_agent,
local_files_only=local_files_only,
cookies=cookies,
local_dir=local_dir)


def dataset_file_download(
dataset_id: str,
file_path: str,
revision: Optional[str] = DEFAULT_DATASET_REVISION,
cache_dir: Union[str, Path, None] = None,
local_dir: Optional[str] = None,
user_agent: Optional[Union[Dict, str]] = None,
local_files_only: Optional[bool] = False,
cookies: Optional[CookieJar] = None,
) -> str:
"""Download raw files of a dataset.
Downloads all files at the specified revision. This
is useful when you want all files from a dataset, because you don't know which
ones you will need a priori. All files are nested inside a folder in order
to keep their actual filename relative to that folder.

An alternative would be to just clone a dataset but this would require that the
user always has git and git-lfs installed, and properly configured.

Args:
dataset_id (str): A user or an organization name and a dataset name separated by a `/`.
file_path (str): The relative path of the file to download.
revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
commit hash. NOTE: currently only branch and tag name is supported
cache_dir (str, Path, optional): Path to the folder where cached files are stored, dataset file will
be save as cache_dir/dataset_id/THE_DATASET_FILES.
local_dir (str, optional): Specific local directory path to which the file will be downloaded.
user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string.
local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the
local cached file if it exists.
cookies (CookieJar, optional): The cookie of the request, default None.
Raises:
ValueError: the value details.

Returns:
str: Local folder path (string) of repo snapshot

Note:
Raises the following errors:
- [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
if `use_auth_token=True` and the token cannot be found.
- [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if
ETag cannot be determined.
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
if some parameter value is invalid
"""
return _repo_file_download(
dataset_id,
file_path,
repo_type=REPO_TYPE_DATASET,
revision=revision,
cache_dir=cache_dir,
user_agent=user_agent,
local_files_only=local_files_only,
cookies=cookies,
local_dir=local_dir)


def _repo_file_download(
repo_id: str,
file_path: str,
*,
repo_type: str = None,
revision: Optional[str] = DEFAULT_MODEL_REVISION,
cache_dir: Optional[str] = None,
user_agent: Union[Dict, str, None] = None,
local_files_only: Optional[bool] = False,
cookies: Optional[CookieJar] = None,
local_dir: Optional[str] = None,
) -> Optional[str]:  # pragma: no cover

if not repo_type:
repo_type = REPO_TYPE_MODEL
if repo_type not in REPO_TYPE_SUPPORT:
raise InvalidParameter('Invalid repo type: %s, only support: %s' (
repo_type, REPO_TYPE_SUPPORT))

temporary_cache_dir, cache = create_temporary_directory_and_cache(
model_id, local_dir, cache_dir)
repo_id, local_dir=local_dir, cache_dir=cache_dir, repo_type=repo_type)

# if local_files_only is `True` and the file already exists in cached_path
# return the cached path
@@ -93,7 +186,7 @@ def model_file_download(
else:
raise ValueError(
'Cannot find the requested files in the cached path and outgoing'
' traffic has been disabled. To enable model look-ups and downloads'
' traffic has been disabled. To enable look-ups and downloads'
" online, set 'local_files_only' to False.")

_api = HubApi()
@@ -102,75 +195,84 @@ def model_file_download(
}
if cookies is None:
cookies = ModelScopeConfig.get_cookies()
repo_files = []
if repo_type == REPO_TYPE_MODEL:
revision = _api.get_valid_revision(
repo_id, revision=revision, cookies=cookies)
file_to_download_meta = None
# we need to confirm the version is up-to-date
# we need to get the file list to check if the latest version is cached, if so return, otherwise download
repo_files = _api.get_model_files(
model_id=repo_id,
revision=revision,
recursive=True,
use_cookies=False if cookies is None else cookies)
elif repo_type == REPO_TYPE_DATASET:
group_or_owner, name = model_id_to_group_owner_name(repo_id)
if not revision:
revision = DEFAULT_DATASET_REVISION
files_list_tree = _api.list_repo_tree(
dataset_name=name,
namespace=group_or_owner,
revision=revision,
root_path='/',
recursive=True)
if not ('Code' in files_list_tree and files_list_tree['Code'] == 200):
print(
'Get dataset: %s file list failed, request_id: %s, message: %s'
% (repo_id, files_list_tree['RequestId'],
files_list_tree['Message']))
return None
repo_files = files_list_tree['Data']['Files']

revision = _api.get_valid_revision(
model_id, revision=revision, cookies=cookies)
file_to_download_info = None
# we need to confirm the version is up-to-date
# we need to get the file list to check if the latest version is cached, if so return, otherwise download
model_files = _api.get_model_files(
model_id=model_id,
revision=revision,
recursive=True,
use_cookies=False if cookies is None else cookies)

for model_file in model_files:
if model_file['Type'] == 'tree':
file_to_download_meta = None
for repo_file in repo_files:
if repo_file['Type'] == 'tree':
continue

if model_file['Path'] == file_path:
if cache.exists(model_file):
if repo_file['Path'] == file_path:
if cache.exists(repo_file):
logger.debug(
f'File {model_file["Name"]} already in cache, skip downloading!'
f'File {repo_file["Name"]} already in cache, skip downloading!'
)
return cache.get_file_by_info(model_file)
return cache.get_file_by_info(repo_file)
else:
file_to_download_info = model_file
file_to_download_meta = repo_file
break

if file_to_download_info is None:
if file_to_download_meta is None:
raise NotExistError('The file path: %s not exist in: %s' %
(file_path, model_id))
(file_path, repo_id))

# we need to download again
url_to_download = get_file_download_url(model_id, file_path, revision)

if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_to_download_info[
'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1:
parallel_download(
url_to_download,
temporary_cache_dir,
file_path,
headers=headers,
cookies=None if cookies is None else cookies.get_dict(),
file_size=file_to_download_info['Size'])
else:
http_get_model_file(
url_to_download,
temporary_cache_dir,
file_path,
file_size=file_to_download_info['Size'],
headers=headers,
cookies=None if cookies is None else cookies.get_dict())

temp_file_path = os.path.join(temporary_cache_dir, file_path)
# for download with commit we can't get Sha256
if file_to_download_info[FILE_HASH] is not None:
file_integrity_validation(temp_file_path,
file_to_download_info[FILE_HASH])
return cache.put_file(file_to_download_info,
os.path.join(temporary_cache_dir, file_path))
if repo_type == REPO_TYPE_MODEL:
url_to_download = get_file_download_url(repo_id, file_path, revision)
elif repo_type == REPO_TYPE_DATASET:
url_to_download = _api.get_dataset_file_url(
file_name=file_to_download_meta['Path'],
dataset_name=name,
namespace=group_or_owner,
revision=revision)
return download_file(url_to_download, file_to_download_meta,
temporary_cache_dir, cache, headers, cookies)


def create_temporary_directory_and_cache(model_id: str, local_dir: str,
cache_dir: str):
def create_temporary_directory_and_cache(model_id: str,
local_dir: str = None,
cache_dir: str = None,
repo_type: str = REPO_TYPE_MODEL):
if repo_type == REPO_TYPE_MODEL:
default_cache_root = get_model_cache_root()
elif repo_type == REPO_TYPE_DATASET:
default_cache_root = get_dataset_cache_root()

group_or_owner, name = model_id_to_group_owner_name(model_id)
if local_dir is not None:
temporary_cache_dir = os.path.join(local_dir, TEMPORARY_FOLDER_NAME)
cache = ModelFileSystemCache(local_dir)
else:
if cache_dir is None:
cache_dir = get_model_cache_root()
cache_dir = default_cache_root
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
temporary_cache_dir = os.path.join(cache_dir, TEMPORARY_FOLDER_NAME,
@@ -269,6 +371,7 @@ def parallel_download(
PART_SIZE = 160 * 1024 * 1024  # every part is 160M
tasks = []
file_path = os.path.join(local_dir, file_name)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
for idx in range(int(file_size / PART_SIZE)):
start = idx * PART_SIZE
end = (idx + 1) * PART_SIZE - 1
@@ -323,6 +426,7 @@ def http_get_model_file(
get_headers = {} if headers is None else copy.deepcopy(headers)
get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
temp_file_path = os.path.join(local_dir, file_name)
os.makedirs(os.path.dirname(temp_file_path), exist_ok=True)
logger.debug('downloading %s to %s', url, temp_file_path)
# retry sleep 0.5s, 1s, 2s, 4s
retry = Retry(
@@ -349,7 +453,7 @@ def http_get_model_file(
break
get_headers['Range'] = 'bytes=%s-%s' % (partial_length,
file_size - 1)
with open(temp_file_path, 'ab') as f:
with open(temp_file_path, 'ab+') as f:
r = requests.get(
url,
stream=True,
@@ -451,3 +555,31 @@ def http_get_file(
logger.error(msg)
raise FileDownloadError(msg)
os.replace(temp_file.name, os.path.join(local_dir, file_name))


def download_file(url, file_meta, temporary_cache_dir, cache, headers,
cookies):
if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[
'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1:  # parallel download large file.
parallel_download(
url,
temporary_cache_dir,
file_meta['Path'],
headers=headers,
cookies=None if cookies is None else cookies.get_dict(),
file_size=file_meta['Size'])
else:
http_get_model_file(
url,
temporary_cache_dir,
file_meta['Path'],
file_size=file_meta['Size'],
headers=headers,
cookies=cookies)

# check file integrity
temp_file = os.path.join(temporary_cache_dir, file_meta['Path'])
if FILE_HASH in file_meta:
file_integrity_validation(temp_file, file_meta[FILE_HASH])
# put file into to cache
return cache.put_file(file_meta, temp_file)
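The new download_file() helper keeps the existing size gate for choosing between parallel and single-stream download; a standalone sketch of that check (the constant values here are assumptions for illustration, the real ones live in modelscope.hub.constants):

# Assumed values purely for the example; see modelscope.hub.constants for the real ones.
MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB = 500
MODELSCOPE_DOWNLOAD_PARALLELS = 4

def use_parallel_download(file_size_bytes: int) -> bool:
    # Split a file into parts only when it is large enough and workers are enabled.
    return (MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_size_bytes
            and MODELSCOPE_DOWNLOAD_PARALLELS > 1)

print(use_parallel_download(10 * 1000 * 1000))        # False: 10 MB, below threshold
print(use_parallel_download(2 * 1000 * 1000 * 1000))  # True: 2 GB file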
@@ -8,14 +8,15 @@ from pathlib import Path
from typing import Dict, List, Optional, Union

from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
from modelscope.hub.errors import InvalidParameter
from modelscope.hub.utils.utils import model_id_to_group_owner_name
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
DEFAULT_MODEL_REVISION,
REPO_TYPE_DATASET, REPO_TYPE_MODEL,
REPO_TYPE_SUPPORT)
from modelscope.utils.logger import get_logger
from .constants import (FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS,
MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB)
from .file_download import (create_temporary_directory_and_cache,
get_file_download_url, http_get_model_file,
parallel_download)
from .utils.utils import file_integrity_validation
download_file, get_file_download_url)

logger = get_logger()

@@ -70,14 +71,110 @@ def snapshot_download(
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
if some parameter value is invalid
"""
return _snapshot_download(
model_id,
repo_type=REPO_TYPE_MODEL,
revision=revision,
cache_dir=cache_dir,
user_agent=user_agent,
local_files_only=local_files_only,
cookies=cookies,
ignore_file_pattern=ignore_file_pattern,
allow_file_pattern=allow_file_pattern,
local_dir=local_dir)


def dataset_snapshot_download(
dataset_id: str,
revision: Optional[str] = DEFAULT_DATASET_REVISION,
cache_dir: Union[str, Path, None] = None,
local_dir: Optional[str] = None,
user_agent: Optional[Union[Dict, str]] = None,
local_files_only: Optional[bool] = False,
cookies: Optional[CookieJar] = None,
ignore_file_pattern: Optional[Union[str, List[str]]] = None,
allow_file_pattern: Optional[Union[str, List[str]]] = None,
) -> str:
"""Download raw files of a dataset.
Downloads all files at the specified revision. This
is useful when you want all files from a dataset, because you don't know which
ones you will need a priori. All files are nested inside a folder in order
to keep their actual filename relative to that folder.

An alternative would be to just clone a dataset but this would require that the
user always has git and git-lfs installed, and properly configured.

Args:
dataset_id (str): A user or an organization name and a dataset name separated by a `/`.
revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
commit hash. NOTE: currently only branch and tag name is supported
cache_dir (str, Path, optional): Path to the folder where cached files are stored, dataset will
be save as cache_dir/dataset_id/THE_DATASET_FILES.
local_dir (str, optional): Specific local directory path to which the file will be downloaded.
user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string.
local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the
local cached file if it exists.
cookies (CookieJar, optional): The cookie of the request, default None.
ignore_file_pattern (`str` or `List`, *optional*, default to `None`):
Any file pattern to be ignored in downloading, like exact file names or file extensions.
Use regression is deprecated.
allow_file_pattern (`str` or `List`, *optional*, default to `None`):
Any file pattern to be downloading, like exact file names or file extensions.
Raises:
ValueError: the value details.

Returns:
str: Local folder path (string) of repo snapshot

Note:
Raises the following errors:
- [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
if `use_auth_token=True` and the token cannot be found.
- [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if
ETag cannot be determined.
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
if some parameter value is invalid
"""
return _snapshot_download(
dataset_id,
repo_type=REPO_TYPE_DATASET,
revision=revision,
cache_dir=cache_dir,
user_agent=user_agent,
local_files_only=local_files_only,
cookies=cookies,
ignore_file_pattern=ignore_file_pattern,
allow_file_pattern=allow_file_pattern,
local_dir=local_dir)


def _snapshot_download(
repo_id: str,
*,
repo_type: Optional[str] = None,
revision: Optional[str] = DEFAULT_MODEL_REVISION,
cache_dir: Union[str, Path, None] = None,
user_agent: Optional[Union[Dict, str]] = None,
local_files_only: Optional[bool] = False,
cookies: Optional[CookieJar] = None,
ignore_file_pattern: Optional[Union[str, List[str]]] = None,
allow_file_pattern: Optional[Union[str, List[str]]] = None,
local_dir: Optional[str] = None,
):
if not repo_type:
repo_type = REPO_TYPE_MODEL
if repo_type not in REPO_TYPE_SUPPORT:
raise InvalidParameter('Invalid repo type: %s, only support: %s' (
repo_type, REPO_TYPE_SUPPORT))

temporary_cache_dir, cache = create_temporary_directory_and_cache(
model_id, local_dir, cache_dir)
repo_id, local_dir=local_dir, cache_dir=cache_dir, repo_type=repo_type)

if local_files_only:
if len(cache.cached_files) == 0:
raise ValueError(
'Cannot find the requested files in the cached path and outgoing'
' traffic has been disabled. To enable model look-ups and downloads'
' traffic has been disabled. To enable look-ups and downloads'
" online, set 'local_files_only' to False.")
logger.warning('We can not confirm the cached file is for revision: %s'
% revision)
@@ -92,27 +189,48 @@ def snapshot_download(
_api = HubApi()
if cookies is None:
cookies = ModelScopeConfig.get_cookies()
revision_detail = _api.get_valid_revision_detail(
model_id, revision=revision, cookies=cookies)
revision = revision_detail['Revision']
repo_files = []
if repo_type == REPO_TYPE_MODEL:
revision_detail = _api.get_valid_revision_detail(
repo_id, revision=revision, cookies=cookies)
revision = revision_detail['Revision']

snapshot_header = headers if 'CI_TEST' in os.environ else {
**headers,
**{
'Snapshot': 'True'
snapshot_header = headers if 'CI_TEST' in os.environ else {
**headers,
**{
'Snapshot': 'True'
}
}
}
if cache.cached_model_revision is not None:
snapshot_header[
'cached_model_revision'] = cache.cached_model_revision
if cache.cached_model_revision is not None:
snapshot_header[
'cached_model_revision'] = cache.cached_model_revision

model_files = _api.get_model_files(
model_id=model_id,
revision=revision,
recursive=True,
use_cookies=False if cookies is None else cookies,
headers=snapshot_header,
)
repo_files = _api.get_model_files(
model_id=repo_id,
revision=revision,
recursive=True,
use_cookies=False if cookies is None else cookies,
headers=snapshot_header,
)
elif repo_type == REPO_TYPE_DATASET:
group_or_owner, name = model_id_to_group_owner_name(repo_id)
if not revision:
revision = DEFAULT_DATASET_REVISION
revision_detail = revision
files_list_tree = _api.list_repo_tree(
dataset_name=name,
namespace=group_or_owner,
revision=revision,
root_path='/',
recursive=True)
if not ('Code' in files_list_tree
and files_list_tree['Code'] == 200):
print(
'Get dataset: %s file list failed, request_id: %s, message: %s'
% (repo_id, files_list_tree['RequestId'],
files_list_tree['Message']))
return None
repo_files = files_list_tree['Data']['Files']

if ignore_file_pattern is None:
ignore_file_pattern = []
@@ -122,6 +240,12 @@ def snapshot_download(
item if not item.endswith('/') else item + '*'
for item in ignore_file_pattern
]
ignore_regex_pattern = []
for file_pattern in ignore_file_pattern:
if file_pattern.startswith('*'):
ignore_regex_pattern.append('.' + file_pattern)
else:
ignore_regex_pattern.append(file_pattern)

if allow_file_pattern is not None:
if isinstance(allow_file_pattern, str):
@@ -131,55 +255,39 @@ def snapshot_download(
for item in allow_file_pattern
]

for model_file in model_files:
if model_file['Type'] == 'tree' or \
any(fnmatch.fnmatch(model_file['Path'], pattern) for pattern in ignore_file_pattern) or \
any([re.search(pattern, model_file['Name']) is not None for pattern in ignore_file_pattern]):
for repo_file in repo_files:
if repo_file['Type'] == 'tree' or \
any([fnmatch.fnmatch(repo_file['Path'], pattern) for pattern in ignore_file_pattern]) or \
any([re.search(pattern, repo_file['Name']) is not None for pattern in ignore_regex_pattern]):  # noqa E501
continue

if allow_file_pattern is not None and allow_file_pattern:
if not any(
fnmatch.fnmatch(model_file['Path'], pattern)
fnmatch.fnmatch(repo_file['Path'], pattern)
for pattern in allow_file_pattern):
continue

# check model_file is exist in cache, if existed, skip download, otherwise download
if cache.exists(model_file):
file_name = os.path.basename(model_file['Name'])
if cache.exists(repo_file):
file_name = os.path.basename(repo_file['Name'])
logger.debug(
f'File {file_name} already in cache, skip downloading!')
continue
if repo_type == REPO_TYPE_MODEL:
# get download url
url = get_file_download_url(
model_id=repo_id,
file_path=repo_file['Path'],
revision=revision)
elif repo_type == REPO_TYPE_DATASET:
url = _api.get_dataset_file_url(
file_name=repo_file['Path'],
dataset_name=name,
namespace=group_or_owner,
revision=revision)

# get download url
url = get_file_download_url(
model_id=model_id,
file_path=model_file['Path'],
revision=revision)

if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < model_file[
'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1:
parallel_download(
url,
temporary_cache_dir,
model_file['Name'],
headers=headers,
cookies=None if cookies is None else cookies.get_dict(),
file_size=model_file['Size'])
else:
http_get_model_file(
url,
temporary_cache_dir,
model_file['Name'],
file_size=model_file['Size'],
headers=headers,
cookies=cookies)

# check file integrity
temp_file = os.path.join(temporary_cache_dir, model_file['Name'])
if FILE_HASH in model_file:
file_integrity_validation(temp_file, model_file[FILE_HASH])
# put file into to cache
cache.put_file(model_file, temp_file)
download_file(url, repo_file, temporary_cache_dir, cache, headers,
cookies)

cache.save_model_version(revision_info=revision_detail)
return os.path.join(cache.get_root_location())
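The snapshot filtering added above matches ignore patterns with fnmatch against the file path and, for backward compatibility, also as regexes against the file name ('*'-prefixed globs get a leading '.'); a self-contained sketch of that decision:

import fnmatch
import re

ignore_file_pattern = ['*.jpeg']
# Mirror of the new glob-to-regex fallback in _snapshot_download.
ignore_regex_pattern = [('.' + p) if p.startswith('*') else p for p in ignore_file_pattern]

def is_ignored(path: str, name: str) -> bool:
    return (any(fnmatch.fnmatch(path, p) for p in ignore_file_pattern)
            or any(re.search(p, name) is not None for p in ignore_regex_pattern))

print(is_ignored('111/222/333/shijian.jpeg', 'shijian.jpeg'))  # True, filtered out
print(is_ignored('open_qa.jsonl', 'open_qa.jsonl'))            # False, downloaded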
@@ -164,9 +164,12 @@ class ModelFileSystemCache(FileSystemCache):
model_version_file_path = os.path.join(
self.cache_root_location, FileSystemCache.MODEL_VERSION_FILE_NAME)
with open(model_version_file_path, 'w') as f:
version_info_str = 'Revision:%s,CreatedAt:%s' % (
revision_info['Revision'], revision_info['CreatedAt'])
f.write(version_info_str)
if isinstance(revision_info, dict):
version_info_str = 'Revision:%s,CreatedAt:%s' % (
revision_info['Revision'], revision_info['CreatedAt'])
f.write(version_info_str)
else:
f.write(revision_info)

def get_model_id(self):
return self.model_meta[FileSystemCache.MODEL_META_MODEL_ID]
@@ -5,6 +5,8 @@ import enum
class Fields(object):
""" Names for different application fields
"""
hub = 'hub'
datasets = 'datasets'
framework = 'framework'
cv = 'cv'
nlp = 'nlp'
@@ -491,6 +493,9 @@ class Frameworks(object):
kaldi = 'kaldi'


REPO_TYPE_MODEL = 'model'
REPO_TYPE_DATASET = 'dataset'
REPO_TYPE_SUPPORT = [REPO_TYPE_MODEL, REPO_TYPE_DATASET]
DEFAULT_MODEL_REVISION = None
MASTER_MODEL_BRANCH = 'master'
DEFAULT_REPOSITORY_REVISION = 'master'
@@ -53,11 +53,35 @@ def get_model_cache_root() -> str:
"""Get model cache root path.

Returns:
str: the modelscope cache root.
str: the modelscope model cache root.
"""
return os.path.join(get_modelscope_cache_dir(), 'hub')


def get_dataset_cache_root() -> str:
"""Get dataset raw file cache root path.

Returns:
str: the modelscope dataset raw file cache root.
"""
return os.path.join(get_modelscope_cache_dir(), 'datasets')


def get_dataset_cache_dir(dataset_id: str) -> str:
"""Get the dataset_id's path.
dataset_cache_root/dataset_id.

Args:
dataset_id (str): The dataset id.

Returns:
str: The dataset_id's cache root path.
"""
dataset_root = get_dataset_cache_root()
return dataset_root if dataset_id is None else os.path.join(
dataset_root, dataset_id + '/')


def get_model_cache_dir(model_id: str) -> str:
"""cache dir precedence:
function parameter > environment > ~/.cache/modelscope/hub/model_id
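For reference, what the new dataset cache helpers are expected to return (assuming the default ModelScope cache root, typically ~/.cache/modelscope; the dataset id is a placeholder):

from modelscope.utils.file_utils import get_dataset_cache_dir, get_dataset_cache_root

# With the default cache root this is typically ~/.cache/modelscope/datasets
print(get_dataset_cache_root())
# ... and per-dataset files land under <root>/<namespace>/<dataset-name>/
print(get_dataset_cache_dir('some-namespace/some-dataset'))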
@@ -25,6 +25,7 @@ imgaug>=0.4.0
kornia>=0.5.0
lmdb
lpips
matplotlib>=3.8.0
ml_collections
mmcls>=0.21.0
mmdet>=2.25.0,<=2.28.2
@@ -44,10 +45,11 @@ opencv-python
paint_ldm
pandas
panopticapi
Pillow>=6.2.0
plyfile>=0.7.4
psutil
pyclipper
PyMCubes
PyMCubes<=0.1.4
pytorch-lightning
regex
# <0.20.0 for compatible python3.7 python3.8
requirements/datasets.txt (new file, 12 lines)
@@ -0,0 +1,12 @@
addict
attrs
datasets>=2.16.0,<2.19.0
einops
oss2
python-dateutil>=2.1
scipy
# latest version has some compatible issue.
setuptools==69.5.1
simplejson>=3.3.0
sortedcontainers>=1.5.9
urllib3>=1.26
@@ -2,21 +2,12 @@ addict
attrs
datasets>=2.16.0,<2.19.0
einops
filelock>=3.3.0
huggingface_hub
numpy
oss2
pandas
Pillow>=6.2.0
# pyarrow 9.0.0 introduced event_loop core dump
pyarrow>=6.0.0,!=9.0.0
python-dateutil>=2.1
pyyaml
requests>=2.25
scipy
setuptools
# latest version has some compatible issue.
setuptools==69.5.1
simplejson>=3.3.0
sortedcontainers>=1.5.9
tqdm>=4.64.0
transformers
urllib3>=1.26
yapf
setup.py
@@ -196,7 +196,7 @@ if __name__ == '__main__':
# add framework dependencies to every field
for field, requires in extra_requires.items():
if field not in [
'server', 'framework'
'server', 'framework', 'hub', 'datasets'
]:  # server need install model's field dependencies before.
extra_requires[field] = framework_requires + extra_requires[field]
extra_requires['all'] = all_requires
tests/hub/test_download_dataset_file.py (new file, 169 lines)
@@ -0,0 +1,169 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import tempfile
import time
import unittest

from modelscope.hub.file_download import dataset_file_download
from modelscope.hub.snapshot_download import dataset_snapshot_download


class DownloadDatasetTest(unittest.TestCase):

def setUp(self):
pass

def test_dataset_file_download(self):
dataset_id = 'citest/test_dataset_download'
file_path = 'open_qa.jsonl'
deep_file_path = '111/222/333/shijian.jpeg'
start_time = time.time()

# test download to cache dir.
with tempfile.TemporaryDirectory() as temp_cache_dir:
# first download to cache.
cache_file_path = dataset_file_download(
dataset_id=dataset_id,
file_path=file_path,
cache_dir=temp_cache_dir)
file_modify_time = os.path.getmtime(cache_file_path)
print(cache_file_path)
assert cache_file_path == os.path.join(temp_cache_dir, dataset_id,
file_path)
assert file_modify_time > start_time
# download again, will get cached file.
cache_file_path = dataset_file_download(
dataset_id=dataset_id,
file_path=file_path,
cache_dir=temp_cache_dir)
file_modify_time2 = os.path.getmtime(cache_file_path)
assert file_modify_time == file_modify_time2

deep_cache_file_path = dataset_file_download(
dataset_id=dataset_id,
file_path=deep_file_path,
cache_dir=temp_cache_dir)
deep_file_cath_path = os.path.join(temp_cache_dir, dataset_id,
deep_file_path)
assert deep_cache_file_path == deep_file_cath_path
os.path.exists(deep_cache_file_path)

# test download to local dir
with tempfile.TemporaryDirectory() as temp_local_dir:
# first download to cache.
cache_file_path = dataset_file_download(
dataset_id=dataset_id,
file_path=file_path,
local_dir=temp_local_dir)
assert cache_file_path == os.path.join(temp_local_dir, file_path)
file_modify_time = os.path.getmtime(cache_file_path)
assert file_modify_time > start_time
# download again, will get cached file.
cache_file_path = dataset_file_download(
dataset_id=dataset_id,
file_path=file_path,
local_dir=temp_local_dir)
file_modify_time2 = os.path.getmtime(cache_file_path)
assert file_modify_time == file_modify_time2

def test_dataset_snapshot_download(self):
dataset_id = 'citest/test_dataset_download'
file_path = 'open_qa.jsonl'
deep_file_path = '111/222/333/shijian.jpeg'
start_time = time.time()

# test download to cache dir.
with tempfile.TemporaryDirectory() as temp_cache_dir:
# first download to cache.
dataset_cache_path = dataset_snapshot_download(
dataset_id=dataset_id, cache_dir=temp_cache_dir)
file_modify_time = os.path.getmtime(
os.path.join(dataset_cache_path, file_path))
assert dataset_cache_path == os.path.join(temp_cache_dir,
dataset_id)
assert file_modify_time > start_time
assert os.path.exists(
os.path.join(temp_cache_dir, dataset_id, deep_file_path))

# download again, will get cached file.
dataset_cache_path2 = dataset_snapshot_download(
dataset_id=dataset_id, cache_dir=temp_cache_dir)
file_modify_time2 = os.path.getmtime(
os.path.join(dataset_cache_path2, file_path))
assert file_modify_time == file_modify_time2

# test download to local dir
with tempfile.TemporaryDirectory() as temp_local_dir:
# first download to cache.
dataset_cache_path = dataset_snapshot_download(
dataset_id=dataset_id, local_dir=temp_local_dir)
# root path is temp_local_dir, file download to local_dir
assert dataset_cache_path == temp_local_dir
file_modify_time = os.path.getmtime(
os.path.join(dataset_cache_path, file_path))
assert file_modify_time > start_time
# download again, will get cached file.
dataset_cache_path2 = dataset_snapshot_download(
dataset_id=dataset_id, local_dir=temp_local_dir)
file_modify_time2 = os.path.getmtime(
os.path.join(dataset_cache_path2, file_path))
assert file_modify_time == file_modify_time2

# test download with wild pattern, ignore_file_pattern
with tempfile.TemporaryDirectory() as temp_cache_dir:
# first download to cache.
dataset_cache_path = dataset_snapshot_download(
dataset_id=dataset_id,
cache_dir=temp_cache_dir,
ignore_file_pattern='*.jpeg')
assert dataset_cache_path == os.path.join(temp_cache_dir,
dataset_id)
assert not os.path.exists(
os.path.join(temp_cache_dir, dataset_id, deep_file_path))
assert not os.path.exists(
os.path.join(temp_cache_dir, dataset_id, '111/shijian.jpeg'))
assert not os.path.exists(
os.path.join(temp_cache_dir, dataset_id,
'111/222/shijian.jpeg'))
assert os.path.exists(
os.path.join(temp_cache_dir, dataset_id, file_path))

# test download with wild pattern, allow_file_pattern
with tempfile.TemporaryDirectory() as temp_cache_dir:
# first download to cache.
dataset_cache_path = dataset_snapshot_download(
dataset_id=dataset_id,
cache_dir=temp_cache_dir,
allow_file_pattern='*.jpeg')
assert dataset_cache_path == os.path.join(temp_cache_dir,
dataset_id)
assert os.path.exists(
os.path.join(temp_cache_dir, dataset_id, deep_file_path))
assert os.path.exists(
os.path.join(temp_cache_dir, dataset_id, '111/shijian.jpeg'))
assert os.path.exists(
os.path.join(temp_cache_dir, dataset_id,
'111/222/shijian.jpeg'))
assert not os.path.exists(
os.path.join(temp_cache_dir, dataset_id, file_path))

# test download with wild pattern, allow_file_pattern and ignore file pattern.
with tempfile.TemporaryDirectory() as temp_cache_dir:
# first download to cache.
dataset_cache_path = dataset_snapshot_download(
dataset_id=dataset_id,
cache_dir=temp_cache_dir,
ignore_file_pattern='*.jpeg',
allow_file_pattern='*.xxx')
assert dataset_cache_path == os.path.join(temp_cache_dir,
dataset_id)
assert not os.path.exists(
os.path.join(temp_cache_dir, dataset_id, deep_file_path))
assert not os.path.exists(
os.path.join(temp_cache_dir, dataset_id, '111/shijian.jpeg'))
assert not os.path.exists(
os.path.join(temp_cache_dir, dataset_id,
'111/222/shijian.jpeg'))
assert not os.path.exists(
os.path.join(temp_cache_dir, dataset_id, file_path))
@@ -12,7 +12,7 @@ from modelscope.hub.api import HubApi
from modelscope.hub.file_download import http_get_model_file


class HubOperationTest(unittest.TestCase):
class HubRetryTest(unittest.TestCase):

def setUp(self):
self.api = HubApi()
@@ -56,6 +56,8 @@ class HubOperationTest(unittest.TestCase):
rsp.msg = HTTPMessage()
rsp.read = get_content
rsp.chunked = False
rsp.length_remaining = 0
rsp.headers = {}
# retry 2 times and success.
getconn_mock.return_value.getresponse.side_effect = [
Mock(status=500, msg=HTTPMessage()),
@@ -88,16 +90,18 @@ class HubOperationTest(unittest.TestCase):
success_rsp = HTTPResponse(getconn_mock)
success_rsp.status = 200
success_rsp.msg = HTTPMessage()
success_rsp.msg.add_header('Content-Length', '2957783')
success_rsp.read = get_content
success_rsp.chunked = True
success_rsp.length_remaining = 0
success_rsp.headers = {'Content-Length': '2957783'}

failed_rsp = HTTPResponse(getconn_mock)
failed_rsp.status = 502
failed_rsp.msg = HTTPMessage()
failed_rsp.msg.add_header('Content-Length', '2957783')
failed_rsp.read = get_content
failed_rsp.chunked = True
success_rsp.length_remaining = 2957783
success_rsp.headers = {'Content-Length': '2957783'}

# retry 5 times and success.
getconn_mock.return_value.getresponse.side_effect = [