|
|
|
|
@@ -81,10 +81,6 @@ def snapshot_download(
|
|
|
|
|
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
|
|
|
|
if some parameter value is invalid
|
|
|
|
|
"""
|
|
|
|
|
if allow_patterns:
|
|
|
|
|
allow_file_pattern = allow_patterns
|
|
|
|
|
if ignore_patterns:
|
|
|
|
|
ignore_file_pattern = ignore_patterns
|
|
|
|
|
return _snapshot_download(
|
|
|
|
|
model_id,
|
|
|
|
|
repo_type=REPO_TYPE_MODEL,
|
|
|
|
|
@@ -95,7 +91,9 @@ def snapshot_download(
|
|
|
|
|
cookies=cookies,
|
|
|
|
|
ignore_file_pattern=ignore_file_pattern,
|
|
|
|
|
allow_file_pattern=allow_file_pattern,
|
|
|
|
|
local_dir=local_dir)
|
|
|
|
|
local_dir=local_dir,
|
|
|
|
|
ignore_patterns=ignore_patterns,
|
|
|
|
|
allow_patterns=allow_patterns)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def dataset_snapshot_download(
|
|
|
|
|
@@ -157,10 +155,6 @@ def dataset_snapshot_download(
|
|
|
|
|
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
|
|
|
|
if some parameter value is invalid
|
|
|
|
|
"""
|
|
|
|
|
if allow_patterns:
|
|
|
|
|
allow_file_pattern = allow_patterns
|
|
|
|
|
if ignore_patterns:
|
|
|
|
|
ignore_file_pattern = ignore_patterns
|
|
|
|
|
return _snapshot_download(
|
|
|
|
|
dataset_id,
|
|
|
|
|
repo_type=REPO_TYPE_DATASET,
|
|
|
|
|
@@ -171,7 +165,9 @@ def dataset_snapshot_download(
|
|
|
|
|
cookies=cookies,
|
|
|
|
|
ignore_file_pattern=ignore_file_pattern,
|
|
|
|
|
allow_file_pattern=allow_file_pattern,
|
|
|
|
|
local_dir=local_dir)
|
|
|
|
|
local_dir=local_dir,
|
|
|
|
|
ignore_patterns=ignore_patterns,
|
|
|
|
|
allow_patterns=allow_patterns)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _snapshot_download(
|
|
|
|
|
@@ -186,6 +182,8 @@ def _snapshot_download(
|
|
|
|
|
ignore_file_pattern: Optional[Union[str, List[str]]] = None,
|
|
|
|
|
allow_file_pattern: Optional[Union[str, List[str]]] = None,
|
|
|
|
|
local_dir: Optional[str] = None,
|
|
|
|
|
allow_patterns: Optional[Union[List[str], str]] = None,
|
|
|
|
|
ignore_patterns: Optional[Union[List[str], str]] = None,
|
|
|
|
|
):
|
|
|
|
|
if not repo_type:
|
|
|
|
|
repo_type = REPO_TYPE_MODEL
|
|
|
|
|
@@ -249,12 +247,13 @@ def _snapshot_download(
|
|
|
|
|
None,
|
|
|
|
|
None,
|
|
|
|
|
headers,
|
|
|
|
|
revision_detail=revision_detail,
|
|
|
|
|
repo_type=repo_type,
|
|
|
|
|
revision=revision,
|
|
|
|
|
cookies=cookies,
|
|
|
|
|
ignore_file_pattern=ignore_file_pattern,
|
|
|
|
|
allow_file_pattern=allow_file_pattern)
|
|
|
|
|
allow_file_pattern=allow_file_pattern,
|
|
|
|
|
ignore_patterns=ignore_patterns,
|
|
|
|
|
allow_patterns=allow_patterns)
|
|
|
|
|
|
|
|
|
|
elif repo_type == REPO_TYPE_DATASET:
|
|
|
|
|
group_or_owner, name = model_id_to_group_owner_name(repo_id)
|
|
|
|
|
@@ -290,12 +289,13 @@ def _snapshot_download(
|
|
|
|
|
name,
|
|
|
|
|
group_or_owner,
|
|
|
|
|
headers,
|
|
|
|
|
revision_detail=revision_detail,
|
|
|
|
|
repo_type=repo_type,
|
|
|
|
|
revision=revision,
|
|
|
|
|
cookies=cookies,
|
|
|
|
|
ignore_file_pattern=ignore_file_pattern,
|
|
|
|
|
allow_file_pattern=allow_file_pattern)
|
|
|
|
|
allow_file_pattern=allow_file_pattern,
|
|
|
|
|
ignore_patterns=ignore_patterns,
|
|
|
|
|
allow_patterns=allow_patterns)
|
|
|
|
|
if len(repo_files) < page_size:
|
|
|
|
|
break
|
|
|
|
|
page_number += 1
|
|
|
|
|
@@ -304,6 +304,35 @@ def _snapshot_download(
|
|
|
|
|
return os.path.join(cache.get_root_location())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_valid_regex(pattern: str):
|
|
|
|
|
try:
|
|
|
|
|
re.compile(pattern)
|
|
|
|
|
return True
|
|
|
|
|
except BaseException:
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _normalize_patterns(patterns: Union[str, List[str]]):
|
|
|
|
|
if isinstance(patterns, str):
|
|
|
|
|
patterns = [patterns]
|
|
|
|
|
if patterns is not None:
|
|
|
|
|
patterns = [
|
|
|
|
|
item if not item.endswith('/') else item + '*' for item in patterns
|
|
|
|
|
]
|
|
|
|
|
return patterns
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_valid_regex_pattern(patterns: List[str]):
    """Keep only the entries of ``patterns`` that are valid regular expressions.

    Args:
        patterns: List of candidate pattern strings, or ``None``.

    Returns:
        A new list holding, in order, the entries accepted by
        ``_is_valid_regex``; ``None`` when ``patterns`` is ``None``.
    """
    if patterns is None:
        return None
    return [candidate for candidate in patterns if _is_valid_regex(candidate)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _download_file_lists(
|
|
|
|
|
repo_files: List[str],
|
|
|
|
|
cache: ModelFileSystemCache,
|
|
|
|
|
@@ -313,48 +342,58 @@ def _download_file_lists(
|
|
|
|
|
name: str,
|
|
|
|
|
group_or_owner: str,
|
|
|
|
|
headers,
|
|
|
|
|
revision_detail: str,
|
|
|
|
|
repo_type: Optional[str] = None,
|
|
|
|
|
revision: Optional[str] = DEFAULT_MODEL_REVISION,
|
|
|
|
|
cookies: Optional[CookieJar] = None,
|
|
|
|
|
ignore_file_pattern: Optional[Union[str, List[str]]] = None,
|
|
|
|
|
allow_file_pattern: Optional[Union[str, List[str]]] = None,
|
|
|
|
|
allow_patterns: Optional[Union[List[str], str]] = None,
|
|
|
|
|
ignore_patterns: Optional[Union[List[str], str]] = None,
|
|
|
|
|
):
|
|
|
|
|
if ignore_file_pattern is None:
|
|
|
|
|
ignore_file_pattern = []
|
|
|
|
|
if isinstance(ignore_file_pattern, str):
|
|
|
|
|
ignore_file_pattern = [ignore_file_pattern]
|
|
|
|
|
ignore_file_pattern = [
|
|
|
|
|
item if not item.endswith('/') else item + '*'
|
|
|
|
|
for item in ignore_file_pattern
|
|
|
|
|
]
|
|
|
|
|
ignore_regex_pattern = []
|
|
|
|
|
for file_pattern in ignore_file_pattern:
|
|
|
|
|
if file_pattern.startswith('*'):
|
|
|
|
|
ignore_regex_pattern.append('.' + file_pattern)
|
|
|
|
|
else:
|
|
|
|
|
ignore_regex_pattern.append(file_pattern)
|
|
|
|
|
|
|
|
|
|
if allow_file_pattern is not None:
|
|
|
|
|
if isinstance(allow_file_pattern, str):
|
|
|
|
|
allow_file_pattern = [allow_file_pattern]
|
|
|
|
|
allow_file_pattern = [
|
|
|
|
|
item if not item.endswith('/') else item + '*'
|
|
|
|
|
for item in allow_file_pattern
|
|
|
|
|
]
|
|
|
|
|
ignore_patterns = _normalize_patterns(ignore_patterns)
|
|
|
|
|
allow_patterns = _normalize_patterns(allow_patterns)
|
|
|
|
|
ignore_file_pattern = _normalize_patterns(ignore_file_pattern)
|
|
|
|
|
allow_file_pattern = _normalize_patterns(allow_file_pattern)
|
|
|
|
|
# For compatibility with regex usage: keep the entries that are valid regular expressions.
|
|
|
|
|
ignore_regex_pattern = _get_valid_regex_pattern(ignore_file_pattern)
|
|
|
|
|
|
|
|
|
|
for repo_file in repo_files:
|
|
|
|
|
if repo_file['Type'] == 'tree' or \
|
|
|
|
|
any([fnmatch.fnmatch(repo_file['Path'], pattern) for pattern in ignore_file_pattern]) or \
|
|
|
|
|
any([re.search(pattern, repo_file['Name']) is not None for pattern in ignore_regex_pattern]): # noqa E501
|
|
|
|
|
if repo_file['Type'] == 'tree':
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
if allow_file_pattern is not None and allow_file_pattern:
|
|
|
|
|
if not any(
|
|
|
|
|
try:
|
|
|
|
|
# processing patterns
|
|
|
|
|
if ignore_patterns and any([
|
|
|
|
|
fnmatch.fnmatch(repo_file['Path'], pattern)
|
|
|
|
|
for pattern in allow_file_pattern):
|
|
|
|
|
for pattern in ignore_patterns
|
|
|
|
|
]):
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
if ignore_file_pattern and any([
|
|
|
|
|
fnmatch.fnmatch(repo_file['Path'], pattern)
|
|
|
|
|
for pattern in ignore_file_pattern
|
|
|
|
|
]):
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
if ignore_regex_pattern and any([
|
|
|
|
|
re.search(pattern, repo_file['Name']) is not None
|
|
|
|
|
for pattern in ignore_regex_pattern
|
|
|
|
|
]): # noqa E501
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
if allow_patterns is not None and allow_patterns:
|
|
|
|
|
if not any(
|
|
|
|
|
fnmatch.fnmatch(repo_file['Path'], pattern)
|
|
|
|
|
for pattern in allow_patterns):
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
if allow_file_pattern is not None and allow_file_pattern:
|
|
|
|
|
if not any(
|
|
|
|
|
fnmatch.fnmatch(repo_file['Path'], pattern)
|
|
|
|
|
for pattern in allow_file_pattern):
|
|
|
|
|
continue
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.warning('The file pattern is invalid : %s' % e)
|
|
|
|
|
|
|
|
|
|
# Check whether the file already exists in the cache; if it does, skip the download, otherwise download it.
|
|
|
|
|
if cache.exists(repo_file):
|
|
|
|
|
file_name = os.path.basename(repo_file['Name'])
|
|
|
|
|
|