merge main for tqdm

This commit is contained in:
xingjun.wang
2024-09-26 17:12:35 +08:00
2 changed files with 18 additions and 40 deletions

View File

@@ -555,7 +555,7 @@ def get_module_without_script(self) -> DatasetModule:
download_config = self.download_config.copy()
if download_config.download_desc is None:
download_config.download_desc = 'Downloading readme'
download_config.download_desc = 'Downloading [README.md]'
try:
url_or_filename = _ms_api.get_dataset_file_url(
file_name='README.md',

View File

@@ -9,13 +9,13 @@ import copy
import shutil
import time
import warnings
import inspect
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import Optional, Union
from urllib.parse import urljoin, urlparse
import requests
from tqdm import tqdm
from datasets import config
from datasets.utils.file_utils import hash_url_to_filename, \
@@ -131,7 +131,6 @@ def http_get_ms(
content_length = response.headers.get('Content-Length')
total = resume_size + int(content_length) if content_length is not None else None
from tqdm import tqdm
progress = tqdm(total=total, initial=resume_size, unit_scale=True, unit='B', desc=desc or 'Downloading')
for chunk in response.iter_content(chunk_size=1024):
progress.update(len(chunk))
@@ -156,7 +155,7 @@ def get_from_cache_ms(
ignore_url_params=False,
storage_options=None,
download_desc=None,
disable_tqdm=False,
disable_tqdm=None,
) -> str:
"""
Given a URL, look for the corresponding file in the local cache.
@@ -202,6 +201,8 @@ def get_from_cache_ms(
# if we don't ask for 'force_download' then we spare a request
filename = hash_url_to_filename(cached_url, etag=None)
cache_path = os.path.join(cache_dir, filename)
if download_desc is None:
download_desc = 'Downloading [' + filename + ']'
if os.path.exists(cache_path) and not force_download and not use_etag:
return cache_path
@@ -316,46 +317,23 @@ def get_from_cache_ms(
# Download to temporary file, then copy to cache path once finished.
# Otherwise, you get corrupt cache entries if the download gets interrupted.
with temp_file_manager() as temp_file:
logger.info(f'Downloading to {temp_file.name}')
# GET file object
if scheme not in ('http', 'https'):
fsspec_get_sig = inspect.signature(fsspec_get)
if 'disable_tqdm' in fsspec_get_sig.parameters:
fsspec_get(url,
temp_file,
storage_options=storage_options,
desc=download_desc,
disable_tqdm=disable_tqdm
)
else:
fsspec_get(url, temp_file, storage_options=storage_options, desc=download_desc)
# fsspec_get_sig = inspect.signature(fsspec_get)
fsspec_get(url, temp_file, storage_options=storage_options, desc=download_desc)
else:
http_get_sig = inspect.signature(http_get_ms)
if 'disable_tqdm' in http_get_sig.parameters:
http_get_ms(
url,
temp_file=temp_file,
proxies=proxies,
resume_size=resume_size,
headers=headers,
cookies=cookies,
max_retries=max_retries,
desc=download_desc,
disable_tqdm=disable_tqdm,
)
else:
http_get_ms(
url,
temp_file=temp_file,
proxies=proxies,
resume_size=resume_size,
headers=headers,
cookies=cookies,
max_retries=max_retries,
desc=download_desc,
)
# http_get_sig = inspect.signature(http_get_ms)
http_get_ms(
url,
temp_file=temp_file,
proxies=proxies,
resume_size=resume_size,
headers=headers,
cookies=cookies,
max_retries=max_retries,
desc=download_desc,
)
logger.info(f'storing {url} in cache at {cache_path}')
shutil.move(temp_file.name, cache_path)