# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import hashlib
import io
import os
import shutil
import tempfile
import urllib
import uuid
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from http.cookiejar import CookieJar
from pathlib import Path
from typing import Dict, List, Optional, Type, Union

import requests
from requests.adapters import Retry
from tqdm.auto import tqdm

from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.constants import (
    API_FILE_DOWNLOAD_CHUNK_SIZE, API_FILE_DOWNLOAD_RETRY_TIMES,
    API_FILE_DOWNLOAD_TIMEOUT, FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS,
    MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB, TEMPORARY_FOLDER_NAME)
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
                                       DEFAULT_MODEL_REVISION,
                                       INTRA_CLOUD_ACCELERATION,
                                       REPO_TYPE_DATASET, REPO_TYPE_MODEL,
                                       REPO_TYPE_SUPPORT)
from modelscope.utils.file_utils import (get_dataset_cache_root,
                                         get_model_cache_root)
from modelscope.utils.logger import get_logger
from .callback import ProgressCallback, TqdmCallback
from .errors import FileDownloadError, InvalidParameter, NotExistError
from .utils.caching import ModelFileSystemCache
from .utils.utils import (file_integrity_validation, get_endpoint,
                          model_id_to_group_owner_name)

logger = get_logger()


def model_file_download(
    model_id: str,
    file_path: str,
    revision: Optional[str] = DEFAULT_MODEL_REVISION,
    cache_dir: Optional[str] = None,
    user_agent: Union[Dict, str, None] = None,
    local_files_only: Optional[bool] = False,
    cookies: Optional[CookieJar] = None,
    local_dir: Optional[str] = None,
    token: Optional[str] = None,
) -> Optional[str]:  # pragma: no cover
    """Download a model file and cache it if it's not already present in the local cache.

    Given a model id and a file path, this function looks for the corresponding
    file in the local cache. If it's not there, it is downloaded. The path to
    the cached file is then returned.

    Args:
        model_id (str): The model to which the file to be downloaded belongs.
        file_path (str): Path of the file to be downloaded, relative to the root of the model repo.
        revision (str, optional): Revision of the model file to be downloaded.
            Can be any of a branch, tag or commit hash.
        cache_dir (str, Path, optional): Path to the folder where cached files are stored.
        user_agent (dict, str, optional): The user-agent info in the form of a dictionary or a string.
        local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the
            local cached file if it exists; if `False`, download the file even if a cached copy exists.
        cookies (CookieJar, optional): The cookies of the download request.
        local_dir (str, optional): Specific local directory path to which the file will be downloaded.
        token (str, optional): The user token.

    Returns:
        string: path of the local file, or, if networking is off, path of the
        last version of the file cached on disk.

    Raises:
        NotExistError: The file does not exist.
        ValueError: A request parameter is invalid.

    Note:
        Raises the following errors:

        - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
          if `use_auth_token=True` and the token cannot be found.
        - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
          if ETag cannot be determined.
        - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
          if some parameter value is invalid.
    """
    return _repo_file_download(
        model_id,
        file_path,
        repo_type=REPO_TYPE_MODEL,
        revision=revision,
        cache_dir=cache_dir,
        user_agent=user_agent,
        local_files_only=local_files_only,
        cookies=cookies,
        local_dir=local_dir,
        token=token)
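
# Example usage (an illustrative sketch; the model id and file name below are
# placeholders and may not exist on the Hub):
#
#     config_path = model_file_download(
#         model_id='some-org/some-model', file_path='configuration.json')
#     print(config_path)  # path of the cached file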


def dataset_file_download(
    dataset_id: str,
    file_path: str,
    revision: Optional[str] = DEFAULT_DATASET_REVISION,
    cache_dir: Union[str, Path, None] = None,
    local_dir: Optional[str] = None,
    user_agent: Optional[Union[Dict, str]] = None,
    local_files_only: Optional[bool] = False,
    cookies: Optional[CookieJar] = None,
    token: Optional[str] = None,
) -> str:
    """Download a raw file of a dataset and cache it locally.

    Downloads the single file at `file_path` for the specified revision and
    returns its local path. An alternative would be to clone the dataset repo,
    but that would require the user to always have git and git-lfs installed,
    and properly configured.

    Args:
        dataset_id (str): A user or an organization name and a dataset name separated by a `/`.
        file_path (str): The relative path of the file to download.
        revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
            commit hash. NOTE: currently only branch and tag names are supported.
        cache_dir (str, Path, optional): Path to the folder where cached files are stored; the dataset
            file will be saved as cache_dir/dataset_id/THE_DATASET_FILES.
        local_dir (str, optional): Specific local directory path to which the file will be downloaded.
        user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string.
        local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the
            local cached file if it exists.
        cookies (CookieJar, optional): The cookies of the request, default None.
        token (str, optional): The user token.

    Raises:
        ValueError: If some parameter value is invalid.

    Returns:
        str: Local path of the downloaded file.

    Note:
        Raises the following errors:

        - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
          if `use_auth_token=True` and the token cannot be found.
        - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
          if ETag cannot be determined.
        - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
          if some parameter value is invalid.
    """
    return _repo_file_download(
        dataset_id,
        file_path,
        repo_type=REPO_TYPE_DATASET,
        revision=revision,
        cache_dir=cache_dir,
        user_agent=user_agent,
        local_files_only=local_files_only,
        cookies=cookies,
        local_dir=local_dir,
        token=token)
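
# Example usage (an illustrative sketch; the dataset id and file path below
# are placeholders and may not exist on the Hub):
#
#     csv_path = dataset_file_download(
#         dataset_id='some-org/some-dataset', file_path='data/train.csv')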


def _repo_file_download(
    repo_id: str,
    file_path: str,
    *,
    repo_type: str = None,
    revision: Optional[str] = DEFAULT_MODEL_REVISION,
    cache_dir: Optional[str] = None,
    user_agent: Union[Dict, str, None] = None,
    local_files_only: Optional[bool] = False,
    cookies: Optional[CookieJar] = None,
    local_dir: Optional[str] = None,
    disable_tqdm: bool = False,
    token: Optional[str] = None,
) -> Optional[str]:  # pragma: no cover
    if not repo_type:
        repo_type = REPO_TYPE_MODEL
    if repo_type not in REPO_TYPE_SUPPORT:
        raise InvalidParameter('Invalid repo type: %s, supported types: %s' %
                               (repo_type, REPO_TYPE_SUPPORT))

    temporary_cache_dir, cache = create_temporary_directory_and_cache(
        repo_id, local_dir=local_dir, cache_dir=cache_dir, repo_type=repo_type)

    # if local_files_only is `True` and the file already exists in the cache,
    # return the cached path
    if local_files_only:
        cached_file_path = cache.get_file_by_path(file_path)
        if cached_file_path is not None:
            logger.warning(
                "File exists in local cache, but we're not sure it's up to date"
            )
            return cached_file_path
        else:
            raise ValueError(
                'Cannot find the requested files in the cached path and outgoing'
                ' traffic has been disabled. To enable look-ups and downloads'
                " online, set 'local_files_only' to False.")

    _api = HubApi(token=token)

    headers = {
        'user-agent': ModelScopeConfig.get_user_agent(user_agent=user_agent),
        'snapshot-identifier': str(uuid.uuid4()),
    }

    if INTRA_CLOUD_ACCELERATION == 'true':
        region_id: str = (
            os.getenv('INTRA_CLOUD_ACCELERATION_REGION')
            or _api._get_internal_acceleration_domain())
        if region_id:
            logger.info(
                f'Intra-cloud acceleration enabled for downloading from {repo_id}'
            )
            headers['x-aliyun-region-id'] = region_id

    if cookies is None:
        cookies = _api.get_cookies()
    repo_files = []
    endpoint = _api.get_endpoint_for_read(repo_id=repo_id, repo_type=repo_type)
    file_to_download_meta = None
    if repo_type == REPO_TYPE_MODEL:
        revision = _api.get_valid_revision(
            repo_id, revision=revision, cookies=cookies, endpoint=endpoint)
        # we need to confirm the version is up-to-date, so fetch the file list
        # to check whether the latest version is cached; if so return the
        # cached file, otherwise download it
        repo_files = _api.get_model_files(
            model_id=repo_id,
            revision=revision,
            recursive=True,
            use_cookies=False if cookies is None else cookies,
            endpoint=endpoint)
        for repo_file in repo_files:
            if repo_file['Type'] == 'tree':
                continue

            if repo_file['Path'] == file_path:
                if cache.exists(repo_file):
                    file_name = repo_file['Name']
                    logger.debug(
                        f'File {file_name} already in cache with identical hash, skip downloading!'
                    )
                    return cache.get_file_by_info(repo_file)
                else:
                    file_to_download_meta = repo_file
                break
    elif repo_type == REPO_TYPE_DATASET:
        group_or_owner, name = model_id_to_group_owner_name(repo_id)
        if not revision:
            revision = DEFAULT_DATASET_REVISION
        page_number = 1
        page_size = 100
        while True:
            try:
                dataset_files = _api.get_dataset_files(
                    repo_id=repo_id,
                    revision=revision,
                    root_path='/',
                    recursive=True,
                    page_number=page_number,
                    page_size=page_size,
                    endpoint=endpoint)
            except Exception as e:
                logger.error(
                    f'Get dataset: {repo_id} file list failed, error: {e}')
                break

            is_exist = False
            for repo_file in dataset_files:
                if repo_file['Type'] == 'tree':
                    continue

                if repo_file['Path'] == file_path:
                    if cache.exists(repo_file):
                        file_name = repo_file['Name']
                        logger.debug(
                            f'File {file_name} already in cache with identical hash, skip downloading!'
                        )
                        return cache.get_file_by_info(repo_file)
                    else:
                        file_to_download_meta = repo_file
                        is_exist = True
                    break
            if len(dataset_files) < page_size or is_exist:
                break
            page_number += 1

    if file_to_download_meta is None:
        raise NotExistError('The file path: %s does not exist in: %s' %
                            (file_path, repo_id))

    # we need to download the file again
    if repo_type == REPO_TYPE_MODEL:
        url_to_download = get_file_download_url(repo_id, file_path, revision,
                                                endpoint)
    elif repo_type == REPO_TYPE_DATASET:
        url_to_download = _api.get_dataset_file_url(
            file_name=file_to_download_meta['Path'],
            dataset_name=name,
            namespace=group_or_owner,
            revision=revision,
            endpoint=endpoint)
    else:
        raise ValueError(f'Invalid repo type {repo_type}')

    return download_file(url_to_download, file_to_download_meta,
                         temporary_cache_dir, cache, headers, cookies)


def move_legacy_cache_to_standard_dir(cache_dir: str, model_id: str):
    if cache_dir.endswith(os.path.sep):
        cache_dir = cache_dir.strip(os.path.sep)
    legacy_cache_root = os.path.dirname(cache_dir)
    base_name = os.path.basename(cache_dir)
    if base_name == 'datasets':
        # datasets are not affected
        return
    if not legacy_cache_root.endswith('hub'):
        # We have restructured the ModelScope cache directory.
        # Two scenarios:
        # Scenario 1:
        #   When MODELSCOPE_CACHE is not set, the default directory remains
        #   the same at ~/.cache/modelscope/hub
        # Scenario 2:
        #   When MODELSCOPE_CACHE is set, the cache directory is moved from
        #   $MODELSCOPE_CACHE/hub to $MODELSCOPE_CACHE/. In this case,
        #   we migrate the hub directory accordingly.
        legacy_cache_root = os.path.join(legacy_cache_root, 'hub')
    group_or_owner, name = model_id_to_group_owner_name(model_id)
    name = name.replace('.', '___')
    temporary_cache_dir = os.path.join(cache_dir, group_or_owner, name)
    legacy_cache_dir = os.path.join(legacy_cache_root, group_or_owner, name)
    if os.path.exists(
            legacy_cache_dir) and not os.path.exists(temporary_cache_dir):
        logger.info(
            f'Legacy cache dir exists: {legacy_cache_dir}, move to {temporary_cache_dir}'
        )
        try:
            shutil.move(legacy_cache_dir, temporary_cache_dir)
        except Exception:  # noqa
            # Failed, skip
            pass
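
# Illustrative sketch of the migration (hypothetical paths): with the new
# cache root at /data/ms/models, the legacy root resolves to /data/ms/hub, and
#
#     /data/ms/hub/owner/my___model  ->  /data/ms/models/owner/my___model
#
# is moved only when the target directory does not already exist.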


def create_temporary_directory_and_cache(model_id: str,
                                         local_dir: str = None,
                                         cache_dir: str = None,
                                         repo_type: str = REPO_TYPE_MODEL):
    if repo_type == REPO_TYPE_MODEL:
        default_cache_root = get_model_cache_root()
    elif repo_type == REPO_TYPE_DATASET:
        default_cache_root = get_dataset_cache_root()
    else:
        raise ValueError(
            f'repo_type only supports model and dataset, but got: {repo_type}')

    group_or_owner, name = model_id_to_group_owner_name(model_id)
    if local_dir is not None:
        temporary_cache_dir = os.path.join(local_dir, TEMPORARY_FOLDER_NAME)
        cache = ModelFileSystemCache(local_dir)
    else:
        if cache_dir is None:
            cache_dir = default_cache_root
            move_legacy_cache_to_standard_dir(cache_dir, model_id)
        if isinstance(cache_dir, Path):
            cache_dir = str(cache_dir)
        temporary_cache_dir = os.path.join(cache_dir, TEMPORARY_FOLDER_NAME,
                                           group_or_owner, name)
        name = name.replace('.', '___')
        cache = ModelFileSystemCache(cache_dir, group_or_owner, name)

    os.makedirs(temporary_cache_dir, exist_ok=True)
    return temporary_cache_dir, cache
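
# For instance (an illustrative sketch; actual locations depend on the
# platform and on MODELSCOPE_CACHE), downloading 'owner/my.model' without
# local_dir uses:
#
#     temporary dir: <cache_root>/<TEMPORARY_FOLDER_NAME>/owner/my.model
#     final cache:   <cache_root>/owner/my___model
#
# note that '.' in the repo name is escaped as '___' in the final cache path.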


def get_file_download_url(model_id: str,
                          file_path: str,
                          revision: str,
                          endpoint: Optional[str] = None):
    """Format file download url according to `model_id`, `revision` and `file_path`.

    e.g., Given `model_id=john/bert`, `revision=master`, `file_path=README.md`,
    the resulting download url is:
    https://modelscope.cn/api/v1/models/john/bert/repo?Revision=master&FilePath=README.md

    Args:
        model_id (str): The model id.
        file_path (str): The file path.
        revision (str): The file revision.
        endpoint (str): The remote endpoint.

    Returns:
        str: The file url.
    """
    file_path = urllib.parse.quote_plus(file_path)
    revision = urllib.parse.quote_plus(revision)
    download_url_template = '{endpoint}/api/v1/models/{model_id}/repo?Revision={revision}&FilePath={file_path}'
    if not endpoint:
        endpoint = get_endpoint()
    return download_url_template.format(
        endpoint=endpoint,
        model_id=model_id,
        revision=revision,
        file_path=file_path,
    )
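
# Note that `quote_plus` percent-encodes path separators inside FilePath,
# e.g. (illustrative):
#
#     get_file_download_url('john/bert', 'config/config.json', 'master')
#     # -> '<endpoint>/api/v1/models/john/bert/repo'
#     #    '?Revision=master&FilePath=config%2Fconfig.json'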


def download_part_with_retry(params):
    # unpack parameters
    model_file_path, progress_callbacks, start, end, url, file_name, cookies, headers = params
    get_headers = {} if headers is None else copy.deepcopy(headers)
    get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
    retry = Retry(
        total=API_FILE_DOWNLOAD_RETRY_TIMES,
        backoff_factor=1,
        allowed_methods=['GET'])
    part_file_name = model_file_path + '_%s_%s' % (start, end)
    while True:
        try:
            partial_length = 0
            if os.path.exists(
                    part_file_name):  # partial download exists, resume it
                with open(part_file_name, 'rb') as f:
                    partial_length = f.seek(0, io.SEEK_END)
                    for callback in progress_callbacks:
                        callback.update(partial_length)
            download_start = start + partial_length
            if download_start > end:
                break  # this part has already been fully downloaded
            get_headers['Range'] = 'bytes=%s-%s' % (download_start, end)
            with open(part_file_name, 'ab+') as f:
                r = requests.get(
                    url,
                    stream=True,
                    headers=get_headers,
                    cookies=cookies,
                    timeout=API_FILE_DOWNLOAD_TIMEOUT)
                for chunk in r.iter_content(
                        chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
                    if chunk:  # filter out keep-alive new chunks
                        f.write(chunk)
                        for callback in progress_callbacks:
                            callback.update(len(chunk))
            break
        except Exception as e:  # no matter the exception, we will retry.
            retry = retry.increment('GET', url, error=e)
            logger.warning('Downloading: %s failed, reason: %s, will retry' %
                           (model_file_path, e))
            retry.sleep()


def parallel_download(url: str,
                      local_dir: str,
                      file_name: str,
                      cookies: CookieJar,
                      headers: Optional[Dict[str, str]] = None,
                      file_size: int = None,
                      disable_tqdm: bool = False,
                      progress_callbacks: List[Type[ProgressCallback]] = None,
                      endpoint: str = None):
    progress_callbacks = ([] if progress_callbacks is None else
                          progress_callbacks.copy())
    if not disable_tqdm:
        progress_callbacks.append(TqdmCallback)
    progress_callbacks = [
        callback(file_name, file_size) for callback in progress_callbacks
    ]
    # split the file into fixed-size parts which are downloaded in parallel
    PART_SIZE = 160 * 1024 * 1024  # every part is 160 MiB
    tasks = []
    file_path = os.path.join(local_dir, file_name)
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    end = -1  # guard: the loop below is skipped for files smaller than one part
    for idx in range(int(file_size / PART_SIZE)):
        start = idx * PART_SIZE
        end = (idx + 1) * PART_SIZE - 1
        tasks.append((file_path, progress_callbacks, start, end, url,
                      file_name, cookies, headers))
    if end + 1 < file_size:
        tasks.append((file_path, progress_callbacks, end + 1, file_size - 1,
                      url, file_name, cookies, headers))
    parallels = min(MODELSCOPE_DOWNLOAD_PARALLELS, 16)
    # download every part
    with ThreadPoolExecutor(
            max_workers=parallels, thread_name_prefix='download') as executor:
        list(executor.map(download_part_with_retry, tasks))
    for callback in progress_callbacks:
        callback.end()
    # merge parts.
    hash_sha256 = hashlib.sha256()
    with open(os.path.join(local_dir, file_name), 'wb') as output_file:
        for task in tasks:
            part_file_name = task[0] + '_%s_%s' % (task[2], task[3])
            with open(part_file_name, 'rb') as part_file:
                while True:
                    chunk = part_file.read(16 * API_FILE_DOWNLOAD_CHUNK_SIZE)
                    if not chunk:
                        break
                    output_file.write(chunk)
                    hash_sha256.update(chunk)
            os.remove(part_file_name)
    return hash_sha256.hexdigest()
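
# Part-splitting sketch: for a hypothetical 400 MiB file with 160 MiB parts,
# the ranged requests issued above are:
#
#     part 0: bytes=0-167772159
#     part 1: bytes=167772160-335544319
#     tail:   bytes=335544320-419430399  (ends at file_size - 1)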


def http_get_model_file(
    url: str,
    local_dir: str,
    file_name: str,
    file_size: int,
    cookies: CookieJar,
    headers: Optional[Dict[str, str]] = None,
    disable_tqdm: bool = False,
    progress_callbacks: List[Type[ProgressCallback]] = None,
):
    """Download a remote file, retrying up to API_FILE_DOWNLOAD_RETRY_TIMES times before giving up on errors.

    Args:
        url (str):
            actual download url of the file.
        local_dir (str):
            local directory where the downloaded file is stored.
        file_name (str):
            name of the file stored in `local_dir`.
        file_size (int):
            The file size.
        cookies (CookieJar):
            cookies used to authenticate the user, e.g. for downloading private repos.
        headers (Dict[str, str], optional):
            http headers to carry necessary info when requesting the remote file.
        disable_tqdm (bool, optional): Disable the progress bar with tqdm.
        progress_callbacks (List[Type[ProgressCallback]], optional):
            progress callbacks to track the download progress.

    Raises:
        FileDownloadError: File download failed.

    """
    progress_callbacks = ([] if progress_callbacks is None else
                          progress_callbacks.copy())
    if not disable_tqdm:
        progress_callbacks.append(TqdmCallback)
    progress_callbacks = [
        callback(file_name, file_size) for callback in progress_callbacks
    ]
    get_headers = {} if headers is None else copy.deepcopy(headers)
    get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
    temp_file_path = os.path.join(local_dir, file_name)
    os.makedirs(os.path.dirname(temp_file_path), exist_ok=True)
    logger.debug('downloading %s to %s', url, temp_file_path)
    # exponential backoff between retries
    has_retry = False
    hash_sha256 = hashlib.sha256()
    retry = Retry(
        total=API_FILE_DOWNLOAD_RETRY_TIMES,
        backoff_factor=1,
        allowed_methods=['GET'])

    while True:
        try:
            if file_size == 0:
                # Avoid an empty-file request to the server
                with open(temp_file_path, 'w+'):
                    for callback in progress_callbacks:
                        callback.update(1)
                break
            # Determine the length of any existing partial download
            partial_length = 0
            # partial download exists, resume the download
            if os.path.exists(temp_file_path):
                # resuming from an interrupted download also counts as a retry
                has_retry = True
                with open(temp_file_path, 'rb') as f:
                    partial_length = f.seek(0, io.SEEK_END)
                    for callback in progress_callbacks:
                        callback.update(partial_length)

            # Check if the download is already complete
            if partial_length >= file_size:
                break
            # closed range [start, end], starting from 0.
            get_headers['Range'] = 'bytes=%s-%s' % (partial_length,
                                                    file_size - 1)
            with open(temp_file_path, 'ab+') as f:
                r = requests.get(
                    url,
                    stream=True,
                    headers=get_headers,
                    cookies=cookies,
                    timeout=API_FILE_DOWNLOAD_TIMEOUT)
                r.raise_for_status()
                for chunk in r.iter_content(
                        chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
                    if chunk:  # filter out keep-alive new chunks
                        for callback in progress_callbacks:
                            callback.update(len(chunk))
                        f.write(chunk)
                        # the hash would be discarded in the retry case anyway
                        if not has_retry:
                            hash_sha256.update(chunk)
            break
        except Exception as e:  # no matter what happens, we will retry.
            has_retry = True
            retry = retry.increment('GET', url, error=e)
            retry.sleep()
    for callback in progress_callbacks:
        callback.end()
    # if anything went wrong, discard the real-time hash and return None
    return None if has_retry else hash_sha256.hexdigest()
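
# Resume sketch (illustrative numbers): if 1000 bytes of a 4096-byte file are
# already on disk, the loop above re-enters with the closed-interval request
#
#     Range: bytes=1000-4095
#
# and appends the remaining 3096 bytes to the partial file.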


def http_get_file(
    url: str,
    local_dir: str,
    file_name: str,
    cookies: CookieJar,
    headers: Optional[Dict[str, str]] = None,
):
    """Download a remote file, retrying up to API_FILE_DOWNLOAD_RETRY_TIMES times before giving up on errors.

    Args:
        url (str):
            actual download url of the file.
        local_dir (str):
            local directory where the downloaded file is stored.
        file_name (str):
            name of the file stored in `local_dir`.
        cookies (CookieJar):
            cookies used to authenticate the user, e.g. for downloading private repos.
        headers (Dict[str, str], optional):
            http headers to carry necessary info when requesting the remote file.

    Raises:
        FileDownloadError: File download failed.

    """
    total = -1
    temp_file_manager = partial(
        tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False)
    get_headers = {} if headers is None else copy.deepcopy(headers)
    get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
    with temp_file_manager() as temp_file:
        logger.debug('downloading %s to %s', url, temp_file.name)
        # exponential backoff between retries
        retry = Retry(
            total=API_FILE_DOWNLOAD_RETRY_TIMES,
            backoff_factor=1,
            allowed_methods=['GET'])
        while True:
            try:
                downloaded_size = temp_file.tell()
                get_headers['Range'] = 'bytes=%d-' % downloaded_size
                r = requests.get(
                    url,
                    stream=True,
                    headers=get_headers,
                    cookies=cookies,
                    timeout=API_FILE_DOWNLOAD_TIMEOUT)
                r.raise_for_status()
                content_length = r.headers.get('Content-Length')
                total = int(
                    content_length) if content_length is not None else None
                progress = tqdm(
                    unit='B',
                    unit_scale=True,
                    unit_divisor=1024,
                    total=total,
                    initial=downloaded_size,
                    desc='Downloading [' + file_name + ']',
                )
                for chunk in r.iter_content(
                        chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
                    if chunk:  # filter out keep-alive new chunks
                        progress.update(len(chunk))
                        temp_file.write(chunk)
                progress.close()
                break
            except Exception as e:  # no matter what happens, we will retry.
                retry = retry.increment('GET', url, error=e)
                retry.sleep()

    logger.debug('storing %s in cache at %s', url, local_dir)
    # `total` is the Content-Length of the most recent response
    downloaded_length = os.path.getsize(temp_file.name)
    if total != downloaded_length:
        os.remove(temp_file.name)
        msg = 'File %s download incomplete, content_length: %s but the \
file downloaded length: %s, please download again' % (
            file_name, total, downloaded_length)
        logger.error(msg)
        raise FileDownloadError(msg)
    os.replace(temp_file.name, os.path.join(local_dir, file_name))


def download_file(
    url,
    file_meta,
    temporary_cache_dir,
    cache,
    headers,
    cookies,
    disable_tqdm=False,
    progress_callbacks: List[Type[ProgressCallback]] = None,
):
    # parallel download is only used for large files
    if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[
            'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1:
        file_digest = parallel_download(
            url,
            temporary_cache_dir,
            file_meta['Path'],
            headers=headers,
            cookies=None if cookies is None else cookies.get_dict(),
            file_size=file_meta['Size'],
            disable_tqdm=disable_tqdm,
            progress_callbacks=progress_callbacks,
        )
    else:
        file_digest = http_get_model_file(
            url,
            temporary_cache_dir,
            file_meta['Path'],
            file_size=file_meta['Size'],
            headers=headers,
            cookies=cookies,
            disable_tqdm=disable_tqdm,
            progress_callbacks=progress_callbacks,
        )

    # check file integrity
    temp_file = os.path.join(temporary_cache_dir, file_meta['Path'])
    if FILE_HASH in file_meta:
        expected_hash = file_meta[FILE_HASH]
        # if a real-time hash has been computed
        if file_digest is not None:
            # if the real-time hash mismatched, compute the hash again
            if file_digest != expected_hash:
                logger.info(
                    'Mismatched real-time digest found, falling back to lump-sum hash computation'
                )
                file_integrity_validation(temp_file, expected_hash)
        else:
            file_integrity_validation(temp_file, expected_hash)
    # put the file into the cache
    return cache.put_file(file_meta, temp_file)
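
# Dispatch sketch (illustrative): with MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB
# at, say, 500 and MODELSCOPE_DOWNLOAD_PARALLELS > 1, a 10 MB file takes the
# single-stream `http_get_model_file` path, while a multi-GB checkpoint shard
# crosses the threshold and is fetched by `parallel_download` in ranged parts.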