This commit is contained in:
Yunnglin
2024-12-04 17:36:47 +08:00
parent 0411beeb05
commit 623ebf6c61
6 changed files with 85 additions and 0 deletions

26
download_test.py Normal file
View File

@@ -0,0 +1,26 @@
import shutil
from modelscope import dataset_snapshot_download, snapshot_download
from modelscope.utils.ms_tqdm import timing_decorator
# shutil.rmtree("/root/.cache/modelscope/datasets/AlexEz", ignore_errors=True)
shutil.rmtree('/root/.cache/modelscope/hub/AlexEz', ignore_errors=True)
@timing_decorator
def total_test():
snapshot_download(model_id='AlexEz/test_model', max_workers=1)
total_test()
# dir = dataset_snapshot_download(dataset_id="AlexEz/image_dataset_example", max_workers=1)
# print(dir)
# from modelscope.msdatasets import MsDataset
# ds = MsDataset.load('clip-benchmark/wds_flickr8k', split='test')
# from huggingface_hub import snapshot_download
# snapshot_download(repo_id='gaia-benchmark/GAIA', repo_type='dataset', force_download=True)
# print(ds[0])

View File

@@ -28,6 +28,7 @@ from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
from modelscope.utils.file_utils import (get_dataset_cache_root,
get_model_cache_root)
from modelscope.utils.logger import get_logger
from modelscope.utils.ms_tqdm import timing_decorator
from .errors import FileDownloadError, InvalidParameter, NotExistError
from .utils.caching import ModelFileSystemCache
from .utils.utils import (file_integrity_validation, get_endpoint,
@@ -372,6 +373,7 @@ def download_part_with_retry(params):
retry.sleep()
@timing_decorator
def parallel_download(
url: str,
local_dir: str,
@@ -418,6 +420,7 @@ def parallel_download(
os.remove(part_file_name)
@timing_decorator
def http_get_model_file(
url: str,
local_dir: str,
@@ -589,6 +592,7 @@ def http_get_file(
os.replace(temp_file.name, os.path.join(local_dir, file_name))
@timing_decorator
def download_file(url, file_meta, temporary_cache_dir, cache, headers,
cookies):
if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[

View File

@@ -21,12 +21,14 @@ from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
REPO_TYPE_DATASET, REPO_TYPE_MODEL,
REPO_TYPE_SUPPORT)
from modelscope.utils.logger import get_logger
from modelscope.utils.ms_tqdm import timing_decorator
from .file_download import (create_temporary_directory_and_cache,
download_file, get_file_download_url)
logger = get_logger()
@timing_decorator
def snapshot_download(
model_id: str,
revision: Optional[str] = DEFAULT_MODEL_REVISION,

View File

@@ -10,6 +10,7 @@ from typing import Dict
from modelscope.hub.constants import FILE_HASH
from modelscope.hub.utils.utils import compute_hash
from modelscope.utils.logger import get_logger
from modelscope.utils.ms_tqdm import timing_decorator
logger = get_logger()
"""Implements caching functionality, used internally only
@@ -253,6 +254,7 @@ class ModelFileSystemCache(FileSystemCache):
}
return cache_key
@timing_decorator
def exists(self, model_file_info):
"""Check the file is cached or not. Note existence check will also cover digest check
@@ -305,6 +307,7 @@ class ModelFileSystemCache(FileSystemCache):
os.remove(file_path)
break
@timing_decorator
def put_file(self, model_file_info, model_file_location):
"""Put model on model_file_location to cache, the model first download to /tmp, and move to cache.

View File

@@ -15,6 +15,7 @@ from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN,
from modelscope.hub.errors import FileIntegrityError
from modelscope.utils.file_utils import get_default_modelscope_cache_dir
from modelscope.utils.logger import get_logger
from modelscope.utils.ms_tqdm import timing_decorator
logger = get_logger()
@@ -95,6 +96,7 @@ def get_endpoint():
return MODELSCOPE_URL_SCHEME + modelscope_domain
@timing_decorator
def compute_hash(file_path):
BUFFER_SIZE = 1024 * 64 # 64k buffer size
sha256_hash = hashlib.sha256()
@@ -107,6 +109,7 @@ def compute_hash(file_path):
return sha256_hash.hexdigest()
@timing_decorator
def file_integrity_validation(file_path, expected_sha256):
"""Validate the file hash is expected, if not, delete the file

View File

@@ -0,0 +1,47 @@
import inspect
import os
import threading
import time
from functools import wraps
from tqdm.auto import tqdm as old_tqdm
def timing_decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
# 获取调用函数的文件信息
frame = inspect.currentframe()
try:
# 获取调用函数的调用者的信息
caller_frame = frame.f_back
filename = os.path.basename(caller_frame.f_code.co_filename)
line_number = caller_frame.f_lineno
finally:
del frame # 明确删除以防止循环引用
start_time = time.time()
result = func(*args, **kwargs)
end_time = time.time()
elapsed_time = end_time - start_time
# 打印丰富的调试信息
print(
f"Function '{func.__name__}' in {filename} - line {line_number}, took {elapsed_time:.4f} seconds."
)
return result
return wrapper
class tqdm(old_tqdm):
_lock = threading.Lock()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def update(self, n=1):
with self._lock:
super().update(n)