add dataset download

This commit is contained in:
mulin.lyh
2024-07-11 21:01:32 +08:00
parent 53acf79d04
commit 9a656a149e
16 changed files with 595 additions and 72 deletions

View File

@@ -34,21 +34,7 @@ RUN if [ "$USE_GPU" = "True" ] ; then \
echo 'cpu unsupport detectron2'; \
fi
# torchmetrics==0.11.4 for ofa
# tinycudann for cuda12.1.0 pytorch 2.1.2
RUN if [ "$USE_GPU" = "True" ] ; then \
pip install --no-cache-dir --upgrade pip && \
pip install --no-cache-dir torchsde jupyterlab torchmetrics==0.11.4 tiktoken transformers_stream_generator bitsandbytes basicsr optimum && \
pip install --no-cache-dir flash_attn==2.5.9.post1 -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
pip install --no-cache-dir auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu121/ && \
pip install --no-cache-dir -U 'xformers' --index-url https://download.pytorch.org/whl/cu121 && \
pip install --no-cache-dir --force tinycudann==1.7 -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
pip uninstall -y torch-scatter && TORCH_CUDA_ARCH_LIST="6.0;6.1;6.2;7.0;7.5;8.0;8.6;8.9;9.0" pip install --no-cache-dir -U torch-scatter && \
pip install --no-cache-dir -U vllm; \
else \
echo 'cpu unsupport vllm auto-gptq'; \
fi
# install dependencies
COPY requirements /var/modelscope
RUN pip install --no-cache-dir --upgrade pip && \
pip install --no-cache-dir -r /var/modelscope/framework.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
@@ -64,5 +50,21 @@ RUN pip install --no-cache-dir --upgrade pip && \
pip cache purge
# 'scipy<1.13.0' for cannot import name 'kaiser' from 'scipy.signal'
COPY examples /modelscope/examples
# torchmetrics==0.11.4 for ofa
# tinycudann for cuda12.1.0 pytorch 2.1.2
RUN if [ "$USE_GPU" = "True" ] ; then \
pip install --no-cache-dir --upgrade pip && \
pip install --no-cache-dir torchsde jupyterlab torchmetrics==0.11.4 tiktoken transformers_stream_generator bitsandbytes basicsr optimum && \
pip install --no-cache-dir flash_attn==2.5.9.post1 -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
pip install --no-cache-dir auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu121/ && \
pip install --no-cache-dir -U 'xformers' --index-url https://download.pytorch.org/whl/cu121 && \
pip install --no-cache-dir --force tinycudann==1.7 -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
pip uninstall -y torch-scatter && TORCH_CUDA_ARCH_LIST="6.0;6.1;6.2;7.0;7.5;8.0;8.6;8.9;9.0" pip install --no-cache-dir -U torch-scatter && \
pip install --no-cache-dir -U triton vllm https://modelscope.oss-cn-beijing.aliyuncs.com/packages/lmdeploy-0.5.0-cp310-cp310-linux_x86_64.whl; \
else \
echo 'cpu unsupport vllm auto-gptq'; \
fi
ENV SETUPTOOLS_USE_DISTUTILS=stdlib
ENV VLLM_USE_MODELSCOPE=True
ENV LMDEPLOY_USE_MODELSCOPE=True

View File

@@ -11,6 +11,8 @@ if TYPE_CHECKING:
from .hub.check_model import check_local_model_is_latest, check_model_is_id
from .hub.push_to_hub import push_to_hub, push_to_hub_async
from .hub.snapshot_download import snapshot_download
from .hub.dataset_download import dataset_snapshot_download, dataset_file_download
from .metrics import (
AccuracyMetric, AudioNoiseMetric, BleuMetric, ImageColorEnhanceMetric,
ImageColorizationMetric, ImageDenoiseMetric, ImageInpaintingMetric,
@@ -60,6 +62,8 @@ else:
],
'hub.api': ['HubApi'],
'hub.snapshot_download': ['snapshot_download'],
'hub.dataset_download':
['dataset_snapshot_download', 'dataset_file_download'],
'hub.push_to_hub': ['push_to_hub', 'push_to_hub_async'],
'hub.check_model':
['check_model_is_id', 'check_local_model_is_latest'],

View File

@@ -3,6 +3,8 @@
from argparse import ArgumentParser
from modelscope.cli.base import CLICommand
from modelscope.hub.dataset_download import (dataset_file_download,
dataset_snapshot_download)
from modelscope.hub.file_download import model_file_download
from modelscope.hub.snapshot_download import snapshot_download
@@ -24,11 +26,17 @@ class DownloadCMD(CLICommand):
""" define args for download command.
"""
parser: ArgumentParser = parsers.add_parser(DownloadCMD.name)
parser.add_argument(
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
'--model',
type=str,
required=True,
help='The model id to be downloaded.')
help='The model id to be downloaded, model or dataset must provide.'
)
group.add_argument(
'--dataset',
type=str,
help=
'The dataset id to be downloaded, model or dataset must provide.')
parser.add_argument(
'--revision',
type=str,
@@ -69,27 +77,55 @@ class DownloadCMD(CLICommand):
parser.set_defaults(func=subparser_func)
def execute(self):
    """Run the download command.

    Exactly one of --model / --dataset is set (enforced by the
    mutually-exclusive argument group), so dispatch once to the matching
    single-file / snapshot download helpers instead of duplicating the
    three-way branch for each repo kind.
    """
    if self.args.model is not None:
        repo_id = self.args.model
        single_file_download = model_file_download
        repo_download = snapshot_download
    else:
        repo_id = self.args.dataset
        single_file_download = dataset_file_download
        repo_download = dataset_snapshot_download

    if len(self.args.files) == 1:  # download a single file
        single_file_download(
            repo_id,
            self.args.files[0],
            cache_dir=self.args.cache_dir,
            local_dir=self.args.local_dir,
            revision=self.args.revision)
    elif len(self.args.files) > 1:  # download specified multiple files
        repo_download(
            repo_id,
            revision=self.args.revision,
            cache_dir=self.args.cache_dir,
            local_dir=self.args.local_dir,
            allow_file_pattern=self.args.files,
        )
    else:  # download the whole repo, honoring include/exclude patterns
        repo_download(
            repo_id,
            revision=self.args.revision,
            cache_dir=self.args.cache_dir,
            local_dir=self.args.local_dir,
            allow_file_pattern=self.args.include,
            ignore_file_pattern=self.args.exclude,
        )

View File

@@ -741,7 +741,8 @@ class HubApi:
recursive = 'True' if recursive else 'False'
datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_hub_id}/repo/tree'
params = {'Revision': revision, 'Root': root_path, 'Recursive': recursive}
params = {'Revision': revision if revision else 'master',
'Root': root_path if root_path else '/', 'Recursive': recursive}
cookies = ModelScopeConfig.get_cookies()
r = self.session.get(datahub_url, params=params, cookies=cookies)
@@ -771,8 +772,10 @@ class HubApi:
@staticmethod
def dump_datatype_file(dataset_type: int, meta_cache_dir: str):
"""
Dump the data_type as a local file, in order to get the dataset formation without calling the datahub.
More details, please refer to the class `modelscope.utils.constant.DatasetFormations`.
Dump the data_type as a local file, in order to get the dataset
formation without calling the datahub.
More details, please refer to the class
`modelscope.utils.constant.DatasetFormations`.
"""
dataset_type_file_path = os.path.join(meta_cache_dir,
f'{str(dataset_type)}{DatasetFormations.formation_mark_ext.value}')
@@ -874,13 +877,14 @@ class HubApi:
dataset_name: str,
namespace: str,
revision: Optional[str] = DEFAULT_DATASET_REVISION,
view: Optional[bool] = False,
extension_filter: Optional[bool] = True):
if not file_name or not dataset_name or not namespace:
raise ValueError('Args (file_name, dataset_name, namespace) cannot be empty!')
# Note: make sure the FilePath is the last parameter in the url
params: dict = {'Source': 'SDK', 'Revision': revision, 'FilePath': file_name}
params: dict = {'Source': 'SDK', 'Revision': revision, 'FilePath': file_name, 'View': view}
params: str = urlencode(params)
file_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?{params}'
@@ -1113,7 +1117,8 @@ class ModelScopeConfig:
ModelScopeConfig.cookie_expired_warning = True
logger.warning(
'Authentication has expired, '
'please re-login if you need to access private models or datasets.')
'please re-login with modelscope login --token "YOUR_SDK_TOKEN" '
'if you need to access private models or datasets.')
return None
return cookies
return None

View File

@@ -0,0 +1,328 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import fnmatch
import os
from http.cookiejar import CookieJar
from pathlib import Path
from typing import Dict, List, Optional, Union
from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.errors import NotExistError
from modelscope.utils.constant import DEFAULT_DATASET_REVISION
from modelscope.utils.file_utils import get_dataset_cache_root
from modelscope.utils.logger import get_logger
from .constants import (FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS,
MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB)
from .file_download import (create_temporary_directory_and_cache,
http_get_model_file, parallel_download)
from .utils.utils import (file_integrity_validation,
model_id_to_group_owner_name)
logger = get_logger()
def dataset_file_download(
    dataset_id: str,
    file_path: str,
    revision: Optional[str] = DEFAULT_DATASET_REVISION,
    cache_dir: Union[str, Path, None] = None,
    local_dir: Optional[str] = None,
    user_agent: Optional[Union[Dict, str]] = None,
    local_files_only: Optional[bool] = False,
    cookies: Optional[CookieJar] = None,
) -> str:
    """Download a single raw file from a dataset repo.

    Args:
        dataset_id (str): A user or an organization name and a dataset name
            separated by a `/`.
        file_path (str): The relative path of the file to download.
        revision (str, optional): An optional Git revision id which can be a
            branch name or a tag. NOTE: currently only branch and tag name
            are supported, not commit hashes.
        cache_dir (str, Path, optional): Path to the folder where cached
            files are stored; the file will be saved under
            cache_dir/dataset_id/.
        local_dir (str, optional): Specific local directory path to which
            the file will be downloaded.
        user_agent (str, dict, optional): The user-agent info in the form of
            a dictionary or a string.
        user local_files_only (bool, optional): If `True`, avoid downloading and
            return the path to the local cached file if it exists.
        cookies (CookieJar, optional): The cookies of the request,
            default None.

    Raises:
        ValueError: If `local_files_only` is True and the file is not cached.
        NotExistError: If `file_path` does not exist in the dataset repo.

    Returns:
        str: Local path of the downloaded (or cached) file, or None when
        the repo file list cannot be fetched from the hub.
    """
    temporary_cache_dir, cache = create_temporary_directory_and_cache(
        dataset_id,
        local_dir=local_dir,
        cache_dir=cache_dir,
        default_cache_root=get_dataset_cache_root())

    if local_files_only:
        cached_file_path = cache.get_file_by_path(file_path)
        if cached_file_path is not None:
            logger.warning(
                "File exists in local cache, but we're not sure it's up to date"
            )
            return cached_file_path
        raise ValueError(
            'Cannot find the requested files in the cached path and outgoing'
            ' traffic has been disabled. To enable dataset look-ups and downloads'
            " online, set 'local_files_only' to False.")

    # make headers
    headers = {
        'user-agent': ModelScopeConfig.get_user_agent(user_agent=user_agent, )
    }
    _api = HubApi()
    if cookies is None:
        cookies = ModelScopeConfig.get_cookies()
    group_or_owner, name = model_id_to_group_owner_name(dataset_id)
    if not revision:
        revision = DEFAULT_DATASET_REVISION
    files_list_tree = _api.list_repo_tree(
        dataset_name=name,
        namespace=group_or_owner,
        revision=revision,
        root_path='/',
        recursive=True)
    if not ('Code' in files_list_tree and files_list_tree['Code'] == 200):
        # Use the module logger (not print) so failures reach log capture.
        logger.error(
            'Get dataset: %s file list failed, request_id: %s, message: %s' %
            (dataset_id, files_list_tree['RequestId'],
             files_list_tree['Message']))
        return None

    file_meta_to_download = None
    # Find the file to download; 'tree' entries are directories.
    for file_meta in files_list_tree['Data']['Files']:
        if file_meta['Type'] == 'tree':
            continue
        if fnmatch.fnmatch(file_meta['Path'], file_path):
            # If the file already exists in the cache, skip downloading.
            if cache.exists(file_meta):
                file_name = os.path.basename(file_meta['Name'])
                logger.debug(
                    f'File {file_name} already in cache, skip downloading!')
                return cache.get_file_by_info(file_meta)
            file_meta_to_download = file_meta
            break

    if file_meta_to_download is None:
        raise NotExistError('The file path: %s not exist in: %s' %
                            (file_path, dataset_id))

    # Start downloading: resolve the download url first.
    url = _api.get_dataset_file_url(
        file_name=file_meta_to_download['Path'],
        dataset_name=name,
        namespace=group_or_owner,
        revision=revision)
    if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta_to_download[
            'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1:
        # Parallel download for large files.
        parallel_download(
            url,
            temporary_cache_dir,
            file_meta_to_download['Name'],  # TODO use Path
            headers=headers,
            cookies=None if cookies is None else cookies.get_dict(),
            file_size=file_meta_to_download['Size'])
    else:
        http_get_model_file(
            url,
            temporary_cache_dir,
            file_meta_to_download['Name'],
            file_size=file_meta_to_download['Size'],
            headers=headers,
            cookies=cookies)

    # Verify integrity when the server provides a hash for the file.
    temp_file = os.path.join(temporary_cache_dir,
                             file_meta_to_download['Name'])
    if FILE_HASH in file_meta_to_download:
        file_integrity_validation(temp_file, file_meta_to_download[FILE_HASH])
    # Move the file into the cache and return its final path.
    return cache.put_file(file_meta_to_download, temp_file)
def dataset_snapshot_download(
    dataset_id: str,
    revision: Optional[str] = DEFAULT_DATASET_REVISION,
    cache_dir: Union[str, Path, None] = None,
    local_dir: Optional[str] = None,
    user_agent: Optional[Union[Dict, str]] = None,
    local_files_only: Optional[bool] = False,
    cookies: Optional[CookieJar] = None,
    ignore_file_pattern: Optional[Union[str, List[str]]] = None,
    allow_file_pattern: Optional[Union[str, List[str]]] = None,
) -> str:
    """Download the raw files of a dataset repo snapshot.

    Downloads all files at the specified revision. This is useful when you
    want all files from a dataset, because you don't know which ones you
    will need a priori. All files are nested inside a folder in order to
    keep their actual filename relative to that folder. An alternative
    would be to just clone the dataset, but that would require git and
    git-lfs to be installed and properly configured.

    Args:
        dataset_id (str): A user or an organization name and a dataset name
            separated by a `/`.
        revision (str, optional): An optional Git revision id which can be a
            branch name or a tag. NOTE: currently only branch and tag name
            are supported, not commit hashes.
        cache_dir (str, Path, optional): Path to the folder where cached
            files are stored; the dataset will be saved under
            cache_dir/dataset_id/.
        local_dir (str, optional): Specific local directory path to which
            the files will be downloaded.
        user_agent (str, dict, optional): The user-agent info in the form of
            a dictionary or a string.
        local_files_only (bool, optional): If `True`, avoid downloading and
            return the cached snapshot root if anything is cached.
        cookies (CookieJar, optional): The cookies of the request,
            default None.
        ignore_file_pattern (`str` or `List`, *optional*, default `None`):
            Any file pattern to be ignored in downloading, like exact file
            names or file extensions.
        allow_file_pattern (`str` or `List`, *optional*, default `None`):
            Any file pattern to be downloaded, like exact file names or
            file extensions.

    Raises:
        ValueError: If `local_files_only` is True and nothing is cached.

    Returns:
        str: Local folder path of the repo snapshot, or None when the repo
        file list cannot be fetched from the hub.
    """
    temporary_cache_dir, cache = create_temporary_directory_and_cache(
        dataset_id,
        local_dir=local_dir,
        cache_dir=cache_dir,
        default_cache_root=get_dataset_cache_root())

    if local_files_only:
        if len(cache.cached_files) == 0:
            raise ValueError(
                'Cannot find the requested files in the cached path and outgoing'
                ' traffic has been disabled. To enable dataset look-ups and downloads'
                " online, set 'local_files_only' to False.")
        logger.warning('We can not confirm the cached file is for revision: %s'
                       % revision)
        # We can not confirm the cached files match snapshot 'revision'.
        return cache.get_root_location()

    # make headers
    headers = {
        'user-agent': ModelScopeConfig.get_user_agent(user_agent=user_agent, )
    }
    _api = HubApi()
    if cookies is None:
        cookies = ModelScopeConfig.get_cookies()
    group_or_owner, name = model_id_to_group_owner_name(dataset_id)
    if not revision:
        revision = DEFAULT_DATASET_REVISION
    files_list_tree = _api.list_repo_tree(
        dataset_name=name,
        namespace=group_or_owner,
        revision=revision,
        root_path='/',
        recursive=True)
    if not ('Code' in files_list_tree and files_list_tree['Code'] == 200):
        # Use the module logger (not print) so failures reach log capture.
        logger.error(
            'Get dataset: %s file list failed, request_id: %s, message: %s' %
            (dataset_id, files_list_tree['RequestId'],
             files_list_tree['Message']))
        return None

    # Normalize patterns: a trailing '/' means 'everything under this dir'.
    if ignore_file_pattern is None:
        ignore_file_pattern = []
    if isinstance(ignore_file_pattern, str):
        ignore_file_pattern = [ignore_file_pattern]
    ignore_file_pattern = [
        item if not item.endswith('/') else item + '*'
        for item in ignore_file_pattern
    ]
    if allow_file_pattern is not None:
        if isinstance(allow_file_pattern, str):
            allow_file_pattern = [allow_file_pattern]
        allow_file_pattern = [
            item if not item.endswith('/') else item + '*'
            for item in allow_file_pattern
        ]

    for file_meta in files_list_tree['Data']['Files']:
        # Skip directories and anything matching an ignore pattern.
        if file_meta['Type'] == 'tree' or \
                any(fnmatch.fnmatch(file_meta['Path'], pattern)
                    for pattern in ignore_file_pattern):
            continue
        if allow_file_pattern is not None and allow_file_pattern:
            if not any(
                    fnmatch.fnmatch(file_meta['Path'], pattern)
                    for pattern in allow_file_pattern):
                continue
        # Skip files already present in the cache.
        if cache.exists(file_meta):
            file_name = os.path.basename(file_meta['Name'])
            logger.debug(
                f'File {file_name} already in cache, skip downloading!')
            continue
        # Resolve the download url for this file.
        url = _api.get_dataset_file_url(
            file_name=file_meta['Path'],
            dataset_name=name,
            namespace=group_or_owner,
            revision=revision)
        if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[
                'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1:
            # Parallel download for large files.
            parallel_download(
                url,
                temporary_cache_dir,
                file_meta['Name'],  # TODO use Path
                headers=headers,
                cookies=None if cookies is None else cookies.get_dict(),
                file_size=file_meta['Size'])
        else:
            http_get_model_file(
                url,
                temporary_cache_dir,
                file_meta['Name'],
                file_size=file_meta['Size'],
                headers=headers,
                cookies=cookies)
        # Verify integrity when the server provides a hash for the file.
        temp_file = os.path.join(temporary_cache_dir, file_meta['Name'])
        if FILE_HASH in file_meta:
            file_integrity_validation(temp_file, file_meta[FILE_HASH])
        # Move the file into the cache.
        cache.put_file(file_meta, temp_file)

    cache.save_model_version(revision_info=revision)
    return cache.get_root_location()

View File

@@ -153,9 +153,9 @@ def datahub_raise_on_error(url, rsp, http_response: requests.Response):
if rsp.get('Code') == HTTPStatus.OK:
return True
else:
request_id = get_request_id(http_response)
request_id = rsp['RequestId']
raise RequestError(
f"Url = {url}, Request id={request_id} Message = {rsp.get('Message')},\
f"Url = {url}, Request id={request_id} Code = {rsp['Code']} Message = {rsp['Message']},\
Please specify correct dataset_name and namespace.")

View File

@@ -79,7 +79,7 @@ def model_file_download(
if some parameter value is invalid
"""
temporary_cache_dir, cache = create_temporary_directory_and_cache(
model_id, local_dir, cache_dir)
model_id, local_dir, cache_dir, get_model_cache_root())
# if local_files_only is `True` and the file already exists in cached_path
# return the cached path
@@ -163,14 +163,15 @@ def model_file_download(
def create_temporary_directory_and_cache(model_id: str, local_dir: str,
cache_dir: str):
cache_dir: str,
default_cache_root: str):
group_or_owner, name = model_id_to_group_owner_name(model_id)
if local_dir is not None:
temporary_cache_dir = os.path.join(local_dir, TEMPORARY_FOLDER_NAME)
cache = ModelFileSystemCache(local_dir)
else:
if cache_dir is None:
cache_dir = get_model_cache_root()
cache_dir = default_cache_root
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
temporary_cache_dir = os.path.join(cache_dir, TEMPORARY_FOLDER_NAME,
@@ -269,6 +270,7 @@ def parallel_download(
PART_SIZE = 160 * 1024 * 1024 # every part is 160M
tasks = []
file_path = os.path.join(local_dir, file_name)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
for idx in range(int(file_size / PART_SIZE)):
start = idx * PART_SIZE
end = (idx + 1) * PART_SIZE - 1
@@ -323,6 +325,7 @@ def http_get_model_file(
get_headers = {} if headers is None else copy.deepcopy(headers)
get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
temp_file_path = os.path.join(local_dir, file_name)
os.makedirs(os.path.dirname(temp_file_path), exist_ok=True)
logger.debug('downloading %s to %s', url, temp_file_path)
# retry sleep 0.5s, 1s, 2s, 4s
retry = Retry(
@@ -349,7 +352,7 @@ def http_get_model_file(
break
get_headers['Range'] = 'bytes=%s-%s' % (partial_length,
file_size - 1)
with open(temp_file_path, 'ab') as f:
with open(temp_file_path, 'ab+') as f:
r = requests.get(
url,
stream=True,

View File

@@ -9,6 +9,7 @@ from typing import Dict, List, Optional, Union
from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
from modelscope.utils.file_utils import get_model_cache_root
from modelscope.utils.logger import get_logger
from .constants import (FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS,
MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB)
@@ -71,7 +72,7 @@ def snapshot_download(
if some parameter value is invalid
"""
temporary_cache_dir, cache = create_temporary_directory_and_cache(
model_id, local_dir, cache_dir)
model_id, local_dir, cache_dir, get_model_cache_root())
if local_files_only:
if len(cache.cached_files) == 0:

View File

@@ -164,9 +164,12 @@ class ModelFileSystemCache(FileSystemCache):
model_version_file_path = os.path.join(
self.cache_root_location, FileSystemCache.MODEL_VERSION_FILE_NAME)
with open(model_version_file_path, 'w') as f:
version_info_str = 'Revision:%s,CreatedAt:%s' % (
revision_info['Revision'], revision_info['CreatedAt'])
f.write(version_info_str)
if isinstance(revision_info, dict):
version_info_str = 'Revision:%s,CreatedAt:%s' % (
revision_info['Revision'], revision_info['CreatedAt'])
f.write(version_info_str)
else:
f.write(revision_info)
def get_model_id(self):
return self.model_meta[FileSystemCache.MODEL_META_MODEL_ID]

View File

@@ -5,6 +5,8 @@ import enum
class Fields(object):
""" Names for different application fields
"""
hub = 'hub'
datasets = 'datasets'
framework = 'framework'
cv = 'cv'
nlp = 'nlp'

View File

@@ -53,11 +53,35 @@ def get_model_cache_root() -> str:
"""Get model cache root path.
Returns:
str: the modelscope cache root.
str: the model cache root.
"""
return os.path.join(get_modelscope_cache_dir(), 'hub')
def get_dataset_cache_root() -> str:
"""Get dataset raw file cache root path.
Returns:
str: the dataset raw file cache root.
"""
return os.path.join(get_modelscope_cache_dir(), 'datasets')
def get_dataset_cache_dir(dataset_id: str) -> str:
"""Get the dataset_id's path.
dataset_cache_root/dataset_id.
Args:
dataset_id (str): The dataset id.
Returns:
str: The dataset_id's cache root path.
"""
dataset_root = get_dataset_cache_root()
return dataset_root if dataset_id is None else os.path.join(
dataset_root, dataset_id + '/')
def get_model_cache_dir(model_id: str) -> str:
"""cache dir precedence:
function parameter > environment > ~/.cache/modelscope/hub/model_id

View File

@@ -44,10 +44,11 @@ opencv-python
paint_ldm
pandas
panopticapi
Pillow>=6.2.0
plyfile>=0.7.4
psutil
pyclipper
PyMCubes
PyMCubes<=0.1.4
pytorch-lightning
regex
# <0.20.0 for compatible python3.7 python3.8

12
requirements/datasets.txt Normal file
View File

@@ -0,0 +1,12 @@
addict
attrs
datasets>=2.16.0,<2.19.0
einops
oss2
python-dateutil>=2.1
scipy
# latest version has some compatible issue.
setuptools==69.5.1
simplejson>=3.3.0
sortedcontainers>=1.5.9
urllib3>=1.26

View File

@@ -2,21 +2,12 @@ addict
attrs
datasets>=2.16.0,<2.19.0
einops
filelock>=3.3.0
huggingface_hub
numpy
oss2
pandas
Pillow>=6.2.0
# pyarrow 9.0.0 introduced event_loop core dump
pyarrow>=6.0.0,!=9.0.0
python-dateutil>=2.1
pyyaml
requests>=2.25
scipy
setuptools
# latest version has some compatible issue.
setuptools==69.5.1
simplejson>=3.3.0
sortedcontainers>=1.5.9
tqdm>=4.64.0
transformers
urllib3>=1.26
yapf

View File

@@ -196,7 +196,7 @@ if __name__ == '__main__':
# add framework dependencies to every field
for field, requires in extra_requires.items():
if field not in [
'server', 'framework'
'server', 'framework', 'hub', 'datasets'
]: # server need install model's field dependencies before.
extra_requires[field] = framework_requires + extra_requires[field]
extra_requires['all'] = all_requires

View File

@@ -0,0 +1,111 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import tempfile
import time
import unittest
from modelscope.hub.dataset_download import (dataset_file_download,
dataset_snapshot_download)
class DownloadDatasetTest(unittest.TestCase):
    """Integration tests for dataset_file_download / dataset_snapshot_download.

    NOTE(review): these tests hit the ModelScope hub over the network and
    depend on the public dataset 'citest/test_dataset_download'.
    """

    def setUp(self):
        pass

    def test_dataset_file_download(self):
        dataset_id = 'citest/test_dataset_download'
        file_path = 'open_qa.jsonl'
        deep_file_path = '111/222/333/shijian.jpeg'
        start_time = time.time()
        # Test download to a fresh cache dir.
        with tempfile.TemporaryDirectory() as temp_cache_dir:
            # First download populates the cache.
            cache_file_path = dataset_file_download(
                dataset_id=dataset_id,
                file_path=file_path,
                cache_dir=temp_cache_dir)
            file_modify_time = os.path.getmtime(cache_file_path)
            assert cache_file_path == os.path.join(temp_cache_dir, dataset_id,
                                                   file_path)
            assert file_modify_time > start_time
            # Second download must hit the cache (mtime unchanged).
            cache_file_path = dataset_file_download(
                dataset_id=dataset_id,
                file_path=file_path,
                cache_dir=temp_cache_dir)
            file_modify_time2 = os.path.getmtime(cache_file_path)
            assert file_modify_time == file_modify_time2
            # A file nested in subdirectories keeps its relative path.
            deep_cache_file_path = dataset_file_download(
                dataset_id=dataset_id,
                file_path=deep_file_path,
                cache_dir=temp_cache_dir)
            expected_deep_path = os.path.join(temp_cache_dir, dataset_id,
                                              deep_file_path)
            assert deep_cache_file_path == expected_deep_path
            # The original dropped the result of os.path.exists; assert it.
            assert os.path.exists(deep_cache_file_path)
        # Test download to an explicit local dir.
        with tempfile.TemporaryDirectory() as temp_local_dir:
            cache_file_path = dataset_file_download(
                dataset_id=dataset_id,
                file_path=file_path,
                local_dir=temp_local_dir)
            assert cache_file_path == os.path.join(temp_local_dir, file_path)
            file_modify_time = os.path.getmtime(cache_file_path)
            assert file_modify_time > start_time
            # Second download must hit the cache (mtime unchanged).
            cache_file_path = dataset_file_download(
                dataset_id=dataset_id,
                file_path=file_path,
                local_dir=temp_local_dir)
            file_modify_time2 = os.path.getmtime(cache_file_path)
            assert file_modify_time == file_modify_time2

    def test_dataset_snapshot_download(self):
        dataset_id = 'citest/test_dataset_download'
        file_path = 'open_qa.jsonl'
        deep_file_path = '111/222/333/shijian.jpeg'
        start_time = time.time()
        # Test download to a fresh cache dir.
        with tempfile.TemporaryDirectory() as temp_cache_dir:
            # First download populates the cache.
            dataset_cache_path = dataset_snapshot_download(
                dataset_id=dataset_id, cache_dir=temp_cache_dir)
            file_modify_time = os.path.getmtime(
                os.path.join(dataset_cache_path, file_path))
            assert dataset_cache_path == os.path.join(temp_cache_dir,
                                                      dataset_id)
            assert file_modify_time > start_time
            assert os.path.exists(
                os.path.join(temp_cache_dir, dataset_id, deep_file_path))
            # Second download must hit the cache (mtime unchanged).
            dataset_cache_path2 = dataset_snapshot_download(
                dataset_id=dataset_id, cache_dir=temp_cache_dir)
            file_modify_time2 = os.path.getmtime(
                os.path.join(dataset_cache_path2, file_path))
            assert file_modify_time == file_modify_time2
        # Test download to an explicit local dir.
        with tempfile.TemporaryDirectory() as temp_local_dir:
            dataset_cache_path = dataset_snapshot_download(
                dataset_id=dataset_id, local_dir=temp_local_dir)
            # Root path is the local_dir itself; files land inside it.
            assert dataset_cache_path == temp_local_dir
            file_modify_time = os.path.getmtime(
                os.path.join(dataset_cache_path, file_path))
            assert file_modify_time > start_time
            # Second download must hit the cache (mtime unchanged).
            dataset_cache_path2 = dataset_snapshot_download(
                dataset_id=dataset_id, local_dir=temp_local_dir)
            file_modify_time2 = os.path.getmtime(
                os.path.join(dataset_cache_path2, file_path))
            assert file_modify_time == file_modify_time2