support multiple include/exclude filter patterns in command line (#1214)

Co-authored-by: Yingda Chen <yingda.chen@alibaba-inc.com>
(cherry picked from commit 1f88654aa1)
This commit is contained in:
Yingda Chen
2025-02-07 16:02:37 +08:00
committed by yuze.zyz
parent 9dda88a399
commit 2df80c0c40
4 changed files with 44 additions and 7 deletions

View File

@@ -8,6 +8,7 @@ from modelscope.hub.file_download import (dataset_file_download,
model_file_download) model_file_download)
from modelscope.hub.snapshot_download import (dataset_snapshot_download, from modelscope.hub.snapshot_download import (dataset_snapshot_download,
snapshot_download) snapshot_download)
from modelscope.hub.utils.utils import convert_patterns
from modelscope.utils.constant import DEFAULT_DATASET_REVISION from modelscope.utils.constant import DEFAULT_DATASET_REVISION
@@ -141,8 +142,8 @@ class DownloadCMD(CLICommand):
revision=self.args.revision, revision=self.args.revision,
cache_dir=self.args.cache_dir, cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir, local_dir=self.args.local_dir,
allow_file_pattern=self.args.include, allow_file_pattern=convert_patterns(self.args.include),
ignore_file_pattern=self.args.exclude, ignore_file_pattern=convert_patterns(self.args.exclude),
max_workers=self.args.max_workers, max_workers=self.args.max_workers,
) )
elif self.args.dataset: elif self.args.dataset:
@@ -170,8 +171,8 @@ class DownloadCMD(CLICommand):
revision=dataset_revision, revision=dataset_revision,
cache_dir=self.args.cache_dir, cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir, local_dir=self.args.local_dir,
allow_file_pattern=self.args.include, allow_file_pattern=convert_patterns(self.args.include),
ignore_file_pattern=self.args.exclude, ignore_file_pattern=convert_patterns(self.args.exclude),
max_workers=self.args.max_workers, max_workers=self.args.max_workers,
) )
else: else:

View File

@@ -4,6 +4,7 @@ from argparse import ArgumentParser, _SubParsersAction
from modelscope.cli.base import CLICommand from modelscope.cli.base import CLICommand
from modelscope.hub.api import HubApi, ModelScopeConfig from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.utils.utils import convert_patterns, get_endpoint
from modelscope.utils.constant import REPO_TYPE_MODEL, REPO_TYPE_SUPPORT from modelscope.utils.constant import REPO_TYPE_MODEL, REPO_TYPE_SUPPORT
@@ -89,7 +90,7 @@ class UploadCMD(CLICommand):
parser.add_argument( parser.add_argument(
'--endpoint', '--endpoint',
type=str, type=str,
default='https://www.modelscope.cn', default=get_endpoint(),
help='Endpoint for Modelscope service.') help='Endpoint for Modelscope service.')
parser.set_defaults(func=subparser_func) parser.set_defaults(func=subparser_func)
@@ -166,8 +167,8 @@ class UploadCMD(CLICommand):
commit_message=self.args.commit_message, commit_message=self.args.commit_message,
commit_description=self.args.commit_description, commit_description=self.args.commit_description,
repo_type=self.args.repo_type, repo_type=self.args.repo_type,
allow_patterns=self.args.include, allow_file_pattern=convert_patterns(self.args.include),
ignore_patterns=self.args.exclude, ignore_file_pattern=convert_patterns(self.args.exclude),
max_workers=self.args.max_workers, max_workers=self.args.max_workers,
) )
else: else:

View File

@@ -31,6 +31,25 @@ def model_id_to_group_owner_name(model_id):
return group_or_owner, name return group_or_owner, name
def convert_patterns(raw_input: Union[str, List[str]]):
output = None
if isinstance(raw_input, str):
output = list()
if ',' in raw_input:
output = [s.strip() for s in raw_input.split(',')]
else:
output.append(raw_input.strip())
elif isinstance(raw_input, list):
output = list()
for s in raw_input:
if isinstance(s, str):
if ',' in s:
output.extend([ss.strip() for ss in s.split(',')])
else:
output.append(s.strip())
return output
# during model download, the '.' would be converted to '___' to produce # during model download, the '.' would be converted to '___' to produce
# actual physical (masked) directory for storage # actual physical (masked) directory for storage
def get_model_masked_directory(directory, model_id): def get_model_masked_directory(directory, model_id):

View File

@@ -6,10 +6,26 @@ import unittest
from requests import HTTPError from requests import HTTPError
from modelscope.fileio.file import File, HTTPStorage, LocalStorage from modelscope.fileio.file import File, HTTPStorage, LocalStorage
from modelscope.hub.utils.utils import convert_patterns
class FileTest(unittest.TestCase): class FileTest(unittest.TestCase):
def test_pattern_conversion(self):
self._assert_patterns(None, None)
self._assert_patterns('*.h5', ['*.h5'])
self._assert_patterns('*.h5 ', ['*.h5'])
self._assert_patterns('*.h5, *flax_model.msgpack',
['*.h5', '*flax_model.msgpack'])
self._assert_patterns(['*.h5, *flax_model.msgpack'],
['*.h5', '*flax_model.msgpack'])
self._assert_patterns(['*.h5 ', '*flax_model.msgpack'],
['*.h5', '*flax_model.msgpack'])
def _assert_patterns(self, raw_input, expected_output):
output = convert_patterns(raw_input)
self.assertEqual(expected_output, output)
def test_local_storage(self): def test_local_storage(self):
storage = LocalStorage() storage = LocalStorage()
temp_name = tempfile.gettempdir() + '/' + next( temp_name = tempfile.gettempdir() + '/' + next(