Adapt to new datasets version (#1002)

* update to datasets==3.0

* update

* add http_get_ms func

* delete unused code

* fix PR issues and update requirements
Author: Xingjun.Wang
Date: 2024-09-30 16:46:00 +08:00
Committed by: GitHub
Parent: 834db59952
Commit: 2c4505e13a
5 changed files with 139 additions and 17 deletions


@@ -989,7 +989,6 @@ class DatasetsWrapperHF:
            download_config=download_config,
            download_mode=download_mode,
            verification_mode=verification_mode,
-            try_from_hf_gcs=False,
            num_proc=num_proc,
            storage_options=storage_options,
            # base_path=builder_instance.base_path,
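
datasets 3.0 removed the try_from_hf_gcs argument from download_and_prepare, so the wrapper simply stops passing it. A minimal compatibility sketch (helper name is illustrative, not part of this PR):

    import datasets
    from packaging import version

    def download_and_prepare_kwargs(**kwargs):
        # try_from_hf_gcs no longer exists in datasets>=3.0; only pass it on older versions.
        if version.parse(datasets.__version__) < version.parse('3.0.0'):
            kwargs.setdefault('try_from_hf_gcs', False)
        return kwargs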


@@ -5,27 +5,138 @@
import json
import os
import re
import copy
import shutil
import time
import warnings
import inspect
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import Optional, Union
from urllib.parse import urljoin, urlparse
import requests
from tqdm import tqdm
from datasets import config
-from datasets.utils.file_utils import hash_url_to_filename, get_authentication_headers_for_url, ftp_head, fsspec_head, \
-    http_head, _raise_if_offline_mode_is_enabled, ftp_get, fsspec_get, http_get
+from datasets.utils.file_utils import hash_url_to_filename, \
+    get_authentication_headers_for_url, fsspec_head, fsspec_get
from filelock import FileLock
from modelscope.utils.config_ds import MS_DATASETS_CACHE
from modelscope.utils.logger import get_logger
from modelscope.hub.api import ModelScopeConfig
from modelscope import __version__
logger = get_logger()

def get_datasets_user_agent_ms(user_agent: Optional[Union[str, dict]] = None) -> str:
    ua = f'datasets/{__version__}'
    ua += f'; python/{config.PY_VERSION}'
    ua += f'; pyarrow/{config.PYARROW_VERSION}'
    if config.TORCH_AVAILABLE:
        ua += f'; torch/{config.TORCH_VERSION}'
    if config.TF_AVAILABLE:
        ua += f'; tensorflow/{config.TF_VERSION}'
    if config.JAX_AVAILABLE:
        ua += f'; jax/{config.JAX_VERSION}'
    if isinstance(user_agent, dict):
        ua += f"; {'; '.join(f'{k}/{v}' for k, v in user_agent.items())}"
    elif isinstance(user_agent, str):
        ua += '; ' + user_agent
    return ua
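
# Note: __version__ here comes from `from modelscope import __version__` above, so the leading
# 'datasets/...' field actually reports the modelscope release.
# Example (illustrative values):
#     get_datasets_user_agent_ms({'ms-source': 'msdataset'})
#     -> 'datasets/1.18.0; python/3.10; pyarrow/15.0.0; ms-source/msdataset'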

def _request_with_retry_ms(
    method: str,
    url: str,
    max_retries: int = 2,
    base_wait_time: float = 0.5,
    max_wait_time: float = 2,
    timeout: float = 10.0,
    **params,
) -> requests.Response:
    """Wrapper around requests that retries if the request fails with a ConnectTimeout, using exponential backoff.

    Args:
        method (str): HTTP method, such as 'GET' or 'HEAD'.
        url (str): The URL of the resource to fetch.
        max_retries (int): Maximum number of retries, defaults to 2.
        base_wait_time (float): Duration (in seconds) to wait before retrying the first time. Wait time between
            retries then grows exponentially, capped by max_wait_time.
        max_wait_time (float): Maximum duration (in seconds) between two retries.
        timeout (float): Per-request timeout in seconds, defaults to 10.0.
        **params (additional keyword arguments): Params to pass to :obj:`requests.request`.
    """
    tries, success = 0, False
    response = None
    while not success:
        tries += 1
        try:
            response = requests.request(method=method.upper(), url=url, timeout=timeout, **params)
            success = True
        except (requests.exceptions.ConnectTimeout, requests.exceptions.ConnectionError) as err:
            if tries > max_retries:
                raise err
            else:
                logger.info(f'{method} request to {url} timed out, retrying... [{tries}/{max_retries}]')
                sleep_time = min(max_wait_time, base_wait_time * 2 ** (tries - 1))  # Exponential backoff
                time.sleep(sleep_time)
    return response
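
# Backoff illustration: with base_wait_time=0.5 and max_wait_time=2, the waits between retries
# are min(2, 0.5 * 2 ** (tries - 1)) -> 0.5s, 1.0s, 2.0s, 2.0s, ...
# Hedged usage sketch (URL is illustrative):
#     resp = _request_with_retry_ms('HEAD', 'https://example.com/file.json', max_retries=3, timeout=5.0)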

def http_head_ms(
    url, proxies=None, headers=None, cookies=None, allow_redirects=True, timeout=10.0, max_retries=0
) -> requests.Response:
    headers = copy.deepcopy(headers) or {}
    headers['user-agent'] = get_datasets_user_agent_ms(user_agent=headers.get('user-agent'))
    response = _request_with_retry_ms(
        method='HEAD',
        url=url,
        proxies=proxies,
        headers=headers,
        cookies=cookies,
        allow_redirects=allow_redirects,
        timeout=timeout,
        max_retries=max_retries,
    )
    return response
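
# Usage sketch (illustrative URL): probe a resource's headers before downloading, e.g. for its ETag:
#     head = http_head_ms('https://example.com/data.jsonl', timeout=5.0, max_retries=1)
#     etag = head.headers.get('ETag')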

def http_get_ms(
    url, temp_file, proxies=None, resume_size=0, headers=None, cookies=None, timeout=100.0, max_retries=0, desc=None
) -> Optional[requests.Response]:
    headers = dict(headers) if headers is not None else {}
    headers['user-agent'] = get_datasets_user_agent_ms(user_agent=headers.get('user-agent'))
    if resume_size > 0:
        headers['Range'] = f'bytes={resume_size:d}-'
    response = _request_with_retry_ms(
        method='GET',
        url=url,
        stream=True,
        proxies=proxies,
        headers=headers,
        cookies=cookies,
        max_retries=max_retries,
        timeout=timeout,
    )
    if temp_file is None:
        return response
    if response.status_code == 416:  # Range not satisfiable
        return
    content_length = response.headers.get('Content-Length')
    total = (resume_size + int(content_length)) if content_length is not None else None
    progress = tqdm(total=total, initial=resume_size, unit_scale=True, unit='B', desc=desc or 'Downloading')
    for chunk in response.iter_content(chunk_size=1024):
        progress.update(len(chunk))
        temp_file.write(chunk)
    progress.close()
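
# Usage sketch (illustrative names): resume an interrupted download by passing the bytes already
# on disk; the Range header above then requests only the remainder:
#     with open(part_path, 'ab') as temp_file:
#         http_get_ms(url, temp_file, resume_size=os.path.getsize(part_path), desc='resuming')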

def get_from_cache_ms(
    url,
    cache_dir=None,
@@ -42,6 +153,7 @@ def get_from_cache_ms(
    ignore_url_params=False,
    storage_options=None,
    download_desc=None,
+    disable_tqdm=None,
) -> str:
    """
    Given a URL, look for the corresponding file in the local cache.
@@ -101,16 +213,14 @@ def get_from_cache_ms(
    # We don't have the file locally or we need an eTag
    if not local_files_only:
        scheme = urlparse(url).scheme
-        if scheme == 'ftp':
-            connected = ftp_head(url)
-        elif scheme not in ('http', 'https'):
+        if scheme not in ('http', 'https'):
            response = fsspec_head(url, storage_options=storage_options)
            # s3fs uses "ETag", gcsfs uses "etag"
            etag = (response.get('ETag', None) or response.get('etag', None)) if use_etag else None
            connected = True
        try:
            cookies = ModelScopeConfig.get_cookies()
-            response = http_head(
+            response = http_head_ms(
                url,
                allow_redirects=True,
                proxies=proxies,
@@ -167,7 +277,6 @@ def get_from_cache_ms(
            )
        elif response is not None and response.status_code == 404:
            raise FileNotFoundError(f"Couldn't find file at {url}")
-        _raise_if_offline_mode_is_enabled(f'Tried to reach {url}')
        if head_error is not None:
            raise ConnectionError(f"Couldn't reach {url} ({repr(head_error)})")
        elif response is not None:
@@ -206,16 +315,21 @@ def get_from_cache_ms(
        # Download to temporary file, then copy to cache path once finished.
        # Otherwise, you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            logger.info(f'Downloading to {temp_file.name}')
            # GET file object
-            if scheme == 'ftp':
-                ftp_get(url, temp_file)
-            elif scheme not in ('http', 'https'):
+            if scheme not in ('http', 'https'):
                fsspec_get(url, temp_file, storage_options=storage_options, desc=download_desc)
            else:
-                http_get(url, temp_file=temp_file, proxies=proxies, resume_size=resume_size,
-                         headers=headers, cookies=cookies, max_retries=max_retries, desc=download_desc)
+                http_get_ms(
+                    url,
+                    temp_file=temp_file,
+                    proxies=proxies,
+                    resume_size=resume_size,
+                    headers=headers,
+                    cookies=cookies,
+                    max_retries=max_retries,
+                    desc=download_desc,
+                )
        logger.info(f'storing {url} in cache at {cache_path}')
        shutil.move(temp_file.name, cache_path)
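
With FTP support dropped in datasets 3.0, the download path reduces to two branches: fsspec_get for non-HTTP schemes and http_get_ms for HTTP(S). The surrounding temp-file-then-move pattern avoids corrupt cache entries when a download is interrupted; a minimal standalone sketch (illustrative helper, not part of this PR):

    import shutil
    import tempfile

    def atomic_fetch(url: str, cache_path: str) -> str:
        # Stream to a named temp file first, then move it into the cache in one step.
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            http_get_ms(url, temp_file=temp_file)
        shutil.move(temp_file.name, cache_path)
        return cache_path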


@@ -1,6 +1,6 @@
addict
attrs
-datasets>=2.18.0,<3.0.0
+datasets>=3.0.0
einops
oss2
Pillow


@@ -1,6 +1,6 @@
addict
attrs
-datasets>=2.18.0,<3.0.0
+datasets>=3.0.0
einops
oss2
Pillow


@@ -44,6 +44,15 @@ class TestStreamLoad(unittest.TestCase):
        assert sample['question'], f'Failed to load sample from {repo_id}'

+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_stream_swift_jsonl(self):
+        repo_id: str = 'iic/MSAgent-MultiRole'
+        ds = MsDataset.load(repo_id, split='train', use_streaming=True)
+        sample = next(iter(ds))
+        logger.info(sample)
+        assert sample['id'], f'Failed to load sample from {repo_id}'

if __name__ == '__main__':
    unittest.main()
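
The new test streams a ModelScope JSONL dataset; the same pattern as a standalone sketch:

    from modelscope.msdatasets import MsDataset

    ds = MsDataset.load('iic/MSAgent-MultiRole', split='train', use_streaming=True)
    print(next(iter(ds))['id'])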