[hub] Support ProgressCallback (#1380)

1. Add ProgressCallback for file download
Author: Jintao
Date: 2025-06-26 14:02:05 +08:00
Committed by: GitHub
Parent: 694fc89735
Commit: 29ea57dfbc
5 changed files with 188 additions and 81 deletions
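
In short: snapshot_download and the lower-level download helpers now accept a progress_callbacks argument, a list of ProgressCallback subclasses (classes, not instances). For each file, every class in the list is instantiated with the file name and total size, update() is called with the byte count of each downloaded chunk, and end() is called once the file finishes. The built-in TqdmCallback is appended automatically unless disable_tqdm is set. A minimal usage sketch (PrintCallback is illustrative, not part of this change):

    from modelscope import snapshot_download
    from modelscope.hub import ProgressCallback

    class PrintCallback(ProgressCallback):
        # Illustrative subclass: accumulate byte deltas, report at the end.
        def __init__(self, filename: str, file_size: int):
            super().__init__(filename, file_size)
            self.downloaded = 0

        def update(self, size: int):
            self.downloaded += size  # size is a per-chunk delta

        def end(self):
            print(f'{self.filename}: {self.downloaded}/{self.file_size} bytes')

    model_dir = snapshot_download(
        'swift/test_lora', progress_callbacks=[PrintCallback])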

View File

@@ -0,0 +1 @@
from .callback import ProgressCallback

View File

@@ -0,0 +1,34 @@
from tqdm import tqdm


class ProgressCallback:

    def __init__(self, filename: str, file_size: int):
        self.filename = filename
        self.file_size = file_size

    def update(self, size: int):
        pass

    def end(self):
        pass


class TqdmCallback(ProgressCallback):

    def __init__(self, filename: str, file_size: int):
        super().__init__(filename, file_size)
        self.progress = tqdm(
            unit='B',
            unit_scale=True,
            unit_divisor=1024,
            total=file_size if file_size > 0 else 1,
            initial=0,
            desc='Downloading [' + self.filename + ']',
            leave=True)

    def update(self, size: int):
        self.progress.update(size)

    def end(self):
        self.progress.close()
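
Note on the contract, as the downloaders below use it: __init__ receives the file name and total size, update(size) receives byte deltas as chunks arrive (including any already-downloaded partial length when a download resumes), and end() is called once per file after the download completes. TqdmCallback is the default implementation and simply drives a tqdm bar; it clamps a zero total to 1 so empty files still render a finished bar.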

View File

@@ -12,7 +12,7 @@ from concurrent.futures import ThreadPoolExecutor
from functools import partial
from http.cookiejar import CookieJar
from pathlib import Path
-from typing import Dict, Optional, Union
+from typing import Dict, List, Optional, Type, Union

import requests
from requests.adapters import Retry
@@ -30,6 +30,7 @@ from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
from modelscope.utils.file_utils import (get_dataset_cache_root,
                                         get_model_cache_root)
from modelscope.utils.logger import get_logger
+from .callback import ProgressCallback, TqdmCallback
from .errors import FileDownloadError, InvalidParameter, NotExistError
from .utils.caching import ModelFileSystemCache
from .utils.utils import (file_integrity_validation, get_endpoint,
@@ -391,7 +392,7 @@ def get_file_download_url(model_id: str,

def download_part_with_retry(params):
    # unpack parameters
-    model_file_path, progress, start, end, url, file_name, cookies, headers = params
+    model_file_path, progress_callbacks, start, end, url, file_name, cookies, headers = params
    get_headers = {} if headers is None else copy.deepcopy(headers)
    get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
    retry = Retry(
@@ -406,7 +407,8 @@ def download_part_with_retry(params):
                    part_file_name):  # download partial, continue download
                with open(part_file_name, 'rb') as f:
                    partial_length = f.seek(0, io.SEEK_END)
-                    progress.update(partial_length)
+                    for callback in progress_callbacks:
+                        callback.update(partial_length)
                download_start = start + partial_length
                if download_start > end:
                    break  # this part is download completed.
@@ -422,7 +424,8 @@ def download_part_with_retry(params):
                        chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
                    if chunk:  # filter out keep-alive new chunks
                        f.write(chunk)
-                        progress.update(len(chunk))
+                        for callback in progress_callbacks:
+                            callback.update(len(chunk))
            break
        except (Exception) as e:  # no matter what exception, we will retry.
            retry = retry.increment('GET', url, error=e)
@@ -438,18 +441,16 @@ def parallel_download(url: str,
                      headers: Optional[Dict[str, str]] = None,
                      file_size: int = None,
                      disable_tqdm: bool = False,
+                      progress_callbacks: List[Type[ProgressCallback]] = None,
                      endpoint: str = None):
+    progress_callbacks = [] if progress_callbacks is None else progress_callbacks.copy(
+    )
+    if not disable_tqdm:
+        progress_callbacks.append(TqdmCallback)
+    progress_callbacks = [
+        callback(file_name, file_size) for callback in progress_callbacks
+    ]
    # create temp file
-    with tqdm(
-            unit='B',
-            unit_scale=True,
-            unit_divisor=1024,
-            total=file_size,
-            initial=0,
-            desc='Downloading [' + file_name + ']',
-            leave=True,
-            disable=disable_tqdm,
-    ) as progress:
    PART_SIZE = 160 * 1024 * 1024  # every part is 160M
    tasks = []
    file_path = os.path.join(local_dir, file_name)
@@ -457,18 +458,18 @@ def parallel_download(url: str,
    for idx in range(int(file_size / PART_SIZE)):
        start = idx * PART_SIZE
        end = (idx + 1) * PART_SIZE - 1
-        tasks.append((file_path, progress, start, end, url, file_name,
-                      cookies, headers))
-    if end + 1 < file_size:
-        tasks.append((file_path, progress, end + 1, file_size - 1, url,
-                      file_name, cookies, headers))
+        tasks.append((file_path, progress_callbacks, start, end, url,
+                      file_name, cookies, headers))
+    if end + 1 < file_size:
+        tasks.append((file_path, progress_callbacks, end + 1, file_size - 1,
+                      url, file_name, cookies, headers))
    parallels = min(MODELSCOPE_DOWNLOAD_PARALLELS, 16)
    # download every part
    with ThreadPoolExecutor(
-            max_workers=parallels,
-            thread_name_prefix='download') as executor:
+            max_workers=parallels, thread_name_prefix='download') as executor:
        list(executor.map(download_part_with_retry, tasks))
+    for callback in progress_callbacks:
+        callback.end()
    # merge parts.
    hash_sha256 = hashlib.sha256()
    with open(os.path.join(local_dir, file_name), 'wb') as output_file:
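
Note on the parallel path: the callback instances are created once per file in parallel_download, the same instances are shared by every part-download task in the thread pool, and end() fires only after executor.map has drained. tqdm guards update() with an internal lock, so the default TqdmCallback is safe here; custom callbacks should likewise tolerate concurrent update() calls from multiple threads.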
@@ -493,6 +494,7 @@ def http_get_model_file(
    cookies: CookieJar,
    headers: Optional[Dict[str, str]] = None,
    disable_tqdm: bool = False,
+    progress_callbacks: List[Type[ProgressCallback]] = None,
):
    """Download remote file, will retry 5 times before giving up on errors.
@@ -510,11 +512,20 @@ def http_get_model_file(
        headers(Dict[str, str], optional):
            http headers to carry necessary info when requesting the remote file
        disable_tqdm(bool, optional): Disable the progress bar with tqdm.
+        progress_callbacks(List[Type[ProgressCallback]], optional):
+            progress callbacks to track the download progress.

    Raises:
        FileDownloadError: File download failed.

    """
+    progress_callbacks = [] if progress_callbacks is None else progress_callbacks.copy(
+    )
+    if not disable_tqdm:
+        progress_callbacks.append(TqdmCallback)
+    progress_callbacks = [
+        callback(file_name, file_size) for callback in progress_callbacks
+    ]
    get_headers = {} if headers is None else copy.deepcopy(headers)
    get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
    temp_file_path = os.path.join(local_dir, file_name)
@@ -527,22 +538,14 @@ def http_get_model_file(
        total=API_FILE_DOWNLOAD_RETRY_TIMES,
        backoff_factor=1,
        allowed_methods=['GET'])
    while True:
        try:
-            with tqdm(
-                    unit='B',
-                    unit_scale=True,
-                    unit_divisor=1024,
-                    total=file_size if file_size > 0 else 1,
-                    initial=0,
-                    desc='Downloading [' + file_name + ']',
-                    leave=True,
-                    disable=disable_tqdm,
-            ) as progress:
            if file_size == 0:
                # Avoid empty file server request
                with open(temp_file_path, 'w+'):
-                    progress.update(1)
+                    for callback in progress_callbacks:
+                        callback.update(1)
                break

            # Determine the length of any existing partial download
            partial_length = 0
@@ -552,7 +555,8 @@ def http_get_model_file(
                has_retry = True
                with open(temp_file_path, 'rb') as f:
                    partial_length = f.seek(0, io.SEEK_END)
-                    progress.update(partial_length)
+                    for callback in progress_callbacks:
+                        callback.update(partial_length)

            # Check if download is complete
            if partial_length >= file_size:
@@ -571,7 +575,8 @@ def http_get_model_file(
                for chunk in r.iter_content(
                        chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
                    if chunk:  # filter out keep-alive new chunks
-                        progress.update(len(chunk))
+                        for callback in progress_callbacks:
+                            callback.update(len(chunk))
                        f.write(chunk)
                        # hash would be discarded in retry case anyway
                        if not has_retry:
@@ -581,6 +586,8 @@ def http_get_model_file(
            has_retry = True
            retry = retry.increment('GET', url, error=e)
            retry.sleep()
+    for callback in progress_callbacks:
+        callback.end()

    # if anything went wrong, we would discard the real-time computed hash and return None
    return None if has_retry else hash_sha256.hexdigest()
@@ -675,6 +682,7 @@ def download_file(
    headers,
    cookies,
    disable_tqdm=False,
+    progress_callbacks: List[Type[ProgressCallback]] = None,
):
    if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[
            'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1:  # parallel download large file.
@@ -686,6 +694,7 @@ def download_file(
            cookies=None if cookies is None else cookies.get_dict(),
            file_size=file_meta['Size'],
            disable_tqdm=disable_tqdm,
+            progress_callbacks=progress_callbacks,
        )
    else:
        file_digest = http_get_model_file(
@@ -696,6 +705,7 @@ def download_file(
            headers=headers,
            cookies=cookies,
            disable_tqdm=disable_tqdm,
+            progress_callbacks=progress_callbacks,
        )

    # check file integrity
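
download_file is the common entry point for both paths: files larger than MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB (when MODELSCOPE_DOWNLOAD_PARALLELS > 1) go through parallel_download, everything else through http_get_model_file, and the same progress_callbacks list of classes is forwarded unchanged to either one.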

View File

@@ -6,7 +6,7 @@ import re
import uuid
from http.cookiejar import CookieJar
from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Type, Union

from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.errors import InvalidParameter
@@ -23,6 +23,7 @@ from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
from modelscope.utils.file_utils import get_modelscope_cache_dir
from modelscope.utils.logger import get_logger
from modelscope.utils.thread_utils import thread_executor
+from .callback import ProgressCallback

logger = get_logger()
@@ -42,6 +43,7 @@ def snapshot_download(
    max_workers: int = 8,
    repo_id: str = None,
    repo_type: Optional[str] = REPO_TYPE_MODEL,
+    progress_callbacks: List[Type[ProgressCallback]] = None,
) -> str:
    """Download all files of a repo.

    Downloads a whole snapshot of a repo's files at the specified revision. This
@@ -77,6 +79,8 @@ def snapshot_download(
            If provided, files matching any of the patterns are not downloaded, priority over ignore_file_pattern.
            For hugging-face compatibility.
        max_workers (`int`): The maximum number of workers to download files, default 8.
+        progress_callbacks (`List[Type[ProgressCallback]]`, **optional**, default to `None`):
+            progress callbacks to track the download progress.

    Raises:
        ValueError: the value details.
@@ -118,7 +122,8 @@ def snapshot_download(
        local_dir=local_dir,
        ignore_patterns=ignore_patterns,
        allow_patterns=allow_patterns,
-        max_workers=max_workers)
+        max_workers=max_workers,
+        progress_callbacks=progress_callbacks)


def dataset_snapshot_download(
@@ -213,6 +218,7 @@ def _snapshot_download(
    allow_patterns: Optional[Union[List[str], str]] = None,
    ignore_patterns: Optional[Union[List[str], str]] = None,
    max_workers: int = 8,
+    progress_callbacks: List[Type[ProgressCallback]] = None,
):
    if not repo_type:
        repo_type = REPO_TYPE_MODEL
@@ -304,6 +310,7 @@ def _snapshot_download(
                allow_patterns=allow_patterns,
                max_workers=max_workers,
                endpoint=endpoint,
+                progress_callbacks=progress_callbacks,
            )
            if '.' in repo_id:
                masked_directory = get_model_masked_directory(
@@ -362,6 +369,7 @@ def _snapshot_download(
                allow_patterns=allow_patterns,
                max_workers=max_workers,
                endpoint=endpoint,
+                progress_callbacks=progress_callbacks,
            )

            cache.save_model_version(revision_info=revision_detail)
@@ -449,6 +457,7 @@ def _download_file_lists(
    ignore_patterns: Optional[Union[List[str], str]] = None,
    max_workers: int = 8,
    endpoint: Optional[str] = None,
+    progress_callbacks: List[Type[ProgressCallback]] = None,
):
    ignore_patterns = _normalize_patterns(ignore_patterns)
    allow_patterns = _normalize_patterns(allow_patterns)
@@ -532,6 +541,7 @@ def _download_file_lists(
            headers,
            cookies,
            disable_tqdm=False,
+            progress_callbacks=progress_callbacks,
        )

    if len(filtered_repo_files) > 0:
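
The argument is threaded through the whole snapshot path unchanged: snapshot_download -> _snapshot_download -> _download_file_lists -> download_file, so a single list of callback classes passed at the top level is instantiated fresh for every file in the repo.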

View File

@@ -0,0 +1,52 @@
import tempfile
import unittest

from tqdm import tqdm

from modelscope import snapshot_download
from modelscope.hub import ProgressCallback


class NewProgressCallback(ProgressCallback):
    all_files = set()  # just for test

    def __init__(self, filename: str, file_size: int):
        super().__init__(filename, file_size)
        self.progress = tqdm(total=file_size)
        self.all_files.add(filename)

    def update(self, size: int):
        self.progress.update(size)

    def end(self):
        self.all_files.remove(self.filename)
        assert self.progress.n == self.progress.total == self.file_size
        self.progress.close()


class ProgressCallbackTest(unittest.TestCase):

    def setUp(self):
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        self.temp_dir.cleanup()

    def test_progress_callback(self):
        model_dir = snapshot_download(
            'swift/test_lora',
            progress_callbacks=[NewProgressCallback],
            cache_dir=self.temp_dir.name)
        print(f'model_dir: {model_dir}')
        self.assertTrue(len(NewProgressCallback.all_files) == 0)

    def test_empty_progress_callback(self):
        model_dir = snapshot_download(
            'swift/test_lora',
            progress_callbacks=[],
            cache_dir=self.temp_dir.name)
        print(f'model_dir: {model_dir}')


if __name__ == '__main__':
    unittest.main()