mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 08:17:45 +01:00
add if download interval is too small, use local cache (#752)
Co-authored-by: mulin.lyh <mulin.lyh@taobao.com>
This commit is contained in:
@@ -30,6 +30,8 @@ MODELSCOPE_CLOUD_USERNAME = 'MODELSCOPE_USERNAME'
|
||||
MODELSCOPE_SDK_DEBUG = 'MODELSCOPE_SDK_DEBUG'
|
||||
ONE_YEAR_SECONDS = 24 * 365 * 60 * 60
|
||||
MODELSCOPE_REQUEST_ID = 'X-Request-ID'
|
||||
MINIMUM_DOWNLOAD_INTERVAL_SECONDS = os.environ.get(
|
||||
'MODELSCOPE_MINIMUM_DOWNLOAD_INTERVAL_SECONDS', 10)
|
||||
|
||||
|
||||
class Licenses(object):
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from http.cookiejar import CookieJar
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
@@ -10,7 +12,8 @@ from typing import Dict, List, Optional, Union
|
||||
from modelscope.hub.api import HubApi, ModelScopeConfig
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .constants import (FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS,
|
||||
from .constants import (FILE_HASH, MINIMUM_DOWNLOAD_INTERVAL_SECONDS,
|
||||
MODELSCOPE_DOWNLOAD_PARALLELS,
|
||||
MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB)
|
||||
from .file_download import (get_file_download_url, http_get_file,
|
||||
parallel_download)
|
||||
@@ -20,6 +23,8 @@ from .utils.utils import (file_integrity_validation, get_cache_dir,
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
recent_downloaded = threading.local()
|
||||
|
||||
|
||||
def snapshot_download(model_id: str,
|
||||
revision: Optional[str] = DEFAULT_MODEL_REVISION,
|
||||
@@ -75,6 +80,18 @@ def snapshot_download(model_id: str,
|
||||
name = name.replace('.', '___')
|
||||
|
||||
cache = ModelFileSystemCache(cache_dir, group_or_owner, name)
|
||||
|
||||
is_recent_downloaded = False
|
||||
current_time = time.time()
|
||||
recent_download_models = getattr(recent_downloaded, 'models', None)
|
||||
if recent_download_models is None:
|
||||
recent_downloaded.models = {}
|
||||
else:
|
||||
if model_id in recent_download_models:
|
||||
recent_download_time = recent_download_models[model_id]
|
||||
if current_time - recent_download_time < MINIMUM_DOWNLOAD_INTERVAL_SECONDS:
|
||||
is_recent_downloaded = True
|
||||
recent_download_models[model_id] = current_time
|
||||
if local_files_only:
|
||||
if len(cache.cached_files) == 0:
|
||||
raise ValueError(
|
||||
@@ -85,6 +102,9 @@ def snapshot_download(model_id: str,
|
||||
% revision)
|
||||
return cache.get_root_location(
|
||||
) # we can not confirm the cached file is for snapshot 'revision'
|
||||
elif is_recent_downloaded:
|
||||
logger.warning('Download interval is too small, use local cache')
|
||||
return cache.get_root_location()
|
||||
else:
|
||||
# make headers
|
||||
headers = {
|
||||
@@ -167,5 +187,5 @@ def snapshot_download(model_id: str,
|
||||
cache.put_file(model_file, temp_file)
|
||||
|
||||
cache.save_model_version(revision_info=revision_detail)
|
||||
|
||||
recent_downloaded.models[model_id] = time.time()
|
||||
return os.path.join(cache.get_root_location())
|
||||
|
||||
Reference in New Issue
Block a user