mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
support multiple include/exclude filter patterns in command line (#1214)
Co-authored-by: Yingda Chen <yingda.chen@alibaba-inc.com>
(cherry picked from commit 1f88654aa1)
This commit is contained in:
@@ -8,6 +8,7 @@ from modelscope.hub.file_download import (dataset_file_download,
|
||||
model_file_download)
|
||||
from modelscope.hub.snapshot_download import (dataset_snapshot_download,
|
||||
snapshot_download)
|
||||
from modelscope.hub.utils.utils import convert_patterns
|
||||
from modelscope.utils.constant import DEFAULT_DATASET_REVISION
|
||||
|
||||
|
||||
@@ -141,8 +142,8 @@ class DownloadCMD(CLICommand):
|
||||
revision=self.args.revision,
|
||||
cache_dir=self.args.cache_dir,
|
||||
local_dir=self.args.local_dir,
|
||||
allow_file_pattern=self.args.include,
|
||||
ignore_file_pattern=self.args.exclude,
|
||||
allow_file_pattern=convert_patterns(self.args.include),
|
||||
ignore_file_pattern=convert_patterns(self.args.exclude),
|
||||
max_workers=self.args.max_workers,
|
||||
)
|
||||
elif self.args.dataset:
|
||||
@@ -170,8 +171,8 @@ class DownloadCMD(CLICommand):
|
||||
revision=dataset_revision,
|
||||
cache_dir=self.args.cache_dir,
|
||||
local_dir=self.args.local_dir,
|
||||
allow_file_pattern=self.args.include,
|
||||
ignore_file_pattern=self.args.exclude,
|
||||
allow_file_pattern=convert_patterns(self.args.include),
|
||||
ignore_file_pattern=convert_patterns(self.args.exclude),
|
||||
max_workers=self.args.max_workers,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -4,6 +4,7 @@ from argparse import ArgumentParser, _SubParsersAction
|
||||
|
||||
from modelscope.cli.base import CLICommand
|
||||
from modelscope.hub.api import HubApi, ModelScopeConfig
|
||||
from modelscope.hub.utils.utils import convert_patterns, get_endpoint
|
||||
from modelscope.utils.constant import REPO_TYPE_MODEL, REPO_TYPE_SUPPORT
|
||||
|
||||
|
||||
@@ -89,7 +90,7 @@ class UploadCMD(CLICommand):
|
||||
parser.add_argument(
|
||||
'--endpoint',
|
||||
type=str,
|
||||
default='https://www.modelscope.cn',
|
||||
default=get_endpoint(),
|
||||
help='Endpoint for Modelscope service.')
|
||||
|
||||
parser.set_defaults(func=subparser_func)
|
||||
@@ -166,8 +167,8 @@ class UploadCMD(CLICommand):
|
||||
commit_message=self.args.commit_message,
|
||||
commit_description=self.args.commit_description,
|
||||
repo_type=self.args.repo_type,
|
||||
allow_patterns=self.args.include,
|
||||
ignore_patterns=self.args.exclude,
|
||||
allow_file_pattern=convert_patterns(self.args.include),
|
||||
ignore_file_pattern=convert_patterns(self.args.exclude),
|
||||
max_workers=self.args.max_workers,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -31,6 +31,25 @@ def model_id_to_group_owner_name(model_id):
|
||||
return group_or_owner, name
|
||||
|
||||
|
||||
def convert_patterns(raw_input: Union[str, List[str]]):
|
||||
output = None
|
||||
if isinstance(raw_input, str):
|
||||
output = list()
|
||||
if ',' in raw_input:
|
||||
output = [s.strip() for s in raw_input.split(',')]
|
||||
else:
|
||||
output.append(raw_input.strip())
|
||||
elif isinstance(raw_input, list):
|
||||
output = list()
|
||||
for s in raw_input:
|
||||
if isinstance(s, str):
|
||||
if ',' in s:
|
||||
output.extend([ss.strip() for ss in s.split(',')])
|
||||
else:
|
||||
output.append(s.strip())
|
||||
return output
|
||||
|
||||
|
||||
# during model download, the '.' would be converted to '___' to produce
|
||||
# actual physical (masked) directory for storage
|
||||
def get_model_masked_directory(directory, model_id):
|
||||
|
||||
@@ -6,10 +6,26 @@ import unittest
|
||||
from requests import HTTPError
|
||||
|
||||
from modelscope.fileio.file import File, HTTPStorage, LocalStorage
|
||||
from modelscope.hub.utils.utils import convert_patterns
|
||||
|
||||
|
||||
class FileTest(unittest.TestCase):
|
||||
|
||||
def test_pattern_conversion(self):
|
||||
self._assert_patterns(None, None)
|
||||
self._assert_patterns('*.h5', ['*.h5'])
|
||||
self._assert_patterns('*.h5 ', ['*.h5'])
|
||||
self._assert_patterns('*.h5, *flax_model.msgpack',
|
||||
['*.h5', '*flax_model.msgpack'])
|
||||
self._assert_patterns(['*.h5, *flax_model.msgpack'],
|
||||
['*.h5', '*flax_model.msgpack'])
|
||||
self._assert_patterns(['*.h5 ', '*flax_model.msgpack'],
|
||||
['*.h5', '*flax_model.msgpack'])
|
||||
|
||||
def _assert_patterns(self, raw_input, expected_output):
|
||||
output = convert_patterns(raw_input)
|
||||
self.assertEqual(expected_output, output)
|
||||
|
||||
def test_local_storage(self):
|
||||
storage = LocalStorage()
|
||||
temp_name = tempfile.gettempdir() + '/' + next(
|
||||
|
||||
Reference in New Issue
Block a user