mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-24 20:19:22 +01:00
Fix file pattern compatibility issue
This commit is contained in:
@@ -81,10 +81,6 @@ def snapshot_download(
|
||||
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
||||
if some parameter value is invalid
|
||||
"""
|
||||
if allow_patterns:
|
||||
allow_file_pattern = allow_patterns
|
||||
if ignore_patterns:
|
||||
ignore_file_pattern = ignore_patterns
|
||||
return _snapshot_download(
|
||||
model_id,
|
||||
repo_type=REPO_TYPE_MODEL,
|
||||
@@ -95,7 +91,9 @@ def snapshot_download(
|
||||
cookies=cookies,
|
||||
ignore_file_pattern=ignore_file_pattern,
|
||||
allow_file_pattern=allow_file_pattern,
|
||||
local_dir=local_dir)
|
||||
local_dir=local_dir,
|
||||
ignore_patterns=ignore_patterns,
|
||||
allow_patterns=allow_patterns)
|
||||
|
||||
|
||||
def dataset_snapshot_download(
|
||||
@@ -157,10 +155,6 @@ def dataset_snapshot_download(
|
||||
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
||||
if some parameter value is invalid
|
||||
"""
|
||||
if allow_patterns:
|
||||
allow_file_pattern = allow_patterns
|
||||
if ignore_patterns:
|
||||
ignore_file_pattern = ignore_patterns
|
||||
return _snapshot_download(
|
||||
dataset_id,
|
||||
repo_type=REPO_TYPE_DATASET,
|
||||
@@ -171,7 +165,9 @@ def dataset_snapshot_download(
|
||||
cookies=cookies,
|
||||
ignore_file_pattern=ignore_file_pattern,
|
||||
allow_file_pattern=allow_file_pattern,
|
||||
local_dir=local_dir)
|
||||
local_dir=local_dir,
|
||||
ignore_patterns=ignore_patterns,
|
||||
allow_patterns=allow_patterns)
|
||||
|
||||
|
||||
def _snapshot_download(
|
||||
@@ -186,6 +182,8 @@ def _snapshot_download(
|
||||
ignore_file_pattern: Optional[Union[str, List[str]]] = None,
|
||||
allow_file_pattern: Optional[Union[str, List[str]]] = None,
|
||||
local_dir: Optional[str] = None,
|
||||
allow_patterns: Optional[Union[List[str], str]] = None,
|
||||
ignore_patterns: Optional[Union[List[str], str]] = None,
|
||||
):
|
||||
if not repo_type:
|
||||
repo_type = REPO_TYPE_MODEL
|
||||
@@ -254,7 +252,9 @@ def _snapshot_download(
|
||||
revision=revision,
|
||||
cookies=cookies,
|
||||
ignore_file_pattern=ignore_file_pattern,
|
||||
allow_file_pattern=allow_file_pattern)
|
||||
allow_file_pattern=allow_file_pattern,
|
||||
ignore_patterns=ignore_patterns,
|
||||
allow_patterns=allow_patterns)
|
||||
|
||||
elif repo_type == REPO_TYPE_DATASET:
|
||||
group_or_owner, name = model_id_to_group_owner_name(repo_id)
|
||||
@@ -295,7 +295,9 @@ def _snapshot_download(
|
||||
revision=revision,
|
||||
cookies=cookies,
|
||||
ignore_file_pattern=ignore_file_pattern,
|
||||
allow_file_pattern=allow_file_pattern)
|
||||
allow_file_pattern=allow_file_pattern,
|
||||
ignore_patterns=ignore_patterns,
|
||||
allow_patterns=allow_patterns)
|
||||
if len(repo_files) < page_size:
|
||||
break
|
||||
page_number += 1
|
||||
@@ -304,6 +306,35 @@ def _snapshot_download(
|
||||
return os.path.join(cache.get_root_location())
|
||||
|
||||
|
||||
def _is_valid_regex(pattern: str):
|
||||
try:
|
||||
re.compile(pattern)
|
||||
return True
|
||||
except BaseException:
|
||||
return False
|
||||
|
||||
|
||||
def _normalize_patterns(patterns: Union[str, List[str]]):
|
||||
if isinstance(patterns, str):
|
||||
patterns = [patterns]
|
||||
if patterns is not None:
|
||||
patterns = [
|
||||
item if not item.endswith('/') else item + '*' for item in patterns
|
||||
]
|
||||
return patterns
|
||||
|
||||
|
||||
def _get_valid_regex_pattern(patterns: List[str]):
    """Keep only the entries of ``patterns`` that pass ``_is_valid_regex``.

    Args:
        patterns: Pattern strings to check, or ``None``.

    Returns:
        A new list holding, in their original order, the entries for which
        ``_is_valid_regex`` returned a truthy value; ``None`` when
        ``patterns`` is ``None``.
    """
    if patterns is None:
        return None
    return [candidate for candidate in patterns if _is_valid_regex(candidate)]
|
||||
|
||||
|
||||
def _download_file_lists(
|
||||
repo_files: List[str],
|
||||
cache: ModelFileSystemCache,
|
||||
@@ -319,42 +350,53 @@ def _download_file_lists(
|
||||
cookies: Optional[CookieJar] = None,
|
||||
ignore_file_pattern: Optional[Union[str, List[str]]] = None,
|
||||
allow_file_pattern: Optional[Union[str, List[str]]] = None,
|
||||
allow_patterns: Optional[Union[List[str], str]] = None,
|
||||
ignore_patterns: Optional[Union[List[str], str]] = None,
|
||||
):
|
||||
if ignore_file_pattern is None:
|
||||
ignore_file_pattern = []
|
||||
if isinstance(ignore_file_pattern, str):
|
||||
ignore_file_pattern = [ignore_file_pattern]
|
||||
ignore_file_pattern = [
|
||||
item if not item.endswith('/') else item + '*'
|
||||
for item in ignore_file_pattern
|
||||
]
|
||||
ignore_regex_pattern = []
|
||||
for file_pattern in ignore_file_pattern:
|
||||
if file_pattern.startswith('*'):
|
||||
ignore_regex_pattern.append('.' + file_pattern)
|
||||
else:
|
||||
ignore_regex_pattern.append(file_pattern)
|
||||
|
||||
if allow_file_pattern is not None:
|
||||
if isinstance(allow_file_pattern, str):
|
||||
allow_file_pattern = [allow_file_pattern]
|
||||
allow_file_pattern = [
|
||||
item if not item.endswith('/') else item + '*'
|
||||
for item in allow_file_pattern
|
||||
]
|
||||
ignore_patterns = _normalize_patterns(ignore_patterns)
|
||||
allow_patterns = _normalize_patterns(allow_patterns)
|
||||
ignore_file_pattern = _normalize_patterns(ignore_file_pattern)
|
||||
allow_file_pattern = _normalize_patterns(ignore_file_pattern)
|
||||
# to compatible regex usage.
|
||||
ignore_regex_pattern = _get_valid_regex_pattern(ignore_file_pattern)
|
||||
|
||||
for repo_file in repo_files:
|
||||
if repo_file['Type'] == 'tree' or \
|
||||
any([fnmatch.fnmatch(repo_file['Path'], pattern) for pattern in ignore_file_pattern]) or \
|
||||
any([re.search(pattern, repo_file['Name']) is not None for pattern in ignore_regex_pattern]): # noqa E501
|
||||
if repo_file['Type'] == 'tree':
|
||||
continue
|
||||
|
||||
if allow_file_pattern is not None and allow_file_pattern:
|
||||
if not any(
|
||||
try:
|
||||
# processing patterns
|
||||
if ignore_patterns and any([
|
||||
fnmatch.fnmatch(repo_file['Path'], pattern)
|
||||
for pattern in allow_file_pattern):
|
||||
for pattern in ignore_patterns
|
||||
]):
|
||||
continue
|
||||
|
||||
if ignore_file_pattern and any([
|
||||
fnmatch.fnmatch(repo_file['Path'], pattern)
|
||||
for pattern in ignore_file_pattern
|
||||
]):
|
||||
continue
|
||||
|
||||
if ignore_regex_pattern and any([
|
||||
re.search(pattern, repo_file['Name']) is not None
|
||||
for pattern in ignore_regex_pattern
|
||||
]): # noqa E501
|
||||
continue
|
||||
|
||||
if allow_patterns is not None and allow_patterns:
|
||||
if not any(
|
||||
fnmatch.fnmatch(repo_file['Path'], pattern)
|
||||
for pattern in allow_patterns):
|
||||
continue
|
||||
|
||||
if allow_file_pattern is not None and allow_file_pattern:
|
||||
if not any(
|
||||
fnmatch.fnmatch(repo_file['Path'], pattern)
|
||||
for pattern in allow_file_pattern):
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning('The file pattern is invalid : %s' % e)
|
||||
|
||||
# check model_file is exist in cache, if existed, skip download, otherwise download
|
||||
if cache.exists(repo_file):
|
||||
file_name = os.path.basename(repo_file['Name'])
|
||||
|
||||
Reference in New Issue
Block a user