mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-25 04:30:48 +01:00
user-custom timeout and retry in HubApi
This commit is contained in:
@@ -22,7 +22,8 @@ import requests
|
||||
from requests import Session
|
||||
from requests.adapters import HTTPAdapter, Retry
|
||||
|
||||
from modelscope.hub.constants import (API_HTTP_CLIENT_TIMEOUT,
|
||||
from modelscope.hub.constants import (API_HTTP_CLIENT_MAX_RETRIES,
|
||||
API_HTTP_CLIENT_TIMEOUT,
|
||||
API_RESPONSE_FIELD_DATA,
|
||||
API_RESPONSE_FIELD_EMAIL,
|
||||
API_RESPONSE_FIELD_GIT_ACCESS_TOKEN,
|
||||
@@ -61,7 +62,10 @@ logger = get_logger()
|
||||
class HubApi:
|
||||
"""Model hub api interface.
|
||||
"""
|
||||
def __init__(self, endpoint: Optional[str] = None, timeout=API_HTTP_CLIENT_TIMEOUT):
|
||||
def __init__(self,
|
||||
endpoint: Optional[str] = None,
|
||||
timeout=API_HTTP_CLIENT_TIMEOUT,
|
||||
max_retries=API_HTTP_CLIENT_MAX_RETRIES):
|
||||
"""The ModelScope HubApi.
|
||||
|
||||
Args:
|
||||
@@ -71,7 +75,7 @@ class HubApi:
|
||||
self.headers = {'user-agent': ModelScopeConfig.get_user_agent()}
|
||||
self.session = Session()
|
||||
retry = Retry(
|
||||
total=2,
|
||||
total=max_retries,
|
||||
read=2,
|
||||
connect=2,
|
||||
backoff_factor=1,
|
||||
|
||||
@@ -17,6 +17,7 @@ LOGGER_NAME = 'ModelScopeHub'
|
||||
DEFAULT_CREDENTIALS_PATH = Path.home().joinpath('.modelscope', 'credentials')
|
||||
REQUESTS_API_HTTP_METHOD = ['get', 'head', 'post', 'put', 'patch', 'delete']
|
||||
API_HTTP_CLIENT_TIMEOUT = 60
|
||||
API_HTTP_CLIENT_MAX_RETRIES = 2
|
||||
API_RESPONSE_FIELD_DATA = 'Data'
|
||||
API_FILE_DOWNLOAD_RETRY_TIMES = 5
|
||||
API_FILE_DOWNLOAD_TIMEOUT = 60
|
||||
|
||||
@@ -218,7 +218,7 @@ def _list_repo_tree(
|
||||
token: Optional[Union[bool, str]] = None,
|
||||
) -> Iterable[Union[RepoFile, RepoFolder]]:
|
||||
|
||||
_api = HubApi()
|
||||
_api = HubApi(timeout=3 * 60, max_retries=3)
|
||||
|
||||
if is_relative_path(repo_id) and repo_id.count('/') == 1:
|
||||
_namespace, _dataset_name = repo_id.split('/')
|
||||
@@ -231,7 +231,6 @@ def _list_repo_tree(
|
||||
|
||||
page_number = 1
|
||||
page_size = 100
|
||||
total_data_list = []
|
||||
while True:
|
||||
data: dict = _api.list_repo_tree(dataset_name=_dataset_name,
|
||||
namespace=_namespace,
|
||||
@@ -247,7 +246,6 @@ def _list_repo_tree(
|
||||
|
||||
# Parse data (Type: 'tree' or 'blob')
|
||||
data_file_list: list = data['Data']['Files']
|
||||
total_data_list.extend(data_file_list)
|
||||
|
||||
for file_info_d in data_file_list:
|
||||
path_info = {}
|
||||
@@ -398,7 +396,10 @@ def _resolve_pattern(
|
||||
# 10 times faster glob with detail=True (ignores costly info like lastCommit)
|
||||
glob_kwargs['expand_info'] = False
|
||||
|
||||
tmp_file_paths = fs.glob(pattern, detail=True, **glob_kwargs)
|
||||
try:
|
||||
tmp_file_paths = fs.glob(pattern, detail=True, **glob_kwargs)
|
||||
except FileNotFoundError:
|
||||
raise DataFilesNotFoundError(f"Unable to find '{pattern}'")
|
||||
|
||||
matched_paths = [
|
||||
filepath if filepath.startswith(protocol_prefix) else protocol_prefix
|
||||
@@ -857,13 +858,6 @@ def get_module_with_script(self) -> DatasetModule:
|
||||
return DatasetModule(module_path, hash, builder_kwargs)
|
||||
|
||||
|
||||
def increase_load_count_ms(name: str, resource_type: str):
|
||||
"""
|
||||
Placeholder for increasing the load count of an HF resource.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class DatasetsWrapperHF:
|
||||
|
||||
@staticmethod
|
||||
@@ -1307,6 +1301,7 @@ class DatasetsWrapperHF:
|
||||
).get_module()
|
||||
except Exception as e1:
|
||||
# All the attempts failed; before raising the error, check whether the module is already cached
|
||||
logger.error(f'>> Error loading {path}: {e1}')
|
||||
try:
|
||||
return CachedDatasetModuleFactory(
|
||||
path,
|
||||
@@ -1337,7 +1332,6 @@ class DatasetsWrapperHF:
|
||||
|
||||
@contextlib.contextmanager
|
||||
def load_dataset_with_ctx(*args, **kwargs):
|
||||
from datasets.load import increase_load_count as increase_load_count
|
||||
|
||||
hf_endpoint_origin = config.HF_ENDPOINT
|
||||
get_from_cache_origin = file_utils.get_from_cache
|
||||
@@ -1346,8 +1340,6 @@ def load_dataset_with_ctx(*args, **kwargs):
|
||||
_download_origin = DownloadManager._download if hasattr(DownloadManager, '_download') \
|
||||
else DownloadManager._download_single
|
||||
|
||||
increase_load_count_origin = increase_load_count
|
||||
|
||||
dataset_info_origin = HfApi.dataset_info
|
||||
list_repo_tree_origin = HfApi.list_repo_tree
|
||||
get_paths_info_origin = HfApi.get_paths_info
|
||||
@@ -1370,7 +1362,6 @@ def load_dataset_with_ctx(*args, **kwargs):
|
||||
data_files.resolve_pattern = _resolve_pattern
|
||||
HubDatasetModuleFactoryWithoutScript.get_module = get_module_without_script
|
||||
HubDatasetModuleFactoryWithScript.get_module = get_module_with_script
|
||||
increase_load_count = increase_load_count_ms
|
||||
|
||||
try:
|
||||
dataset_res = DatasetsWrapperHF.load_dataset(*args, **kwargs)
|
||||
@@ -1391,4 +1382,3 @@ def load_dataset_with_ctx(*args, **kwargs):
|
||||
data_files.resolve_pattern = resolve_pattern_origin
|
||||
HubDatasetModuleFactoryWithoutScript.get_module = get_module_without_script_origin
|
||||
HubDatasetModuleFactoryWithScript.get_module = get_module_with_script_origin
|
||||
increase_load_count = increase_load_count_origin
|
||||
|
||||
Reference in New Issue
Block a user