diff --git a/modelscope/cli/download.py b/modelscope/cli/download.py index ba313a14..5839bdb9 100644 --- a/modelscope/cli/download.py +++ b/modelscope/cli/download.py @@ -30,13 +30,13 @@ class DownloadCMD(CLICommand): group.add_argument( '--model', type=str, - help='The model id to be downloaded, model or dataset must provide.' - ) + help='The id of the model to be downloaded. For download, ' + 'the id of either a model or dataset must be provided.') group.add_argument( '--dataset', type=str, - help= - 'The dataset id to be downloaded, model or dataset must provide.') + help='The id of the dataset to be downloaded. For download, ' + 'the id of either a model or dataset must be provided.') parser.add_argument( '--revision', type=str, @@ -77,7 +77,7 @@ class DownloadCMD(CLICommand): parser.set_defaults(func=subparser_func) def execute(self): - if self.args.model is not None: + if self.args.model: if len(self.args.files) == 1: # download single file model_file_download( self.args.model, @@ -103,7 +103,7 @@ class DownloadCMD(CLICommand): allow_file_pattern=self.args.include, ignore_file_pattern=self.args.exclude, ) - else: + elif self.args.dataset: if len(self.args.files) == 1: # download single file dataset_file_download( self.args.dataset, @@ -129,3 +129,5 @@ class DownloadCMD(CLICommand): allow_file_pattern=self.args.include, ignore_file_pattern=self.args.exclude, ) + else: + pass # noop diff --git a/modelscope/utils/file_utils.py b/modelscope/utils/file_utils.py index 845d936b..c00e8d26 100644 --- a/modelscope/utils/file_utils.py +++ b/modelscope/utils/file_utils.py @@ -53,7 +53,7 @@ def get_model_cache_root() -> str: """Get model cache root path. Returns: - str: the model cache root. + str: the modelscope model cache root. """ return os.path.join(get_modelscope_cache_dir(), 'hub') @@ -62,7 +62,7 @@ def get_dataset_cache_root() -> str: """Get dataset raw file cache root path. Returns: - str: the dataset raw file cache root. + str: the modelscope dataset raw file cache root. """ return os.path.join(get_modelscope_cache_dir(), 'datasets') diff --git a/tests/hub/test_download_dataset_file.py b/tests/hub/test_download_dataset_file.py index 6fc14e38..49d8c238 100644 --- a/tests/hub/test_download_dataset_file.py +++ b/tests/hub/test_download_dataset_file.py @@ -109,3 +109,61 @@ class DownloadDatasetTest(unittest.TestCase): file_modify_time2 = os.path.getmtime( os.path.join(dataset_cache_path2, file_path)) assert file_modify_time == file_modify_time2 + + # test download with wild pattern, ignore_file_pattern + with tempfile.TemporaryDirectory() as temp_cache_dir: + # first download to cache. + dataset_cache_path = dataset_snapshot_download( + dataset_id=dataset_id, + cache_dir=temp_cache_dir, + ignore_file_pattern='*.jpeg') + assert dataset_cache_path == os.path.join(temp_cache_dir, + dataset_id) + assert not os.path.exists( + os.path.join(temp_cache_dir, dataset_id, deep_file_path)) + assert not os.path.exists( + os.path.join(temp_cache_dir, dataset_id, '111/shijian.jpeg')) + assert not os.path.exists( + os.path.join(temp_cache_dir, dataset_id, + '111/222/shijian.jpeg')) + assert os.path.exists( + os.path.join(temp_cache_dir, dataset_id, file_path)) + + # test download with wild pattern, allow_file_pattern + with tempfile.TemporaryDirectory() as temp_cache_dir: + # first download to cache. + dataset_cache_path = dataset_snapshot_download( + dataset_id=dataset_id, + cache_dir=temp_cache_dir, + allow_file_pattern='*.jpeg') + assert dataset_cache_path == os.path.join(temp_cache_dir, + dataset_id) + assert os.path.exists( + os.path.join(temp_cache_dir, dataset_id, deep_file_path)) + assert os.path.exists( + os.path.join(temp_cache_dir, dataset_id, '111/shijian.jpeg')) + assert os.path.exists( + os.path.join(temp_cache_dir, dataset_id, + '111/222/shijian.jpeg')) + assert not os.path.exists( + os.path.join(temp_cache_dir, dataset_id, file_path)) + + # test download with wild pattern, allow_file_pattern and ignore file pattern. + with tempfile.TemporaryDirectory() as temp_cache_dir: + # first download to cache. + dataset_cache_path = dataset_snapshot_download( + dataset_id=dataset_id, + cache_dir=temp_cache_dir, + ignore_file_pattern='*.jpeg', + allow_file_pattern='*.xxx') + assert dataset_cache_path == os.path.join(temp_cache_dir, + dataset_id) + assert not os.path.exists( + os.path.join(temp_cache_dir, dataset_id, deep_file_path)) + assert not os.path.exists( + os.path.join(temp_cache_dir, dataset_id, '111/shijian.jpeg')) + assert not os.path.exists( + os.path.join(temp_cache_dir, dataset_id, + '111/222/shijian.jpeg')) + assert not os.path.exists( + os.path.join(temp_cache_dir, dataset_id, file_path))