diff --git a/modelscope/hub/constants.py b/modelscope/hub/constants.py
index 362f323d..0804f337 100644
--- a/modelscope/hub/constants.py
+++ b/modelscope/hub/constants.py
@@ -30,6 +30,9 @@ MODELSCOPE_CLOUD_USERNAME = 'MODELSCOPE_USERNAME'
 MODELSCOPE_SDK_DEBUG = 'MODELSCOPE_SDK_DEBUG'
 ONE_YEAR_SECONDS = 24 * 365 * 60 * 60
 MODELSCOPE_REQUEST_ID = 'X-Request-ID'
+# Environment values are strings; convert so interval comparisons work.
+MINIMUM_DOWNLOAD_INTERVAL_SECONDS = int(
+    os.environ.get('MODELSCOPE_MINIMUM_DOWNLOAD_INTERVAL_SECONDS', 10))
 
 
 class Licenses(object):
diff --git a/modelscope/hub/snapshot_download.py b/modelscope/hub/snapshot_download.py
index dd332c6b..68548f60 100644
--- a/modelscope/hub/snapshot_download.py
+++ b/modelscope/hub/snapshot_download.py
@@ -3,6 +3,8 @@
 import os
 import re
 import tempfile
+import threading
+import time
 from http.cookiejar import CookieJar
 from pathlib import Path
 from typing import Dict, List, Optional, Union
@@ -10,7 +12,8 @@ from typing import Dict, List, Optional, Union
 from modelscope.hub.api import HubApi, ModelScopeConfig
 from modelscope.utils.constant import DEFAULT_MODEL_REVISION
 from modelscope.utils.logger import get_logger
-from .constants import (FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS,
+from .constants import (FILE_HASH, MINIMUM_DOWNLOAD_INTERVAL_SECONDS,
+                        MODELSCOPE_DOWNLOAD_PARALLELS,
                         MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB)
 from .file_download import (get_file_download_url, http_get_file,
                             parallel_download)
@@ -20,6 +23,8 @@ from .utils.utils import (file_integrity_validation, get_cache_dir,
 
 logger = get_logger()
 
+recent_downloaded = threading.local()
+
 
 def snapshot_download(model_id: str,
                       revision: Optional[str] = DEFAULT_MODEL_REVISION,
@@ -75,6 +80,18 @@ def snapshot_download(model_id: str,
     name = name.replace('.', '___')
 
     cache = ModelFileSystemCache(cache_dir, group_or_owner, name)
+
+    # Only successful downloads are recorded (at the end of this function),
+    # so a failed attempt never makes a retry fall back to a partial cache.
+    is_recent_downloaded = False
+    current_time = time.time()
+    recent_download_models = getattr(recent_downloaded, 'models', None)
+    if recent_download_models is None:
+        recent_download_models = recent_downloaded.models = {}
+    recent_download_time = recent_download_models.get(model_id)
+    if recent_download_time is not None:
+        elapsed = current_time - recent_download_time
+        is_recent_downloaded = elapsed < MINIMUM_DOWNLOAD_INTERVAL_SECONDS
     if local_files_only:
         if len(cache.cached_files) == 0:
             raise ValueError(
@@ -85,6 +102,9 @@ def snapshot_download(model_id: str,
                        % revision)
         return cache.get_root_location(
         )  # we can not confirm the cached file is for snapshot 'revision'
+    elif is_recent_downloaded:
+        logger.warning('Download interval is too small, use local cache')
+        return cache.get_root_location()
     else:
         # make headers
         headers = {
@@ -167,5 +187,5 @@ def snapshot_download(model_id: str,
                 cache.put_file(model_file, temp_file)
 
         cache.save_model_version(revision_info=revision_detail)
-
+        recent_downloaded.models[model_id] = time.time()
         return os.path.join(cache.get_root_location())