Mirror of https://github.com/modelscope/modelscope.git (synced 2026-02-24 20:19:51 +01:00)
add dataset download
@@ -34,21 +34,7 @@ RUN if [ "$USE_GPU" = "True" ] ; then \
         echo 'cpu unsupport detectron2'; \
     fi

-# torchmetrics==0.11.4 for ofa
-# tinycudann for cuda12.1.0 pytorch 2.1.2
-RUN if [ "$USE_GPU" = "True" ] ; then \
-        pip install --no-cache-dir --upgrade pip && \
-        pip install --no-cache-dir torchsde jupyterlab torchmetrics==0.11.4 tiktoken transformers_stream_generator bitsandbytes basicsr optimum && \
-        pip install --no-cache-dir flash_attn==2.5.9.post1 -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
-        pip install --no-cache-dir auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu121/ && \
-        pip install --no-cache-dir -U 'xformers' --index-url https://download.pytorch.org/whl/cu121 && \
-        pip install --no-cache-dir --force tinycudann==1.7 -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
-        pip uninstall -y torch-scatter && TORCH_CUDA_ARCH_LIST="6.0;6.1;6.2;7.0;7.5;8.0;8.6;8.9;9.0" pip install --no-cache-dir -U torch-scatter && \
-        pip install --no-cache-dir -U vllm; \
-    else \
-        echo 'cpu unsupport vllm auto-gptq'; \
-    fi
 # install dependencies
 COPY requirements /var/modelscope
 RUN pip install --no-cache-dir --upgrade pip && \
     pip install --no-cache-dir -r /var/modelscope/framework.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
@@ -64,5 +50,21 @@ RUN pip install --no-cache-dir --upgrade pip && \
     pip cache purge
 # 'scipy<1.13.0' for cannot import name 'kaiser' from 'scipy.signal'
 COPY examples /modelscope/examples
+# torchmetrics==0.11.4 for ofa
+# tinycudann for cuda12.1.0 pytorch 2.1.2
+RUN if [ "$USE_GPU" = "True" ] ; then \
+        pip install --no-cache-dir --upgrade pip && \
+        pip install --no-cache-dir torchsde jupyterlab torchmetrics==0.11.4 tiktoken transformers_stream_generator bitsandbytes basicsr optimum && \
+        pip install --no-cache-dir flash_attn==2.5.9.post1 -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
+        pip install --no-cache-dir auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu121/ && \
+        pip install --no-cache-dir -U 'xformers' --index-url https://download.pytorch.org/whl/cu121 && \
+        pip install --no-cache-dir --force tinycudann==1.7 -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
+        pip uninstall -y torch-scatter && TORCH_CUDA_ARCH_LIST="6.0;6.1;6.2;7.0;7.5;8.0;8.6;8.9;9.0" pip install --no-cache-dir -U torch-scatter && \
+        pip install --no-cache-dir -U triton vllm https://modelscope.oss-cn-beijing.aliyuncs.com/packages/lmdeploy-0.5.0-cp310-cp310-linux_x86_64.whl; \
+    else \
+        echo 'cpu unsupport vllm auto-gptq'; \
+    fi
+
 ENV SETUPTOOLS_USE_DISTUTILS=stdlib
 ENV VLLM_USE_MODELSCOPE=True
+ENV LMDEPLOY_USE_MODELSCOPE=True
@@ -11,6 +11,8 @@ if TYPE_CHECKING:
     from .hub.check_model import check_local_model_is_latest, check_model_is_id
     from .hub.push_to_hub import push_to_hub, push_to_hub_async
     from .hub.snapshot_download import snapshot_download
+    from .hub.dataset_download import dataset_snapshot_download, dataset_file_download
+
     from .metrics import (
         AccuracyMetric, AudioNoiseMetric, BleuMetric, ImageColorEnhanceMetric,
         ImageColorizationMetric, ImageDenoiseMetric, ImageInpaintingMetric,
@@ -60,6 +62,8 @@ else:
         ],
         'hub.api': ['HubApi'],
         'hub.snapshot_download': ['snapshot_download'],
+        'hub.dataset_download':
+        ['dataset_snapshot_download', 'dataset_file_download'],
         'hub.push_to_hub': ['push_to_hub', 'push_to_hub_async'],
         'hub.check_model':
        ['check_model_is_id', 'check_local_model_is_latest'],
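
With the lazy-import table extended above, the two helpers should also resolve from the package root. A minimal sketch, assuming the usual LazyImportModule behavior:

from modelscope import dataset_file_download, dataset_snapshot_download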
@@ -3,6 +3,8 @@

 from argparse import ArgumentParser

 from modelscope.cli.base import CLICommand
+from modelscope.hub.dataset_download import (dataset_file_download,
+                                             dataset_snapshot_download)
 from modelscope.hub.file_download import model_file_download
 from modelscope.hub.snapshot_download import snapshot_download
@@ -24,11 +26,17 @@ class DownloadCMD(CLICommand):
         """ define args for download command.
         """
         parser: ArgumentParser = parsers.add_parser(DownloadCMD.name)
-        parser.add_argument(
+        group = parser.add_mutually_exclusive_group(required=True)
+        group.add_argument(
             '--model',
             type=str,
-            required=True,
-            help='The model id to be downloaded.')
+            help='The model id to be downloaded, model or dataset must provide.'
+        )
+        group.add_argument(
+            '--dataset',
+            type=str,
+            help=
+            'The dataset id to be downloaded, model or dataset must provide.')
         parser.add_argument(
             '--revision',
             type=str,
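
Because --model and --dataset now sit in a mutually exclusive required group, argparse itself rejects invocations that pass both flags or neither. A self-contained sketch of the same wiring (the dataset id is the one used in the tests added by this commit):

from argparse import ArgumentParser

parser = ArgumentParser()
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('--model', type=str)
group.add_argument('--dataset', type=str)

# Exactly one of the two flags must be present:
args = parser.parse_args(['--dataset', 'citest/test_dataset_download'])
assert args.model is None and args.dataset == 'citest/test_dataset_download'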
@@ -69,27 +77,55 @@ class DownloadCMD(CLICommand):
         parser.set_defaults(func=subparser_func)

     def execute(self):
-        if len(self.args.files) == 1:  # download single file
-            model_file_download(
-                self.args.model,
-                self.args.files[0],
-                cache_dir=self.args.cache_dir,
-                local_dir=self.args.local_dir,
-                revision=self.args.revision)
-        elif len(self.args.files) > 1:  # download specified multiple files.
-            snapshot_download(
-                self.args.model,
-                revision=self.args.revision,
-                cache_dir=self.args.cache_dir,
-                local_dir=self.args.local_dir,
-                allow_file_pattern=self.args.files,
-            )
-        else:  # download repo
-            snapshot_download(
-                self.args.model,
-                revision=self.args.revision,
-                cache_dir=self.args.cache_dir,
-                local_dir=self.args.local_dir,
-                allow_file_pattern=self.args.include,
-                ignore_file_pattern=self.args.exclude,
-            )
+        if self.args.model is not None:
+            if len(self.args.files) == 1:  # download single file
+                model_file_download(
+                    self.args.model,
+                    self.args.files[0],
+                    cache_dir=self.args.cache_dir,
+                    local_dir=self.args.local_dir,
+                    revision=self.args.revision)
+            elif len(
+                    self.args.files) > 1:  # download specified multiple files.
+                snapshot_download(
+                    self.args.model,
+                    revision=self.args.revision,
+                    cache_dir=self.args.cache_dir,
+                    local_dir=self.args.local_dir,
+                    allow_file_pattern=self.args.files,
+                )
+            else:  # download repo
+                snapshot_download(
+                    self.args.model,
+                    revision=self.args.revision,
+                    cache_dir=self.args.cache_dir,
+                    local_dir=self.args.local_dir,
+                    allow_file_pattern=self.args.include,
+                    ignore_file_pattern=self.args.exclude,
+                )
+        else:
+            if len(self.args.files) == 1:  # download single file
+                dataset_file_download(
+                    self.args.dataset,
+                    self.args.files[0],
+                    cache_dir=self.args.cache_dir,
+                    local_dir=self.args.local_dir,
+                    revision=self.args.revision)
+            elif len(
+                    self.args.files) > 1:  # download specified multiple files.
+                dataset_snapshot_download(
+                    self.args.dataset,
+                    revision=self.args.revision,
+                    cache_dir=self.args.cache_dir,
+                    local_dir=self.args.local_dir,
+                    allow_file_pattern=self.args.files,
+                )
+            else:  # download repo
+                dataset_snapshot_download(
+                    self.args.dataset,
+                    revision=self.args.revision,
+                    cache_dir=self.args.cache_dir,
+                    local_dir=self.args.local_dir,
+                    allow_file_pattern=self.args.include,
+                    ignore_file_pattern=self.args.exclude,
+                )
@@ -741,7 +741,8 @@ class HubApi:

         recursive = 'True' if recursive else 'False'
         datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_hub_id}/repo/tree'
-        params = {'Revision': revision, 'Root': root_path, 'Recursive': recursive}
+        params = {'Revision': revision if revision else 'master',
+                  'Root': root_path if root_path else '/', 'Recursive': recursive}
         cookies = ModelScopeConfig.get_cookies()

         r = self.session.get(datahub_url, params=params, cookies=cookies)
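
With these fallbacks, callers may leave revision and root_path unset and still get a full listing. A hedged sketch (dataset id borrowed from the tests in this commit):

from modelscope.hub.api import HubApi

api = HubApi()
# Revision falls back to 'master' and root_path to '/' when falsy.
tree = api.list_repo_tree(
    dataset_name='test_dataset_download',
    namespace='citest',
    revision=None,
    root_path=None,
    recursive=True)
print([f['Path'] for f in tree['Data']['Files']])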
@@ -771,8 +772,10 @@ class HubApi:
     @staticmethod
     def dump_datatype_file(dataset_type: int, meta_cache_dir: str):
         """
-        Dump the data_type as a local file, in order to get the dataset formation without calling the datahub.
-        More details, please refer to the class `modelscope.utils.constant.DatasetFormations`.
+        Dump the data_type as a local file, in order to get the dataset
+        formation without calling the datahub.
+        More details, please refer to the class
+        `modelscope.utils.constant.DatasetFormations`.
         """
         dataset_type_file_path = os.path.join(meta_cache_dir,
                                               f'{str(dataset_type)}{DatasetFormations.formation_mark_ext.value}')
@@ -874,13 +877,14 @@ class HubApi:
                              dataset_name: str,
                              namespace: str,
                              revision: Optional[str] = DEFAULT_DATASET_REVISION,
+                             view: Optional[bool] = False,
                              extension_filter: Optional[bool] = True):

         if not file_name or not dataset_name or not namespace:
             raise ValueError('Args (file_name, dataset_name, namespace) cannot be empty!')

         # Note: make sure the FilePath is the last parameter in the url
-        params: dict = {'Source': 'SDK', 'Revision': revision, 'FilePath': file_name}
+        params: dict = {'Source': 'SDK', 'Revision': revision, 'FilePath': file_name, 'View': view}
         params: str = urlencode(params)
         file_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?{params}'
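
The new view flag is simply forwarded as an extra View query parameter. A sketch of a call, with argument values that are illustrative rather than tested:

from modelscope.hub.api import HubApi

api = HubApi()
url = api.get_dataset_file_url(
    file_name='open_qa.jsonl',
    dataset_name='test_dataset_download',
    namespace='citest',
    view=True)  # appends View=True to the query string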
@@ -1113,7 +1117,8 @@ class ModelScopeConfig:
                     ModelScopeConfig.cookie_expired_warning = True
                     logger.warning(
                         'Authentication has expired, '
-                        'please re-login if you need to access private models or datasets.')
+                        'please re-login with modelscope login --token "YOUR_SDK_TOKEN" '
+                        'if you need to access private models or datasets.')
                 return None
             return cookies
         return None

modelscope/hub/dataset_download.py (new file, 328 lines)
@@ -0,0 +1,328 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import fnmatch
import os
from http.cookiejar import CookieJar
from pathlib import Path
from typing import Dict, List, Optional, Union

from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.errors import NotExistError
from modelscope.utils.constant import DEFAULT_DATASET_REVISION
from modelscope.utils.file_utils import get_dataset_cache_root
from modelscope.utils.logger import get_logger
from .constants import (FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS,
                        MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB)
from .file_download import (create_temporary_directory_and_cache,
                            http_get_model_file, parallel_download)
from .utils.utils import (file_integrity_validation,
                          model_id_to_group_owner_name)

logger = get_logger()


def dataset_file_download(
    dataset_id: str,
    file_path: str,
    revision: Optional[str] = DEFAULT_DATASET_REVISION,
    cache_dir: Union[str, Path, None] = None,
    local_dir: Optional[str] = None,
    user_agent: Optional[Union[Dict, str]] = None,
    local_files_only: Optional[bool] = False,
    cookies: Optional[CookieJar] = None,
) -> str:
    """Download raw files of a dataset.
    Downloads all files at the specified revision. This
    is useful when you want all files from a dataset, because you don't know which
    ones you will need a priori. All files are nested inside a folder in order
    to keep their actual filename relative to that folder.

    An alternative would be to just clone a dataset but this would require that the
    user always has git and git-lfs installed, and properly configured.

    Args:
        dataset_id (str): A user or an organization name and a dataset name separated by a `/`.
        file_path (str): The relative path of the file to download.
        revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
            commit hash. NOTE: currently only branch and tag name is supported
        cache_dir (str, Path, optional): Path to the folder where cached files are stored, dataset file will
            be save as cache_dir/dataset_id/THE_DATASET_FILES.
        local_dir (str, optional): Specific local directory path to which the file will be downloaded.
        user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string.
        local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the
            local cached file if it exists.
        cookies (CookieJar, optional): The cookie of the request, default None.
    Raises:
        ValueError: the value details.

    Returns:
        str: Local folder path (string) of repo snapshot

    Note:
        Raises the following errors:
        - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
          if `use_auth_token=True` and the token cannot be found.
        - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if
          ETag cannot be determined.
        - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
          if some parameter value is invalid
    """
    temporary_cache_dir, cache = create_temporary_directory_and_cache(
        dataset_id,
        local_dir=local_dir,
        cache_dir=cache_dir,
        default_cache_root=get_dataset_cache_root())

    if local_files_only:
        cached_file_path = cache.get_file_by_path(file_path)
        if cached_file_path is not None:
            logger.warning(
                "File exists in local cache, but we're not sure it's up to date"
            )
            return cached_file_path
        else:
            raise ValueError(
                'Cannot find the requested files in the cached path and outgoing'
                ' traffic has been disabled. To enable dataset look-ups and downloads'
                " online, set 'local_files_only' to False.")
    else:
        # make headers
        headers = {
            'user-agent':
            ModelScopeConfig.get_user_agent(user_agent=user_agent, )
        }
        _api = HubApi()
        if cookies is None:
            cookies = ModelScopeConfig.get_cookies()
        group_or_owner, name = model_id_to_group_owner_name(dataset_id)
        if not revision:
            revision = DEFAULT_DATASET_REVISION
        files_list_tree = _api.list_repo_tree(
            dataset_name=name,
            namespace=group_or_owner,
            revision=revision,
            root_path='/',
            recursive=True)
        if not ('Code' in files_list_tree and files_list_tree['Code'] == 200):
            print(
                'Get dataset: %s file list failed, request_id: %s, message: %s'
                % (dataset_id, files_list_tree['RequestId'],
                   files_list_tree['Message']))
            return None
        file_meta_to_download = None
        # find the file to download.
        for file_meta in files_list_tree['Data']['Files']:
            if file_meta['Type'] == 'tree':
                continue

            if fnmatch.fnmatch(file_meta['Path'], file_path):
                # check file is exist in cache, if existed, skip download, otherwise download
                if cache.exists(file_meta):
                    file_name = os.path.basename(file_meta['Name'])
                    logger.debug(
                        f'File {file_name} already in cache, skip downloading!'
                    )
                    return cache.get_file_by_info(file_meta)
                else:
                    file_meta_to_download = file_meta
                break
        if file_meta_to_download is None:
            raise NotExistError('The file path: %s not exist in: %s' %
                                (file_path, dataset_id))

        # start download file.
        # get download url
        url = _api.get_dataset_file_url(
            file_name=file_meta_to_download['Path'],
            dataset_name=name,
            namespace=group_or_owner,
            revision=revision)

        if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta_to_download[
                'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1:  # parallel download large file.
            parallel_download(
                url,
                temporary_cache_dir,
                file_meta_to_download['Name'],  # TODO use Path
                headers=headers,
                cookies=None if cookies is None else cookies.get_dict(),
                file_size=file_meta_to_download['Size'])
        else:
            http_get_model_file(
                url,
                temporary_cache_dir,
                file_meta_to_download['Name'],
                file_size=file_meta_to_download['Size'],
                headers=headers,
                cookies=cookies)

        # check file integrity
        temp_file = os.path.join(temporary_cache_dir,
                                 file_meta_to_download['Name'])
        if FILE_HASH in file_meta_to_download:
            file_integrity_validation(temp_file,
                                      file_meta_to_download[FILE_HASH])
        # put file into to cache
        return cache.put_file(file_meta_to_download, temp_file)


def dataset_snapshot_download(
    dataset_id: str,
    revision: Optional[str] = DEFAULT_DATASET_REVISION,
    cache_dir: Union[str, Path, None] = None,
    local_dir: Optional[str] = None,
    user_agent: Optional[Union[Dict, str]] = None,
    local_files_only: Optional[bool] = False,
    cookies: Optional[CookieJar] = None,
    ignore_file_pattern: Optional[Union[str, List[str]]] = None,
    allow_file_pattern: Optional[Union[str, List[str]]] = None,
) -> str:
    """Download raw files of a dataset.
    Downloads all files at the specified revision. This
    is useful when you want all files from a dataset, because you don't know which
    ones you will need a priori. All files are nested inside a folder in order
    to keep their actual filename relative to that folder.

    An alternative would be to just clone a dataset but this would require that the
    user always has git and git-lfs installed, and properly configured.

    Args:
        dataset_id (str): A user or an organization name and a dataset name separated by a `/`.
        revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
            commit hash. NOTE: currently only branch and tag name is supported
        cache_dir (str, Path, optional): Path to the folder where cached files are stored, dataset will
            be save as cache_dir/dataset_id/THE_DATASET_FILES.
        local_dir (str, optional): Specific local directory path to which the file will be downloaded.
        user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string.
        local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the
            local cached file if it exists.
        cookies (CookieJar, optional): The cookie of the request, default None.
        ignore_file_pattern (`str` or `List`, *optional*, default to `None`):
            Any file pattern to be ignored in downloading, like exact file names or file extensions.
        allow_file_pattern (`str` or `List`, *optional*, default to `None`):
            Any file pattern to be downloading, like exact file names or file extensions.
    Raises:
        ValueError: the value details.

    Returns:
        str: Local folder path (string) of repo snapshot

    Note:
        Raises the following errors:
        - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
          if `use_auth_token=True` and the token cannot be found.
        - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if
          ETag cannot be determined.
        - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
          if some parameter value is invalid
    """
    temporary_cache_dir, cache = create_temporary_directory_and_cache(
        dataset_id,
        local_dir=local_dir,
        cache_dir=cache_dir,
        default_cache_root=get_dataset_cache_root())

    if local_files_only:
        if len(cache.cached_files) == 0:
            raise ValueError(
                'Cannot find the requested files in the cached path and outgoing'
                ' traffic has been disabled. To enable dataset look-ups and downloads'
                " online, set 'local_files_only' to False.")
        logger.warning('We can not confirm the cached file is for revision: %s'
                       % revision)
        return cache.get_root_location(
        )  # we can not confirm the cached file is for snapshot 'revision'
    else:
        # make headers
        headers = {
            'user-agent':
            ModelScopeConfig.get_user_agent(user_agent=user_agent, )
        }
        _api = HubApi()
        if cookies is None:
            cookies = ModelScopeConfig.get_cookies()
        group_or_owner, name = model_id_to_group_owner_name(dataset_id)
        if not revision:
            revision = DEFAULT_DATASET_REVISION
        files_list_tree = _api.list_repo_tree(
            dataset_name=name,
            namespace=group_or_owner,
            revision=revision,
            root_path='/',
            recursive=True)
        if not ('Code' in files_list_tree and files_list_tree['Code'] == 200):
            print(
                'Get dataset: %s file list failed, request_id: %s, message: %s'
                % (dataset_id, files_list_tree['RequestId'],
                   files_list_tree['Message']))
            return None

        if ignore_file_pattern is None:
            ignore_file_pattern = []
        if isinstance(ignore_file_pattern, str):
            ignore_file_pattern = [ignore_file_pattern]
        ignore_file_pattern = [
            item if not item.endswith('/') else item + '*'
            for item in ignore_file_pattern
        ]

        if allow_file_pattern is not None:
            if isinstance(allow_file_pattern, str):
                allow_file_pattern = [allow_file_pattern]
            allow_file_pattern = [
                item if not item.endswith('/') else item + '*'
                for item in allow_file_pattern
            ]

        for file_meta in files_list_tree['Data']['Files']:
            if file_meta['Type'] == 'tree' or \
                    any(fnmatch.fnmatch(file_meta['Path'], pattern) for pattern in ignore_file_pattern):
                continue

            if allow_file_pattern is not None and allow_file_pattern:
                if not any(
                        fnmatch.fnmatch(file_meta['Path'], pattern)
                        for pattern in allow_file_pattern):
                    continue

            # check file is exist in cache, if existed, skip download, otherwise download
            if cache.exists(file_meta):
                file_name = os.path.basename(file_meta['Name'])
                logger.debug(
                    f'File {file_name} already in cache, skip downloading!')
                continue

            # get download url
            url = _api.get_dataset_file_url(
                file_name=file_meta['Path'],
                dataset_name=name,
                namespace=group_or_owner,
                revision=revision)

            if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[
                    'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1:  # parallel download large file.
                parallel_download(
                    url,
                    temporary_cache_dir,
                    file_meta['Name'],  # TODO use Path
                    headers=headers,
                    cookies=None if cookies is None else cookies.get_dict(),
                    file_size=file_meta['Size'])
            else:
                http_get_model_file(
                    url,
                    temporary_cache_dir,
                    file_meta['Name'],
                    file_size=file_meta['Size'],
                    headers=headers,
                    cookies=cookies)

            # check file integrity
            temp_file = os.path.join(temporary_cache_dir, file_meta['Name'])
            if FILE_HASH in file_meta:
                file_integrity_validation(temp_file, file_meta[FILE_HASH])
            # put file into to cache
            cache.put_file(file_meta, temp_file)

        cache.save_model_version(revision_info=revision)
        return os.path.join(cache.get_root_location())
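
A usage sketch of the two new entry points (dataset id and file names taken from the tests added in this commit; the cache location assumes the defaults in modelscope.utils.file_utils):

from modelscope.hub.dataset_download import (dataset_file_download,
                                             dataset_snapshot_download)

# Single raw file; returns its local path under the dataset cache,
# by default <modelscope cache>/datasets/<dataset_id>/...
local_file = dataset_file_download(
    dataset_id='citest/test_dataset_download', file_path='open_qa.jsonl')

# Whole snapshot, restricted to JSONL files; a pattern ending in '/'
# is expanded to '<dir>/*' before fnmatch is applied.
local_root = dataset_snapshot_download(
    'citest/test_dataset_download', allow_file_pattern='*.jsonl')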
@@ -153,9 +153,9 @@ def datahub_raise_on_error(url, rsp, http_response: requests.Response):
     if rsp.get('Code') == HTTPStatus.OK:
         return True
     else:
-        request_id = get_request_id(http_response)
+        request_id = rsp['RequestId']
         raise RequestError(
-            f"Url = {url}, Request id={request_id} Message = {rsp.get('Message')},\
+            f"Url = {url}, Request id={request_id} Code = {rsp['Code']} Message = {rsp['Message']},\
             Please specify correct dataset_name and namespace.")

@@ -79,7 +79,7 @@ def model_file_download(
           if some parameter value is invalid
     """
     temporary_cache_dir, cache = create_temporary_directory_and_cache(
-        model_id, local_dir, cache_dir)
+        model_id, local_dir, cache_dir, get_model_cache_root())

     # if local_files_only is `True` and the file already exists in cached_path
     # return the cached path
@@ -163,14 +163,15 @@ def model_file_download(


 def create_temporary_directory_and_cache(model_id: str, local_dir: str,
-                                         cache_dir: str):
+                                         cache_dir: str,
+                                         default_cache_root: str):
     group_or_owner, name = model_id_to_group_owner_name(model_id)
     if local_dir is not None:
         temporary_cache_dir = os.path.join(local_dir, TEMPORARY_FOLDER_NAME)
         cache = ModelFileSystemCache(local_dir)
     else:
         if cache_dir is None:
-            cache_dir = get_model_cache_root()
+            cache_dir = default_cache_root
         if isinstance(cache_dir, Path):
             cache_dir = str(cache_dir)
         temporary_cache_dir = os.path.join(cache_dir, TEMPORARY_FOLDER_NAME,
@@ -269,6 +270,7 @@ def parallel_download(
     PART_SIZE = 160 * 1024 * 1024  # every part is 160M
     tasks = []
     file_path = os.path.join(local_dir, file_name)
+    os.makedirs(os.path.dirname(file_path), exist_ok=True)
     for idx in range(int(file_size / PART_SIZE)):
         start = idx * PART_SIZE
         end = (idx + 1) * PART_SIZE - 1
@@ -323,6 +325,7 @@ def http_get_model_file(
     get_headers = {} if headers is None else copy.deepcopy(headers)
     get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
     temp_file_path = os.path.join(local_dir, file_name)
+    os.makedirs(os.path.dirname(temp_file_path), exist_ok=True)
     logger.debug('downloading %s to %s', url, temp_file_path)
     # retry sleep 0.5s, 1s, 2s, 4s
     retry = Retry(
@@ -349,7 +352,7 @@ def http_get_model_file(
                 break
             get_headers['Range'] = 'bytes=%s-%s' % (partial_length,
                                                     file_size - 1)
-            with open(temp_file_path, 'ab') as f:
+            with open(temp_file_path, 'ab+') as f:
                 r = requests.get(
                     url,
                     stream=True,
@@ -9,6 +9,7 @@ from typing import Dict, List, Optional, Union

 from modelscope.hub.api import HubApi, ModelScopeConfig
 from modelscope.utils.constant import DEFAULT_MODEL_REVISION
+from modelscope.utils.file_utils import get_model_cache_root
 from modelscope.utils.logger import get_logger
 from .constants import (FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS,
                         MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB)
@@ -71,7 +72,7 @@ def snapshot_download(
           if some parameter value is invalid
     """
     temporary_cache_dir, cache = create_temporary_directory_and_cache(
-        model_id, local_dir, cache_dir)
+        model_id, local_dir, cache_dir, get_model_cache_root())

     if local_files_only:
         if len(cache.cached_files) == 0:
@@ -164,9 +164,12 @@ class ModelFileSystemCache(FileSystemCache):
         model_version_file_path = os.path.join(
             self.cache_root_location, FileSystemCache.MODEL_VERSION_FILE_NAME)
         with open(model_version_file_path, 'w') as f:
-            version_info_str = 'Revision:%s,CreatedAt:%s' % (
-                revision_info['Revision'], revision_info['CreatedAt'])
-            f.write(version_info_str)
+            if isinstance(revision_info, dict):
+                version_info_str = 'Revision:%s,CreatedAt:%s' % (
+                    revision_info['Revision'], revision_info['CreatedAt'])
+                f.write(version_info_str)
+            else:
+                f.write(revision_info)

     def get_model_id(self):
         return self.model_meta[FileSystemCache.MODEL_META_MODEL_ID]
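
save_model_version now accepts either the revision dict used for models or a bare string, which is what dataset_snapshot_download passes. A minimal sketch (the cache object and values are illustrative):

# cache: a ModelFileSystemCache instance
cache.save_model_version(revision_info={'Revision': 'v1.0', 'CreatedAt': 1718000000})
cache.save_model_version(revision_info='master')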
@@ -5,6 +5,8 @@ import enum
 class Fields(object):
     """ Names for different application fields
     """
+    hub = 'hub'
+    datasets = 'datasets'
     framework = 'framework'
     cv = 'cv'
     nlp = 'nlp'
@@ -53,11 +53,35 @@ def get_model_cache_root() -> str:
     """Get model cache root path.

     Returns:
-        str: the modelscope cache root.
+        str: the model cache root.
     """
     return os.path.join(get_modelscope_cache_dir(), 'hub')


+def get_dataset_cache_root() -> str:
+    """Get dataset raw file cache root path.
+
+    Returns:
+        str: the dataset raw file cache root.
+    """
+    return os.path.join(get_modelscope_cache_dir(), 'datasets')
+
+
+def get_dataset_cache_dir(dataset_id: str) -> str:
+    """Get the dataset_id's path.
+    dataset_cache_root/dataset_id.
+
+    Args:
+        dataset_id (str): The dataset id.
+
+    Returns:
+        str: The dataset_id's cache root path.
+    """
+    dataset_root = get_dataset_cache_root()
+    return dataset_root if dataset_id is None else os.path.join(
+        dataset_root, dataset_id + '/')
+
+
 def get_model_cache_dir(model_id: str) -> str:
     """cache dir precedence:
     function parameter > environment > ~/.cache/modelscope/hub/model_id
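
A quick sketch of the new path helpers ('damo/some_dataset' is a made-up id):

from modelscope.utils.file_utils import (get_dataset_cache_dir,
                                         get_dataset_cache_root)

print(get_dataset_cache_root())  # <modelscope cache>/datasets
print(get_dataset_cache_dir('damo/some_dataset'))  # .../datasets/damo/some_dataset/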
@@ -44,10 +44,11 @@ opencv-python
 paint_ldm
 pandas
 panopticapi
 Pillow>=6.2.0
+plyfile>=0.7.4
 psutil
 pyclipper
-PyMCubes
+PyMCubes<=0.1.4
 pytorch-lightning
 regex
 # <0.20.0 for compatible python3.7 python3.8
requirements/datasets.txt (new file, 12 lines)
@@ -0,0 +1,12 @@
addict
attrs
datasets>=2.16.0,<2.19.0
einops
oss2
python-dateutil>=2.1
scipy
# latest version has some compatible issue.
setuptools==69.5.1
simplejson>=3.3.0
sortedcontainers>=1.5.9
urllib3>=1.26
@@ -2,21 +2,12 @@ addict
 attrs
-datasets>=2.16.0,<2.19.0
-einops
 filelock>=3.3.0
 huggingface_hub
 numpy
-oss2
 pandas
 Pillow>=6.2.0
-# pyarrow 9.0.0 introduced event_loop core dump
-pyarrow>=6.0.0,!=9.0.0
-python-dateutil>=2.1
 pyyaml
 requests>=2.25
-scipy
-setuptools
+# latest version has some compatible issue.
+setuptools==69.5.1
-simplejson>=3.3.0
-sortedcontainers>=1.5.9
 tqdm>=4.64.0
 transformers
-urllib3>=1.26
 yapf

setup.py
@@ -196,7 +196,7 @@ if __name__ == '__main__':
     # add framework dependencies to every field
     for field, requires in extra_requires.items():
         if field not in [
-                'server', 'framework'
+                'server', 'framework', 'hub', 'datasets'
         ]:  # server need install model's field dependencies before.
             extra_requires[field] = framework_requires + extra_requires[field]
     extra_requires['all'] = all_requires

tests/hub/test_download_dataset_file.py (new file, 111 lines)
@@ -0,0 +1,111 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import tempfile
import time
import unittest

from modelscope.hub.dataset_download import (dataset_file_download,
                                             dataset_snapshot_download)


class DownloadDatasetTest(unittest.TestCase):

    def setUp(self):
        pass

    def test_dataset_file_download(self):
        dataset_id = 'citest/test_dataset_download'
        file_path = 'open_qa.jsonl'
        deep_file_path = '111/222/333/shijian.jpeg'
        start_time = time.time()

        # test download to cache dir.
        with tempfile.TemporaryDirectory() as temp_cache_dir:
            # first download to cache.
            cache_file_path = dataset_file_download(
                dataset_id=dataset_id,
                file_path=file_path,
                cache_dir=temp_cache_dir)
            file_modify_time = os.path.getmtime(cache_file_path)
            print(cache_file_path)
            assert cache_file_path == os.path.join(temp_cache_dir, dataset_id,
                                                   file_path)
            assert file_modify_time > start_time
            # download again, will get cached file.
            cache_file_path = dataset_file_download(
                dataset_id=dataset_id,
                file_path=file_path,
                cache_dir=temp_cache_dir)
            file_modify_time2 = os.path.getmtime(cache_file_path)
            assert file_modify_time == file_modify_time2

            deep_cache_file_path = dataset_file_download(
                dataset_id=dataset_id,
                file_path=deep_file_path,
                cache_dir=temp_cache_dir)
            deep_file_cath_path = os.path.join(temp_cache_dir, dataset_id,
                                               deep_file_path)
            assert deep_cache_file_path == deep_file_cath_path
            os.path.exists(deep_cache_file_path)

        # test download to local dir
        with tempfile.TemporaryDirectory() as temp_local_dir:
            # first download to cache.
            cache_file_path = dataset_file_download(
                dataset_id=dataset_id,
                file_path=file_path,
                local_dir=temp_local_dir)
            assert cache_file_path == os.path.join(temp_local_dir, file_path)
            file_modify_time = os.path.getmtime(cache_file_path)
            assert file_modify_time > start_time
            # download again, will get cached file.
            cache_file_path = dataset_file_download(
                dataset_id=dataset_id,
                file_path=file_path,
                local_dir=temp_local_dir)
            file_modify_time2 = os.path.getmtime(cache_file_path)
            assert file_modify_time == file_modify_time2

    def test_dataset_snapshot_download(self):
        dataset_id = 'citest/test_dataset_download'
        file_path = 'open_qa.jsonl'
        deep_file_path = '111/222/333/shijian.jpeg'
        start_time = time.time()

        # test download to cache dir.
        with tempfile.TemporaryDirectory() as temp_cache_dir:
            # first download to cache.
            dataset_cache_path = dataset_snapshot_download(
                dataset_id=dataset_id, cache_dir=temp_cache_dir)
            file_modify_time = os.path.getmtime(
                os.path.join(dataset_cache_path, file_path))
            assert dataset_cache_path == os.path.join(temp_cache_dir,
                                                      dataset_id)
            assert file_modify_time > start_time
            assert os.path.exists(
                os.path.join(temp_cache_dir, dataset_id, deep_file_path))

            # download again, will get cached file.
            dataset_cache_path2 = dataset_snapshot_download(
                dataset_id=dataset_id, cache_dir=temp_cache_dir)
            file_modify_time2 = os.path.getmtime(
                os.path.join(dataset_cache_path2, file_path))
            assert file_modify_time == file_modify_time2

        # test download to local dir
        with tempfile.TemporaryDirectory() as temp_local_dir:
            # first download to cache.
            dataset_cache_path = dataset_snapshot_download(
                dataset_id=dataset_id, local_dir=temp_local_dir)
            # root path is temp_local_dir, file download to local_dir
            assert dataset_cache_path == temp_local_dir
            file_modify_time = os.path.getmtime(
                os.path.join(dataset_cache_path, file_path))
            assert file_modify_time > start_time
            # download again, will get cached file.
            dataset_cache_path2 = dataset_snapshot_download(
                dataset_id=dataset_id, local_dir=temp_local_dir)
            file_modify_time2 = os.path.getmtime(
                os.path.join(dataset_cache_path2, file_path))
            assert file_modify_time == file_modify_time2