@@ -5,27 +5,138 @@
 import json
 import os
 import re
+import copy
 import shutil
+import time
 import warnings
+import inspect
 from contextlib import contextmanager
 from functools import partial
 from pathlib import Path
+from typing import Optional, Union
 from urllib.parse import urljoin, urlparse
 
 import requests
+from tqdm import tqdm
 from datasets import config
-from datasets.utils.file_utils import hash_url_to_filename, get_authentication_headers_for_url, ftp_head, fsspec_head, \
-    http_head, _raise_if_offline_mode_is_enabled, ftp_get, fsspec_get, http_get
+from datasets.utils.file_utils import hash_url_to_filename, \
+    get_authentication_headers_for_url, fsspec_head, fsspec_get
 from filelock import FileLock
 
 from modelscope.utils.config_ds import MS_DATASETS_CACHE
 from modelscope.utils.logger import get_logger
 from modelscope.hub.api import ModelScopeConfig
+from modelscope import __version__
 
 logger = get_logger()
 
 
+def get_datasets_user_agent_ms(user_agent: Optional[Union[str, dict]] = None) -> str:
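+    # Build a 'datasets/<ver>; python/<ver>; ...' user-agent string. Note that
+    # __version__ here is the modelscope package version (imported above).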
+    ua = f'datasets/{__version__}'
+    ua += f'; python/{config.PY_VERSION}'
+    ua += f'; pyarrow/{config.PYARROW_VERSION}'
+    if config.TORCH_AVAILABLE:
+        ua += f'; torch/{config.TORCH_VERSION}'
+    if config.TF_AVAILABLE:
+        ua += f'; tensorflow/{config.TF_VERSION}'
+    if config.JAX_AVAILABLE:
+        ua += f'; jax/{config.JAX_VERSION}'
+    if isinstance(user_agent, dict):
+        ua += f"; {'; '.join(f'{k}/{v}' for k, v in user_agent.items())}"
+    elif isinstance(user_agent, str):
+        ua += '; ' + user_agent
+    return ua
+
+
+def _request_with_retry_ms(
+    method: str,
+    url: str,
+    max_retries: int = 2,
+    base_wait_time: float = 0.5,
+    max_wait_time: float = 2,
+    timeout: float = 10.0,
+    **params,
+) -> requests.Response:
+    """Wrapper around requests to retry in case it fails with a ConnectTimeout, with exponential backoff.
+
+    Args:
+        method (str): HTTP method, such as 'GET' or 'HEAD'.
+        url (str): The URL of the resource to fetch.
+        max_retries (int): Maximum number of retries, defaults to 2.
+        base_wait_time (float): Duration (in seconds) to wait before retrying the first time. Wait time between
+            retries then grows exponentially, capped by max_wait_time.
+        max_wait_time (float): Maximum duration (in seconds) to wait between two retries.
+        timeout (float): How many seconds to wait for the server to respond before giving up.
+        **params (additional keyword arguments): Params to pass to :obj:`requests.request`.
+    """
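+    # Retry only on connection-level failures (ConnectTimeout / ConnectionError);
+    # HTTP error responses are returned to the caller as-is. The sleep time doubles
+    # on each retry, capped at max_wait_time.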
+    tries, success = 0, False
+    response = None
+    while not success:
+        tries += 1
+        try:
+            response = requests.request(method=method.upper(), url=url, timeout=timeout, **params)
+            success = True
+        except (requests.exceptions.ConnectTimeout, requests.exceptions.ConnectionError) as err:
+            if tries > max_retries:
+                raise err
+            logger.info(f'{method} request to {url} timed out, retrying... [{tries}/{max_retries}]')
+            sleep_time = min(max_wait_time, base_wait_time * 2 ** (tries - 1))  # Exponential backoff
+            time.sleep(sleep_time)
+    return response
+
+
+def http_head_ms(
+    url, proxies=None, headers=None, cookies=None, allow_redirects=True, timeout=10.0, max_retries=0
+) -> requests.Response:
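+    # Deep-copy the caller's headers before injecting the user-agent so the original dict is not mutated.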
+    headers = copy.deepcopy(headers) or {}
+    headers['user-agent'] = get_datasets_user_agent_ms(user_agent=headers.get('user-agent'))
+    response = _request_with_retry_ms(
+        method='HEAD',
+        url=url,
+        proxies=proxies,
+        headers=headers,
+        cookies=cookies,
+        allow_redirects=allow_redirects,
+        timeout=timeout,
+        max_retries=max_retries,
+    )
+    return response
+
+
+def http_get_ms(
+    url, temp_file, proxies=None, resume_size=0, headers=None, cookies=None, timeout=100.0, max_retries=0, desc=None
+) -> Optional[requests.Response]:
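+    # Stream url into temp_file with a progress bar. A non-zero resume_size continues
+    # an interrupted download by requesting only the remaining bytes via an HTTP Range header.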
+    headers = dict(headers) if headers is not None else {}
+    headers['user-agent'] = get_datasets_user_agent_ms(user_agent=headers.get('user-agent'))
+    if resume_size > 0:
+        headers['Range'] = f'bytes={resume_size:d}-'
+    response = _request_with_retry_ms(
+        method='GET',
+        url=url,
+        stream=True,
+        proxies=proxies,
+        headers=headers,
+        cookies=cookies,
+        max_retries=max_retries,
+        timeout=timeout,
+    )
+    if temp_file is None:
+        return response
+    if response.status_code == 416:  # Range not satisfiable
+        return
+    content_length = response.headers.get('Content-Length')
+    total = resume_size + int(content_length) if content_length is not None else None
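+    # On a ranged request, Content-Length only covers the remaining bytes, so the
+    # resume_size already on disk is added back to get the true progress-bar total.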
+    progress = tqdm(total=total, initial=resume_size, unit_scale=True, unit='B', desc=desc or 'Downloading')
+    for chunk in response.iter_content(chunk_size=1024):
+        progress.update(len(chunk))
+        temp_file.write(chunk)
+    progress.close()
+
+
 def get_from_cache_ms(
     url,
     cache_dir=None,
@@ -42,6 +153,7 @@ def get_from_cache_ms(
     ignore_url_params=False,
     storage_options=None,
     download_desc=None,
+    disable_tqdm=None,
 ) -> str:
     """
     Given a URL, look for the corresponding file in the local cache.
@@ -101,16 +213,14 @@ def get_from_cache_ms(
     # We don't have the file locally or we need an eTag
     if not local_files_only:
         scheme = urlparse(url).scheme
-        if scheme == 'ftp':
-            connected = ftp_head(url)
-        elif scheme not in ('http', 'https'):
+        if scheme not in ('http', 'https'):
             response = fsspec_head(url, storage_options=storage_options)
             # s3fs uses "ETag", gcsfs uses "etag"
             etag = (response.get('ETag', None) or response.get('etag', None)) if use_etag else None
             connected = True
         try:
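+            # Attach the locally stored ModelScope cookies so the HEAD request is authenticated.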
             cookies = ModelScopeConfig.get_cookies()
-            response = http_head(
+            response = http_head_ms(
                 url,
                 allow_redirects=True,
                 proxies=proxies,
@@ -167,7 +277,6 @@ def get_from_cache_ms(
             )
         elif response is not None and response.status_code == 404:
             raise FileNotFoundError(f"Couldn't find file at {url}")
-        _raise_if_offline_mode_is_enabled(f'Tried to reach {url}')
         if head_error is not None:
             raise ConnectionError(f"Couldn't reach {url} ({repr(head_error)})")
         elif response is not None:
@@ -206,16 +315,21 @@ def get_from_cache_ms(
     # Download to temporary file, then copy to cache path once finished.
     # Otherwise, you get corrupt cache entries if the download gets interrupted.
     with temp_file_manager() as temp_file:
         logger.info(f'Downloading to {temp_file.name}')
 
         # GET file object
-        if scheme == 'ftp':
-            ftp_get(url, temp_file)
-        elif scheme not in ('http', 'https'):
+        if scheme not in ('http', 'https'):
             fsspec_get(url, temp_file, storage_options=storage_options, desc=download_desc)
         else:
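+            # http_get_ms mirrors datasets' http_get, but uses the retrying request
+            # helper and the user-agent built by get_datasets_user_agent_ms above.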
-            http_get(url, temp_file=temp_file, proxies=proxies, resume_size=resume_size,
-                     headers=headers, cookies=cookies, max_retries=max_retries, desc=download_desc)
+            http_get_ms(
+                url,
+                temp_file=temp_file,
+                proxies=proxies,
+                resume_size=resume_size,
+                headers=headers,
+                cookies=cookies,
+                max_retries=max_retries,
+                desc=download_desc,
+            )
 
     logger.info(f'storing {url} in cache at {cache_path}')
     shutil.move(temp_file.name, cache_path)