# Mirror of https://github.com/modelscope/modelscope.git
# Synced 2025-12-24 20:19:22 +01:00
# 53 lines, 1.4 KiB, Python
|
|
import tempfile
|
||
|
|
import unittest
|
||
|
|
|
||
|
|
from tqdm import tqdm
|
||
|
|
|
||
|
|
from modelscope import snapshot_download
|
||
|
|
from modelscope.hub import ProgressCallback
|
||
|
|
|
||
|
|
|
||
|
|
class NewProgressCallback(ProgressCallback):
    """Progress callback that mirrors a file download onto a tqdm bar.

    Each instance adds its filename to the class-level ``all_files`` set on
    construction and removes it again in ``end``, so an empty set means every
    callback that was started has also been finished.
    """

    # Shared by all instances on purpose: bookkeeping used just for the test.
    all_files = set()

    def __init__(self, filename: str, file_size: int):
        """Create the progress bar and register ``filename`` as in flight.

        Args:
            filename: Name of the file this callback tracks.
            file_size: Expected size of the file; used as the bar's total.
        """
        # NOTE(review): ``end`` reads ``self.filename`` and ``self.file_size``,
        # presumably stored by the base-class initializer -- confirm.
        super().__init__(filename, file_size)
        self.progress = tqdm(total=file_size)
        self.all_files.add(filename)

    def update(self, size: int):
        """Advance the progress bar by ``size``."""
        self.progress.update(size)

    def end(self):
        """Unregister the file, verify the bar is complete, and close it.

        Raises:
            KeyError: If this callback's filename is not in ``all_files``.
            AssertionError: If the bar's count, its total and ``file_size``
                are not all equal.
        """
        try:
            self.all_files.remove(self.filename)
            assert self.progress.n == self.progress.total == self.file_size
        finally:
            # Close the bar even when a check above fails, so a failing
            # download does not leave an open tqdm bar behind.
            self.progress.close()
|
||
|
|
|
||
|
|
|
||
|
|
class ProgressCallbackTest(unittest.TestCase):
    """Exercise ``snapshot_download`` with and without progress callbacks.

    Every test downloads into its own temporary cache directory, created in
    ``setUp`` and removed in ``tearDown``.
    """

    def setUp(self):
        """Create a fresh temporary directory to use as the cache dir."""
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        """Delete the temporary cache directory."""
        self.temp_dir.cleanup()

    def test_progress_callback(self):
        """Download with ``NewProgressCallback`` and check nothing is left open."""
        model_dir = snapshot_download(
            'swift/test_lora',
            progress_callbacks=[NewProgressCallback],
            cache_dir=self.temp_dir.name)
        print(f'model_dir: {model_dir}')
        # Compare against an empty set rather than asserting on a boolean so
        # that a failure reports which filenames were never finished.
        self.assertEqual(NewProgressCallback.all_files, set())

    def test_empty_progress_callback(self):
        """Download with an empty callback list; passes if no exception is raised."""
        model_dir = snapshot_download(
            'swift/test_lora',
            progress_callbacks=[],
            cache_dir=self.temp_dir.name)
        print(f'model_dir: {model_dir}')
|
||
|
|
|
||
|
|
|
||
|
|
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
|