From 2d14d24d3e67294708d16a84c1f09d494cf7b76a Mon Sep 17 00:00:00 2001
From: "mulin.lyh"
Date: Wed, 31 Jul 2024 10:43:33 +0800
Subject: [PATCH] fix file pattern compatible issue

---
 modelscope/hub/snapshot_download.py | 127 +++++++++++++++++++---------
 1 file changed, 86 insertions(+), 41 deletions(-)

diff --git a/modelscope/hub/snapshot_download.py b/modelscope/hub/snapshot_download.py
index 7f069795..2f5d5a4d 100644
--- a/modelscope/hub/snapshot_download.py
+++ b/modelscope/hub/snapshot_download.py
@@ -81,10 +81,6 @@ def snapshot_download(
         - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
         if some parameter value is invalid
     """
-    if allow_patterns:
-        allow_file_pattern = allow_patterns
-    if ignore_patterns:
-        ignore_file_pattern = ignore_patterns
     return _snapshot_download(
         model_id,
         repo_type=REPO_TYPE_MODEL,
@@ -95,7 +91,9 @@ def snapshot_download(
         cookies=cookies,
         ignore_file_pattern=ignore_file_pattern,
         allow_file_pattern=allow_file_pattern,
-        local_dir=local_dir)
+        local_dir=local_dir,
+        ignore_patterns=ignore_patterns,
+        allow_patterns=allow_patterns)
 
 
 def dataset_snapshot_download(
@@ -157,10 +155,6 @@ def dataset_snapshot_download(
         - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
         if some parameter value is invalid
     """
-    if allow_patterns:
-        allow_file_pattern = allow_patterns
-    if ignore_patterns:
-        ignore_file_pattern = ignore_patterns
     return _snapshot_download(
         dataset_id,
         repo_type=REPO_TYPE_DATASET,
@@ -171,7 +165,9 @@ def dataset_snapshot_download(
         cookies=cookies,
         ignore_file_pattern=ignore_file_pattern,
         allow_file_pattern=allow_file_pattern,
-        local_dir=local_dir)
+        local_dir=local_dir,
+        ignore_patterns=ignore_patterns,
+        allow_patterns=allow_patterns)
 
 
 def _snapshot_download(
@@ -186,6 +182,8 @@ def _snapshot_download(
     ignore_file_pattern: Optional[Union[str, List[str]]] = None,
     allow_file_pattern: Optional[Union[str, List[str]]] = None,
     local_dir: Optional[str] = None,
+    allow_patterns: Optional[Union[List[str], str]] = None,
+    ignore_patterns: Optional[Union[List[str], str]] = None,
 ):
     if not repo_type:
         repo_type = REPO_TYPE_MODEL
@@ -254,7 +252,9 @@ def _snapshot_download(
                 revision=revision,
                 cookies=cookies,
                 ignore_file_pattern=ignore_file_pattern,
-                allow_file_pattern=allow_file_pattern)
+                allow_file_pattern=allow_file_pattern,
+                ignore_patterns=ignore_patterns,
+                allow_patterns=allow_patterns)
 
         elif repo_type == REPO_TYPE_DATASET:
             group_or_owner, name = model_id_to_group_owner_name(repo_id)
@@ -295,7 +295,9 @@ def _snapshot_download(
                         revision=revision,
                         cookies=cookies,
                         ignore_file_pattern=ignore_file_pattern,
-                        allow_file_pattern=allow_file_pattern)
+                        allow_file_pattern=allow_file_pattern,
+                        ignore_patterns=ignore_patterns,
+                        allow_patterns=allow_patterns)
                     if len(repo_files) < page_size:
                         break
                     page_number += 1
@@ -304,6 +306,38 @@ def _snapshot_download(
         return os.path.join(cache.get_root_location())
 
 
+def _is_valid_regex(pattern: str):
+    """Return True if `pattern` compiles as a regular expression."""
+    try:
+        re.compile(pattern)
+        return True
+    except Exception:
+        return False
+
+
+def _normalize_patterns(patterns: Union[str, List[str]]):
+    """Wrap a single pattern in a list and append '*' to 'dir/' patterns."""
+    if isinstance(patterns, str):
+        patterns = [patterns]
+    if patterns is not None:
+        patterns = [
+            item if not item.endswith('/') else item + '*' for item in patterns
+        ]
+    return patterns
+
+
+def _get_valid_regex_pattern(patterns: List[str]):
+    """Keep only the patterns that are valid regexes; None stays None."""
+    if patterns is not None:
+        regex_patterns = []
+        for item in patterns:
+            if _is_valid_regex(item):
+                regex_patterns.append(item)
+        return regex_patterns
+    else:
+        return None
+
+
 def _download_file_lists(
     repo_files: List[str],
     cache: ModelFileSystemCache,
@@ -319,42 +353,53 @@ def _download_file_lists(
     cookies: Optional[CookieJar] = None,
     ignore_file_pattern: Optional[Union[str, List[str]]] = None,
     allow_file_pattern: Optional[Union[str, List[str]]] = None,
+    allow_patterns: Optional[Union[List[str], str]] = None,
+    ignore_patterns: Optional[Union[List[str], str]] = None,
 ):
-    if ignore_file_pattern is None:
-        ignore_file_pattern = []
-    if isinstance(ignore_file_pattern, str):
-        ignore_file_pattern = [ignore_file_pattern]
-    ignore_file_pattern = [
-        item if not item.endswith('/') else item + '*'
-        for item in ignore_file_pattern
-    ]
-    ignore_regex_pattern = []
-    for file_pattern in ignore_file_pattern:
-        if file_pattern.startswith('*'):
-            ignore_regex_pattern.append('.' + file_pattern)
-        else:
-            ignore_regex_pattern.append(file_pattern)
-
-    if allow_file_pattern is not None:
-        if isinstance(allow_file_pattern, str):
-            allow_file_pattern = [allow_file_pattern]
-        allow_file_pattern = [
-            item if not item.endswith('/') else item + '*'
-            for item in allow_file_pattern
-        ]
+    ignore_patterns = _normalize_patterns(ignore_patterns)
+    allow_patterns = _normalize_patterns(allow_patterns)
+    ignore_file_pattern = _normalize_patterns(ignore_file_pattern)
+    allow_file_pattern = _normalize_patterns(allow_file_pattern)
+    # Legacy ignore_file_pattern entries may be regexes; keep the valid ones.
+    ignore_regex_pattern = _get_valid_regex_pattern(ignore_file_pattern)
 
     for repo_file in repo_files:
-        if repo_file['Type'] == 'tree' or \
-                any([fnmatch.fnmatch(repo_file['Path'], pattern) for pattern in ignore_file_pattern]) or \
-                any([re.search(pattern, repo_file['Name']) is not None for pattern in ignore_regex_pattern]):  # noqa E501
+        if repo_file['Type'] == 'tree':
             continue
-
-        if allow_file_pattern is not None and allow_file_pattern:
-            if not any(
+        try:
+            # Apply the ignore filters first, then the allow filters.
+            if ignore_patterns and any([
                     fnmatch.fnmatch(repo_file['Path'], pattern)
-                    for pattern in allow_file_pattern):
+                    for pattern in ignore_patterns
+            ]):
                 continue
 
+            if ignore_file_pattern and any([
+                    fnmatch.fnmatch(repo_file['Path'], pattern)
+                    for pattern in ignore_file_pattern
+            ]):
+                continue
+
+            if ignore_regex_pattern and any([
+                    re.search(pattern, repo_file['Name']) is not None
+                    for pattern in ignore_regex_pattern
+            ]):  # noqa E501
+                continue
+
+            if allow_patterns is not None and allow_patterns:
+                if not any(
+                        fnmatch.fnmatch(repo_file['Path'], pattern)
+                        for pattern in allow_patterns):
+                    continue
+
+            if allow_file_pattern is not None and allow_file_pattern:
+                if not any(
+                        fnmatch.fnmatch(repo_file['Path'], pattern)
+                        for pattern in allow_file_pattern):
+                    continue
+        except Exception as e:
+            logger.warning('The file pattern is invalid : %s' % e)
+
         # check model_file is exist in cache, if existed, skip download, otherwise download
         if cache.exists(repo_file):
             file_name = os.path.basename(repo_file['Name'])