diff --git a/download_test.py b/download_test.py
new file mode 100644
index 00000000..1dd2b29a
--- /dev/null
+++ b/download_test.py
@@ -0,0 +1,26 @@
+import shutil
+
+from modelscope import dataset_snapshot_download, snapshot_download
+from modelscope.utils.ms_tqdm import timing_decorator
+
+# shutil.rmtree("/root/.cache/modelscope/datasets/AlexEz", ignore_errors=True)
+shutil.rmtree('/root/.cache/modelscope/hub/AlexEz', ignore_errors=True)
+
+
+@timing_decorator
+def total_test():
+    snapshot_download(model_id='AlexEz/test_model', max_workers=1)
+
+
+total_test()
+# dir = dataset_snapshot_download(dataset_id="AlexEz/image_dataset_example", max_workers=1)
+
+# print(dir)
+
+# from modelscope.msdatasets import MsDataset
+# ds = MsDataset.load('clip-benchmark/wds_flickr8k', split='test')
+
+# from huggingface_hub import snapshot_download
+# snapshot_download(repo_id='gaia-benchmark/GAIA', repo_type='dataset', force_download=True)
+
+# print(ds[0])
diff --git a/modelscope/hub/file_download.py b/modelscope/hub/file_download.py
index 40ac8a03..25bd8025 100644
--- a/modelscope/hub/file_download.py
+++ b/modelscope/hub/file_download.py
@@ -28,6 +28,7 @@ from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
 from modelscope.utils.file_utils import (get_dataset_cache_root,
                                          get_model_cache_root)
 from modelscope.utils.logger import get_logger
+from modelscope.utils.ms_tqdm import timing_decorator
 from .errors import FileDownloadError, InvalidParameter, NotExistError
 from .utils.caching import ModelFileSystemCache
 from .utils.utils import (file_integrity_validation, get_endpoint,
@@ -372,6 +373,7 @@ def download_part_with_retry(params):
         retry.sleep()
 
 
+@timing_decorator
 def parallel_download(
     url: str,
     local_dir: str,
@@ -418,6 +420,7 @@ def parallel_download(
         os.remove(part_file_name)
 
 
+@timing_decorator
 def http_get_model_file(
     url: str,
     local_dir: str,
@@ -589,6 +592,7 @@ def http_get_file(
     os.replace(temp_file.name, os.path.join(local_dir, file_name))
 
 
+@timing_decorator
 def download_file(url, file_meta, temporary_cache_dir, cache, headers,
                   cookies):
     if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[
diff --git a/modelscope/hub/snapshot_download.py b/modelscope/hub/snapshot_download.py
index 015cadbd..ff284ebf 100644
--- a/modelscope/hub/snapshot_download.py
+++ b/modelscope/hub/snapshot_download.py
@@ -21,12 +21,14 @@ from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
                                        REPO_TYPE_DATASET, REPO_TYPE_MODEL,
                                        REPO_TYPE_SUPPORT)
 from modelscope.utils.logger import get_logger
+from modelscope.utils.ms_tqdm import timing_decorator
 from .file_download import (create_temporary_directory_and_cache,
                             download_file, get_file_download_url)
 
 logger = get_logger()
 
 
+@timing_decorator
 def snapshot_download(
     model_id: str,
     revision: Optional[str] = DEFAULT_MODEL_REVISION,
diff --git a/modelscope/hub/utils/caching.py b/modelscope/hub/utils/caching.py
index 675d62a8..58823b58 100644
--- a/modelscope/hub/utils/caching.py
+++ b/modelscope/hub/utils/caching.py
@@ -10,6 +10,7 @@ from typing import Dict
 from modelscope.hub.constants import FILE_HASH
 from modelscope.hub.utils.utils import compute_hash
 from modelscope.utils.logger import get_logger
+from modelscope.utils.ms_tqdm import timing_decorator
 
 logger = get_logger()
 """Implements caching functionality, used internally only
@@ -253,6 +254,7 @@ class ModelFileSystemCache(FileSystemCache):
         }
         return cache_key
 
+    @timing_decorator
     def exists(self, model_file_info):
         """Check the file is cached or not.
         Note existence check will also cover digest check
@@ -305,6 +307,7 @@ class ModelFileSystemCache(FileSystemCache):
                 os.remove(file_path)
                 break
 
+    @timing_decorator
     def put_file(self, model_file_info, model_file_location):
         """Put model on model_file_location to cache, the model
         first download to /tmp, and move to cache.
diff --git a/modelscope/hub/utils/utils.py b/modelscope/hub/utils/utils.py
index bb38f26a..f8c1e8cc 100644
--- a/modelscope/hub/utils/utils.py
+++ b/modelscope/hub/utils/utils.py
@@ -15,6 +15,7 @@ from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN,
 from modelscope.hub.errors import FileIntegrityError
 from modelscope.utils.file_utils import get_default_modelscope_cache_dir
 from modelscope.utils.logger import get_logger
+from modelscope.utils.ms_tqdm import timing_decorator
 
 logger = get_logger()
 
@@ -95,6 +96,7 @@ def get_endpoint():
     return MODELSCOPE_URL_SCHEME + modelscope_domain
 
 
+@timing_decorator
 def compute_hash(file_path):
     BUFFER_SIZE = 1024 * 64  # 64k buffer size
     sha256_hash = hashlib.sha256()
@@ -107,6 +109,7 @@ def compute_hash(file_path):
     return sha256_hash.hexdigest()
 
 
+@timing_decorator
 def file_integrity_validation(file_path, expected_sha256):
     """Validate the file hash is expected, if not, delete the file
 
diff --git a/modelscope/utils/ms_tqdm.py b/modelscope/utils/ms_tqdm.py
new file mode 100644
index 00000000..4c05f67f
--- /dev/null
+++ b/modelscope/utils/ms_tqdm.py
@@ -0,0 +1,47 @@
+import inspect
+import os
+import threading
+import time
+from functools import wraps
+
+from tqdm.auto import tqdm as old_tqdm
+
+
+def timing_decorator(func):
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        # 获取调用函数的文件信息
+        frame = inspect.currentframe()
+        try:
+            # 获取调用函数的调用者的信息
+            caller_frame = frame.f_back
+            filename = os.path.basename(caller_frame.f_code.co_filename)
+            line_number = caller_frame.f_lineno
+        finally:
+            del frame  # 明确删除以防止循环引用
+
+        start_time = time.time()
+        result = func(*args, **kwargs)
+        end_time = time.time()
+        elapsed_time = end_time - start_time
+
+        # 打印丰富的调试信息
+        print(
+            f"Function '{func.__name__}' in {filename} - line {line_number}, took {elapsed_time:.4f} seconds."
+        )
+
+        return result
+
+    return wrapper
+
+
+class tqdm(old_tqdm):
+    _lock = threading.Lock()
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def update(self, n=1):
+        with self._lock:
+            super().update(n)