add if download interval is too small, use local cache (#752)

Co-authored-by: mulin.lyh <mulin.lyh@taobao.com>
This commit is contained in:
liuyhwangyh
2024-02-04 21:25:09 +08:00
committed by GitHub
parent 3d442b5c7e
commit 5e8a8f4e93
2 changed files with 24 additions and 2 deletions

View File

@@ -30,6 +30,8 @@ MODELSCOPE_CLOUD_USERNAME = 'MODELSCOPE_USERNAME'
MODELSCOPE_SDK_DEBUG = 'MODELSCOPE_SDK_DEBUG'
ONE_YEAR_SECONDS = 24 * 365 * 60 * 60
MODELSCOPE_REQUEST_ID = 'X-Request-ID'
MINIMUM_DOWNLOAD_INTERVAL_SECONDS = os.environ.get(
'MODELSCOPE_MINIMUM_DOWNLOAD_INTERVAL_SECONDS', 10)
class Licenses(object):

View File

@@ -3,6 +3,8 @@
import os
import re
import tempfile
import threading
import time
from http.cookiejar import CookieJar
from pathlib import Path
from typing import Dict, List, Optional, Union
@@ -10,7 +12,8 @@ from typing import Dict, List, Optional, Union
from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
from modelscope.utils.logger import get_logger
from .constants import (FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS,
from .constants import (FILE_HASH, MINIMUM_DOWNLOAD_INTERVAL_SECONDS,
MODELSCOPE_DOWNLOAD_PARALLELS,
MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB)
from .file_download import (get_file_download_url, http_get_file,
parallel_download)
@@ -20,6 +23,8 @@ from .utils.utils import (file_integrity_validation, get_cache_dir,
logger = get_logger()
recent_downloaded = threading.local()
def snapshot_download(model_id: str,
revision: Optional[str] = DEFAULT_MODEL_REVISION,
@@ -75,6 +80,18 @@ def snapshot_download(model_id: str,
name = name.replace('.', '___')
cache = ModelFileSystemCache(cache_dir, group_or_owner, name)
is_recent_downloaded = False
current_time = time.time()
recent_download_models = getattr(recent_downloaded, 'models', None)
if recent_download_models is None:
recent_downloaded.models = {}
else:
if model_id in recent_download_models:
recent_download_time = recent_download_models[model_id]
if current_time - recent_download_time < MINIMUM_DOWNLOAD_INTERVAL_SECONDS:
is_recent_downloaded = True
recent_download_models[model_id] = current_time
if local_files_only:
if len(cache.cached_files) == 0:
raise ValueError(
@@ -85,6 +102,9 @@ def snapshot_download(model_id: str,
% revision)
return cache.get_root_location(
) # we can not confirm the cached file is for snapshot 'revision'
elif is_recent_downloaded:
logger.warning('Download interval is too small, use local cache')
return cache.get_root_location()
else:
# make headers
headers = {
@@ -167,5 +187,5 @@ def snapshot_download(model_id: str,
cache.put_file(model_file, temp_file)
cache.save_model_version(revision_info=revision_detail)
recent_downloaded.models[model_id] = time.time()
return os.path.join(cache.get_root_location())