Adapt new datasets (#1002)

* Update to datasets==3.0

* Update

* Add http_get_ms function

* Delete unused code

* Fix PR review issues and update requirements
Xingjun.Wang
2024-09-30 16:46:00 +08:00
committed by GitHub
parent 834db59952
commit 2c4505e13a
5 changed files with 139 additions and 17 deletions

View File

@@ -989,7 +989,6 @@ class DatasetsWrapperHF:
             download_config=download_config,
             download_mode=download_mode,
             verification_mode=verification_mode,
-            try_from_hf_gcs=False,
             num_proc=num_proc,
             storage_options=storage_options,
             # base_path=builder_instance.base_path,
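Note: datasets 3.0 removed the `try_from_hf_gcs` argument from `DatasetBuilder.download_and_prepare`, so the wrapper simply drops it. For code that has to span both major versions, a minimal compatibility sketch could gate the kwarg on the installed signature (the `download_and_prepare_compat` helper below is hypothetical, not part of this PR):

```python
import inspect

from datasets import load_dataset_builder


def download_and_prepare_compat(builder, **kwargs):
    # Only pass try_from_hf_gcs when the installed datasets version still accepts it.
    if 'try_from_hf_gcs' in inspect.signature(builder.download_and_prepare).parameters:
        kwargs.setdefault('try_from_hf_gcs', False)  # datasets < 3.0 only
    builder.download_and_prepare(**kwargs)


builder = load_dataset_builder('squad')  # any builder works; 'squad' is just an example
download_and_prepare_compat(builder)
```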

View File

@@ -5,27 +5,138 @@
 import json
 import os
 import re
+import copy
 import shutil
+import time
 import warnings
+import inspect
 from contextlib import contextmanager
 from functools import partial
 from pathlib import Path
+from typing import Optional, Union
 from urllib.parse import urljoin, urlparse

 import requests
+from tqdm import tqdm

 from datasets import config
-from datasets.utils.file_utils import hash_url_to_filename, get_authentication_headers_for_url, ftp_head, fsspec_head, \
-    http_head, _raise_if_offline_mode_is_enabled, ftp_get, fsspec_get, http_get
+from datasets.utils.file_utils import hash_url_to_filename, \
+    get_authentication_headers_for_url, fsspec_head, fsspec_get
 from filelock import FileLock

 from modelscope.utils.config_ds import MS_DATASETS_CACHE
 from modelscope.utils.logger import get_logger
 from modelscope.hub.api import ModelScopeConfig
+from modelscope import __version__

 logger = get_logger()
+
+
+def get_datasets_user_agent_ms(user_agent: Optional[Union[str, dict]] = None) -> str:
+    ua = f'datasets/{__version__}'
+    ua += f'; python/{config.PY_VERSION}'
+    ua += f'; pyarrow/{config.PYARROW_VERSION}'
+    if config.TORCH_AVAILABLE:
+        ua += f'; torch/{config.TORCH_VERSION}'
+    if config.TF_AVAILABLE:
+        ua += f'; tensorflow/{config.TF_VERSION}'
+    if config.JAX_AVAILABLE:
+        ua += f'; jax/{config.JAX_VERSION}'
+    if isinstance(user_agent, dict):
+        ua += f"; {'; '.join(f'{k}/{v}' for k, v in user_agent.items())}"
+    elif isinstance(user_agent, str):
+        ua += '; ' + user_agent
+    return ua
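Note that `__version__` here is imported from `modelscope`, so the leading `datasets/...` token of the UA string reports the installed ModelScope version rather than the datasets version. A quick illustrative usage (the exact versions depend on the local environment):

```python
# Illustrative only; output values vary with the installed packages.
ua = get_datasets_user_agent_ms(user_agent={'modelscope': '1.18.0'})
print(ua)
# datasets/1.18.0; python/3.10; pyarrow/15.0.0; torch/2.1.2; modelscope/1.18.0
```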
+
+
+def _request_with_retry_ms(
+    method: str,
+    url: str,
+    max_retries: int = 2,
+    base_wait_time: float = 0.5,
+    max_wait_time: float = 2,
+    timeout: float = 10.0,
+    **params,
+) -> requests.Response:
+    """Wrapper around requests that retries on ConnectTimeout/ConnectionError, with exponential backoff.
+
+    Args:
+        method (str): HTTP method, such as 'GET' or 'HEAD'.
+        url (str): The URL of the resource to fetch.
+        max_retries (int): Maximum number of retries, defaults to 2.
+        base_wait_time (float): Duration (in seconds) to wait before retrying the first time. Wait time between
+            retries then grows exponentially, capped by max_wait_time.
+        max_wait_time (float): Maximum duration (in seconds) between two retries.
+        timeout (float): Request timeout in seconds, defaults to 10.0.
+        **params (additional keyword arguments): Params to pass to :obj:`requests.request`.
+    """
+    tries, success = 0, False
+    response = None
+    while not success:
+        tries += 1
+        try:
+            response = requests.request(method=method.upper(), url=url, timeout=timeout, **params)
+            success = True
+        except (requests.exceptions.ConnectTimeout, requests.exceptions.ConnectionError) as err:
+            if tries > max_retries:
+                raise err
+            else:
+                logger.info(f'{method} request to {url} timed out, retrying... [{tries}/{max_retries}]')
+                sleep_time = min(max_wait_time, base_wait_time * 2 ** (tries - 1))  # Exponential backoff
+                time.sleep(sleep_time)
+    return response
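The backoff doubles the base wait on each retry and caps it at `max_wait_time`; with the defaults that means 0.5 s, then 1 s, then 2 s for every further retry. A quick sketch of the schedule:

```python
# Backoff schedule produced by min(max_wait_time, base_wait_time * 2 ** (tries - 1))
base_wait_time, max_wait_time = 0.5, 2
for tries in range(1, 6):
    print(tries, min(max_wait_time, base_wait_time * 2 ** (tries - 1)))
# 1 0.5
# 2 1.0
# 3 2.0
# 4 2.0  (capped)
# 5 2.0  (capped)
```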
+
+
+def http_head_ms(
+    url, proxies=None, headers=None, cookies=None, allow_redirects=True, timeout=10.0, max_retries=0
+) -> requests.Response:
+    headers = copy.deepcopy(headers) or {}
+    headers['user-agent'] = get_datasets_user_agent_ms(user_agent=headers.get('user-agent'))
+    response = _request_with_retry_ms(
+        method='HEAD',
+        url=url,
+        proxies=proxies,
+        headers=headers,
+        cookies=cookies,
+        allow_redirects=allow_redirects,
+        timeout=timeout,
+        max_retries=max_retries,
+    )
+    return response
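`http_head_ms` is how `get_from_cache_ms` probes a URL for its ETag before deciding whether a cached copy is stale. A hedged usage sketch (the URL is a placeholder):

```python
# Placeholder URL; any server that returns an ETag header will do.
resp = http_head_ms('https://example.com/data.json', allow_redirects=True, max_retries=2)
print(resp.status_code, resp.headers.get('ETag'))
```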
+
+
+def http_get_ms(
+    url, temp_file, proxies=None, resume_size=0, headers=None, cookies=None, timeout=100.0, max_retries=0, desc=None
+) -> Optional[requests.Response]:
+    headers = dict(headers) if headers is not None else {}
+    headers['user-agent'] = get_datasets_user_agent_ms(user_agent=headers.get('user-agent'))
+    if resume_size > 0:
+        headers['Range'] = f'bytes={resume_size:d}-'
+    response = _request_with_retry_ms(
+        method='GET',
+        url=url,
+        stream=True,
+        proxies=proxies,
+        headers=headers,
+        cookies=cookies,
+        max_retries=max_retries,
+        timeout=timeout,
+    )
+    if temp_file is None:
+        return response
+    if response.status_code == 416:  # Range not satisfiable
+        return
+    content_length = response.headers.get('Content-Length')
+    total = resume_size + int(content_length) if content_length is not None else None
+    progress = tqdm(total=total, initial=resume_size, unit_scale=True, unit='B', desc=desc or 'Downloading')
+    for chunk in response.iter_content(chunk_size=1024):
+        progress.update(len(chunk))
+        temp_file.write(chunk)
+    progress.close()
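`http_get_ms` supports resumption via the `Range` header: pass the number of bytes already on disk as `resume_size` and append to the partial file; a 416 response means the server has nothing left to send and the function returns without writing. A sketch under those assumptions (the URL and path are placeholders):

```python
import os

url = 'https://example.com/data.jsonl'   # placeholder URL
path = '/tmp/data.jsonl.incomplete'      # placeholder partial download
resume = os.path.getsize(path) if os.path.exists(path) else 0
with open(path, 'ab') as f:              # append, so resumed bytes line up
    http_get_ms(url, temp_file=f, resume_size=resume, desc='data.jsonl')
```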

 def get_from_cache_ms(
     url,
     cache_dir=None,
@@ -42,6 +153,7 @@ def get_from_cache_ms(
     ignore_url_params=False,
     storage_options=None,
     download_desc=None,
+    disable_tqdm=None,
 ) -> str:
     """
     Given a URL, look for the corresponding file in the local cache.
@@ -101,16 +213,14 @@ def get_from_cache_ms(
     # We don't have the file locally or we need an eTag
     if not local_files_only:
         scheme = urlparse(url).scheme
-        if scheme == 'ftp':
-            connected = ftp_head(url)
-        elif scheme not in ('http', 'https'):
+        if scheme not in ('http', 'https'):
             response = fsspec_head(url, storage_options=storage_options)
             # s3fs uses "ETag", gcsfs uses "etag"
             etag = (response.get('ETag', None) or response.get('etag', None)) if use_etag else None
             connected = True
         try:
             cookies = ModelScopeConfig.get_cookies()
-            response = http_head(
+            response = http_head_ms(
                 url,
                 allow_redirects=True,
                 proxies=proxies,
@@ -167,7 +277,6 @@ def get_from_cache_ms(
             )
         elif response is not None and response.status_code == 404:
             raise FileNotFoundError(f"Couldn't find file at {url}")
-        _raise_if_offline_mode_is_enabled(f'Tried to reach {url}')
         if head_error is not None:
             raise ConnectionError(f"Couldn't reach {url} ({repr(head_error)})")
         elif response is not None:
@@ -206,16 +315,21 @@ def get_from_cache_ms(
     # Download to temporary file, then copy to cache path once finished.
     # Otherwise, you get corrupt cache entries if the download gets interrupted.
     with temp_file_manager() as temp_file:
+        logger.info(f'Downloading to {temp_file.name}')
         # GET file object
-        if scheme == 'ftp':
-            ftp_get(url, temp_file)
-        elif scheme not in ('http', 'https'):
+        if scheme not in ('http', 'https'):
             fsspec_get(url, temp_file, storage_options=storage_options, desc=download_desc)
         else:
-            http_get(url, temp_file=temp_file, proxies=proxies, resume_size=resume_size,
-                     headers=headers, cookies=cookies, max_retries=max_retries, desc=download_desc)
+            http_get_ms(
+                url,
+                temp_file=temp_file,
+                proxies=proxies,
+                resume_size=resume_size,
+                headers=headers,
+                cookies=cookies,
+                max_retries=max_retries,
+                desc=download_desc,
+            )

     logger.info(f'storing {url} in cache at {cache_path}')
     shutil.move(temp_file.name, cache_path)
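datasets 3.0 dropped the dedicated `ftp_head`/`ftp_get` helpers, so the wrapper now routes every non-HTTP(S) scheme (ftp://, s3://, gs://, ...) through fsspec, which picks the filesystem implementation from the URL scheme. A condensed sketch of the new dispatch (the `fetch` helper is hypothetical):

```python
from urllib.parse import urlparse


def fetch(url, temp_file, storage_options=None):
    # Non-HTTP(S) schemes are delegated to the matching fsspec filesystem.
    if urlparse(url).scheme not in ('http', 'https'):
        fsspec_get(url, temp_file, storage_options=storage_options)
    else:
        http_get_ms(url, temp_file=temp_file)
```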

View File

@@ -1,6 +1,6 @@
 addict
 attrs
-datasets>=2.18.0,<3.0.0
+datasets>=3.0.0
 einops
 oss2
 Pillow

View File

@@ -1,6 +1,6 @@
 addict
 attrs
-datasets>=2.18.0,<3.0.0
+datasets>=3.0.0
 einops
 oss2
 Pillow

View File

@@ -44,6 +44,15 @@ class TestStreamLoad(unittest.TestCase):
         assert sample['question'], f'Failed to load sample from {repo_id}'

+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_stream_swift_jsonl(self):
+        repo_id: str = 'iic/MSAgent-MultiRole'
+        ds = MsDataset.load(repo_id, split='train', use_streaming=True)
+        sample = next(iter(ds))
+
+        logger.info(sample)
+        assert sample['id'], f'Failed to load sample from {repo_id}'
+

 if __name__ == '__main__':
     unittest.main()