[hub] Support ProgressCallback (#1380)

1. Add ProgressCallback for file download
Authored by Jintao on 2025-06-26 14:02:05 +08:00, committed by GitHub
parent 694fc89735
commit 29ea57dfbc
5 changed files with 188 additions and 81 deletions
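
As a usage sketch (not part of the commit itself): callers pass `ProgressCallback` subclasses, i.e. classes rather than instances, and the hub instantiates one per downloaded file with `(filename, file_size)`. The callback class and repo id below are placeholders for illustration.

    from modelscope import snapshot_download
    from modelscope.hub import ProgressCallback


    class ByteCounter(ProgressCallback):
        # Hypothetical example class, not part of this commit.

        def __init__(self, filename: str, file_size: int):
            super().__init__(filename, file_size)
            self.received = 0

        def update(self, size: int):
            self.received += size  # bytes received so far for this file

        def end(self):
            print(f'{self.filename}: {self.received}/{self.file_size} bytes')


    model_dir = snapshot_download(
        'your-namespace/your-model',  # placeholder repo id
        progress_callbacks=[ByteCounter])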


@@ -0,0 +1 @@
+from .callback import ProgressCallback


@@ -0,0 +1,34 @@
+from tqdm import tqdm
+
+
+class ProgressCallback:
+
+    def __init__(self, filename: str, file_size: int):
+        self.filename = filename
+        self.file_size = file_size
+
+    def update(self, size: int):
+        pass
+
+    def end(self):
+        pass
+
+
+class TqdmCallback(ProgressCallback):
+
+    def __init__(self, filename: str, file_size: int):
+        super().__init__(filename, file_size)
+        self.progress = tqdm(
+            unit='B',
+            unit_scale=True,
+            unit_divisor=1024,
+            total=file_size if file_size > 0 else 1,
+            initial=0,
+            desc='Downloading [' + self.filename + ']',
+            leave=True)
+
+    def update(self, size: int):
+        self.progress.update(size)
+
+    def end(self):
+        self.progress.close()
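
The base class is deliberately a pair of no-op hooks, so subclasses override only what they need. As the download code below shows, callbacks are registered as classes and instantiated per file; a rough sketch of that lifecycle (the filename and sizes are made-up values, and this is not literal library code):

    # Sketch of the per-file contract the hub drives for each callback class.
    callback_classes = [TqdmCallback]  # what callers register
    callbacks = [cls('example.bin', 1024) for cls in callback_classes]
    for chunk in (512, 512):  # as chunks arrive during the download...
        for cb in callbacks:
            cb.update(chunk)
    for cb in callbacks:  # ...and once the file is complete
        cb.end()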


@@ -12,7 +12,7 @@ from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 from http.cookiejar import CookieJar
 from pathlib import Path
-from typing import Dict, Optional, Union
+from typing import Dict, List, Optional, Type, Union
 
 import requests
 from requests.adapters import Retry
@@ -30,6 +30,7 @@ from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
 from modelscope.utils.file_utils import (get_dataset_cache_root,
                                          get_model_cache_root)
 from modelscope.utils.logger import get_logger
+from .callback import ProgressCallback, TqdmCallback
 from .errors import FileDownloadError, InvalidParameter, NotExistError
 from .utils.caching import ModelFileSystemCache
 from .utils.utils import (file_integrity_validation, get_endpoint,
@@ -391,7 +392,7 @@ def get_file_download_url(model_id: str,
 
 def download_part_with_retry(params):
     # unpack parameters
-    model_file_path, progress, start, end, url, file_name, cookies, headers = params
+    model_file_path, progress_callbacks, start, end, url, file_name, cookies, headers = params
     get_headers = {} if headers is None else copy.deepcopy(headers)
     get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
     retry = Retry(
@@ -406,7 +407,8 @@ def download_part_with_retry(params):
                     part_file_name):  # download partial, continue download
                 with open(part_file_name, 'rb') as f:
                     partial_length = f.seek(0, io.SEEK_END)
-                    progress.update(partial_length)
+                    for callback in progress_callbacks:
+                        callback.update(partial_length)
                 download_start = start + partial_length
                 if download_start > end:
                     break  # this part is download completed.
@@ -422,7 +424,8 @@ def download_part_with_retry(params):
                         chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
                     if chunk:  # filter out keep-alive new chunks
                         f.write(chunk)
-                        progress.update(len(chunk))
+                        for callback in progress_callbacks:
+                            callback.update(len(chunk))
                 break
         except (Exception) as e:  # no matter what exception, we will retry.
             retry = retry.increment('GET', url, error=e)
@@ -438,37 +441,35 @@ def parallel_download(url: str,
                       headers: Optional[Dict[str, str]] = None,
                       file_size: int = None,
                       disable_tqdm: bool = False,
+                      progress_callbacks: List[Type[ProgressCallback]] = None,
                       endpoint: str = None):
+    progress_callbacks = [] if progress_callbacks is None else progress_callbacks.copy(
+    )
+    if not disable_tqdm:
+        progress_callbacks.append(TqdmCallback)
+    progress_callbacks = [
+        callback(file_name, file_size) for callback in progress_callbacks
+    ]
     # create temp file
-    with tqdm(
-            unit='B',
-            unit_scale=True,
-            unit_divisor=1024,
-            total=file_size,
-            initial=0,
-            desc='Downloading [' + file_name + ']',
-            leave=True,
-            disable=disable_tqdm,
-    ) as progress:
-        PART_SIZE = 160 * 1024 * 1024  # every part is 160M
-        tasks = []
-        file_path = os.path.join(local_dir, file_name)
-        os.makedirs(os.path.dirname(file_path), exist_ok=True)
-        for idx in range(int(file_size / PART_SIZE)):
-            start = idx * PART_SIZE
-            end = (idx + 1) * PART_SIZE - 1
-            tasks.append((file_path, progress, start, end, url, file_name,
-                          cookies, headers))
-        if end + 1 < file_size:
-            tasks.append((file_path, progress, end + 1, file_size - 1, url,
-                          file_name, cookies, headers))
-        parallels = min(MODELSCOPE_DOWNLOAD_PARALLELS, 16)
-        # download every part
-        with ThreadPoolExecutor(
-                max_workers=parallels,
-                thread_name_prefix='download') as executor:
-            list(executor.map(download_part_with_retry, tasks))
+    PART_SIZE = 160 * 1024 * 1024  # every part is 160M
+    tasks = []
+    file_path = os.path.join(local_dir, file_name)
+    os.makedirs(os.path.dirname(file_path), exist_ok=True)
+    for idx in range(int(file_size / PART_SIZE)):
+        start = idx * PART_SIZE
+        end = (idx + 1) * PART_SIZE - 1
+        tasks.append((file_path, progress_callbacks, start, end, url,
+                      file_name, cookies, headers))
+    if end + 1 < file_size:
+        tasks.append((file_path, progress_callbacks, end + 1, file_size - 1,
+                      url, file_name, cookies, headers))
+    parallels = min(MODELSCOPE_DOWNLOAD_PARALLELS, 16)
+    # download every part
+    with ThreadPoolExecutor(
+            max_workers=parallels, thread_name_prefix='download') as executor:
+        list(executor.map(download_part_with_retry, tasks))
+    for callback in progress_callbacks:
+        callback.end()
     # merge parts.
     hash_sha256 = hashlib.sha256()
     with open(os.path.join(local_dir, file_name), 'wb') as output_file:
@@ -493,6 +494,7 @@ def http_get_model_file(
     cookies: CookieJar,
     headers: Optional[Dict[str, str]] = None,
     disable_tqdm: bool = False,
+    progress_callbacks: List[Type[ProgressCallback]] = None,
 ):
     """Download remote file, will retry 5 times before giving up on errors.
@@ -510,11 +512,20 @@ def http_get_model_file(
         headers(Dict[str, str], optional):
             http headers to carry necessary info when requesting the remote file
         disable_tqdm(bool, optional): Disable the progress bar with tqdm.
+        progress_callbacks(List[Type[ProgressCallback]], optional):
+            progress callbacks to track the download progress.
 
     Raises:
         FileDownloadError: File download failed.
 
     """
+    progress_callbacks = [] if progress_callbacks is None else progress_callbacks.copy(
+    )
+    if not disable_tqdm:
+        progress_callbacks.append(TqdmCallback)
+    progress_callbacks = [
+        callback(file_name, file_size) for callback in progress_callbacks
+    ]
     get_headers = {} if headers is None else copy.deepcopy(headers)
     get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
     temp_file_path = os.path.join(local_dir, file_name)
@@ -527,60 +538,56 @@ def http_get_model_file(
         total=API_FILE_DOWNLOAD_RETRY_TIMES,
         backoff_factor=1,
         allowed_methods=['GET'])
     while True:
         try:
-            with tqdm(
-                    unit='B',
-                    unit_scale=True,
-                    unit_divisor=1024,
-                    total=file_size if file_size > 0 else 1,
-                    initial=0,
-                    desc='Downloading [' + file_name + ']',
-                    leave=True,
-                    disable=disable_tqdm,
-            ) as progress:
-                if file_size == 0:
-                    # Avoid empty file server request
-                    with open(temp_file_path, 'w+'):
-                        progress.update(1)
-                    break
-                # Determine the length of any existing partial download
-                partial_length = 0
-                # download partial, continue download
-                if os.path.exists(temp_file_path):
-                    # resuming from interrupted download is also considered as retry
-                    has_retry = True
-                    with open(temp_file_path, 'rb') as f:
-                        partial_length = f.seek(0, io.SEEK_END)
-                        progress.update(partial_length)
+            if file_size == 0:
+                # Avoid empty file server request
+                with open(temp_file_path, 'w+'):
+                    for callback in progress_callbacks:
+                        callback.update(1)
+                break
+            # Determine the length of any existing partial download
+            partial_length = 0
+            # download partial, continue download
+            if os.path.exists(temp_file_path):
+                # resuming from interrupted download is also considered as retry
+                has_retry = True
+                with open(temp_file_path, 'rb') as f:
+                    partial_length = f.seek(0, io.SEEK_END)
+                    for callback in progress_callbacks:
+                        callback.update(partial_length)
             # Check if download is complete
             if partial_length >= file_size:
                 break
             # closed range[], from 0.
             get_headers['Range'] = 'bytes=%s-%s' % (partial_length,
                                                     file_size - 1)
             with open(temp_file_path, 'ab+') as f:
                 r = requests.get(
                     url,
                     stream=True,
                     headers=get_headers,
                     cookies=cookies,
                     timeout=API_FILE_DOWNLOAD_TIMEOUT)
                 r.raise_for_status()
                 for chunk in r.iter_content(
                         chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
                     if chunk:  # filter out keep-alive new chunks
-                        progress.update(len(chunk))
-                        f.write(chunk)
+                        for callback in progress_callbacks:
+                            callback.update(len(chunk))
+                        f.write(chunk)
                         # hash would be discarded in retry case anyway
                         if not has_retry:
                             hash_sha256.update(chunk)
             break
         except Exception as e:  # no matter what happen, we will retry.
             has_retry = True
             retry = retry.increment('GET', url, error=e)
             retry.sleep()
+    for callback in progress_callbacks:
+        callback.end()
     # if anything went wrong, we would discard the real-time computed hash and return None
     return None if has_retry else hash_sha256.hexdigest()
@@ -675,6 +682,7 @@ def download_file(
     headers,
     cookies,
     disable_tqdm=False,
+    progress_callbacks: List[Type[ProgressCallback]] = None,
 ):
     if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[
             'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1:  # parallel download large file.
@@ -686,6 +694,7 @@ def download_file(
             cookies=None if cookies is None else cookies.get_dict(),
             file_size=file_meta['Size'],
             disable_tqdm=disable_tqdm,
+            progress_callbacks=progress_callbacks,
         )
     else:
         file_digest = http_get_model_file(
@@ -696,6 +705,7 @@ def download_file(
             headers=headers,
             cookies=cookies,
             disable_tqdm=disable_tqdm,
+            progress_callbacks=progress_callbacks,
         )
     # check file integrity
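
One caveat worth noting for the parallel path above: every worker thread in the `ThreadPoolExecutor` calls `update()` on the same callback instances (tqdm synchronizes its own updates, so `TqdmCallback` is safe). A custom callback that mutates shared state should therefore guard it; a hedged sketch with an invented class name:

    import threading

    from modelscope.hub import ProgressCallback


    class ThreadSafeCounter(ProgressCallback):
        # Hypothetical example: parallel_download invokes update() from
        # several threads at once, so protect shared state with a lock.

        def __init__(self, filename: str, file_size: int):
            super().__init__(filename, file_size)
            self._lock = threading.Lock()
            self.received = 0

        def update(self, size: int):
            with self._lock:
                self.received += size

        def end(self):
            print(f'{self.filename}: {self.received} bytes downloaded')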


@@ -6,7 +6,7 @@ import re
 import uuid
 from http.cookiejar import CookieJar
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Type, Union
 
 from modelscope.hub.api import HubApi, ModelScopeConfig
 from modelscope.hub.errors import InvalidParameter
@@ -23,6 +23,7 @@ from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
 from modelscope.utils.file_utils import get_modelscope_cache_dir
 from modelscope.utils.logger import get_logger
 from modelscope.utils.thread_utils import thread_executor
+from .callback import ProgressCallback
 
 logger = get_logger()
@@ -42,6 +43,7 @@ def snapshot_download(
     max_workers: int = 8,
     repo_id: str = None,
     repo_type: Optional[str] = REPO_TYPE_MODEL,
+    progress_callbacks: List[Type[ProgressCallback]] = None,
 ) -> str:
     """Download all files of a repo.
     Downloads a whole snapshot of a repo's files at the specified revision. This
@@ -77,6 +79,8 @@ def snapshot_download(
             If provided, files matching any of the patterns are not downloaded, priority over ignore_file_pattern.
             For hugging-face compatibility.
         max_workers (`int`): The maximum number of workers to download files, default 8.
+        progress_callbacks (`List[Type[ProgressCallback]]`, **optional**, default to `None`):
+            progress callbacks to track the download progress.
 
     Raises:
         ValueError: the value details.
@@ -118,7 +122,8 @@ def snapshot_download(
         local_dir=local_dir,
         ignore_patterns=ignore_patterns,
         allow_patterns=allow_patterns,
-        max_workers=max_workers)
+        max_workers=max_workers,
+        progress_callbacks=progress_callbacks)
 
 
 def dataset_snapshot_download(
@@ -213,6 +218,7 @@ def _snapshot_download(
     allow_patterns: Optional[Union[List[str], str]] = None,
     ignore_patterns: Optional[Union[List[str], str]] = None,
     max_workers: int = 8,
+    progress_callbacks: List[Type[ProgressCallback]] = None,
 ):
     if not repo_type:
         repo_type = REPO_TYPE_MODEL
@@ -304,6 +310,7 @@ def _snapshot_download(
                     allow_patterns=allow_patterns,
                     max_workers=max_workers,
                     endpoint=endpoint,
+                    progress_callbacks=progress_callbacks,
                 )
                 if '.' in repo_id:
                     masked_directory = get_model_masked_directory(
@@ -362,6 +369,7 @@ def _snapshot_download(
                     allow_patterns=allow_patterns,
                     max_workers=max_workers,
                     endpoint=endpoint,
+                    progress_callbacks=progress_callbacks,
                 )
                 cache.save_model_version(revision_info=revision_detail)
@@ -449,6 +457,7 @@ def _download_file_lists(
     ignore_patterns: Optional[Union[List[str], str]] = None,
     max_workers: int = 8,
     endpoint: Optional[str] = None,
+    progress_callbacks: List[Type[ProgressCallback]] = None,
 ):
     ignore_patterns = _normalize_patterns(ignore_patterns)
     allow_patterns = _normalize_patterns(allow_patterns)
@@ -532,6 +541,7 @@ def _download_file_lists(
             headers,
             cookies,
             disable_tqdm=False,
+            progress_callbacks=progress_callbacks,
         )
 
     if len(filtered_repo_files) > 0:


@@ -0,0 +1,52 @@
+import tempfile
+import unittest
+
+from tqdm import tqdm
+
+from modelscope import snapshot_download
+from modelscope.hub import ProgressCallback
+
+
+class NewProgressCallback(ProgressCallback):
+    all_files = set()  # just for test
+
+    def __init__(self, filename: str, file_size: int):
+        super().__init__(filename, file_size)
+        self.progress = tqdm(total=file_size)
+        self.all_files.add(filename)
+
+    def update(self, size: int):
+        self.progress.update(size)
+
+    def end(self):
+        self.all_files.remove(self.filename)
+        assert self.progress.n == self.progress.total == self.file_size
+        self.progress.close()
+
+
+class ProgressCallbackTest(unittest.TestCase):
+
+    def setUp(self):
+        self.temp_dir = tempfile.TemporaryDirectory()
+
+    def tearDown(self):
+        self.temp_dir.cleanup()
+
+    def test_progress_callback(self):
+        model_dir = snapshot_download(
+            'swift/test_lora',
+            progress_callbacks=[NewProgressCallback],
+            cache_dir=self.temp_dir.name)
+        print(f'model_dir: {model_dir}')
+        self.assertTrue(len(NewProgressCallback.all_files) == 0)
+
+    def test_empty_progress_callback(self):
+        model_dir = snapshot_download(
+            'swift/test_lora',
+            progress_callbacks=[],
+            cache_dir=self.temp_dir.name)
+        print(f'model_dir: {model_dir}')
+
+
+if __name__ == '__main__':
+    unittest.main()