Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-16 16:27:45 +01:00)
[hub] Support ProgressCallback (#1380)
1. Add ProgressCallback for file download
modelscope/hub/__init__.py
@@ -0,0 +1 @@
+from .callback import ProgressCallback
modelscope/hub/callback.py (new file, 34 lines)
@@ -0,0 +1,34 @@
+from tqdm import tqdm
+
+
+class ProgressCallback:
+
+    def __init__(self, filename: str, file_size: int):
+        self.filename = filename
+        self.file_size = file_size
+
+    def update(self, size: int):
+        pass
+
+    def end(self):
+        pass
+
+
+class TqdmCallback(ProgressCallback):
+
+    def __init__(self, filename: str, file_size: int):
+        super().__init__(filename, file_size)
+        self.progress = tqdm(
+            unit='B',
+            unit_scale=True,
+            unit_divisor=1024,
+            total=file_size if file_size > 0 else 1,
+            initial=0,
+            desc='Downloading [' + self.filename + ']',
+            leave=True)
+
+    def update(self, size: int):
+        self.progress.update(size)
+
+    def end(self):
+        self.progress.close()
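For context: the base class above is the whole contract. The hub constructs one instance per file with (filename, file_size), calls update() with byte increments as data arrives, and calls end() when the file completes. A minimal subclass could look like the following sketch; PercentCallback is a hypothetical name, not part of this commit:

from modelscope.hub.callback import ProgressCallback


class PercentCallback(ProgressCallback):
    """Hypothetical sketch: print a line roughly every 10% of the file."""

    def __init__(self, filename: str, file_size: int):
        super().__init__(filename, file_size)
        self.downloaded = 0
        self.last_decile = -1

    def update(self, size: int):
        self.downloaded += size
        if self.file_size > 0:
            decile = self.downloaded * 10 // self.file_size
            if decile > self.last_decile:
                self.last_decile = decile
                print(f'{self.filename}: {decile * 10}%')

    def end(self):
        print(f'{self.filename}: done')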
modelscope/hub/file_download.py
@@ -12,7 +12,7 @@ from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 from http.cookiejar import CookieJar
 from pathlib import Path
-from typing import Dict, Optional, Union
+from typing import Dict, List, Optional, Type, Union
 
 import requests
 from requests.adapters import Retry
@@ -30,6 +30,7 @@ from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
 from modelscope.utils.file_utils import (get_dataset_cache_root,
                                          get_model_cache_root)
 from modelscope.utils.logger import get_logger
+from .callback import ProgressCallback, TqdmCallback
 from .errors import FileDownloadError, InvalidParameter, NotExistError
 from .utils.caching import ModelFileSystemCache
 from .utils.utils import (file_integrity_validation, get_endpoint,
@@ -391,7 +392,7 @@ def get_file_download_url(model_id: str,
 
 def download_part_with_retry(params):
     # unpack parameters
-    model_file_path, progress, start, end, url, file_name, cookies, headers = params
+    model_file_path, progress_callbacks, start, end, url, file_name, cookies, headers = params
     get_headers = {} if headers is None else copy.deepcopy(headers)
     get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
     retry = Retry(
@@ -406,7 +407,8 @@ def download_part_with_retry(params):
                     part_file_name):  # download partial, continue download
                 with open(part_file_name, 'rb') as f:
                     partial_length = f.seek(0, io.SEEK_END)
-                    progress.update(partial_length)
+                    for callback in progress_callbacks:
+                        callback.update(partial_length)
             download_start = start + partial_length
             if download_start > end:
                 break  # this part is download completed.
@@ -422,7 +424,8 @@ def download_part_with_retry(params):
                         chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
                     if chunk:  # filter out keep-alive new chunks
                         f.write(chunk)
-                        progress.update(len(chunk))
+                        for callback in progress_callbacks:
+                            callback.update(len(chunk))
             break
         except (Exception) as e:  # no matter what exception, we will retry.
             retry = retry.increment('GET', url, error=e)
@@ -438,37 +441,35 @@ def parallel_download(url: str,
                       headers: Optional[Dict[str, str]] = None,
                       file_size: int = None,
                       disable_tqdm: bool = False,
+                      progress_callbacks: List[Type[ProgressCallback]] = None,
                       endpoint: str = None):
+    progress_callbacks = [] if progress_callbacks is None else progress_callbacks.copy(
+    )
+    if not disable_tqdm:
+        progress_callbacks.append(TqdmCallback)
+    progress_callbacks = [
+        callback(file_name, file_size) for callback in progress_callbacks
+    ]
     # create temp file
-    with tqdm(
-            unit='B',
-            unit_scale=True,
-            unit_divisor=1024,
-            total=file_size,
-            initial=0,
-            desc='Downloading [' + file_name + ']',
-            leave=True,
-            disable=disable_tqdm,
-    ) as progress:
-        PART_SIZE = 160 * 1024 * 1024  # every part is 160M
-        tasks = []
-        file_path = os.path.join(local_dir, file_name)
-        os.makedirs(os.path.dirname(file_path), exist_ok=True)
-        for idx in range(int(file_size / PART_SIZE)):
-            start = idx * PART_SIZE
-            end = (idx + 1) * PART_SIZE - 1
-            tasks.append((file_path, progress, start, end, url, file_name,
-                          cookies, headers))
-        if end + 1 < file_size:
-            tasks.append((file_path, progress, end + 1, file_size - 1, url,
-                          file_name, cookies, headers))
-        parallels = min(MODELSCOPE_DOWNLOAD_PARALLELS, 16)
-        # download every part
-        with ThreadPoolExecutor(
-                max_workers=parallels,
-                thread_name_prefix='download') as executor:
-            list(executor.map(download_part_with_retry, tasks))
-
+    PART_SIZE = 160 * 1024 * 1024  # every part is 160M
+    tasks = []
+    file_path = os.path.join(local_dir, file_name)
+    os.makedirs(os.path.dirname(file_path), exist_ok=True)
+    for idx in range(int(file_size / PART_SIZE)):
+        start = idx * PART_SIZE
+        end = (idx + 1) * PART_SIZE - 1
+        tasks.append((file_path, progress_callbacks, start, end, url,
+                      file_name, cookies, headers))
+    if end + 1 < file_size:
+        tasks.append((file_path, progress_callbacks, end + 1, file_size - 1,
+                      url, file_name, cookies, headers))
+    parallels = min(MODELSCOPE_DOWNLOAD_PARALLELS, 16)
+    # download every part
+    with ThreadPoolExecutor(
+            max_workers=parallels, thread_name_prefix='download') as executor:
+        list(executor.map(download_part_with_retry, tasks))
+    for callback in progress_callbacks:
+        callback.end()
     # merge parts.
     hash_sha256 = hashlib.sha256()
     with open(os.path.join(local_dir, file_name), 'wb') as output_file:
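One consequence of the hunk above: the same callback instances are shared by every 160M part, and download_part_with_retry runs on a ThreadPoolExecutor, so update() can be called from several threads at once (tqdm serializes its own updates with an internal lock). A custom callback that accumulates state may want a lock of its own; a sketch, with the hypothetical name ThreadSafeCounter:

import threading

from modelscope.hub.callback import ProgressCallback


class ThreadSafeCounter(ProgressCallback):
    """Hypothetical sketch: aggregate bytes across parallel parts."""

    def __init__(self, filename: str, file_size: int):
        super().__init__(filename, file_size)
        self._lock = threading.Lock()
        self._done = 0

    def update(self, size: int):
        # update() may arrive from several download threads concurrently
        with self._lock:
            self._done += size

    def end(self):
        print(f'{self.filename}: {self._done}/{self.file_size} bytes')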
@@ -493,6 +494,7 @@ def http_get_model_file(
     cookies: CookieJar,
     headers: Optional[Dict[str, str]] = None,
     disable_tqdm: bool = False,
+    progress_callbacks: List[Type[ProgressCallback]] = None,
 ):
     """Download remote file, will retry 5 times before giving up on errors.
 
@@ -510,11 +512,20 @@ def http_get_model_file(
         headers(Dict[str, str], optional):
             http headers to carry necessary info when requesting the remote file
         disable_tqdm(bool, optional): Disable the progress bar with tqdm.
+        progress_callbacks(List[Type[ProgressCallback]], optional):
+            progress callbacks to track the download progress.
 
     Raises:
         FileDownloadError: File download failed.
 
     """
+    progress_callbacks = [] if progress_callbacks is None else progress_callbacks.copy(
+    )
+    if not disable_tqdm:
+        progress_callbacks.append(TqdmCallback)
+    progress_callbacks = [
+        callback(file_name, file_size) for callback in progress_callbacks
+    ]
     get_headers = {} if headers is None else copy.deepcopy(headers)
     get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
     temp_file_path = os.path.join(local_dir, file_name)
@@ -527,60 +538,56 @@ def http_get_model_file(
         total=API_FILE_DOWNLOAD_RETRY_TIMES,
         backoff_factor=1,
         allowed_methods=['GET'])
 
     while True:
         try:
-            with tqdm(
-                    unit='B',
-                    unit_scale=True,
-                    unit_divisor=1024,
-                    total=file_size if file_size > 0 else 1,
-                    initial=0,
-                    desc='Downloading [' + file_name + ']',
-                    leave=True,
-                    disable=disable_tqdm,
-            ) as progress:
-                if file_size == 0:
-                    # Avoid empty file server request
-                    with open(temp_file_path, 'w+'):
-                        progress.update(1)
-                    break
-                # Determine the length of any existing partial download
-                partial_length = 0
-                # download partial, continue download
-                if os.path.exists(temp_file_path):
-                    # resuming from interrupted download is also considered as retry
-                    has_retry = True
-                    with open(temp_file_path, 'rb') as f:
-                        partial_length = f.seek(0, io.SEEK_END)
-                        progress.update(partial_length)
+            if file_size == 0:
+                # Avoid empty file server request
+                with open(temp_file_path, 'w+'):
+                    for callback in progress_callbacks:
+                        callback.update(1)
+                break
+            # Determine the length of any existing partial download
+            partial_length = 0
+            # download partial, continue download
+            if os.path.exists(temp_file_path):
+                # resuming from interrupted download is also considered as retry
+                has_retry = True
+                with open(temp_file_path, 'rb') as f:
+                    partial_length = f.seek(0, io.SEEK_END)
+                    for callback in progress_callbacks:
+                        callback.update(partial_length)
 
-                # Check if download is complete
-                if partial_length >= file_size:
-                    break
-                # closed range[], from 0.
-                get_headers['Range'] = 'bytes=%s-%s' % (partial_length,
-                                                        file_size - 1)
-                with open(temp_file_path, 'ab+') as f:
-                    r = requests.get(
-                        url,
-                        stream=True,
-                        headers=get_headers,
-                        cookies=cookies,
-                        timeout=API_FILE_DOWNLOAD_TIMEOUT)
-                    r.raise_for_status()
-                    for chunk in r.iter_content(
-                            chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
-                        if chunk:  # filter out keep-alive new chunks
-                            progress.update(len(chunk))
-                            f.write(chunk)
-                            # hash would be discarded in retry case anyway
-                            if not has_retry:
-                                hash_sha256.update(chunk)
+            # Check if download is complete
+            if partial_length >= file_size:
+                break
+            # closed range[], from 0.
+            get_headers['Range'] = 'bytes=%s-%s' % (partial_length,
+                                                    file_size - 1)
+            with open(temp_file_path, 'ab+') as f:
+                r = requests.get(
+                    url,
+                    stream=True,
+                    headers=get_headers,
+                    cookies=cookies,
+                    timeout=API_FILE_DOWNLOAD_TIMEOUT)
+                r.raise_for_status()
+                for chunk in r.iter_content(
+                        chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
+                    if chunk:  # filter out keep-alive new chunks
+                        for callback in progress_callbacks:
+                            callback.update(len(chunk))
+                        f.write(chunk)
+                        # hash would be discarded in retry case anyway
+                        if not has_retry:
+                            hash_sha256.update(chunk)
             break
         except Exception as e:  # no matter what happen, we will retry.
             has_retry = True
            retry = retry.increment('GET', url, error=e)
             retry.sleep()
+    for callback in progress_callbacks:
+        callback.end()
     # if anything went wrong, we would discard the real-time computed hash and return None
     return None if has_retry else hash_sha256.hexdigest()
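A subtlety for callback authors, visible in the hunk above: when a partial file is resumed, callbacks first receive one update(partial_length) for the bytes already on disk, then per-chunk updates for the remainder, so updates normally sum to file_size. Two caveats: zero-byte files are special-cased with a single update(1), and a mid-download retry re-enters the loop and reports the partial length again. A sketch of a callback that checks the invariant (CheckedCallback is a hypothetical name):

from modelscope.hub.callback import ProgressCallback


class CheckedCallback(ProgressCallback):
    """Hypothetical sketch: verify that updates sum to file_size."""

    def __init__(self, filename: str, file_size: int):
        super().__init__(filename, file_size)
        self.downloaded = 0  # resumed bytes are reported up front

    def update(self, size: int):
        self.downloaded += size

    def end(self):
        # Holds for fresh and resumed non-empty downloads without retries;
        # the test at the bottom of this commit asserts the same invariant.
        assert self.downloaded == self.file_size, (self.filename,
                                                   self.downloaded)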
@@ -675,6 +682,7 @@ def download_file(
     headers,
     cookies,
     disable_tqdm=False,
+    progress_callbacks: List[Type[ProgressCallback]] = None,
 ):
     if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[
             'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1:  # parallel download large file.
@@ -686,6 +694,7 @@ def download_file(
             cookies=None if cookies is None else cookies.get_dict(),
             file_size=file_meta['Size'],
             disable_tqdm=disable_tqdm,
+            progress_callbacks=progress_callbacks,
         )
     else:
         file_digest = http_get_model_file(
@@ -696,6 +705,7 @@ def download_file(
             headers=headers,
             cookies=cookies,
             disable_tqdm=disable_tqdm,
+            progress_callbacks=progress_callbacks,
         )
 
     # check file integrity
modelscope/hub/snapshot_download.py
@@ -6,7 +6,7 @@ import re
 import uuid
 from http.cookiejar import CookieJar
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Type, Union
 
 from modelscope.hub.api import HubApi, ModelScopeConfig
 from modelscope.hub.errors import InvalidParameter
@@ -23,6 +23,7 @@ from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
 from modelscope.utils.file_utils import get_modelscope_cache_dir
 from modelscope.utils.logger import get_logger
 from modelscope.utils.thread_utils import thread_executor
+from .callback import ProgressCallback
 
 logger = get_logger()
 
@@ -42,6 +43,7 @@ def snapshot_download(
     max_workers: int = 8,
     repo_id: str = None,
     repo_type: Optional[str] = REPO_TYPE_MODEL,
+    progress_callbacks: List[Type[ProgressCallback]] = None,
 ) -> str:
     """Download all files of a repo.
     Downloads a whole snapshot of a repo's files at the specified revision. This
@@ -77,6 +79,8 @@ def snapshot_download(
             If provided, files matching any of the patterns are not downloaded, priority over ignore_file_pattern.
             For hugging-face compatibility.
         max_workers (`int`): The maximum number of workers to download files, default 8.
+        progress_callbacks (`List[Type[ProgressCallback]]`, **optional**, default to `None`):
+            progress callbacks to track the download progress.
     Raises:
         ValueError: the value details.
 
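Since the parameter threads through to every per-file download, end users drive the feature entirely from snapshot_download. A usage sketch (LogCallback is hypothetical; the repo id is borrowed from the test at the bottom of this commit):

from modelscope import snapshot_download
from modelscope.hub import ProgressCallback


class LogCallback(ProgressCallback):
    """Hypothetical sketch: log each finished file."""

    def end(self):
        print(f'finished {self.filename} ({self.file_size} bytes)')


# Callback classes, not instances, are passed; the hub builds one per file.
model_dir = snapshot_download(
    'swift/test_lora', progress_callbacks=[LogCallback])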
@@ -118,7 +122,8 @@ def snapshot_download(
         local_dir=local_dir,
         ignore_patterns=ignore_patterns,
         allow_patterns=allow_patterns,
-        max_workers=max_workers)
+        max_workers=max_workers,
+        progress_callbacks=progress_callbacks)
 
 
 def dataset_snapshot_download(
@@ -213,6 +218,7 @@ def _snapshot_download(
     allow_patterns: Optional[Union[List[str], str]] = None,
     ignore_patterns: Optional[Union[List[str], str]] = None,
     max_workers: int = 8,
+    progress_callbacks: List[Type[ProgressCallback]] = None,
 ):
     if not repo_type:
         repo_type = REPO_TYPE_MODEL
@@ -304,6 +310,7 @@ def _snapshot_download(
                 allow_patterns=allow_patterns,
                 max_workers=max_workers,
                 endpoint=endpoint,
+                progress_callbacks=progress_callbacks,
             )
             if '.' in repo_id:
                 masked_directory = get_model_masked_directory(
@@ -362,6 +369,7 @@ def _snapshot_download(
                 allow_patterns=allow_patterns,
                 max_workers=max_workers,
                 endpoint=endpoint,
+                progress_callbacks=progress_callbacks,
             )
 
             cache.save_model_version(revision_info=revision_detail)
@@ -449,6 +457,7 @@ def _download_file_lists(
     ignore_patterns: Optional[Union[List[str], str]] = None,
     max_workers: int = 8,
     endpoint: Optional[str] = None,
+    progress_callbacks: List[Type[ProgressCallback]] = None,
 ):
     ignore_patterns = _normalize_patterns(ignore_patterns)
     allow_patterns = _normalize_patterns(allow_patterns)
@@ -532,6 +541,7 @@ def _download_file_lists(
             headers,
             cookies,
             disable_tqdm=False,
+            progress_callbacks=progress_callbacks,
         )
 
     if len(filtered_repo_files) > 0:
tests/hub/test_download_callback.py (new file, 52 lines)
@@ -0,0 +1,52 @@
+import tempfile
+import unittest
+
+from tqdm import tqdm
+
+from modelscope import snapshot_download
+from modelscope.hub import ProgressCallback
+
+
+class NewProgressCallback(ProgressCallback):
+    all_files = set()  # just for test
+
+    def __init__(self, filename: str, file_size: int):
+        super().__init__(filename, file_size)
+        self.progress = tqdm(total=file_size)
+        self.all_files.add(filename)
+
+    def update(self, size: int):
+        self.progress.update(size)
+
+    def end(self):
+        self.all_files.remove(self.filename)
+        assert self.progress.n == self.progress.total == self.file_size
+        self.progress.close()
+
+
+class ProgressCallbackTest(unittest.TestCase):
+
+    def setUp(self):
+        self.temp_dir = tempfile.TemporaryDirectory()
+
+    def tearDown(self):
+        self.temp_dir.cleanup()
+
+    def test_progress_callback(self):
+        model_dir = snapshot_download(
+            'swift/test_lora',
+            progress_callbacks=[NewProgressCallback],
+            cache_dir=self.temp_dir.name)
+        print(f'model_dir: {model_dir}')
+        self.assertTrue(len(NewProgressCallback.all_files) == 0)
+
+    def test_empty_progress_callback(self):
+        model_dir = snapshot_download(
+            'swift/test_lora',
+            progress_callbacks=[],
+            cache_dir=self.temp_dir.name)
+        print(f'model_dir: {model_dir}')
+
+
+if __name__ == '__main__':
+    unittest.main()
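Assuming network access to the hub (the test downloads the real swift/test_lora repo), the file has a __main__ guard and runs standalone:

python tests/hub/test_download_callback.py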