mirror of
https://github.com/modelscope/modelscope.git
synced 2026-05-18 13:15:06 +02:00
[Fix] Fix dataset preview args, private streaming auth, and download retries (#1700)
This commit is contained in:
@@ -12,6 +12,7 @@ Sub-modules:
|
||||
_module_factories – dataset module factory functions & data-file resolution
|
||||
"""
|
||||
import contextlib
|
||||
import inspect
|
||||
import os
|
||||
import warnings
|
||||
from dataclasses import fields
|
||||
@@ -407,10 +408,26 @@ def _get_paths_info(
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# HfFileSystem patch (_hf_fs_open)
|
||||
# HfFileSystem patches (_hf_fs_open, _hf_fs_init)
|
||||
# ===================================================================
|
||||
|
||||
_hf_fs_open_original = None
|
||||
_hf_fs_init_original = None
|
||||
|
||||
|
||||
def _hf_fs_init_with_cookie(self, *args, endpoint=None, token=None, **kwargs):
|
||||
"""HfFileSystem.__init__ wrapper that injects ModelScope cookie auth.
|
||||
|
||||
ModelScope's /resolve/ endpoint authenticates via `m_session_id` cookie
|
||||
rather than the `Authorization: Bearer` header used by HuggingFace Hub.
|
||||
This wrapper ensures the cookie is included in all subsequent HTTP
|
||||
requests made by the HfFileSystem instance.
|
||||
"""
|
||||
_hf_fs_init_original(self, *args, endpoint=endpoint, token=token, **kwargs)
|
||||
if token and isinstance(token, str):
|
||||
if not hasattr(self._api, 'headers') or self._api.headers is None:
|
||||
self._api.headers = {}
|
||||
self._api.headers['Cookie'] = f'm_session_id={token}'
|
||||
|
||||
|
||||
def _hf_fs_open(self, path, mode='rb', **kwargs):
|
||||
@@ -769,6 +786,24 @@ class DatasetsWrapperHF:
|
||||
builder_cls = get_dataset_builder_class(
|
||||
dataset_module, dataset_name=dataset_name)
|
||||
|
||||
_config_cls = builder_cls.BUILDER_CONFIG_CLASS
|
||||
if hasattr(_config_cls, '__dataclass_fields__'):
|
||||
_valid_fields = set(_config_cls.__dataclass_fields__.keys())
|
||||
# Also preserve parameters accepted by the builder's
|
||||
# __init__ (e.g. writer_batch_size, base_path, repo_id)
|
||||
# so they are not inadvertently stripped.
|
||||
try:
|
||||
_init_params = set(
|
||||
inspect.signature(builder_cls.__init__).parameters.keys()
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
_init_params = set()
|
||||
_valid_fields = _valid_fields | _init_params
|
||||
config_kwargs = {
|
||||
k: v for k, v in config_kwargs.items()
|
||||
if k in _valid_fields
|
||||
}
|
||||
|
||||
builder_instance: DatasetBuilder = builder_cls(
|
||||
cache_dir=cache_dir,
|
||||
dataset_name=dataset_name,
|
||||
@@ -1032,7 +1067,7 @@ def load_dataset_with_ctx(*args, **kwargs):
|
||||
non-streaming mode) or kept alive (for streaming mode, where lazy
|
||||
iteration needs the patches to remain active).
|
||||
"""
|
||||
global _hf_fs_open_original
|
||||
global _hf_fs_open_original, _hf_fs_init_original
|
||||
|
||||
# Save originals
|
||||
hf_endpoint_origin = config.HF_ENDPOINT
|
||||
@@ -1048,6 +1083,7 @@ def load_dataset_with_ctx(*args, **kwargs):
|
||||
HubDatasetModuleFactoryWithScript.get_module if _HAS_SCRIPT_LOADING else None)
|
||||
generate_from_dict_origin = features.generate_from_dict
|
||||
hf_fs_open_origin = HfFileSystem._open
|
||||
hf_fs_init_origin = HfFileSystem.__init__
|
||||
|
||||
# Apply patches
|
||||
config.HF_ENDPOINT = get_endpoint()
|
||||
@@ -1066,6 +1102,8 @@ def load_dataset_with_ctx(*args, **kwargs):
|
||||
features.generate_from_dict = generate_from_dict_ms
|
||||
_hf_fs_open_original = hf_fs_open_origin
|
||||
HfFileSystem._open = _hf_fs_open
|
||||
_hf_fs_init_original = hf_fs_init_origin
|
||||
HfFileSystem.__init__ = _hf_fs_init_with_cookie
|
||||
|
||||
streaming = kwargs.get('streaming', False)
|
||||
|
||||
@@ -1076,10 +1114,12 @@ def load_dataset_with_ctx(*args, **kwargs):
|
||||
_repo_tree_cache.clear()
|
||||
HubApi._dataset_id_type_cache.clear()
|
||||
|
||||
HfFileSystem._open = hf_fs_open_origin
|
||||
_hf_fs_open_original = None
|
||||
|
||||
if not streaming:
|
||||
HfFileSystem._open = hf_fs_open_origin
|
||||
_hf_fs_open_original = None
|
||||
HfFileSystem.__init__ = hf_fs_init_origin
|
||||
_hf_fs_init_original = None
|
||||
|
||||
config.HF_ENDPOINT = hf_endpoint_origin
|
||||
file_utils.get_from_cache = get_from_cache_origin
|
||||
features.generate_from_dict = generate_from_dict_origin
|
||||
|
||||
@@ -53,7 +53,7 @@ def _request_with_retry_ms(
|
||||
url: str,
|
||||
max_retries: int = 2,
|
||||
base_wait_time: float = 0.5,
|
||||
max_wait_time: float = 8,
|
||||
max_wait_time: float = 3,
|
||||
timeout: float = 10.0,
|
||||
**params,
|
||||
) -> requests.Response:
|
||||
@@ -73,14 +73,35 @@ def _request_with_retry_ms(
|
||||
"""
|
||||
tries, success = 0, False
|
||||
response = None
|
||||
range_header = (params.get('headers') or {}).get('Range', '')
|
||||
while not success:
|
||||
tries += 1
|
||||
try:
|
||||
logger.debug(
|
||||
'[MS_DOWNLOAD] _request_with_retry_ms sending request: '
|
||||
'method=%s, url=%s, timeout=%s, Range=%s',
|
||||
method, url, timeout, range_header or 'N/A',
|
||||
)
|
||||
t0 = time.perf_counter()
|
||||
response = requests.request(method=method.upper(), url=url, timeout=timeout, **params)
|
||||
elapsed = time.perf_counter() - t0
|
||||
logger.debug(
|
||||
'[MS_DOWNLOAD] _request_with_retry_ms response: '
|
||||
'status=%s, content_length=%s, elapsed=%.3fs, url=%s',
|
||||
response.status_code,
|
||||
response.headers.get('Content-Length', 'N/A'),
|
||||
elapsed,
|
||||
url,
|
||||
)
|
||||
success = True
|
||||
except (requests.exceptions.ConnectTimeout,
|
||||
requests.exceptions.ConnectionError,
|
||||
requests.exceptions.ReadTimeout) as err:
|
||||
except (requests.exceptions.ReadTimeout,
|
||||
requests.exceptions.ConnectTimeout,
|
||||
requests.exceptions.ConnectionError) as err:
|
||||
logger.error(
|
||||
'[MS_DOWNLOAD] _request_with_retry_ms %s: '
|
||||
'method=%s, url=%s, timeout=%s, error=%s',
|
||||
type(err).__name__, method, url, timeout, err,
|
||||
)
|
||||
if tries > max_retries:
|
||||
raise err
|
||||
else:
|
||||
@@ -111,6 +132,10 @@ def http_head_ms(
|
||||
def http_get_ms(
|
||||
url, temp_file, proxies=None, resume_size=0, headers=None, cookies=None, timeout=300.0, max_retries=3, desc=None
|
||||
) -> Optional[requests.Response]:
|
||||
logger.debug(
|
||||
'[MS_DOWNLOAD] http_get_ms entry: url=%s, timeout=%s, resume_size=%s',
|
||||
url, timeout, resume_size,
|
||||
)
|
||||
headers = dict(headers) if headers is not None else {}
|
||||
headers['user-agent'] = get_datasets_user_agent_ms(user_agent=headers.get('user-agent'))
|
||||
if resume_size > 0:
|
||||
@@ -323,6 +348,7 @@ def get_from_cache_ms(
|
||||
if scheme not in ('http', 'https'):
|
||||
fsspec_get(url, temp_file, storage_options=storage_options, desc=download_desc)
|
||||
else:
|
||||
logger.info('[MS_DOWNLOAD] get_from_cache_ms downloading: url=%s', url)
|
||||
http_get_ms(
|
||||
url,
|
||||
temp_file=temp_file,
|
||||
|
||||
Reference in New Issue
Block a user