diff --git a/modelscope/cli/download.py b/modelscope/cli/download.py index 321c2b5d..6b430453 100644 --- a/modelscope/cli/download.py +++ b/modelscope/cli/download.py @@ -8,6 +8,7 @@ from modelscope.hub.file_download import (dataset_file_download, model_file_download) from modelscope.hub.snapshot_download import (dataset_snapshot_download, snapshot_download) +from modelscope.hub.utils.utils import convert_patterns from modelscope.utils.constant import DEFAULT_DATASET_REVISION @@ -141,8 +142,8 @@ class DownloadCMD(CLICommand): revision=self.args.revision, cache_dir=self.args.cache_dir, local_dir=self.args.local_dir, - allow_file_pattern=self.args.include, - ignore_file_pattern=self.args.exclude, + allow_file_pattern=convert_patterns(self.args.include), + ignore_file_pattern=convert_patterns(self.args.exclude), max_workers=self.args.max_workers, ) elif self.args.dataset: @@ -170,8 +171,8 @@ class DownloadCMD(CLICommand): revision=dataset_revision, cache_dir=self.args.cache_dir, local_dir=self.args.local_dir, - allow_file_pattern=self.args.include, - ignore_file_pattern=self.args.exclude, + allow_file_pattern=convert_patterns(self.args.include), + ignore_file_pattern=convert_patterns(self.args.exclude), max_workers=self.args.max_workers, ) else: diff --git a/modelscope/cli/upload.py b/modelscope/cli/upload.py index 29dacbe5..d32abdcc 100644 --- a/modelscope/cli/upload.py +++ b/modelscope/cli/upload.py @@ -4,6 +4,7 @@ from argparse import ArgumentParser, _SubParsersAction from modelscope.cli.base import CLICommand from modelscope.hub.api import HubApi, ModelScopeConfig +from modelscope.hub.utils.utils import convert_patterns, get_endpoint from modelscope.utils.constant import REPO_TYPE_MODEL, REPO_TYPE_SUPPORT @@ -89,7 +90,7 @@ class UploadCMD(CLICommand): parser.add_argument( '--endpoint', type=str, - default='https://www.modelscope.cn', + default=get_endpoint(), help='Endpoint for Modelscope service.') parser.set_defaults(func=subparser_func) @@ -166,8 +167,8 @@ class UploadCMD(CLICommand): commit_message=self.args.commit_message, commit_description=self.args.commit_description, repo_type=self.args.repo_type, - allow_patterns=self.args.include, - ignore_patterns=self.args.exclude, + allow_file_pattern=convert_patterns(self.args.include), + ignore_file_pattern=convert_patterns(self.args.exclude), max_workers=self.args.max_workers, ) else: diff --git a/modelscope/hub/utils/utils.py b/modelscope/hub/utils/utils.py index 3f3a4c75..3ad96fe2 100644 --- a/modelscope/hub/utils/utils.py +++ b/modelscope/hub/utils/utils.py @@ -31,6 +31,25 @@ def model_id_to_group_owner_name(model_id): return group_or_owner, name +def convert_patterns(raw_input: Union[str, List[str]]): + output = None + if isinstance(raw_input, str): + output = list() + if ',' in raw_input: + output = [s.strip() for s in raw_input.split(',')] + else: + output.append(raw_input.strip()) + elif isinstance(raw_input, list): + output = list() + for s in raw_input: + if isinstance(s, str): + if ',' in s: + output.extend([ss.strip() for ss in s.split(',')]) + else: + output.append(s.strip()) + return output + + # during model download, the '.' would be converted to '___' to produce # actual physical (masked) directory for storage def get_model_masked_directory(directory, model_id): diff --git a/tests/fileio/test_file.py b/tests/fileio/test_file.py index ded8ece7..383e8231 100644 --- a/tests/fileio/test_file.py +++ b/tests/fileio/test_file.py @@ -6,10 +6,26 @@ import unittest from requests import HTTPError from modelscope.fileio.file import File, HTTPStorage, LocalStorage +from modelscope.hub.utils.utils import convert_patterns class FileTest(unittest.TestCase): + def test_pattern_conversion(self): + self._assert_patterns(None, None) + self._assert_patterns('*.h5', ['*.h5']) + self._assert_patterns('*.h5 ', ['*.h5']) + self._assert_patterns('*.h5, *flax_model.msgpack', + ['*.h5', '*flax_model.msgpack']) + self._assert_patterns(['*.h5, *flax_model.msgpack'], + ['*.h5', '*flax_model.msgpack']) + self._assert_patterns(['*.h5 ', '*flax_model.msgpack'], + ['*.h5', '*flax_model.msgpack']) + + def _assert_patterns(self, raw_input, expected_output): + output = convert_patterns(raw_input) + self.assertEqual(expected_output, output) + def test_local_storage(self): storage = LocalStorage() temp_name = tempfile.gettempdir() + '/' + next(