Merge pull request #942 from modelscope/release/1.17

Release/1.17
This commit is contained in:
Yingda Chen
2024-08-06 14:26:13 +08:00
committed by GitHub
4 changed files with 90 additions and 45 deletions

View File

@@ -514,6 +514,11 @@ class HubApi:
revision_detail = self.get_branch_tag_detail(all_branches_detail, revision)
logger.info('Development mode use revision: %s' % revision)
else:
if revision is not None and revision in all_branches:
revision_detail = self.get_branch_tag_detail(all_branches_detail, revision)
logger.warning('Using branch: %s as version is unstable, use with caution' % revision)
return revision_detail
if len(all_tags_detail) == 0: # use no revision use master as default.
if revision is None or revision == MASTER_MODEL_BRANCH:
revision = MASTER_MODEL_BRANCH

View File

@@ -81,10 +81,6 @@ def snapshot_download(
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
if some parameter value is invalid
"""
if allow_patterns:
allow_file_pattern = allow_patterns
if ignore_patterns:
ignore_file_pattern = ignore_patterns
return _snapshot_download(
model_id,
repo_type=REPO_TYPE_MODEL,
@@ -95,7 +91,9 @@ def snapshot_download(
cookies=cookies,
ignore_file_pattern=ignore_file_pattern,
allow_file_pattern=allow_file_pattern,
local_dir=local_dir)
local_dir=local_dir,
ignore_patterns=ignore_patterns,
allow_patterns=allow_patterns)
def dataset_snapshot_download(
@@ -157,10 +155,6 @@ def dataset_snapshot_download(
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
if some parameter value is invalid
"""
if allow_patterns:
allow_file_pattern = allow_patterns
if ignore_patterns:
ignore_file_pattern = ignore_patterns
return _snapshot_download(
dataset_id,
repo_type=REPO_TYPE_DATASET,
@@ -171,7 +165,9 @@ def dataset_snapshot_download(
cookies=cookies,
ignore_file_pattern=ignore_file_pattern,
allow_file_pattern=allow_file_pattern,
local_dir=local_dir)
local_dir=local_dir,
ignore_patterns=ignore_patterns,
allow_patterns=allow_patterns)
def _snapshot_download(
@@ -186,6 +182,8 @@ def _snapshot_download(
ignore_file_pattern: Optional[Union[str, List[str]]] = None,
allow_file_pattern: Optional[Union[str, List[str]]] = None,
local_dir: Optional[str] = None,
allow_patterns: Optional[Union[List[str], str]] = None,
ignore_patterns: Optional[Union[List[str], str]] = None,
):
if not repo_type:
repo_type = REPO_TYPE_MODEL
@@ -249,12 +247,13 @@ def _snapshot_download(
None,
None,
headers,
revision_detail=revision_detail,
repo_type=repo_type,
revision=revision,
cookies=cookies,
ignore_file_pattern=ignore_file_pattern,
allow_file_pattern=allow_file_pattern)
allow_file_pattern=allow_file_pattern,
ignore_patterns=ignore_patterns,
allow_patterns=allow_patterns)
elif repo_type == REPO_TYPE_DATASET:
group_or_owner, name = model_id_to_group_owner_name(repo_id)
@@ -290,12 +289,13 @@ def _snapshot_download(
name,
group_or_owner,
headers,
revision_detail=revision_detail,
repo_type=repo_type,
revision=revision,
cookies=cookies,
ignore_file_pattern=ignore_file_pattern,
allow_file_pattern=allow_file_pattern)
allow_file_pattern=allow_file_pattern,
ignore_patterns=ignore_patterns,
allow_patterns=allow_patterns)
if len(repo_files) < page_size:
break
page_number += 1
@@ -304,6 +304,35 @@ def _snapshot_download(
return os.path.join(cache.get_root_location())
def _is_valid_regex(pattern: str):
try:
re.compile(pattern)
return True
except BaseException:
return False
def _normalize_patterns(patterns: Union[str, List[str]]):
if isinstance(patterns, str):
patterns = [patterns]
if patterns is not None:
patterns = [
item if not item.endswith('/') else item + '*' for item in patterns
]
return patterns
def _get_valid_regex_pattern(patterns: List[str]):
if patterns is not None:
regex_patterns = []
for item in patterns:
if _is_valid_regex(item):
regex_patterns.append(item)
return regex_patterns
else:
return None
def _download_file_lists(
repo_files: List[str],
cache: ModelFileSystemCache,
@@ -313,48 +342,58 @@ def _download_file_lists(
name: str,
group_or_owner: str,
headers,
revision_detail: str,
repo_type: Optional[str] = None,
revision: Optional[str] = DEFAULT_MODEL_REVISION,
cookies: Optional[CookieJar] = None,
ignore_file_pattern: Optional[Union[str, List[str]]] = None,
allow_file_pattern: Optional[Union[str, List[str]]] = None,
allow_patterns: Optional[Union[List[str], str]] = None,
ignore_patterns: Optional[Union[List[str], str]] = None,
):
if ignore_file_pattern is None:
ignore_file_pattern = []
if isinstance(ignore_file_pattern, str):
ignore_file_pattern = [ignore_file_pattern]
ignore_file_pattern = [
item if not item.endswith('/') else item + '*'
for item in ignore_file_pattern
]
ignore_regex_pattern = []
for file_pattern in ignore_file_pattern:
if file_pattern.startswith('*'):
ignore_regex_pattern.append('.' + file_pattern)
else:
ignore_regex_pattern.append(file_pattern)
if allow_file_pattern is not None:
if isinstance(allow_file_pattern, str):
allow_file_pattern = [allow_file_pattern]
allow_file_pattern = [
item if not item.endswith('/') else item + '*'
for item in allow_file_pattern
]
ignore_patterns = _normalize_patterns(ignore_patterns)
allow_patterns = _normalize_patterns(allow_patterns)
ignore_file_pattern = _normalize_patterns(ignore_file_pattern)
allow_file_pattern = _normalize_patterns(allow_file_pattern)
# to compatible regex usage.
ignore_regex_pattern = _get_valid_regex_pattern(ignore_file_pattern)
for repo_file in repo_files:
if repo_file['Type'] == 'tree' or \
any([fnmatch.fnmatch(repo_file['Path'], pattern) for pattern in ignore_file_pattern]) or \
any([re.search(pattern, repo_file['Name']) is not None for pattern in ignore_regex_pattern]): # noqa E501
if repo_file['Type'] == 'tree':
continue
if allow_file_pattern is not None and allow_file_pattern:
if not any(
try:
# processing patterns
if ignore_patterns and any([
fnmatch.fnmatch(repo_file['Path'], pattern)
for pattern in allow_file_pattern):
for pattern in ignore_patterns
]):
continue
if ignore_file_pattern and any([
fnmatch.fnmatch(repo_file['Path'], pattern)
for pattern in ignore_file_pattern
]):
continue
if ignore_regex_pattern and any([
re.search(pattern, repo_file['Name']) is not None
for pattern in ignore_regex_pattern
]): # noqa E501
continue
if allow_patterns is not None and allow_patterns:
if not any(
fnmatch.fnmatch(repo_file['Path'], pattern)
for pattern in allow_patterns):
continue
if allow_file_pattern is not None and allow_file_pattern:
if not any(
fnmatch.fnmatch(repo_file['Path'], pattern)
for pattern in allow_file_pattern):
continue
except Exception as e:
logger.warning('The file pattern is invalid : %s' % e)
# check model_file is exist in cache, if existed, skip download, otherwise download
if cache.exists(repo_file):
file_name = os.path.basename(repo_file['Name'])

View File

@@ -116,7 +116,7 @@ class DownloadDatasetTest(unittest.TestCase):
dataset_cache_path = dataset_snapshot_download(
dataset_id=dataset_id,
cache_dir=temp_cache_dir,
ignore_file_pattern='*.jpeg')
ignore_file_pattern=['*.jpeg', '.jpg'])
assert dataset_cache_path == os.path.join(temp_cache_dir,
dataset_id)
assert not os.path.exists(

View File

@@ -93,6 +93,7 @@ class HubRevisionTest(unittest.TestCase):
self.prepare_repo_data() # no tag, default get master
branch_name = 'test'
self.add_new_file_and_branch_to_repo(branch_name)
time.sleep(5)
with tempfile.TemporaryDirectory() as temp_cache_dir:
snapshot_path = snapshot_download(
self.model_id,