Download optimize (#862)

* fix #845

Supports resuming interrupted downloads from breakpoints, optimizes the download progress bar with finer display granularity for a better experience on low-bandwidth connections, and adds the ability to download specified directories.

* restore push to hub

* fix merge issue

* fix ut issue

---------

Co-authored-by: mulin.lyh <mulin.lyh@taobao.com>
liuyhwangyh authored 2024-05-24 15:37:43 +08:00, committed by GitHub
parent f9451bfe38
commit 5c470f8941
5 changed files with 149 additions and 142 deletions
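The resumption feature rests on HTTP Range requests: a partial file on disk tells the client how many bytes it already has, and the server is asked only for the rest. A minimal standalone sketch of the pattern (hypothetical `url` and `dest` arguments; assumes the server honors `Range` headers):

import os
import requests

def resume_download(url: str, dest: str, chunk_size: int = 1024 * 1024) -> None:
    # Start from however many bytes already made it to disk.
    offset = os.path.getsize(dest) if os.path.exists(dest) else 0
    headers = {'Range': 'bytes=%d-' % offset} if offset else {}
    with requests.get(url, headers=headers, stream=True, timeout=60) as r:
        r.raise_for_status()
        with open(dest, 'ab') as f:  # append mode: a re-run continues the file
            for chunk in r.iter_content(chunk_size=chunk_size):
                if chunk:
                    f.write(chunk)

The diffs below apply this idea in three places: single-stream downloads (http_get_file), per-part parallel downloads (download_part_with_retry), and whole-repo snapshots, which now stage partial files in a persistent temp directory.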

File 1 of 5

@@ -20,7 +20,7 @@ API_HTTP_CLIENT_TIMEOUT = 60
API_RESPONSE_FIELD_DATA = 'Data'
API_FILE_DOWNLOAD_RETRY_TIMES = 5
API_FILE_DOWNLOAD_TIMEOUT = 60
API_FILE_DOWNLOAD_CHUNK_SIZE = 1024 * 1024 * 16
API_FILE_DOWNLOAD_CHUNK_SIZE = 1024 * 1024 * 1
API_RESPONSE_FIELD_GIT_ACCESS_TOKEN = 'AccessToken'
API_RESPONSE_FIELD_USERNAME = 'Username'
API_RESPONSE_FIELD_EMAIL = 'Email'
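Shrinking API_FILE_DOWNLOAD_CHUNK_SIZE from 16 MiB to 1 MiB is what gives the progress bar its finer granularity: the bar advances once per chunk, so each update now represents 1 MiB instead of 16 MiB, and on a slow link the bar moves every second or two rather than stalling between huge chunks. An illustrative sketch of the effect (the numbers are examples, not from the diff):

from tqdm import tqdm

CHUNK = 1024 * 1024  # 1 MiB, the new chunk size
FILE = 160 * CHUNK   # a 160 MiB download
with tqdm(total=FILE, unit='B', unit_scale=True, unit_divisor=1024) as bar:
    for _ in range(FILE // CHUNK):  # 160 updates; the old 16 MiB chunks gave 10
        bar.update(CHUNK)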

File 2 of 5

@@ -1,12 +1,11 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import io
import os
import tempfile
import urllib
import uuid
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from http.cookiejar import CookieJar
from pathlib import Path
from typing import Dict, Optional, Union
@@ -23,7 +22,7 @@ from modelscope.hub.constants import (
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
from modelscope.utils.file_utils import get_model_cache_root
from modelscope.utils.logger import get_logger
from .errors import FileDownloadError, NotExistError
from .errors import NotExistError
from .utils.caching import ModelFileSystemCache
from .utils.utils import (file_integrity_validation, get_endpoint,
                          model_id_to_group_owner_name)
@@ -79,11 +78,9 @@ def model_file_download(
    cache_dir = get_model_cache_root()
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    temporary_cache_dir = os.path.join(cache_dir, 'temp')
    os.makedirs(temporary_cache_dir, exist_ok=True)
    group_or_owner, name = model_id_to_group_owner_name(model_id)
    temporary_cache_dir = os.path.join(cache_dir, 'temp', group_or_owner, name)
    os.makedirs(temporary_cache_dir, exist_ok=True)
    cache = ModelFileSystemCache(cache_dir, group_or_owner, name)
    # if local_files_only is `True` and the file already exists in cached_path
@@ -139,14 +136,13 @@ def model_file_download(
    # we need to download again
    url_to_download = get_file_download_url(model_id, file_path, revision)
    temp_file_name = next(tempfile._get_candidate_names())
    if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_to_download_info[
            'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1:
        parallel_download(
            url_to_download,
            temporary_cache_dir,
            temp_file_name,
            file_path,
            headers=headers,
            cookies=None if cookies is None else cookies.get_dict(),
            file_size=file_to_download_info['Size'])
@@ -154,17 +150,18 @@ def model_file_download(
        http_get_file(
            url_to_download,
            temporary_cache_dir,
            temp_file_name,
            file_path,
            file_size=file_to_download_info['Size'],
            headers=headers,
            cookies=None if cookies is None else cookies.get_dict())
    temp_file_path = os.path.join(temporary_cache_dir, temp_file_name)
    temp_file_path = os.path.join(temporary_cache_dir, file_path)
    # for downloads pinned to a commit we can't get the Sha256
    if file_to_download_info[FILE_HASH] is not None:
        file_integrity_validation(temp_file_path,
                                  file_to_download_info[FILE_HASH])
    return cache.put_file(file_to_download_info,
                          os.path.join(temporary_cache_dir, temp_file_name))
                          os.path.join(temporary_cache_dir, file_path))


def get_file_download_url(model_id: str, file_path: str, revision: str):
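Two changes in this hunk make resumption possible at all: the staging area moves from a shared cache_dir/temp to a per-model cache_dir/temp/<group_or_owner>/<name>, and the partial file is written under the real file_path instead of a random name from tempfile._get_candidate_names(). A deterministic location means a re-run of the same download finds its own partial file. A sketch of the layout (hypothetical helper, not part of the diff):

import os

def partial_file_location(cache_dir: str, group_or_owner: str, name: str,
                          file_path: str) -> str:
    # <cache_dir>/temp/<group_or_owner>/<name>/<file_path> is stable across
    # runs, so an interrupted download can be picked up where it stopped.
    temp_dir = os.path.join(cache_dir, 'temp', group_or_owner, name)
    os.makedirs(temp_dir, exist_ok=True)
    return os.path.join(temp_dir, file_path)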
@@ -193,18 +190,27 @@ def get_file_download_url(model_id: str, file_path: str, revision: str):
def download_part_with_retry(params):
    # unpack parameters
    model_file_name, progress, start, end, url, file_name, cookies, headers = params
    model_file_path, progress, start, end, url, file_name, cookies, headers = params
    get_headers = {} if headers is None else copy.deepcopy(headers)
    get_headers['Range'] = 'bytes=%s-%s' % (start, end)
    get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
    retry = Retry(
        total=API_FILE_DOWNLOAD_RETRY_TIMES,
        backoff_factor=1,
        allowed_methods=['GET'])
    part_file_name = model_file_path + '_%s_%s' % (start, end)
    while True:
        try:
            with open(file_name, 'rb+') as f:
                f.seek(start)
            partial_length = 0
            if os.path.exists(
                    part_file_name):  # partial file exists, resume download
                with open(part_file_name, 'rb') as f:
                    partial_length = f.seek(0, io.SEEK_END)
                    progress.update(partial_length)
                start = start + partial_length
                if start > end:
                    break  # this part is already fully downloaded.
            get_headers['Range'] = 'bytes=%s-%s' % (start, end)
            with open(part_file_name, 'ab+') as f:
                r = requests.get(
                    url,
                    stream=True,
@@ -215,12 +221,12 @@ def download_part_with_retry(params):
                        chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
                    if chunk:  # filter out keep-alive new chunks
                        f.write(chunk)
                progress.update(end - start)
                        progress.update(len(chunk))
            break
        except (Exception) as e:  # no matter what exception, we will retry.
            retry = retry.increment('GET', url, error=e)
            logger.warning('Downloading: %s failed, reason: %s will retry' %
                           (model_file_name, e))
                           (model_file_path, e))
            retry.sleep()
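Each worker now stages its byte range in its own '<file>_<start>_<end>' part file and, before requesting anything, checks how much of that part already exists; the Range start is shifted forward accordingly. A distilled sketch of just that check (hypothetical function; the real logic is inline above):

import io
import os

def resume_range(part_file_name: str, start: int, end: int):
    # Returns the Range header value still needed, or None if the part is done.
    if os.path.exists(part_file_name):
        with open(part_file_name, 'rb') as f:
            start += f.seek(0, io.SEEK_END)  # bytes already on disk
    if start > end:
        return None  # inclusive range fully covered
    return 'bytes=%s-%s' % (start, end)

Note also that the progress bar now advances by len(chunk) as data arrives instead of by end - start once per finished part, which is what keeps the bar honest under low bandwidth.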
@@ -233,42 +239,46 @@ def parallel_download(
    file_size: int = None,
):
    # create temp file
    temp_file_manager = partial(
        tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False)
    with temp_file_manager() as temp_file:
        progress = tqdm(
            unit='B',
            unit_scale=True,
            unit_divisor=1024,
            total=file_size,
            initial=0,
            desc='Downloading',
        )
        PART_SIZE = 160 * 1024 * 1024  # every part is 160M
        tasks = []
        for idx in range(int(file_size / PART_SIZE)):
            start = idx * PART_SIZE
            end = (idx + 1) * PART_SIZE - 1
            tasks.append((file_name, progress, start, end, url, temp_file.name,
                          cookies, headers))
        if end + 1 < file_size:
            tasks.append((file_name, progress, end + 1, file_size - 1, url,
                          temp_file.name, cookies, headers))
        parallels = MODELSCOPE_DOWNLOAD_PARALLELS if MODELSCOPE_DOWNLOAD_PARALLELS <= 4 else 4
        with ThreadPoolExecutor(
                max_workers=parallels,
                thread_name_prefix='download') as executor:
            list(executor.map(download_part_with_retry, tasks))
        progress.close()
    os.replace(temp_file.name, os.path.join(local_dir, file_name))
    progress = tqdm(
        unit='B',
        unit_scale=True,
        unit_divisor=1024,
        total=file_size,
        initial=0,
        desc='Downloading',
    )
    PART_SIZE = 160 * 1024 * 1024  # every part is 160M
    tasks = []
    file_path = os.path.join(local_dir, file_name)
    for idx in range(int(file_size / PART_SIZE)):
        start = idx * PART_SIZE
        end = (idx + 1) * PART_SIZE - 1
        tasks.append((file_path, progress, start, end, url, file_name, cookies,
                      headers))
    if end + 1 < file_size:
        tasks.append((file_path, progress, end + 1, file_size - 1, url,
                      file_name, cookies, headers))
    parallels = MODELSCOPE_DOWNLOAD_PARALLELS if MODELSCOPE_DOWNLOAD_PARALLELS <= 4 else 4
    # download every part
    with ThreadPoolExecutor(
            max_workers=parallels, thread_name_prefix='download') as executor:
        list(executor.map(download_part_with_retry, tasks))
    progress.close()
    # merge parts.
    with open(os.path.join(local_dir, file_name), 'wb') as output_file:
        for task in tasks:
            part_file_name = task[0] + '_%s_%s' % (task[2], task[3])
            with open(part_file_name, 'rb') as part_file:
                output_file.write(part_file.read())
            os.remove(part_file_name)
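The part boundaries are plain arithmetic: fixed 160 MiB slices plus a shorter tail, expressed as inclusive byte ranges to match HTTP Range semantics. A sketch of the same computation (hypothetical helper):

def split_ranges(file_size: int, part_size: int = 160 * 1024 * 1024):
    # e.g. 400 MiB -> [(0, 160M-1), (160M, 320M-1), (320M, 400M-1)]
    ranges = [(idx * part_size, (idx + 1) * part_size - 1)
              for idx in range(file_size // part_size)]
    end = ranges[-1][1] if ranges else -1
    if end + 1 < file_size:
        ranges.append((end + 1, file_size - 1))
    return ranges

Because every part now lands in its own file rather than being written into one shared temp file via seek(), a crash leaves a set of independently resumable part files; the merge-and-delete pass only runs once all parts have completed.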
def http_get_file(
    url: str,
    local_dir: str,
    file_name: str,
    file_size: int,
    cookies: CookieJar,
    headers: Optional[Dict[str, str]] = None,
):
@@ -281,6 +291,8 @@ def http_get_file(
            local directory where the downloaded file stores
        file_name(str):
            name of the file stored in `local_dir`
        file_size(int):
            The file size.
        cookies(CookieJar):
            cookies used to authenticate the user, which is used for downloading private repos
        headers(Dict[str, str], optional):
@@ -290,22 +302,36 @@ def http_get_file(
        FileDownloadError: File download failed.
    """
    total = -1
    temp_file_manager = partial(
        tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False)
    get_headers = {} if headers is None else copy.deepcopy(headers)
    get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
    with temp_file_manager() as temp_file:
        logger.debug('downloading %s to %s', url, temp_file.name)
        # retry sleep 0.5s, 1s, 2s, 4s
        retry = Retry(
            total=API_FILE_DOWNLOAD_RETRY_TIMES,
            backoff_factor=1,
            allowed_methods=['GET'])
        while True:
            try:
                downloaded_size = temp_file.tell()
                get_headers['Range'] = 'bytes=%d-' % downloaded_size
    temp_file_path = os.path.join(local_dir, file_name)
    logger.debug('downloading %s to %s', url, temp_file_path)
    # retry sleep 0.5s, 1s, 2s, 4s
    retry = Retry(
        total=API_FILE_DOWNLOAD_RETRY_TIMES,
        backoff_factor=1,
        allowed_methods=['GET'])
    while True:
        try:
            progress = tqdm(
                unit='B',
                unit_scale=True,
                unit_divisor=1024,
                total=file_size,
                initial=0,
                desc='Downloading',
            )
            partial_length = 0
            if os.path.exists(
                    temp_file_path):  # partial file exists, resume download
                with open(temp_file_path, 'rb') as f:
                    partial_length = f.seek(0, io.SEEK_END)
                    progress.update(partial_length)
            if partial_length > file_size:
                break
            get_headers['Range'] = 'bytes=%s-%s' % (partial_length,
                                                    file_size - 1)
            with open(temp_file_path, 'ab') as f:
                r = requests.get(
                    url,
                    stream=True,
@@ -313,35 +339,15 @@ def http_get_file(
                    cookies=cookies,
                    timeout=API_FILE_DOWNLOAD_TIMEOUT)
                r.raise_for_status()
                content_length = r.headers.get('Content-Length')
                total = int(
                    content_length) if content_length is not None else None
                progress = tqdm(
                    unit='B',
                    unit_scale=True,
                    unit_divisor=1024,
                    total=total,
                    initial=downloaded_size,
                    desc='Downloading',
                )
                for chunk in r.iter_content(
                        chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
                    if chunk:  # filter out keep-alive new chunks
                        progress.update(len(chunk))
                        temp_file.write(chunk)
                progress.close()
                break
            except (Exception) as e:  # no matter what happens, we will retry.
                retry = retry.increment('GET', url, error=e)
                retry.sleep()
                        f.write(chunk)
            progress.close()
            break
        except (Exception) as e:  # no matter what happens, we will retry.
            retry = retry.increment('GET', url, error=e)
            retry.sleep()
    logger.debug('storing %s in cache at %s', url, local_dir)
    downloaded_length = os.path.getsize(temp_file.name)
    if total != downloaded_length:
        os.remove(temp_file.name)
        msg = 'File %s download incomplete, content_length: %s but the \
file downloaded length: %s, please download again' % (
            file_name, total, downloaded_length)
        logger.error(msg)
        raise FileDownloadError(msg)
    os.replace(temp_file.name, os.path.join(local_dir, file_name))
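Both download paths wrap their work in the same retry idiom: urllib3's Retry object tracks the attempt budget, and increment() either returns a new Retry with one fewer attempt or raises MaxRetryError once the budget is spent. A minimal sketch of the pattern (with a hypothetical do_request callable):

from urllib3.util.retry import Retry

def with_retries(do_request, url, max_retries=5):
    retry = Retry(total=max_retries, backoff_factor=1, allowed_methods=['GET'])
    while True:
        try:
            return do_request()
        except Exception as e:
            # Raises urllib3.exceptions.MaxRetryError when attempts run out.
            retry = retry.increment('GET', url, error=e)
            retry.sleep()  # exponential backoff, roughly 0.5s, 1s, 2s, 4s

Because the partial file persists on disk, every retry re-reads its size and asks only for the missing suffix, so a flaky connection wastes little work. Note the old post-download Content-Length check (and the FileDownloadError it raised) is gone; integrity is now enforced solely by the sha256 check in file_integrity_validation.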

File 3 of 5

@@ -2,7 +2,6 @@
import os
import re
import tempfile
from http.cookiejar import CookieJar
from pathlib import Path
from typing import Dict, List, Optional, Union
@@ -22,13 +21,15 @@ from .utils.utils import (file_integrity_validation,
logger = get_logger()
def snapshot_download(model_id: str,
                      revision: Optional[str] = DEFAULT_MODEL_REVISION,
                      cache_dir: Union[str, Path, None] = None,
                      user_agent: Optional[Union[Dict, str]] = None,
                      local_files_only: Optional[bool] = False,
                      cookies: Optional[CookieJar] = None,
                      ignore_file_pattern: List = None) -> str:
def snapshot_download(
    model_id: str,
    revision: Optional[str] = DEFAULT_MODEL_REVISION,
    cache_dir: Union[str, Path, None] = None,
    user_agent: Optional[Union[Dict, str]] = None,
    local_files_only: Optional[bool] = False,
    cookies: Optional[CookieJar] = None,
    ignore_file_pattern: List = None,
) -> str:
    """Download all files of a repo.
    Downloads a whole snapshot of a repo's files at the specified revision. This
    is useful when you want all files from a repo, because you don't know which
@@ -69,10 +70,9 @@ def snapshot_download(model_id: str,
    cache_dir = get_model_cache_root()
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    temporary_cache_dir = os.path.join(cache_dir, 'temp')
    os.makedirs(temporary_cache_dir, exist_ok=True)
    group_or_owner, name = model_id_to_group_owner_name(model_id)
    temporary_cache_dir = os.path.join(cache_dir, 'temp', group_or_owner, name)
    os.makedirs(temporary_cache_dir, exist_ok=True)
    name = name.replace('.', '___')
    cache = ModelFileSystemCache(cache_dir, group_or_owner, name)
@@ -123,50 +123,48 @@ def snapshot_download(model_id: str,
        if isinstance(ignore_file_pattern, str):
            ignore_file_pattern = [ignore_file_pattern]
        with tempfile.TemporaryDirectory(
                dir=temporary_cache_dir) as temp_cache_dir:
            for model_file in model_files:
                if model_file['Type'] == 'tree' or \
                        any([re.search(pattern, model_file['Name']) is not None for pattern in ignore_file_pattern]):
                    continue
                # check whether model_file already exists in cache; if so, skip the download
                if cache.exists(model_file):
                    file_name = os.path.basename(model_file['Name'])
                    logger.debug(
                        f'File {file_name} already in cache, skip downloading!'
                    )
                    continue
        for model_file in model_files:
            if model_file['Type'] == 'tree' or \
                    any([re.search(pattern, model_file['Name']) is not None for pattern in ignore_file_pattern]):
                continue
                # get download url
                url = get_file_download_url(
                    model_id=model_id,
                    file_path=model_file['Path'],
                    revision=revision)
            # check whether model_file already exists in cache; if so, skip the download
            if cache.exists(model_file):
                file_name = os.path.basename(model_file['Name'])
                logger.debug(
                    f'File {file_name} already in cache, skip downloading!')
                continue
                if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < model_file[
                        'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1:
                    parallel_download(
                        url,
                        temp_cache_dir,
                        model_file['Name'],
                        headers=headers,
                        cookies=None
                        if cookies is None else cookies.get_dict(),
                        file_size=model_file['Size'])
                else:
                    http_get_file(
                        url,
                        temp_cache_dir,
                        model_file['Name'],
                        headers=headers,
                        cookies=cookies)
            # get download url
            url = get_file_download_url(
                model_id=model_id,
                file_path=model_file['Path'],
                revision=revision)
                # check file integrity
                temp_file = os.path.join(temp_cache_dir, model_file['Name'])
                if FILE_HASH in model_file:
                    file_integrity_validation(temp_file, model_file[FILE_HASH])
                # put file into the cache
                cache.put_file(model_file, temp_file)
            if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < model_file[
                    'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1:
                parallel_download(
                    url,
                    temporary_cache_dir,
                    model_file['Name'],
                    headers=headers,
                    cookies=None if cookies is None else cookies.get_dict(),
                    file_size=model_file['Size'])
            else:
                http_get_file(
                    url,
                    temporary_cache_dir,
                    model_file['Name'],
                    file_size=model_file['Size'],
                    headers=headers,
                    cookies=cookies)
            # check file integrity
            temp_file = os.path.join(temporary_cache_dir, model_file['Name'])
            if FILE_HASH in model_file:
                file_integrity_validation(temp_file, model_file[FILE_HASH])
            # put file into the cache
            cache.put_file(model_file, temp_file)
        cache.save_model_version(revision_info=revision_detail)
        return os.path.join(cache.get_root_location())
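The structural change here mirrors the single-file path: the old code downloaded everything inside a tempfile.TemporaryDirectory, which is deleted when the with-block exits, so an interrupted snapshot lost all partial progress. The new code stages files directly in the persistent temporary_cache_dir, and files already promoted to the cache are skipped on the next run. An illustrative sketch of the resulting loop (download_file stands in for parallel_download/http_get_file and is hypothetical):

import os

def fetch_snapshot(model_files, cache, temporary_cache_dir, download_file):
    for model_file in model_files:
        if cache.exists(model_file):
            continue  # finished by a previous, possibly interrupted, run
        # resumes any partial file left behind in temporary_cache_dir
        download_file(model_file, temporary_cache_dir)
        cache.put_file(model_file,
                       os.path.join(temporary_cache_dir, model_file['Name']))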

File 4 of 5

@@ -72,6 +72,7 @@ def file_integrity_validation(file_path, expected_sha256):
    file_sha256 = compute_hash(file_path)
    if not file_sha256 == expected_sha256:
        os.remove(file_path)
        msg = 'File %s integrity check failed, the download may be incomplete, please try again.' % file_path
        msg = 'File %s integrity check failed, expected sha256 signature is %s, actual is %s, the download may be incomplete, please try again.' % (  # noqa E501
            file_path, expected_sha256, file_sha256)
        logger.error(msg)
        raise FileIntegrityError(msg)
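file_integrity_validation now reports both the expected and the actual digest, which matters more with resumable downloads, since a corrupted partial file would otherwise fail opaquely on every retry. The hashing side is standard chunked sha256; a plausible sketch of compute_hash (the real implementation lives in modelscope.utils.file_utils and may differ):

import hashlib

def compute_hash(file_path: str, block_size: int = 1024 * 1024) -> str:
    sha256 = hashlib.sha256()
    with open(file_path, 'rb') as f:
        for block in iter(lambda: f.read(block_size), b''):
            sha256.update(block)
    return sha256.hexdigest()

Note the os.remove before raising: deleting the corrupt file guarantees the next attempt starts clean instead of resuming from bad bytes.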

File 5 of 5

@@ -113,6 +113,7 @@ class HubOperationTest(unittest.TestCase):
            url=url,
            local_dir='./',
            file_name=test_file_name,
            file_size=2957783,
            headers={},
            cookies=None)
@@ -154,10 +155,11 @@ class HubOperationTest(unittest.TestCase):
            url=url,
            local_dir='./',
            file_name=test_file_name,
            file_size=2957783,
            headers={},
            cookies=None)
        assert not os.path.exists('./%s' % test_file_name)
        assert os.stat('./%s' % test_file_name).st_size == 0


if __name__ == '__main__':
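The flipped assertion in the failure test summarizes the whole commit: previously an aborted transfer left nothing behind (the incomplete temp file was removed and FileDownloadError raised), so the test asserted the file did not exist; now an aborted transfer deliberately leaves a zero-byte or partial file on disk as the seed for resumption, so the test asserts its size instead.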