[hub] Support ProgressCallback (#1380)
1. Add ProgressCallback for file download
modelscope/hub/__init__.py
@@ -0,0 +1 @@
+from .callback import ProgressCallback
modelscope/hub/callback.py (new file, 34 lines)
@@ -0,0 +1,34 @@
+from tqdm import tqdm
+
+
+class ProgressCallback:
+
+    def __init__(self, filename: str, file_size: int):
+        self.filename = filename
+        self.file_size = file_size
+
+    def update(self, size: int):
+        pass
+
+    def end(self):
+        pass
+
+
+class TqdmCallback(ProgressCallback):
+
+    def __init__(self, filename: str, file_size: int):
+        super().__init__(filename, file_size)
+        self.progress = tqdm(
+            unit='B',
+            unit_scale=True,
+            unit_divisor=1024,
+            total=file_size if file_size > 0 else 1,
+            initial=0,
+            desc='Downloading [' + self.filename + ']',
+            leave=True)
+
+    def update(self, size: int):
+        self.progress.update(size)
+
+    def end(self):
+        self.progress.close()
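
For orientation before the call-site changes below: a consumer subclasses ProgressCallback, overrides update and end, and passes the class itself (not an instance) to the download APIs; the hub then instantiates one callback per downloaded file with that file's name and size. A minimal usage sketch, assuming only the modelscope.hub export added above (LoggingCallback is an illustrative name, not part of this commit; the model id is borrowed from the test at the end of this diff):

from modelscope import snapshot_download
from modelscope.hub import ProgressCallback


class LoggingCallback(ProgressCallback):
    # Illustrative subclass, not part of this commit.

    def __init__(self, filename: str, file_size: int):
        super().__init__(filename, file_size)
        self.downloaded = 0  # bytes reported so far

    def update(self, size: int):
        self.downloaded += size
        print(f'{self.filename}: {self.downloaded}/{self.file_size} bytes')

    def end(self):
        print(f'{self.filename}: finished')


model_dir = snapshot_download(
    'swift/test_lora', progress_callbacks=[LoggingCallback])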
modelscope/hub/file_download.py
@@ -12,7 +12,7 @@ from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 from http.cookiejar import CookieJar
 from pathlib import Path
-from typing import Dict, Optional, Union
+from typing import Dict, List, Optional, Type, Union

 import requests
 from requests.adapters import Retry
@@ -30,6 +30,7 @@ from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
 from modelscope.utils.file_utils import (get_dataset_cache_root,
                                          get_model_cache_root)
 from modelscope.utils.logger import get_logger
+from .callback import ProgressCallback, TqdmCallback
 from .errors import FileDownloadError, InvalidParameter, NotExistError
 from .utils.caching import ModelFileSystemCache
 from .utils.utils import (file_integrity_validation, get_endpoint,
@@ -391,7 +392,7 @@ def get_file_download_url(model_id: str,

 def download_part_with_retry(params):
     # unpack parameters
-    model_file_path, progress, start, end, url, file_name, cookies, headers = params
+    model_file_path, progress_callbacks, start, end, url, file_name, cookies, headers = params
     get_headers = {} if headers is None else copy.deepcopy(headers)
     get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
     retry = Retry(
@@ -406,7 +407,8 @@ def download_part_with_retry(params):
                     part_file_name):  # download partial, continue download
                 with open(part_file_name, 'rb') as f:
                     partial_length = f.seek(0, io.SEEK_END)
-                    progress.update(partial_length)
+                    for callback in progress_callbacks:
+                        callback.update(partial_length)
             download_start = start + partial_length
             if download_start > end:
                 break  # this part is download completed.
@@ -422,7 +424,8 @@ def download_part_with_retry(params):
                         chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
                     if chunk:  # filter out keep-alive new chunks
                         f.write(chunk)
-                        progress.update(len(chunk))
+                        for callback in progress_callbacks:
+                            callback.update(len(chunk))
             break
         except (Exception) as e:  # no matter what exception, we will retry.
             retry = retry.increment('GET', url, error=e)
@@ -438,37 +441,35 @@ def parallel_download(url: str,
                       headers: Optional[Dict[str, str]] = None,
                       file_size: int = None,
                       disable_tqdm: bool = False,
+                      progress_callbacks: List[Type[ProgressCallback]] = None,
                       endpoint: str = None):
+    progress_callbacks = [] if progress_callbacks is None else progress_callbacks.copy(
+    )
+    if not disable_tqdm:
+        progress_callbacks.append(TqdmCallback)
+    progress_callbacks = [
+        callback(file_name, file_size) for callback in progress_callbacks
+    ]
     # create temp file
-    with tqdm(
-            unit='B',
-            unit_scale=True,
-            unit_divisor=1024,
-            total=file_size,
-            initial=0,
-            desc='Downloading [' + file_name + ']',
-            leave=True,
-            disable=disable_tqdm,
-    ) as progress:
-        PART_SIZE = 160 * 1024 * 1024  # every part is 160M
-        tasks = []
-        file_path = os.path.join(local_dir, file_name)
-        os.makedirs(os.path.dirname(file_path), exist_ok=True)
-        for idx in range(int(file_size / PART_SIZE)):
-            start = idx * PART_SIZE
-            end = (idx + 1) * PART_SIZE - 1
-            tasks.append((file_path, progress, start, end, url, file_name,
-                          cookies, headers))
-        if end + 1 < file_size:
-            tasks.append((file_path, progress, end + 1, file_size - 1, url,
-                          file_name, cookies, headers))
-        parallels = min(MODELSCOPE_DOWNLOAD_PARALLELS, 16)
-        # download every part
-        with ThreadPoolExecutor(
-                max_workers=parallels,
-                thread_name_prefix='download') as executor:
-            list(executor.map(download_part_with_retry, tasks))
+    PART_SIZE = 160 * 1024 * 1024  # every part is 160M
+    tasks = []
+    file_path = os.path.join(local_dir, file_name)
+    os.makedirs(os.path.dirname(file_path), exist_ok=True)
+    for idx in range(int(file_size / PART_SIZE)):
+        start = idx * PART_SIZE
+        end = (idx + 1) * PART_SIZE - 1
+        tasks.append((file_path, progress_callbacks, start, end, url,
+                      file_name, cookies, headers))
+    if end + 1 < file_size:
+        tasks.append((file_path, progress_callbacks, end + 1, file_size - 1,
+                      url, file_name, cookies, headers))
+    parallels = min(MODELSCOPE_DOWNLOAD_PARALLELS, 16)
+    # download every part
+    with ThreadPoolExecutor(
+            max_workers=parallels, thread_name_prefix='download') as executor:
+        list(executor.map(download_part_with_retry, tasks))
+    for callback in progress_callbacks:
+        callback.end()

     # merge parts.
     hash_sha256 = hashlib.sha256()
     with open(os.path.join(local_dir, file_name), 'wb') as output_file:
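
A note on the hunk above: the instantiated callbacks are shared by every worker in the ThreadPoolExecutor, so update() may be invoked concurrently from several download threads. tqdm serializes its own counter updates with an internal lock, but a custom callback that mutates its own state likely needs one of its own. A sketch under that assumption (ThreadSafeCallback is an illustrative name, not part of this commit):

import threading

from modelscope.hub import ProgressCallback


class ThreadSafeCallback(ProgressCallback):
    # Illustrative subclass, not part of this commit.

    def __init__(self, filename: str, file_size: int):
        super().__init__(filename, file_size)
        self._lock = threading.Lock()  # update() runs on several threads
        self.downloaded = 0

    def update(self, size: int):
        with self._lock:
            self.downloaded += size

    def end(self):
        print(f'{self.filename}: {self.downloaded} of {self.file_size} bytes')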
@@ -493,6 +494,7 @@ def http_get_model_file(
     cookies: CookieJar,
     headers: Optional[Dict[str, str]] = None,
     disable_tqdm: bool = False,
+    progress_callbacks: List[Type[ProgressCallback]] = None,
 ):
     """Download remote file, will retry 5 times before giving up on errors.

@@ -510,11 +512,20 @@ def http_get_model_file(
         headers(Dict[str, str], optional):
             http headers to carry necessary info when requesting the remote file
         disable_tqdm(bool, optional): Disable the progress bar with tqdm.
+        progress_callbacks(List[Type[ProgressCallback]], optional):
+            progress callbacks to track the download progress.

     Raises:
         FileDownloadError: File download failed.

     """
+    progress_callbacks = [] if progress_callbacks is None else progress_callbacks.copy(
+    )
+    if not disable_tqdm:
+        progress_callbacks.append(TqdmCallback)
+    progress_callbacks = [
+        callback(file_name, file_size) for callback in progress_callbacks
+    ]
     get_headers = {} if headers is None else copy.deepcopy(headers)
     get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
     temp_file_path = os.path.join(local_dir, file_name)
@@ -527,60 +538,56 @@ def http_get_model_file(
         total=API_FILE_DOWNLOAD_RETRY_TIMES,
         backoff_factor=1,
         allowed_methods=['GET'])

     while True:
         try:
-            with tqdm(
-                    unit='B',
-                    unit_scale=True,
-                    unit_divisor=1024,
-                    total=file_size if file_size > 0 else 1,
-                    initial=0,
-                    desc='Downloading [' + file_name + ']',
-                    leave=True,
-                    disable=disable_tqdm,
-            ) as progress:
-                if file_size == 0:
-                    # Avoid empty file server request
-                    with open(temp_file_path, 'w+'):
-                        progress.update(1)
-                    break
-                # Determine the length of any existing partial download
-                partial_length = 0
-                # download partial, continue download
-                if os.path.exists(temp_file_path):
-                    # resuming from interrupted download is also considered as retry
-                    has_retry = True
-                    with open(temp_file_path, 'rb') as f:
-                        partial_length = f.seek(0, io.SEEK_END)
-                        progress.update(partial_length)
+            if file_size == 0:
+                # Avoid empty file server request
+                with open(temp_file_path, 'w+'):
+                    for callback in progress_callbacks:
+                        callback.update(1)
+                break
+            # Determine the length of any existing partial download
+            partial_length = 0
+            # download partial, continue download
+            if os.path.exists(temp_file_path):
+                # resuming from interrupted download is also considered as retry
+                has_retry = True
+                with open(temp_file_path, 'rb') as f:
+                    partial_length = f.seek(0, io.SEEK_END)
+                    for callback in progress_callbacks:
+                        callback.update(partial_length)

-                # Check if download is complete
-                if partial_length >= file_size:
-                    break
-                # closed range[], from 0.
-                get_headers['Range'] = 'bytes=%s-%s' % (partial_length,
-                                                        file_size - 1)
-                with open(temp_file_path, 'ab+') as f:
-                    r = requests.get(
-                        url,
-                        stream=True,
-                        headers=get_headers,
-                        cookies=cookies,
-                        timeout=API_FILE_DOWNLOAD_TIMEOUT)
-                    r.raise_for_status()
-                    for chunk in r.iter_content(
-                            chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
-                        if chunk:  # filter out keep-alive new chunks
-                            progress.update(len(chunk))
-                            f.write(chunk)
-                            # hash would be discarded in retry case anyway
-                            if not has_retry:
-                                hash_sha256.update(chunk)
+            # Check if download is complete
+            if partial_length >= file_size:
+                break
+            # closed range[], from 0.
+            get_headers['Range'] = 'bytes=%s-%s' % (partial_length,
+                                                    file_size - 1)
+            with open(temp_file_path, 'ab+') as f:
+                r = requests.get(
+                    url,
+                    stream=True,
+                    headers=get_headers,
+                    cookies=cookies,
+                    timeout=API_FILE_DOWNLOAD_TIMEOUT)
+                r.raise_for_status()
+                for chunk in r.iter_content(
+                        chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
+                    if chunk:  # filter out keep-alive new chunks
+                        for callback in progress_callbacks:
+                            callback.update(len(chunk))
+                        f.write(chunk)
+                        # hash would be discarded in retry case anyway
+                        if not has_retry:
+                            hash_sha256.update(chunk)
             break
         except Exception as e:  # no matter what happen, we will retry.
             has_retry = True
             retry = retry.increment('GET', url, error=e)
             retry.sleep()
+    for callback in progress_callbacks:
+        callback.end()
     # if anything went wrong, we would discard the real-time computed hash and return None
     return None if has_retry else hash_sha256.hexdigest()
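
A behavioral detail worth flagging in the hunk above: the callbacks are now created once, before the retry loop, whereas the old tqdm bar was recreated on every attempt. On a resumed or retried download, callback.update(partial_length) therefore re-reports bytes an earlier attempt already counted, so a callback that blindly accumulates can overshoot file_size; end() still fires exactly once per file, after the loop. A defensive accumulator might clamp, as in this sketch (ClampedCallback is illustrative, not part of this commit):

from modelscope.hub import ProgressCallback


class ClampedCallback(ProgressCallback):
    # Illustrative subclass, not part of this commit.

    def __init__(self, filename: str, file_size: int):
        super().__init__(filename, file_size)
        self.downloaded = 0

    def update(self, size: int):
        # Retried attempts re-report the partial length, so cap the total.
        self.downloaded = min(self.downloaded + size, self.file_size)

    def end(self):
        print(f'{self.filename}: {self.downloaded} bytes written')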
@@ -675,6 +682,7 @@ def download_file(
     headers,
     cookies,
     disable_tqdm=False,
+    progress_callbacks: List[Type[ProgressCallback]] = None,
 ):
     if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[
             'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1:  # parallel download large file.
@@ -686,6 +694,7 @@ def download_file(
             cookies=None if cookies is None else cookies.get_dict(),
             file_size=file_meta['Size'],
             disable_tqdm=disable_tqdm,
+            progress_callbacks=progress_callbacks,
         )
     else:
         file_digest = http_get_model_file(
@@ -696,6 +705,7 @@ def download_file(
             headers=headers,
             cookies=cookies,
             disable_tqdm=disable_tqdm,
+            progress_callbacks=progress_callbacks,
         )

     # check file integrity
modelscope/hub/snapshot_download.py
@@ -6,7 +6,7 @@ import re
 import uuid
 from http.cookiejar import CookieJar
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Type, Union

 from modelscope.hub.api import HubApi, ModelScopeConfig
 from modelscope.hub.errors import InvalidParameter
@@ -23,6 +23,7 @@ from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
 from modelscope.utils.file_utils import get_modelscope_cache_dir
 from modelscope.utils.logger import get_logger
 from modelscope.utils.thread_utils import thread_executor
+from .callback import ProgressCallback

 logger = get_logger()

@@ -42,6 +43,7 @@ def snapshot_download(
     max_workers: int = 8,
     repo_id: str = None,
     repo_type: Optional[str] = REPO_TYPE_MODEL,
+    progress_callbacks: List[Type[ProgressCallback]] = None,
 ) -> str:
     """Download all files of a repo.
     Downloads a whole snapshot of a repo's files at the specified revision. This
@@ -77,6 +79,8 @@ def snapshot_download(
             If provided, files matching any of the patterns are not downloaded, priority over ignore_file_pattern.
             For hugging-face compatibility.
         max_workers (`int`): The maximum number of workers to download files, default 8.
+        progress_callbacks (`List[Type[ProgressCallback]]`, **optional**, default to `None`):
+            progress callbacks to track the download progress.
     Raises:
         ValueError: the value details.

@@ -118,7 +122,8 @@ def snapshot_download(
         local_dir=local_dir,
         ignore_patterns=ignore_patterns,
         allow_patterns=allow_patterns,
-        max_workers=max_workers)
+        max_workers=max_workers,
+        progress_callbacks=progress_callbacks)


 def dataset_snapshot_download(
@@ -213,6 +218,7 @@ def _snapshot_download(
     allow_patterns: Optional[Union[List[str], str]] = None,
     ignore_patterns: Optional[Union[List[str], str]] = None,
     max_workers: int = 8,
+    progress_callbacks: List[Type[ProgressCallback]] = None,
 ):
     if not repo_type:
         repo_type = REPO_TYPE_MODEL
@@ -304,6 +310,7 @@ def _snapshot_download(
                 allow_patterns=allow_patterns,
                 max_workers=max_workers,
                 endpoint=endpoint,
+                progress_callbacks=progress_callbacks,
             )
             if '.' in repo_id:
                 masked_directory = get_model_masked_directory(
@@ -362,6 +369,7 @@ def _snapshot_download(
                 allow_patterns=allow_patterns,
                 max_workers=max_workers,
                 endpoint=endpoint,
+                progress_callbacks=progress_callbacks,
             )

             cache.save_model_version(revision_info=revision_detail)
@@ -449,6 +457,7 @@ def _download_file_lists(
     ignore_patterns: Optional[Union[List[str], str]] = None,
     max_workers: int = 8,
     endpoint: Optional[str] = None,
+    progress_callbacks: List[Type[ProgressCallback]] = None,
 ):
     ignore_patterns = _normalize_patterns(ignore_patterns)
     allow_patterns = _normalize_patterns(allow_patterns)
@@ -532,6 +541,7 @@ def _download_file_lists(
             headers,
             cookies,
             disable_tqdm=False,
+            progress_callbacks=progress_callbacks,
         )

     if len(filtered_repo_files) > 0:
tests/hub/test_download_callback.py (new file, 52 lines)
@@ -0,0 +1,52 @@
+import tempfile
+import unittest
+
+from tqdm import tqdm
+
+from modelscope import snapshot_download
+from modelscope.hub import ProgressCallback
+
+
+class NewProgressCallback(ProgressCallback):
+    all_files = set()  # just for test
+
+    def __init__(self, filename: str, file_size: int):
+        super().__init__(filename, file_size)
+        self.progress = tqdm(total=file_size)
+        self.all_files.add(filename)
+
+    def update(self, size: int):
+        self.progress.update(size)
+
+    def end(self):
+        self.all_files.remove(self.filename)
+        assert self.progress.n == self.progress.total == self.file_size
+        self.progress.close()
+
+
+class ProgressCallbackTest(unittest.TestCase):
+
+    def setUp(self):
+        self.temp_dir = tempfile.TemporaryDirectory()
+
+    def tearDown(self):
+        self.temp_dir.cleanup()
+
+    def test_progress_callback(self):
+        model_dir = snapshot_download(
+            'swift/test_lora',
+            progress_callbacks=[NewProgressCallback],
+            cache_dir=self.temp_dir.name)
+        print(f'model_dir: {model_dir}')
+        self.assertTrue(len(NewProgressCallback.all_files) == 0)
+
+    def test_empty_progress_callback(self):
+        model_dir = snapshot_download(
+            'swift/test_lora',
+            progress_callbacks=[],
+            cache_dir=self.temp_dir.name)
+        print(f'model_dir: {model_dir}')
+
+
+if __name__ == '__main__':
+    unittest.main()