diff --git a/modelscope/hub/api.py b/modelscope/hub/api.py index 59a3b3ba..71d20dc9 100644 --- a/modelscope/hub/api.py +++ b/modelscope/hub/api.py @@ -514,6 +514,11 @@ class HubApi: revision_detail = self.get_branch_tag_detail(all_branches_detail, revision) logger.info('Development mode use revision: %s' % revision) else: + if revision is not None and revision in all_branches: + revision_detail = self.get_branch_tag_detail(all_branches_detail, revision) + logger.warning('Using branch: %s as version is unstable, use with caution' % revision) + return revision_detail + if len(all_tags_detail) == 0: # use no revision use master as default. if revision is None or revision == MASTER_MODEL_BRANCH: revision = MASTER_MODEL_BRANCH diff --git a/modelscope/hub/snapshot_download.py b/modelscope/hub/snapshot_download.py index 7f069795..2544b58e 100644 --- a/modelscope/hub/snapshot_download.py +++ b/modelscope/hub/snapshot_download.py @@ -81,10 +81,6 @@ def snapshot_download( - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid """ - if allow_patterns: - allow_file_pattern = allow_patterns - if ignore_patterns: - ignore_file_pattern = ignore_patterns return _snapshot_download( model_id, repo_type=REPO_TYPE_MODEL, @@ -95,7 +91,9 @@ def snapshot_download( cookies=cookies, ignore_file_pattern=ignore_file_pattern, allow_file_pattern=allow_file_pattern, - local_dir=local_dir) + local_dir=local_dir, + ignore_patterns=ignore_patterns, + allow_patterns=allow_patterns) def dataset_snapshot_download( @@ -157,10 +155,6 @@ def dataset_snapshot_download( - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid """ - if allow_patterns: - allow_file_pattern = allow_patterns - if ignore_patterns: - ignore_file_pattern = ignore_patterns return _snapshot_download( dataset_id, repo_type=REPO_TYPE_DATASET, @@ -171,7 +165,9 @@ def dataset_snapshot_download( cookies=cookies, ignore_file_pattern=ignore_file_pattern, allow_file_pattern=allow_file_pattern, - local_dir=local_dir) + local_dir=local_dir, + ignore_patterns=ignore_patterns, + allow_patterns=allow_patterns) def _snapshot_download( @@ -186,6 +182,8 @@ def _snapshot_download( ignore_file_pattern: Optional[Union[str, List[str]]] = None, allow_file_pattern: Optional[Union[str, List[str]]] = None, local_dir: Optional[str] = None, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, ): if not repo_type: repo_type = REPO_TYPE_MODEL @@ -249,12 +247,13 @@ def _snapshot_download( None, None, headers, - revision_detail=revision_detail, repo_type=repo_type, revision=revision, cookies=cookies, ignore_file_pattern=ignore_file_pattern, - allow_file_pattern=allow_file_pattern) + allow_file_pattern=allow_file_pattern, + ignore_patterns=ignore_patterns, + allow_patterns=allow_patterns) elif repo_type == REPO_TYPE_DATASET: group_or_owner, name = model_id_to_group_owner_name(repo_id) @@ -290,12 +289,13 @@ def _snapshot_download( name, group_or_owner, headers, - revision_detail=revision_detail, repo_type=repo_type, revision=revision, cookies=cookies, ignore_file_pattern=ignore_file_pattern, - allow_file_pattern=allow_file_pattern) + allow_file_pattern=allow_file_pattern, + ignore_patterns=ignore_patterns, + allow_patterns=allow_patterns) if len(repo_files) < page_size: break page_number += 1 @@ -304,6 +304,35 @@ def _snapshot_download( return os.path.join(cache.get_root_location()) +def _is_valid_regex(pattern: str): + try: + re.compile(pattern) + return True + except BaseException: + return False + + +def _normalize_patterns(patterns: Union[str, List[str]]): + if isinstance(patterns, str): + patterns = [patterns] + if patterns is not None: + patterns = [ + item if not item.endswith('/') else item + '*' for item in patterns + ] + return patterns + + +def _get_valid_regex_pattern(patterns: List[str]): + if patterns is not None: + regex_patterns = [] + for item in patterns: + if _is_valid_regex(item): + regex_patterns.append(item) + return regex_patterns + else: + return None + + def _download_file_lists( repo_files: List[str], cache: ModelFileSystemCache, @@ -313,48 +342,58 @@ def _download_file_lists( name: str, group_or_owner: str, headers, - revision_detail: str, repo_type: Optional[str] = None, revision: Optional[str] = DEFAULT_MODEL_REVISION, cookies: Optional[CookieJar] = None, ignore_file_pattern: Optional[Union[str, List[str]]] = None, allow_file_pattern: Optional[Union[str, List[str]]] = None, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, ): - if ignore_file_pattern is None: - ignore_file_pattern = [] - if isinstance(ignore_file_pattern, str): - ignore_file_pattern = [ignore_file_pattern] - ignore_file_pattern = [ - item if not item.endswith('/') else item + '*' - for item in ignore_file_pattern - ] - ignore_regex_pattern = [] - for file_pattern in ignore_file_pattern: - if file_pattern.startswith('*'): - ignore_regex_pattern.append('.' + file_pattern) - else: - ignore_regex_pattern.append(file_pattern) - - if allow_file_pattern is not None: - if isinstance(allow_file_pattern, str): - allow_file_pattern = [allow_file_pattern] - allow_file_pattern = [ - item if not item.endswith('/') else item + '*' - for item in allow_file_pattern - ] + ignore_patterns = _normalize_patterns(ignore_patterns) + allow_patterns = _normalize_patterns(allow_patterns) + ignore_file_pattern = _normalize_patterns(ignore_file_pattern) + allow_file_pattern = _normalize_patterns(allow_file_pattern) + # to compatible regex usage. + ignore_regex_pattern = _get_valid_regex_pattern(ignore_file_pattern) for repo_file in repo_files: - if repo_file['Type'] == 'tree' or \ - any([fnmatch.fnmatch(repo_file['Path'], pattern) for pattern in ignore_file_pattern]) or \ - any([re.search(pattern, repo_file['Name']) is not None for pattern in ignore_regex_pattern]): # noqa E501 + if repo_file['Type'] == 'tree': continue - - if allow_file_pattern is not None and allow_file_pattern: - if not any( + try: + # processing patterns + if ignore_patterns and any([ fnmatch.fnmatch(repo_file['Path'], pattern) - for pattern in allow_file_pattern): + for pattern in ignore_patterns + ]): continue + if ignore_file_pattern and any([ + fnmatch.fnmatch(repo_file['Path'], pattern) + for pattern in ignore_file_pattern + ]): + continue + + if ignore_regex_pattern and any([ + re.search(pattern, repo_file['Name']) is not None + for pattern in ignore_regex_pattern + ]): # noqa E501 + continue + + if allow_patterns is not None and allow_patterns: + if not any( + fnmatch.fnmatch(repo_file['Path'], pattern) + for pattern in allow_patterns): + continue + + if allow_file_pattern is not None and allow_file_pattern: + if not any( + fnmatch.fnmatch(repo_file['Path'], pattern) + for pattern in allow_file_pattern): + continue + except Exception as e: + logger.warning('The file pattern is invalid : %s' % e) + # check model_file is exist in cache, if existed, skip download, otherwise download if cache.exists(repo_file): file_name = os.path.basename(repo_file['Name']) diff --git a/tests/hub/test_download_dataset_file.py b/tests/hub/test_download_dataset_file.py index 0c4e9307..8e8712f5 100644 --- a/tests/hub/test_download_dataset_file.py +++ b/tests/hub/test_download_dataset_file.py @@ -116,7 +116,7 @@ class DownloadDatasetTest(unittest.TestCase): dataset_cache_path = dataset_snapshot_download( dataset_id=dataset_id, cache_dir=temp_cache_dir, - ignore_file_pattern='*.jpeg') + ignore_file_pattern=['*.jpeg', '.jpg']) assert dataset_cache_path == os.path.join(temp_cache_dir, dataset_id) assert not os.path.exists( diff --git a/tests/hub/test_hub_revision_release_mode.py b/tests/hub/test_hub_revision_release_mode.py index 3b8416db..823e1d5d 100644 --- a/tests/hub/test_hub_revision_release_mode.py +++ b/tests/hub/test_hub_revision_release_mode.py @@ -93,6 +93,7 @@ class HubRevisionTest(unittest.TestCase): self.prepare_repo_data() # no tag, default get master branch_name = 'test' self.add_new_file_and_branch_to_repo(branch_name) + time.sleep(5) with tempfile.TemporaryDirectory() as temp_cache_dir: snapshot_path = snapshot_download( self.model_id,