Files
modelscope/tests/hub/test_download_callback.py

53 lines
1.4 KiB
Python
Raw Permalink Normal View History

import os
import tempfile
import unittest

from tqdm import tqdm

from modelscope import snapshot_download
from modelscope.hub import ProgressCallback
class NewProgressCallback(ProgressCallback):
    """Test progress callback that mirrors a file download onto a tqdm bar.

    The class itself (not an instance) is handed to ``snapshot_download`` via
    ``progress_callbacks``; presumably one instance is built per downloaded
    file. Each instance registers its filename in the class-level
    ``all_files`` set on construction and removes it again in ``end``, so an
    empty set after the download means every started callback was ended.
    """

    # Shared by all instances on purpose: the test inspects it after
    # snapshot_download returns. Only meaningful for testing.
    all_files = set()

    def __init__(self, filename: str, file_size: int):
        # ``end`` reads ``self.filename`` and ``self.file_size``; these are
        # assumed to be stored by ProgressCallback.__init__ — TODO confirm.
        super().__init__(filename, file_size)
        # desc/unit only affect the rendered bar; ``progress.n`` and
        # ``progress.total`` stay raw counts for the check in ``end``.
        self.progress = tqdm(
            total=file_size, desc=filename, unit='B', unit_scale=True)
        self.all_files.add(filename)

    def update(self, size: int):
        """Advance the bar by ``size`` units (presumably bytes) just received."""
        self.progress.update(size)

    def end(self):
        """Unregister this file and check the bar reached exactly ``file_size``.

        Raises:
            KeyError: if ``self.filename`` is not present in ``all_files``.
            AssertionError: if the sizes passed to ``update`` do not sum to
                ``file_size``.
        """
        try:
            self.all_files.remove(self.filename)
            assert self.progress.n == self.progress.total == self.file_size, (
                f'{self.filename}: received {self.progress.n}, '
                f'expected {self.file_size}')
        finally:
            # Close even when the check fails so a failing test does not
            # leave a dangling tqdm bar behind.
            self.progress.close()
class ProgressCallbackTest(unittest.TestCase):
    """Tests for the ``progress_callbacks`` argument of ``snapshot_download``.

    Each test downloads ``swift/test_lora`` into its own temporary cache
    directory (presumably from the ModelScope hub, so network access is
    likely required).
    """

    def setUp(self):
        # Fresh cache dir per test so files are not served from an earlier
        # test's cache.
        self.temp_dir = tempfile.TemporaryDirectory()
        # ``all_files`` is class-level state shared by every callback
        # instance; reset it so leftovers from an earlier test cannot leak in.
        NewProgressCallback.all_files.clear()

    def tearDown(self):
        self.temp_dir.cleanup()

    def test_progress_callback(self):
        model_dir = snapshot_download(
            'swift/test_lora',
            progress_callbacks=[NewProgressCallback],
            cache_dir=self.temp_dir.name)
        print(f'model_dir: {model_dir}')
        # Every callback that was started must have been ended. Comparing the
        # set directly makes a failure list the files whose ``end`` never ran.
        self.assertEqual(NewProgressCallback.all_files, set())

    def test_empty_progress_callback(self):
        model_dir = snapshot_download(
            'swift/test_lora',
            progress_callbacks=[],
            cache_dir=self.temp_dir.name)
        print(f'model_dir: {model_dir}')
        # Assumes snapshot_download returns the local snapshot directory.
        self.assertTrue(os.path.isdir(model_dir))
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()