Fix code review issues

This commit is contained in:
mulin.lyh
2024-07-12 10:05:57 +08:00
parent 9a656a149e
commit aa083a291c
3 changed files with 68 additions and 8 deletions

View File

@@ -30,13 +30,13 @@ class DownloadCMD(CLICommand):
group.add_argument(
'--model',
type=str,
help='The model id to be downloaded, model or dataset must provide.'
)
help='The id of the model to be downloaded. For download, '
'the id of either a model or dataset must be provided.')
group.add_argument(
'--dataset',
type=str,
help=
'The dataset id to be downloaded, model or dataset must provide.')
help='The id of the dataset to be downloaded. For download, '
'the id of either a model or dataset must be provided.')
parser.add_argument(
'--revision',
type=str,
@@ -77,7 +77,7 @@ class DownloadCMD(CLICommand):
parser.set_defaults(func=subparser_func)
def execute(self):
if self.args.model is not None:
if self.args.model:
if len(self.args.files) == 1: # download single file
model_file_download(
self.args.model,
@@ -103,7 +103,7 @@ class DownloadCMD(CLICommand):
allow_file_pattern=self.args.include,
ignore_file_pattern=self.args.exclude,
)
else:
elif self.args.dataset:
if len(self.args.files) == 1: # download single file
dataset_file_download(
self.args.dataset,
@@ -129,3 +129,5 @@ class DownloadCMD(CLICommand):
allow_file_pattern=self.args.include,
ignore_file_pattern=self.args.exclude,
)
else:
pass # noop

View File

@@ -53,7 +53,7 @@ def get_model_cache_root() -> str:
"""Get model cache root path.
Returns:
str: the model cache root.
str: the modelscope model cache root.
"""
return os.path.join(get_modelscope_cache_dir(), 'hub')
@@ -62,7 +62,7 @@ def get_dataset_cache_root() -> str:
"""Get dataset raw file cache root path.
Returns:
str: the dataset raw file cache root.
str: the modelscope dataset raw file cache root.
"""
return os.path.join(get_modelscope_cache_dir(), 'datasets')

View File

@@ -109,3 +109,61 @@ class DownloadDatasetTest(unittest.TestCase):
file_modify_time2 = os.path.getmtime(
os.path.join(dataset_cache_path2, file_path))
assert file_modify_time == file_modify_time2
# test download with wild pattern, ignore_file_pattern
with tempfile.TemporaryDirectory() as temp_cache_dir:
# first download to cache.
dataset_cache_path = dataset_snapshot_download(
dataset_id=dataset_id,
cache_dir=temp_cache_dir,
ignore_file_pattern='*.jpeg')
assert dataset_cache_path == os.path.join(temp_cache_dir,
dataset_id)
assert not os.path.exists(
os.path.join(temp_cache_dir, dataset_id, deep_file_path))
assert not os.path.exists(
os.path.join(temp_cache_dir, dataset_id, '111/shijian.jpeg'))
assert not os.path.exists(
os.path.join(temp_cache_dir, dataset_id,
'111/222/shijian.jpeg'))
assert os.path.exists(
os.path.join(temp_cache_dir, dataset_id, file_path))
# test download with wild pattern, allow_file_pattern
with tempfile.TemporaryDirectory() as temp_cache_dir:
# first download to cache.
dataset_cache_path = dataset_snapshot_download(
dataset_id=dataset_id,
cache_dir=temp_cache_dir,
allow_file_pattern='*.jpeg')
assert dataset_cache_path == os.path.join(temp_cache_dir,
dataset_id)
assert os.path.exists(
os.path.join(temp_cache_dir, dataset_id, deep_file_path))
assert os.path.exists(
os.path.join(temp_cache_dir, dataset_id, '111/shijian.jpeg'))
assert os.path.exists(
os.path.join(temp_cache_dir, dataset_id,
'111/222/shijian.jpeg'))
assert not os.path.exists(
os.path.join(temp_cache_dir, dataset_id, file_path))
# test download with wild pattern, allow_file_pattern and ignore file pattern.
with tempfile.TemporaryDirectory() as temp_cache_dir:
# first download to cache.
dataset_cache_path = dataset_snapshot_download(
dataset_id=dataset_id,
cache_dir=temp_cache_dir,
ignore_file_pattern='*.jpeg',
allow_file_pattern='*.xxx')
assert dataset_cache_path == os.path.join(temp_cache_dir,
dataset_id)
assert not os.path.exists(
os.path.join(temp_cache_dir, dataset_id, deep_file_path))
assert not os.path.exists(
os.path.join(temp_cache_dir, dataset_id, '111/shijian.jpeg'))
assert not os.path.exists(
os.path.join(temp_cache_dir, dataset_id,
'111/222/shijian.jpeg'))
assert not os.path.exists(
os.path.join(temp_cache_dir, dataset_id, file_path))