Mirror of https://github.com/modelscope/modelscope.git
Synced 2025-12-16 08:17:45 +01:00
Download optimize (#862)
* fix #845: support resuming interrupted downloads from breakpoints, optimize the download progress bar with finer display granularity for a better experience on low-bandwidth links, and add the ability to download specified directories.
* restore push to hub
* fix merge issue
* fix ut issue
---------
Co-authored-by: mulin.lyh <mulin.lyh@taobao.com>
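The heart of the breakpoint-resumption scheme is the HTTP `Range` header: whatever is already on disk determines the byte offset at which the next request starts, so an interrupted transfer never re-fetches completed bytes. A minimal standalone sketch of the idea (illustrative names, not the ModelScope API):

import os
import requests

def resume_download(url, dest, chunk_size=1024 * 1024):
    # Bytes already on disk decide where the request resumes.
    offset = os.path.getsize(dest) if os.path.exists(dest) else 0
    headers = {'Range': 'bytes=%d-' % offset} if offset else {}
    with requests.get(url, headers=headers, stream=True, timeout=60) as r:
        r.raise_for_status()
        with open(dest, 'ab') as f:  # append after the existing bytes
            for chunk in r.iter_content(chunk_size=chunk_size):
                if chunk:  # skip keep-alive chunks
                    f.write(chunk)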
modelscope/hub/constants.py
@@ -20,7 +20,7 @@ API_HTTP_CLIENT_TIMEOUT = 60
 API_RESPONSE_FIELD_DATA = 'Data'
 API_FILE_DOWNLOAD_RETRY_TIMES = 5
 API_FILE_DOWNLOAD_TIMEOUT = 60
-API_FILE_DOWNLOAD_CHUNK_SIZE = 1024 * 1024 * 16
+API_FILE_DOWNLOAD_CHUNK_SIZE = 1024 * 1024 * 1
 API_RESPONSE_FIELD_GIT_ACCESS_TOKEN = 'AccessToken'
 API_RESPONSE_FIELD_USERNAME = 'Username'
 API_RESPONSE_FIELD_EMAIL = 'Email'
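Shrinking the streaming chunk from 16 MiB to 1 MiB is what gives the progress bar its finer granularity: the bar only advances once per received chunk, so on a slow link a 16 MiB chunk can leave it frozen for minutes. A rough back-of-the-envelope check (assumed 1 Mbit/s bandwidth):

# seconds between progress-bar updates at ~1 Mbit/s (125000 bytes/s)
print((1024 * 1024 * 1) / 125000)    # new 1 MiB chunk: ~8.4 s per update
print((1024 * 1024 * 16) / 125000)   # old 16 MiB chunk: ~134 s per update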
modelscope/hub/file_download.py
@@ -1,12 +1,11 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-
 import copy
+import io
 import os
 import tempfile
 import urllib
+import uuid
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 from http.cookiejar import CookieJar
 from pathlib import Path
 from typing import Dict, Optional, Union
@@ -23,7 +22,7 @@ from modelscope.hub.constants import (
 from modelscope.utils.constant import DEFAULT_MODEL_REVISION
 from modelscope.utils.file_utils import get_model_cache_root
 from modelscope.utils.logger import get_logger
-from .errors import FileDownloadError, NotExistError
+from .errors import NotExistError
 from .utils.caching import ModelFileSystemCache
 from .utils.utils import (file_integrity_validation, get_endpoint,
                           model_id_to_group_owner_name)
@@ -79,11 +78,9 @@ def model_file_download(
     cache_dir = get_model_cache_root()
     if isinstance(cache_dir, Path):
         cache_dir = str(cache_dir)
-    temporary_cache_dir = os.path.join(cache_dir, 'temp')
-    os.makedirs(temporary_cache_dir, exist_ok=True)
-
     group_or_owner, name = model_id_to_group_owner_name(model_id)
-
+    temporary_cache_dir = os.path.join(cache_dir, 'temp', group_or_owner, name)
+    os.makedirs(temporary_cache_dir, exist_ok=True)
     cache = ModelFileSystemCache(cache_dir, group_or_owner, name)
 
     # if local_files_only is `True` and the file already exists in cached_path
@@ -139,14 +136,13 @@ def model_file_download(
 
     # we need to download again
     url_to_download = get_file_download_url(model_id, file_path, revision)
-    temp_file_name = next(tempfile._get_candidate_names())
 
     if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_to_download_info[
             'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1:
         parallel_download(
             url_to_download,
             temporary_cache_dir,
-            temp_file_name,
+            file_path,
             headers=headers,
             cookies=None if cookies is None else cookies.get_dict(),
             file_size=file_to_download_info['Size'])
@@ -154,17 +150,18 @@ def model_file_download(
         http_get_file(
             url_to_download,
             temporary_cache_dir,
-            temp_file_name,
+            file_path,
+            file_size=file_to_download_info['Size'],
             headers=headers,
             cookies=None if cookies is None else cookies.get_dict())
 
-    temp_file_path = os.path.join(temporary_cache_dir, temp_file_name)
+    temp_file_path = os.path.join(temporary_cache_dir, file_path)
     # for download with commit we can't get Sha256
     if file_to_download_info[FILE_HASH] is not None:
         file_integrity_validation(temp_file_path,
                                   file_to_download_info[FILE_HASH])
     return cache.put_file(file_to_download_info,
-                          os.path.join(temporary_cache_dir, temp_file_name))
+                          os.path.join(temporary_cache_dir, file_path))
 
 
 def get_file_download_url(model_id: str, file_path: str, revision: str):
@@ -193,18 +190,27 @@ def get_file_download_url(model_id: str, file_path: str, revision: str):
 
 def download_part_with_retry(params):
     # unpack parameters
-    model_file_name, progress, start, end, url, file_name, cookies, headers = params
+    model_file_path, progress, start, end, url, file_name, cookies, headers = params
     get_headers = {} if headers is None else copy.deepcopy(headers)
-    get_headers['Range'] = 'bytes=%s-%s' % (start, end)
+    get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
     retry = Retry(
         total=API_FILE_DOWNLOAD_RETRY_TIMES,
         backoff_factor=1,
         allowed_methods=['GET'])
+    part_file_name = model_file_path + '_%s_%s' % (start, end)
     while True:
         try:
-            with open(file_name, 'rb+') as f:
-                f.seek(start)
+            partial_length = 0
+            if os.path.exists(
+                    part_file_name):  # download partial, continue download
+                with open(part_file_name, 'rb') as f:
+                    partial_length = f.seek(0, io.SEEK_END)
+                    progress.update(partial_length)
+            start = start + partial_length
+            if start > end:
+                break  # this part is download completed.
+            get_headers['Range'] = 'bytes=%s-%s' % (start, end)
+            with open(part_file_name, 'ab+') as f:
                 r = requests.get(
                     url,
                     stream=True,
@@ -215,12 +221,12 @@ def download_part_with_retry(params):
                         chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
                     if chunk:  # filter out keep-alive new chunks
                         f.write(chunk)
-                progress.update(end - start)
+                        progress.update(len(chunk))
             break
         except (Exception) as e:  # no matter what exception, we will retry.
             retry = retry.increment('GET', url, error=e)
             logger.warning('Downloading: %s failed, reason: %s will retry' %
-                           (model_file_name, e))
+                           (model_file_path, e))
             retry.sleep()
 
 
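Each worker in the hunk above persists its slice to a `<file>_<start>_<end>` part file, so on retry it only needs the bytes its part file is still missing. A sketch of that range arithmetic (hypothetical helper, mirroring the logic above):

import os

def remaining_range(part_file, start, end):
    done = os.path.getsize(part_file) if os.path.exists(part_file) else 0
    if start + done > end:
        return None  # this part is already complete
    return 'bytes=%d-%d' % (start + done, end)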
@@ -233,42 +239,46 @@ def parallel_download(
     file_size: int = None,
 ):
-    # create temp file
-    temp_file_manager = partial(
-        tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False)
-    with temp_file_manager() as temp_file:
-        progress = tqdm(
-            unit='B',
-            unit_scale=True,
-            unit_divisor=1024,
-            total=file_size,
-            initial=0,
-            desc='Downloading',
-        )
-        PART_SIZE = 160 * 1024 * 1024  # every part is 160M
-        tasks = []
-        for idx in range(int(file_size / PART_SIZE)):
-            start = idx * PART_SIZE
-            end = (idx + 1) * PART_SIZE - 1
-            tasks.append((file_name, progress, start, end, url, temp_file.name,
-                          cookies, headers))
-        if end + 1 < file_size:
-            tasks.append((file_name, progress, end + 1, file_size - 1, url,
-                          temp_file.name, cookies, headers))
-        parallels = MODELSCOPE_DOWNLOAD_PARALLELS if MODELSCOPE_DOWNLOAD_PARALLELS <= 4 else 4
-        with ThreadPoolExecutor(
-                max_workers=parallels,
-                thread_name_prefix='download') as executor:
-            list(executor.map(download_part_with_retry, tasks))
-
-        progress.close()
-
-    os.replace(temp_file.name, os.path.join(local_dir, file_name))
+    progress = tqdm(
+        unit='B',
+        unit_scale=True,
+        unit_divisor=1024,
+        total=file_size,
+        initial=0,
+        desc='Downloading',
+    )
+    PART_SIZE = 160 * 1024 * 1024  # every part is 160M
+    tasks = []
+    file_path = os.path.join(local_dir, file_name)
+    for idx in range(int(file_size / PART_SIZE)):
+        start = idx * PART_SIZE
+        end = (idx + 1) * PART_SIZE - 1
+        tasks.append((file_path, progress, start, end, url, file_name, cookies,
+                      headers))
+    if end + 1 < file_size:
+        tasks.append((file_path, progress, end + 1, file_size - 1, url,
+                      file_name, cookies, headers))
+    parallels = MODELSCOPE_DOWNLOAD_PARALLELS if MODELSCOPE_DOWNLOAD_PARALLELS <= 4 else 4
+    # download every part
+    with ThreadPoolExecutor(
+            max_workers=parallels, thread_name_prefix='download') as executor:
+        list(executor.map(download_part_with_retry, tasks))
+    progress.close()
+    # merge parts.
+    with open(os.path.join(local_dir, file_name), 'wb') as output_file:
+        for task in tasks:
+            part_file_name = task[0] + '_%s_%s' % (task[2], task[3])
+            with open(part_file_name, 'rb') as part_file:
+                output_file.write(part_file.read())
+            os.remove(part_file_name)
 
 
 def http_get_file(
     url: str,
     local_dir: str,
     file_name: str,
+    file_size: int,
     cookies: CookieJar,
     headers: Optional[Dict[str, str]] = None,
 ):
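The rewritten `parallel_download` splits the file into fixed 160 MiB slices plus a tail, downloads them concurrently, then concatenates the part files in order and deletes them. A sketch of the split (hypothetical helper; note it initializes `end` so that a file smaller than one part still yields a single range):

PART_SIZE = 160 * 1024 * 1024  # 160 MiB per part

def split_ranges(file_size):
    ranges, end = [], -1
    for idx in range(file_size // PART_SIZE):
        start = idx * PART_SIZE
        end = start + PART_SIZE - 1
        ranges.append((start, end))
    if end + 1 < file_size:
        ranges.append((end + 1, file_size - 1))  # the tail
    return ranges

# e.g. split_ranges(400 * 2**20) -> two 160 MiB parts plus an 80 MiB tail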
@@ -281,6 +291,8 @@ def http_get_file(
             local directory where the downloaded file stores
         file_name(str):
             name of the file stored in `local_dir`
+        file_size(int):
+            The file size.
         cookies(CookieJar):
             cookies used to authentication the user, which is used for downloading private repos
         headers(Dict[str, str], optional):
@@ -290,22 +302,36 @@ def http_get_file(
         FileDownloadError: File download failed.
 
     """
-    total = -1
-    temp_file_manager = partial(
-        tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False)
     get_headers = {} if headers is None else copy.deepcopy(headers)
-    with temp_file_manager() as temp_file:
-        logger.debug('downloading %s to %s', url, temp_file.name)
-        # retry sleep 0.5s, 1s, 2s, 4s
-        retry = Retry(
-            total=API_FILE_DOWNLOAD_RETRY_TIMES,
-            backoff_factor=1,
-            allowed_methods=['GET'])
-        while True:
-            try:
-                downloaded_size = temp_file.tell()
-                get_headers['Range'] = 'bytes=%d-' % downloaded_size
+    get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
+    temp_file_path = os.path.join(local_dir, file_name)
+    logger.debug('downloading %s to %s', url, temp_file_path)
+    # retry sleep 0.5s, 1s, 2s, 4s
+    retry = Retry(
+        total=API_FILE_DOWNLOAD_RETRY_TIMES,
+        backoff_factor=1,
+        allowed_methods=['GET'])
+    while True:
+        try:
+            progress = tqdm(
+                unit='B',
+                unit_scale=True,
+                unit_divisor=1024,
+                total=file_size,
+                initial=0,
+                desc='Downloading',
+            )
+            partial_length = 0
+            if os.path.exists(
+                    temp_file_path):  # download partial, continue download
+                with open(temp_file_path, 'rb') as f:
+                    partial_length = f.seek(0, io.SEEK_END)
+                    progress.update(partial_length)
+            if partial_length > file_size:
+                break
+            get_headers['Range'] = 'bytes=%s-%s' % (partial_length,
+                                                    file_size - 1)
+            with open(temp_file_path, 'ab') as f:
                 r = requests.get(
                     url,
                     stream=True,
@@ -313,35 +339,15 @@ def http_get_file(
                     cookies=cookies,
                     timeout=API_FILE_DOWNLOAD_TIMEOUT)
                 r.raise_for_status()
-                content_length = r.headers.get('Content-Length')
-                total = int(
-                    content_length) if content_length is not None else None
-                progress = tqdm(
-                    unit='B',
-                    unit_scale=True,
-                    unit_divisor=1024,
-                    total=total,
-                    initial=downloaded_size,
-                    desc='Downloading',
-                )
                 for chunk in r.iter_content(
                         chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
                     if chunk:  # filter out keep-alive new chunks
                         progress.update(len(chunk))
-                        temp_file.write(chunk)
-                progress.close()
-                break
-            except (Exception) as e:  # no matter what happen, we will retry.
-                retry = retry.increment('GET', url, error=e)
-                retry.sleep()
+                        f.write(chunk)
+            progress.close()
+            break
+        except (Exception) as e:  # no matter what happen, we will retry.
+            retry = retry.increment('GET', url, error=e)
+            retry.sleep()
 
-    logger.debug('storing %s in cache at %s', url, local_dir)
-    downloaded_length = os.path.getsize(temp_file.name)
-    if total != downloaded_length:
-        os.remove(temp_file.name)
-        msg = 'File %s download incomplete, content_length: %s but the \
-file downloaded length: %s, please download again' % (
-            file_name, total, downloaded_length)
-        logger.error(msg)
-        raise FileDownloadError(msg)
-    os.replace(temp_file.name, os.path.join(local_dir, file_name))
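Both download paths drive urllib3's `Retry` by hand: `increment()` returns a fresh `Retry` with one attempt consumed (raising `MaxRetryError` once the budget is spent) and `sleep()` applies the exponential backoff, hence the "0.5s, 1s, 2s, 4s" comment. A minimal sketch of the same pattern:

import requests
from urllib3.util.retry import Retry

def get_with_retry(url, retries=5):
    retry = Retry(total=retries, backoff_factor=1, allowed_methods=['GET'])
    while True:
        try:
            r = requests.get(url, timeout=60)
            r.raise_for_status()
            return r.content
        except Exception as e:
            retry = retry.increment('GET', url, error=e)  # raises when exhausted
            retry.sleep()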
modelscope/hub/snapshot_download.py
@@ -2,7 +2,6 @@
 
 import os
 import re
-import tempfile
 from http.cookiejar import CookieJar
 from pathlib import Path
 from typing import Dict, List, Optional, Union
@@ -22,13 +21,15 @@ from .utils.utils import (file_integrity_validation,
 logger = get_logger()
 
 
-def snapshot_download(model_id: str,
-                      revision: Optional[str] = DEFAULT_MODEL_REVISION,
-                      cache_dir: Union[str, Path, None] = None,
-                      user_agent: Optional[Union[Dict, str]] = None,
-                      local_files_only: Optional[bool] = False,
-                      cookies: Optional[CookieJar] = None,
-                      ignore_file_pattern: List = None) -> str:
+def snapshot_download(
+    model_id: str,
+    revision: Optional[str] = DEFAULT_MODEL_REVISION,
+    cache_dir: Union[str, Path, None] = None,
+    user_agent: Optional[Union[Dict, str]] = None,
+    local_files_only: Optional[bool] = False,
+    cookies: Optional[CookieJar] = None,
+    ignore_file_pattern: List = None,
+) -> str:
     """Download all files of a repo.
 
     Downloads a whole snapshot of a repo's files at the specified revision. This
     is useful when you want all files from a repo, because you don't know which
@@ -69,10 +70,9 @@ def snapshot_download(model_id: str,
     cache_dir = get_model_cache_root()
     if isinstance(cache_dir, Path):
         cache_dir = str(cache_dir)
-    temporary_cache_dir = os.path.join(cache_dir, 'temp')
-    os.makedirs(temporary_cache_dir, exist_ok=True)
-
     group_or_owner, name = model_id_to_group_owner_name(model_id)
+    temporary_cache_dir = os.path.join(cache_dir, 'temp', group_or_owner, name)
+    os.makedirs(temporary_cache_dir, exist_ok=True)
     name = name.replace('.', '___')
 
     cache = ModelFileSystemCache(cache_dir, group_or_owner, name)
@@ -123,50 +123,48 @@ def snapshot_download(model_id: str,
         if isinstance(ignore_file_pattern, str):
             ignore_file_pattern = [ignore_file_pattern]
 
-        with tempfile.TemporaryDirectory(
-                dir=temporary_cache_dir) as temp_cache_dir:
-            for model_file in model_files:
-                if model_file['Type'] == 'tree' or \
-                        any([re.search(pattern, model_file['Name']) is not None for pattern in ignore_file_pattern]):
-                    continue
-                # check model_file is exist in cache, if existed, skip download, otherwise download
-                if cache.exists(model_file):
-                    file_name = os.path.basename(model_file['Name'])
-                    logger.debug(
-                        f'File {file_name} already in cache, skip downloading!'
-                    )
-                    continue
+        for model_file in model_files:
+            if model_file['Type'] == 'tree' or \
+                    any([re.search(pattern, model_file['Name']) is not None for pattern in ignore_file_pattern]):
+                continue
 
-                # get download url
-                url = get_file_download_url(
-                    model_id=model_id,
-                    file_path=model_file['Path'],
-                    revision=revision)
+            # check model_file is exist in cache, if existed, skip download, otherwise download
+            if cache.exists(model_file):
+                file_name = os.path.basename(model_file['Name'])
+                logger.debug(
+                    f'File {file_name} already in cache, skip downloading!')
+                continue
 
-                if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < model_file[
-                        'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1:
-                    parallel_download(
-                        url,
-                        temp_cache_dir,
-                        model_file['Name'],
-                        headers=headers,
-                        cookies=None
-                        if cookies is None else cookies.get_dict(),
-                        file_size=model_file['Size'])
-                else:
-                    http_get_file(
-                        url,
-                        temp_cache_dir,
-                        model_file['Name'],
-                        headers=headers,
-                        cookies=cookies)
+            # get download url
+            url = get_file_download_url(
+                model_id=model_id,
+                file_path=model_file['Path'],
+                revision=revision)
 
-                # check file integrity
-                temp_file = os.path.join(temp_cache_dir, model_file['Name'])
-                if FILE_HASH in model_file:
-                    file_integrity_validation(temp_file, model_file[FILE_HASH])
-                # put file into to cache
-                cache.put_file(model_file, temp_file)
+            if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < model_file[
+                    'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1:
+                parallel_download(
+                    url,
+                    temporary_cache_dir,
+                    model_file['Name'],
+                    headers=headers,
+                    cookies=None if cookies is None else cookies.get_dict(),
+                    file_size=model_file['Size'])
+            else:
+                http_get_file(
+                    url,
+                    temporary_cache_dir,
+                    model_file['Name'],
+                    file_size=model_file['Size'],
+                    headers=headers,
+                    cookies=cookies)
+
+            # check file integrity
+            temp_file = os.path.join(temporary_cache_dir, model_file['Name'])
+            if FILE_HASH in model_file:
+                file_integrity_validation(temp_file, model_file[FILE_HASH])
+            # put file into to cache
+            cache.put_file(model_file, temp_file)
 
     cache.save_model_version(revision_info=revision_detail)
     return os.path.join(cache.get_root_location())
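The structural change in `snapshot_download` is that the throwaway `tempfile.TemporaryDirectory` is gone: partial files now live in a stable per-model temp directory, which is what lets a rerun find and resume them. A sketch of the layout (illustrative helper):

import os

def temp_dir_for(cache_dir, group_or_owner, name):
    # stable path: <cache>/temp/<group>/<name> survives interrupted runs
    path = os.path.join(cache_dir, 'temp', group_or_owner, name)
    os.makedirs(path, exist_ok=True)
    return path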
modelscope/hub/utils/utils.py
@@ -72,6 +72,7 @@ def file_integrity_validation(file_path, expected_sha256):
     file_sha256 = compute_hash(file_path)
     if not file_sha256 == expected_sha256:
         os.remove(file_path)
-        msg = 'File %s integrity check failed, the download may be incomplete, please try again.' % file_path
+        msg = 'File %s integrity check failed, expected sha256 signature is %s, actual is %s, the download may be incomplete, please try again.' % (  # noqa E501
+            file_path, expected_sha256, file_sha256)
         logger.error(msg)
         raise FileIntegrityError(msg)
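`file_integrity_validation` compares a streamed SHA-256 of the downloaded file against the hash the hub reports, and the improved message now includes both digests. A sketch of the hashing side (illustrative; the real helper is `compute_hash`):

import hashlib

def sha256_of(path, block=1024 * 1024):
    h = hashlib.sha256()
    with open(path, 'rb') as f:
        for data in iter(lambda: f.read(block), b''):
            h.update(data)
    return h.hexdigest()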
@@ -113,6 +113,7 @@ class HubOperationTest(unittest.TestCase):
             url=url,
             local_dir='./',
             file_name=test_file_name,
+            file_size=2957783,
             headers={},
             cookies=None)
@@ -154,10 +155,11 @@ class HubOperationTest(unittest.TestCase):
             url=url,
             local_dir='./',
             file_name=test_file_name,
+            file_size=2957783,
             headers={},
             cookies=None)
 
-        assert not os.path.exists('./%s' % test_file_name)
+        assert os.stat('./%s' % test_file_name).st_size == 0
 
 
 if __name__ == '__main__':