shorten check model timeout and snapshot_download file patter folder/ to folder/*

This commit is contained in:
mulin.lyh
2024-07-01 14:39:32 +08:00
parent 5b8e1e971c
commit c316211e2d
3 changed files with 11 additions and 3 deletions

View File

@@ -61,7 +61,7 @@ logger = get_logger()
class HubApi:
"""Model hub api interface.
"""
def __init__(self, endpoint: Optional[str] = None):
def __init__(self, endpoint: Optional[str] = None, timeout=API_HTTP_CLIENT_TIMEOUT):
"""The ModelScope HubApi。
Args:
@@ -86,7 +86,7 @@ class HubApi:
self.session, method,
functools.partial(
getattr(self.session, method),
timeout=API_HTTP_CLIENT_TIMEOUT))
timeout=timeout))
def login(
self,

View File

@@ -48,7 +48,7 @@ def check_local_model_is_latest(
'Snapshot': 'True'
}
}
_api = HubApi()
_api = HubApi(timeout=0.5)
try:
_, revisions = _api.get_model_branches_and_tags(
model_id=model_id, use_cookies=cookies)

View File

@@ -118,10 +118,18 @@ def snapshot_download(
ignore_file_pattern = []
if isinstance(ignore_file_pattern, str):
ignore_file_pattern = [ignore_file_pattern]
ignore_file_pattern = [
item if not item.endswith('/') else item + '*'
for item in ignore_file_pattern
]
if allow_file_pattern is not None:
if isinstance(allow_file_pattern, str):
allow_file_pattern = [allow_file_pattern]
allow_file_pattern = [
item if not item.endswith('/') else item + '*'
for item in allow_file_pattern
]
for model_file in model_files:
if model_file['Type'] == 'tree' or \