[Fix] Fix dataset preview args, private streaming auth, and download retries (#1700)

This commit is contained in:
Xingjun.Wang
2026-05-08 15:13:52 +08:00
committed by GitHub
parent 2f5f52fc3c
commit c15c5261e1
2 changed files with 75 additions and 9 deletions

View File

@@ -12,6 +12,7 @@ Sub-modules:
_module_factories dataset module factory functions & data-file resolution
"""
import contextlib
import inspect
import os
import warnings
from dataclasses import fields
@@ -407,10 +408,26 @@ def _get_paths_info(
# ===================================================================
# HfFileSystem patch (_hf_fs_open)
# HfFileSystem patches (_hf_fs_open, _hf_fs_init)
# ===================================================================
_hf_fs_open_original = None
_hf_fs_init_original = None
def _hf_fs_init_with_cookie(self, *args, endpoint=None, token=None, **kwargs):
"""HfFileSystem.__init__ wrapper that injects ModelScope cookie auth.
ModelScope's /resolve/ endpoint authenticates via `m_session_id` cookie
rather than the `Authorization: Bearer` header used by HuggingFace Hub.
This wrapper ensures the cookie is included in all subsequent HTTP
requests made by the HfFileSystem instance.
"""
_hf_fs_init_original(self, *args, endpoint=endpoint, token=token, **kwargs)
if token and isinstance(token, str):
if not hasattr(self._api, 'headers') or self._api.headers is None:
self._api.headers = {}
self._api.headers['Cookie'] = f'm_session_id={token}'
def _hf_fs_open(self, path, mode='rb', **kwargs):
@@ -769,6 +786,24 @@ class DatasetsWrapperHF:
builder_cls = get_dataset_builder_class(
dataset_module, dataset_name=dataset_name)
_config_cls = builder_cls.BUILDER_CONFIG_CLASS
if hasattr(_config_cls, '__dataclass_fields__'):
_valid_fields = set(_config_cls.__dataclass_fields__.keys())
# Also preserve parameters accepted by the builder's
# __init__ (e.g. writer_batch_size, base_path, repo_id)
# so they are not inadvertently stripped.
try:
_init_params = set(
inspect.signature(builder_cls.__init__).parameters.keys()
)
except (ValueError, TypeError):
_init_params = set()
_valid_fields = _valid_fields | _init_params
config_kwargs = {
k: v for k, v in config_kwargs.items()
if k in _valid_fields
}
builder_instance: DatasetBuilder = builder_cls(
cache_dir=cache_dir,
dataset_name=dataset_name,
@@ -1032,7 +1067,7 @@ def load_dataset_with_ctx(*args, **kwargs):
non-streaming mode) or kept alive (for streaming mode, where lazy
iteration needs the patches to remain active).
"""
global _hf_fs_open_original
global _hf_fs_open_original, _hf_fs_init_original
# Save originals
hf_endpoint_origin = config.HF_ENDPOINT
@@ -1048,6 +1083,7 @@ def load_dataset_with_ctx(*args, **kwargs):
HubDatasetModuleFactoryWithScript.get_module if _HAS_SCRIPT_LOADING else None)
generate_from_dict_origin = features.generate_from_dict
hf_fs_open_origin = HfFileSystem._open
hf_fs_init_origin = HfFileSystem.__init__
# Apply patches
config.HF_ENDPOINT = get_endpoint()
@@ -1066,6 +1102,8 @@ def load_dataset_with_ctx(*args, **kwargs):
features.generate_from_dict = generate_from_dict_ms
_hf_fs_open_original = hf_fs_open_origin
HfFileSystem._open = _hf_fs_open
_hf_fs_init_original = hf_fs_init_origin
HfFileSystem.__init__ = _hf_fs_init_with_cookie
streaming = kwargs.get('streaming', False)
@@ -1076,10 +1114,12 @@ def load_dataset_with_ctx(*args, **kwargs):
_repo_tree_cache.clear()
HubApi._dataset_id_type_cache.clear()
HfFileSystem._open = hf_fs_open_origin
_hf_fs_open_original = None
if not streaming:
HfFileSystem._open = hf_fs_open_origin
_hf_fs_open_original = None
HfFileSystem.__init__ = hf_fs_init_origin
_hf_fs_init_original = None
config.HF_ENDPOINT = hf_endpoint_origin
file_utils.get_from_cache = get_from_cache_origin
features.generate_from_dict = generate_from_dict_origin

View File

@@ -53,7 +53,7 @@ def _request_with_retry_ms(
url: str,
max_retries: int = 2,
base_wait_time: float = 0.5,
max_wait_time: float = 8,
max_wait_time: float = 3,
timeout: float = 10.0,
**params,
) -> requests.Response:
@@ -73,14 +73,35 @@ def _request_with_retry_ms(
"""
tries, success = 0, False
response = None
range_header = (params.get('headers') or {}).get('Range', '')
while not success:
tries += 1
try:
logger.debug(
'[MS_DOWNLOAD] _request_with_retry_ms sending request: '
'method=%s, url=%s, timeout=%s, Range=%s',
method, url, timeout, range_header or 'N/A',
)
t0 = time.perf_counter()
response = requests.request(method=method.upper(), url=url, timeout=timeout, **params)
elapsed = time.perf_counter() - t0
logger.debug(
'[MS_DOWNLOAD] _request_with_retry_ms response: '
'status=%s, content_length=%s, elapsed=%.3fs, url=%s',
response.status_code,
response.headers.get('Content-Length', 'N/A'),
elapsed,
url,
)
success = True
except (requests.exceptions.ConnectTimeout,
requests.exceptions.ConnectionError,
requests.exceptions.ReadTimeout) as err:
except (requests.exceptions.ReadTimeout,
requests.exceptions.ConnectTimeout,
requests.exceptions.ConnectionError) as err:
logger.error(
'[MS_DOWNLOAD] _request_with_retry_ms %s: '
'method=%s, url=%s, timeout=%s, error=%s',
type(err).__name__, method, url, timeout, err,
)
if tries > max_retries:
raise err
else:
@@ -111,6 +132,10 @@ def http_head_ms(
def http_get_ms(
url, temp_file, proxies=None, resume_size=0, headers=None, cookies=None, timeout=300.0, max_retries=3, desc=None
) -> Optional[requests.Response]:
logger.debug(
'[MS_DOWNLOAD] http_get_ms entry: url=%s, timeout=%s, resume_size=%s',
url, timeout, resume_size,
)
headers = dict(headers) if headers is not None else {}
headers['user-agent'] = get_datasets_user_agent_ms(user_agent=headers.get('user-agent'))
if resume_size > 0:
@@ -323,6 +348,7 @@ def get_from_cache_ms(
if scheme not in ('http', 'https'):
fsspec_get(url, temp_file, storage_options=storage_options, desc=download_desc)
else:
logger.info('[MS_DOWNLOAD] get_from_cache_ms downloading: url=%s', url)
http_get_ms(
url,
temp_file=temp_file,