Support user-customizable timeout and retry count in HubApi

This commit is contained in:
xingjun.wang
2024-08-08 10:42:42 +08:00
parent b355360727
commit fe3eb0d0c3
3 changed files with 14 additions and 19 deletions

View File

@@ -22,7 +22,8 @@ import requests
from requests import Session
from requests.adapters import HTTPAdapter, Retry
from modelscope.hub.constants import (API_HTTP_CLIENT_TIMEOUT,
from modelscope.hub.constants import (API_HTTP_CLIENT_MAX_RETRIES,
API_HTTP_CLIENT_TIMEOUT,
API_RESPONSE_FIELD_DATA,
API_RESPONSE_FIELD_EMAIL,
API_RESPONSE_FIELD_GIT_ACCESS_TOKEN,
@@ -61,7 +62,10 @@ logger = get_logger()
class HubApi:
"""Model hub api interface.
"""
def __init__(self, endpoint: Optional[str] = None, timeout=API_HTTP_CLIENT_TIMEOUT):
def __init__(self,
endpoint: Optional[str] = None,
timeout=API_HTTP_CLIENT_TIMEOUT,
max_retries=API_HTTP_CLIENT_MAX_RETRIES):
"""The ModelScope HubApi.
Args:
@@ -71,7 +75,7 @@ class HubApi:
self.headers = {'user-agent': ModelScopeConfig.get_user_agent()}
self.session = Session()
retry = Retry(
total=2,
total=max_retries,
read=2,
connect=2,
backoff_factor=1,

View File

@@ -17,6 +17,7 @@ LOGGER_NAME = 'ModelScopeHub'
DEFAULT_CREDENTIALS_PATH = Path.home().joinpath('.modelscope', 'credentials')
REQUESTS_API_HTTP_METHOD = ['get', 'head', 'post', 'put', 'patch', 'delete']
API_HTTP_CLIENT_TIMEOUT = 60
API_HTTP_CLIENT_MAX_RETRIES = 2
API_RESPONSE_FIELD_DATA = 'Data'
API_FILE_DOWNLOAD_RETRY_TIMES = 5
API_FILE_DOWNLOAD_TIMEOUT = 60

View File

@@ -218,7 +218,7 @@ def _list_repo_tree(
token: Optional[Union[bool, str]] = None,
) -> Iterable[Union[RepoFile, RepoFolder]]:
_api = HubApi()
_api = HubApi(timeout=3 * 60, max_retries=3)
if is_relative_path(repo_id) and repo_id.count('/') == 1:
_namespace, _dataset_name = repo_id.split('/')
@@ -231,7 +231,6 @@ def _list_repo_tree(
page_number = 1
page_size = 100
total_data_list = []
while True:
data: dict = _api.list_repo_tree(dataset_name=_dataset_name,
namespace=_namespace,
@@ -247,7 +246,6 @@ def _list_repo_tree(
# Parse data (Type: 'tree' or 'blob')
data_file_list: list = data['Data']['Files']
total_data_list.extend(data_file_list)
for file_info_d in data_file_list:
path_info = {}
@@ -398,7 +396,10 @@ def _resolve_pattern(
# 10 times faster glob with detail=True (ignores costly info like lastCommit)
glob_kwargs['expand_info'] = False
tmp_file_paths = fs.glob(pattern, detail=True, **glob_kwargs)
try:
tmp_file_paths = fs.glob(pattern, detail=True, **glob_kwargs)
except FileNotFoundError:
raise DataFilesNotFoundError(f"Unable to find '{pattern}'")
matched_paths = [
filepath if filepath.startswith(protocol_prefix) else protocol_prefix
@@ -857,13 +858,6 @@ def get_module_with_script(self) -> DatasetModule:
return DatasetModule(module_path, hash, builder_kwargs)
def increase_load_count_ms(name: str, resource_type: str):
"""
Placeholder for increasing the load count of an HF resource.
"""
...
class DatasetsWrapperHF:
@staticmethod
@@ -1307,6 +1301,7 @@ class DatasetsWrapperHF:
).get_module()
except Exception as e1:
# All the attempts failed, before raising the error we should check if the module is already cached
logger.error(f'>> Error loading {path}: {e1}')
try:
return CachedDatasetModuleFactory(
path,
@@ -1337,7 +1332,6 @@ class DatasetsWrapperHF:
@contextlib.contextmanager
def load_dataset_with_ctx(*args, **kwargs):
from datasets.load import increase_load_count as increase_load_count
hf_endpoint_origin = config.HF_ENDPOINT
get_from_cache_origin = file_utils.get_from_cache
@@ -1346,8 +1340,6 @@ def load_dataset_with_ctx(*args, **kwargs):
_download_origin = DownloadManager._download if hasattr(DownloadManager, '_download') \
else DownloadManager._download_single
increase_load_count_origin = increase_load_count
dataset_info_origin = HfApi.dataset_info
list_repo_tree_origin = HfApi.list_repo_tree
get_paths_info_origin = HfApi.get_paths_info
@@ -1370,7 +1362,6 @@ def load_dataset_with_ctx(*args, **kwargs):
data_files.resolve_pattern = _resolve_pattern
HubDatasetModuleFactoryWithoutScript.get_module = get_module_without_script
HubDatasetModuleFactoryWithScript.get_module = get_module_with_script
increase_load_count = increase_load_count_ms
try:
dataset_res = DatasetsWrapperHF.load_dataset(*args, **kwargs)
@@ -1391,4 +1382,3 @@ def load_dataset_with_ctx(*args, **kwargs):
data_files.resolve_pattern = resolve_pattern_origin
HubDatasetModuleFactoryWithoutScript.get_module = get_module_without_script_origin
HubDatasetModuleFactoryWithScript.get_module = get_module_with_script_origin
increase_load_count = increase_load_count_origin