fix file pattern compatible issue

This commit is contained in:
mulin.lyh
2024-07-31 10:43:33 +08:00
parent e4f5f95cc2
commit 2d14d24d3e

View File

@@ -81,10 +81,6 @@ def snapshot_download(
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
if some parameter value is invalid
"""
if allow_patterns:
allow_file_pattern = allow_patterns
if ignore_patterns:
ignore_file_pattern = ignore_patterns
return _snapshot_download(
model_id,
repo_type=REPO_TYPE_MODEL,
@@ -95,7 +91,9 @@ def snapshot_download(
cookies=cookies,
ignore_file_pattern=ignore_file_pattern,
allow_file_pattern=allow_file_pattern,
local_dir=local_dir)
local_dir=local_dir,
ignore_patterns=ignore_patterns,
allow_patterns=allow_patterns)
def dataset_snapshot_download(
@@ -157,10 +155,6 @@ def dataset_snapshot_download(
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
if some parameter value is invalid
"""
if allow_patterns:
allow_file_pattern = allow_patterns
if ignore_patterns:
ignore_file_pattern = ignore_patterns
return _snapshot_download(
dataset_id,
repo_type=REPO_TYPE_DATASET,
@@ -171,7 +165,9 @@ def dataset_snapshot_download(
cookies=cookies,
ignore_file_pattern=ignore_file_pattern,
allow_file_pattern=allow_file_pattern,
local_dir=local_dir)
local_dir=local_dir,
ignore_patterns=ignore_patterns,
allow_patterns=allow_patterns)
def _snapshot_download(
@@ -186,6 +182,8 @@ def _snapshot_download(
ignore_file_pattern: Optional[Union[str, List[str]]] = None,
allow_file_pattern: Optional[Union[str, List[str]]] = None,
local_dir: Optional[str] = None,
allow_patterns: Optional[Union[List[str], str]] = None,
ignore_patterns: Optional[Union[List[str], str]] = None,
):
if not repo_type:
repo_type = REPO_TYPE_MODEL
@@ -254,7 +252,9 @@ def _snapshot_download(
revision=revision,
cookies=cookies,
ignore_file_pattern=ignore_file_pattern,
allow_file_pattern=allow_file_pattern)
allow_file_pattern=allow_file_pattern,
ignore_patterns=ignore_patterns,
allow_patterns=allow_patterns)
elif repo_type == REPO_TYPE_DATASET:
group_or_owner, name = model_id_to_group_owner_name(repo_id)
@@ -295,7 +295,9 @@ def _snapshot_download(
revision=revision,
cookies=cookies,
ignore_file_pattern=ignore_file_pattern,
allow_file_pattern=allow_file_pattern)
allow_file_pattern=allow_file_pattern,
ignore_patterns=ignore_patterns,
allow_patterns=allow_patterns)
if len(repo_files) < page_size:
break
page_number += 1
@@ -304,6 +306,35 @@ def _snapshot_download(
return os.path.join(cache.get_root_location())
def _is_valid_regex(pattern: str):
try:
re.compile(pattern)
return True
except BaseException:
return False
def _normalize_patterns(patterns: Union[str, List[str]]):
if isinstance(patterns, str):
patterns = [patterns]
if patterns is not None:
patterns = [
item if not item.endswith('/') else item + '*' for item in patterns
]
return patterns
def _get_valid_regex_pattern(patterns: List[str]):
if patterns is not None:
regex_patterns = []
for item in patterns:
if _is_valid_regex(item):
regex_patterns.append(item)
return regex_patterns
else:
return None
def _download_file_lists(
repo_files: List[str],
cache: ModelFileSystemCache,
@@ -319,42 +350,53 @@ def _download_file_lists(
cookies: Optional[CookieJar] = None,
ignore_file_pattern: Optional[Union[str, List[str]]] = None,
allow_file_pattern: Optional[Union[str, List[str]]] = None,
allow_patterns: Optional[Union[List[str], str]] = None,
ignore_patterns: Optional[Union[List[str], str]] = None,
):
if ignore_file_pattern is None:
ignore_file_pattern = []
if isinstance(ignore_file_pattern, str):
ignore_file_pattern = [ignore_file_pattern]
ignore_file_pattern = [
item if not item.endswith('/') else item + '*'
for item in ignore_file_pattern
]
ignore_regex_pattern = []
for file_pattern in ignore_file_pattern:
if file_pattern.startswith('*'):
ignore_regex_pattern.append('.' + file_pattern)
else:
ignore_regex_pattern.append(file_pattern)
if allow_file_pattern is not None:
if isinstance(allow_file_pattern, str):
allow_file_pattern = [allow_file_pattern]
allow_file_pattern = [
item if not item.endswith('/') else item + '*'
for item in allow_file_pattern
]
ignore_patterns = _normalize_patterns(ignore_patterns)
allow_patterns = _normalize_patterns(allow_patterns)
ignore_file_pattern = _normalize_patterns(ignore_file_pattern)
allow_file_pattern = _normalize_patterns(ignore_file_pattern)
# to compatible regex usage.
ignore_regex_pattern = _get_valid_regex_pattern(ignore_file_pattern)
for repo_file in repo_files:
if repo_file['Type'] == 'tree' or \
any([fnmatch.fnmatch(repo_file['Path'], pattern) for pattern in ignore_file_pattern]) or \
any([re.search(pattern, repo_file['Name']) is not None for pattern in ignore_regex_pattern]): # noqa E501
if repo_file['Type'] == 'tree':
continue
if allow_file_pattern is not None and allow_file_pattern:
if not any(
try:
# processing patterns
if ignore_patterns and any([
fnmatch.fnmatch(repo_file['Path'], pattern)
for pattern in allow_file_pattern):
for pattern in ignore_patterns
]):
continue
if ignore_file_pattern and any([
fnmatch.fnmatch(repo_file['Path'], pattern)
for pattern in ignore_file_pattern
]):
continue
if ignore_regex_pattern and any([
re.search(pattern, repo_file['Name']) is not None
for pattern in ignore_regex_pattern
]): # noqa E501
continue
if allow_patterns is not None and allow_patterns:
if not any(
fnmatch.fnmatch(repo_file['Path'], pattern)
for pattern in allow_patterns):
continue
if allow_file_pattern is not None and allow_file_pattern:
if not any(
fnmatch.fnmatch(repo_file['Path'], pattern)
for pattern in allow_file_pattern):
continue
except Exception as e:
logger.warning('The file pattern is invalid : %s' % e)
# check model_file is exist in cache, if existed, skip download, otherwise download
if cache.exists(repo_file):
file_name = os.path.basename(repo_file['Name'])