diff --git a/modelscope/hub/constants.py b/modelscope/hub/constants.py
index 9b443b71..9eb732da 100644
--- a/modelscope/hub/constants.py
+++ b/modelscope/hub/constants.py
@@ -20,7 +20,7 @@ API_HTTP_CLIENT_TIMEOUT = 60
 API_RESPONSE_FIELD_DATA = 'Data'
 API_FILE_DOWNLOAD_RETRY_TIMES = 5
 API_FILE_DOWNLOAD_TIMEOUT = 60
-API_FILE_DOWNLOAD_CHUNK_SIZE = 1024 * 1024 * 16
+API_FILE_DOWNLOAD_CHUNK_SIZE = 1024 * 1024 * 1
 API_RESPONSE_FIELD_GIT_ACCESS_TOKEN = 'AccessToken'
 API_RESPONSE_FIELD_USERNAME = 'Username'
 API_RESPONSE_FIELD_EMAIL = 'Email'
diff --git a/modelscope/hub/file_download.py b/modelscope/hub/file_download.py
index c925f306..94ced672 100644
--- a/modelscope/hub/file_download.py
+++ b/modelscope/hub/file_download.py
@@ -1,12 +1,11 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
 import copy
+import io
 import os
-import tempfile
 import urllib
 import uuid
 from concurrent.futures import ThreadPoolExecutor
-from functools import partial
 from http.cookiejar import CookieJar
 from pathlib import Path
 from typing import Dict, Optional, Union
@@ -23,7 +22,7 @@ from modelscope.hub.constants import (
 from modelscope.utils.constant import DEFAULT_MODEL_REVISION
 from modelscope.utils.file_utils import get_model_cache_root
 from modelscope.utils.logger import get_logger
-from .errors import FileDownloadError, NotExistError
+from .errors import NotExistError
 from .utils.caching import ModelFileSystemCache
 from .utils.utils import (file_integrity_validation, get_endpoint,
                           model_id_to_group_owner_name)
@@ -79,11 +78,9 @@ def model_file_download(
     cache_dir = get_model_cache_root()
     if isinstance(cache_dir, Path):
         cache_dir = str(cache_dir)
-    temporary_cache_dir = os.path.join(cache_dir, 'temp')
-    os.makedirs(temporary_cache_dir, exist_ok=True)
-
     group_or_owner, name = model_id_to_group_owner_name(model_id)
-
+    temporary_cache_dir = os.path.join(cache_dir, 'temp', group_or_owner, name)
+    os.makedirs(temporary_cache_dir, exist_ok=True)
     cache = ModelFileSystemCache(cache_dir, group_or_owner, name)
 
     # if local_files_only is `True` and the file already exists in cached_path
@@ -139,14 +136,13 @@ def model_file_download(
 
     # we need to download again
     url_to_download = get_file_download_url(model_id, file_path, revision)
-    temp_file_name = next(tempfile._get_candidate_names())
     if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_to_download_info[
             'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1:
         parallel_download(
             url_to_download,
             temporary_cache_dir,
-            temp_file_name,
+            file_path,
             headers=headers,
             cookies=None if cookies is None else cookies.get_dict(),
             file_size=file_to_download_info['Size'])
@@ -154,17 +150,18 @@ def model_file_download(
     else:
         http_get_file(
             url_to_download,
             temporary_cache_dir,
-            temp_file_name,
+            file_path,
+            file_size=file_to_download_info['Size'],
             headers=headers,
             cookies=None if cookies is None else cookies.get_dict())
-    temp_file_path = os.path.join(temporary_cache_dir, temp_file_name)
+    temp_file_path = os.path.join(temporary_cache_dir, file_path)
     # for download with commit we can't get Sha256
     if file_to_download_info[FILE_HASH] is not None:
         file_integrity_validation(temp_file_path,
                                   file_to_download_info[FILE_HASH])
     return cache.put_file(file_to_download_info,
-                          os.path.join(temporary_cache_dir, temp_file_name))
+                          os.path.join(temporary_cache_dir, file_path))
 
 
 def get_file_download_url(model_id: str, file_path: str, revision: str):
@@ -193,18 +190,27 @@ def get_file_download_url(model_id: str, file_path: str, revision: str):
 
 def download_part_with_retry(params):
     # unpack parameters
-    model_file_name, progress, start, end, url, file_name, cookies, headers = params
+    model_file_path, progress, start, end, url, file_name, cookies, headers = params
     get_headers = {} if headers is None else copy.deepcopy(headers)
-    get_headers['Range'] = 'bytes=%s-%s' % (start, end)
     get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
     retry = Retry(
         total=API_FILE_DOWNLOAD_RETRY_TIMES,
         backoff_factor=1,
         allowed_methods=['GET'])
+    part_file_name = model_file_path + '_%s_%s' % (start, end)
     while True:
         try:
-            with open(file_name, 'rb+') as f:
-                f.seek(start)
+            partial_length = 0
+            if os.path.exists(
+                    part_file_name):  # download partial, continue download
+                with open(part_file_name, 'rb') as f:
+                    partial_length = f.seek(0, io.SEEK_END)
+                    progress.update(partial_length)
+                start = start + partial_length
+                if start > end:
+                    break  # this part is download completed.
+            get_headers['Range'] = 'bytes=%s-%s' % (start, end)
+            with open(part_file_name, 'ab+') as f:
                 r = requests.get(
                     url,
                     stream=True,
@@ -215,12 +221,12 @@ def download_part_with_retry(params):
                         chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
                     if chunk:  # filter out keep-alive new chunks
                         f.write(chunk)
-                progress.update(end - start)
+                        progress.update(len(chunk))
             break
         except (Exception) as e:  # no matter what exception, we will retry.
             retry = retry.increment('GET', url, error=e)
             logger.warning('Downloading: %s failed, reason: %s will retry' %
-                           (model_file_name, e))
+                           (model_file_path, e))
             retry.sleep()
 
 
@@ -233,42 +239,46 @@ def parallel_download(
     file_size: int = None,
 ):
     # create temp file
-    temp_file_manager = partial(
-        tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False)
-    with temp_file_manager() as temp_file:
-        progress = tqdm(
-            unit='B',
-            unit_scale=True,
-            unit_divisor=1024,
-            total=file_size,
-            initial=0,
-            desc='Downloading',
-        )
-        PART_SIZE = 160 * 1024 * 1024  # every part is 160M
-        tasks = []
-        for idx in range(int(file_size / PART_SIZE)):
-            start = idx * PART_SIZE
-            end = (idx + 1) * PART_SIZE - 1
-            tasks.append((file_name, progress, start, end, url, temp_file.name,
-                          cookies, headers))
-        if end + 1 < file_size:
-            tasks.append((file_name, progress, end + 1, file_size - 1, url,
-                          temp_file.name, cookies, headers))
-        parallels = MODELSCOPE_DOWNLOAD_PARALLELS if MODELSCOPE_DOWNLOAD_PARALLELS <= 4 else 4
-        with ThreadPoolExecutor(
-                max_workers=parallels,
-                thread_name_prefix='download') as executor:
-            list(executor.map(download_part_with_retry, tasks))
-        progress.close()
-
-    os.replace(temp_file.name, os.path.join(local_dir, file_name))
+    progress = tqdm(
+        unit='B',
+        unit_scale=True,
+        unit_divisor=1024,
+        total=file_size,
+        initial=0,
+        desc='Downloading',
+    )
+    PART_SIZE = 160 * 1024 * 1024  # every part is 160M
+    tasks = []
+    file_path = os.path.join(local_dir, file_name)
+    for idx in range(int(file_size / PART_SIZE)):
+        start = idx * PART_SIZE
+        end = (idx + 1) * PART_SIZE - 1
+        tasks.append((file_path, progress, start, end, url, file_name, cookies,
+                      headers))
+    if end + 1 < file_size:
+        tasks.append((file_path, progress, end + 1, file_size - 1, url,
+                      file_name, cookies, headers))
+    parallels = MODELSCOPE_DOWNLOAD_PARALLELS if MODELSCOPE_DOWNLOAD_PARALLELS <= 4 else 4
+    # download every part
+    with ThreadPoolExecutor(
+            max_workers=parallels, thread_name_prefix='download') as executor:
+        list(executor.map(download_part_with_retry, tasks))
+    progress.close()
+    # merge parts.
+    with open(os.path.join(local_dir, file_name), 'wb') as output_file:
+        for task in tasks:
+            part_file_name = task[0] + '_%s_%s' % (task[2], task[3])
+            with open(part_file_name, 'rb') as part_file:
+                output_file.write(part_file.read())
+            os.remove(part_file_name)
 
 
 def http_get_file(
     url: str,
     local_dir: str,
     file_name: str,
+    file_size: int,
     cookies: CookieJar,
     headers: Optional[Dict[str, str]] = None,
 ):
@@ -281,6 +291,8 @@ def http_get_file(
             local directory where the downloaded file stores
         file_name(str):
             name of the file stored in `local_dir`
+        file_size(int):
+            The file size.
         cookies(CookieJar):
             cookies used to authentication the user, which is used for downloading private repos
         headers(Dict[str, str], optional):
@@ -290,22 +302,36 @@ def http_get_file(
         FileDownloadError: File download failed.
 
     """
-    total = -1
-    temp_file_manager = partial(
-        tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False)
     get_headers = {} if headers is None else copy.deepcopy(headers)
     get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
-    with temp_file_manager() as temp_file:
-        logger.debug('downloading %s to %s', url, temp_file.name)
-        # retry sleep 0.5s, 1s, 2s, 4s
-        retry = Retry(
-            total=API_FILE_DOWNLOAD_RETRY_TIMES,
-            backoff_factor=1,
-            allowed_methods=['GET'])
-        while True:
-            try:
-                downloaded_size = temp_file.tell()
-                get_headers['Range'] = 'bytes=%d-' % downloaded_size
+    temp_file_path = os.path.join(local_dir, file_name)
+    logger.debug('downloading %s to %s', url, temp_file_path)
+    # retry sleep 0.5s, 1s, 2s, 4s
+    retry = Retry(
+        total=API_FILE_DOWNLOAD_RETRY_TIMES,
+        backoff_factor=1,
+        allowed_methods=['GET'])
+    while True:
+        try:
+            progress = tqdm(
+                unit='B',
+                unit_scale=True,
+                unit_divisor=1024,
+                total=file_size,
+                initial=0,
+                desc='Downloading',
+            )
+            partial_length = 0
+            if os.path.exists(
+                    temp_file_path):  # download partial, continue download
+                with open(temp_file_path, 'rb') as f:
+                    partial_length = f.seek(0, io.SEEK_END)
+                    progress.update(partial_length)
+                if partial_length > file_size:
+                    break
+            get_headers['Range'] = 'bytes=%s-%s' % (partial_length,
+                                                    file_size - 1)
+            with open(temp_file_path, 'ab') as f:
                 r = requests.get(
                     url,
                     stream=True,
@@ -313,35 +339,15 @@ def http_get_file(
                     cookies=cookies,
                     timeout=API_FILE_DOWNLOAD_TIMEOUT)
                 r.raise_for_status()
-                content_length = r.headers.get('Content-Length')
-                total = int(
-                    content_length) if content_length is not None else None
-                progress = tqdm(
-                    unit='B',
-                    unit_scale=True,
-                    unit_divisor=1024,
-                    total=total,
-                    initial=downloaded_size,
-                    desc='Downloading',
-                )
                 for chunk in r.iter_content(
                         chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
                     if chunk:  # filter out keep-alive new chunks
                         progress.update(len(chunk))
-                        temp_file.write(chunk)
-                progress.close()
-                break
-            except (Exception) as e:  # no matter what happen, we will retry.
-                retry = retry.increment('GET', url, error=e)
-                retry.sleep()
+                        f.write(chunk)
+            progress.close()
+            break
+        except (Exception) as e:  # no matter what happen, we will retry.
+            retry = retry.increment('GET', url, error=e)
+            retry.sleep()
 
     logger.debug('storing %s in cache at %s', url, local_dir)
-    downloaded_length = os.path.getsize(temp_file.name)
-    if total != downloaded_length:
-        os.remove(temp_file.name)
-        msg = 'File %s download incomplete, content_length: %s but the \
-            file downloaded length: %s, please download again' % (
-                file_name, total, downloaded_length)
-        logger.error(msg)
-        raise FileDownloadError(msg)
-    os.replace(temp_file.name, os.path.join(local_dir, file_name))
diff --git a/modelscope/hub/snapshot_download.py b/modelscope/hub/snapshot_download.py
index 128a251d..6ce306f3 100644
--- a/modelscope/hub/snapshot_download.py
+++ b/modelscope/hub/snapshot_download.py
@@ -2,7 +2,6 @@
 
 import os
 import re
-import tempfile
 from http.cookiejar import CookieJar
 from pathlib import Path
 from typing import Dict, List, Optional, Union
@@ -22,13 +21,15 @@ from .utils.utils import (file_integrity_validation,
 logger = get_logger()
 
 
-def snapshot_download(model_id: str,
-                      revision: Optional[str] = DEFAULT_MODEL_REVISION,
-                      cache_dir: Union[str, Path, None] = None,
-                      user_agent: Optional[Union[Dict, str]] = None,
-                      local_files_only: Optional[bool] = False,
-                      cookies: Optional[CookieJar] = None,
-                      ignore_file_pattern: List = None) -> str:
+def snapshot_download(
+    model_id: str,
+    revision: Optional[str] = DEFAULT_MODEL_REVISION,
+    cache_dir: Union[str, Path, None] = None,
+    user_agent: Optional[Union[Dict, str]] = None,
+    local_files_only: Optional[bool] = False,
+    cookies: Optional[CookieJar] = None,
+    ignore_file_pattern: List = None,
+) -> str:
     """Download all files of a repo.
     Downloads a whole snapshot of a repo's files at the specified revision. This
     is useful when you want all files from a repo, because you don't know which
@@ -69,10 +70,9 @@ def snapshot_download(model_id: str,
     cache_dir = get_model_cache_root()
     if isinstance(cache_dir, Path):
         cache_dir = str(cache_dir)
-    temporary_cache_dir = os.path.join(cache_dir, 'temp')
-    os.makedirs(temporary_cache_dir, exist_ok=True)
-
     group_or_owner, name = model_id_to_group_owner_name(model_id)
+    temporary_cache_dir = os.path.join(cache_dir, 'temp', group_or_owner, name)
+    os.makedirs(temporary_cache_dir, exist_ok=True)
     name = name.replace('.', '___')
 
     cache = ModelFileSystemCache(cache_dir, group_or_owner, name)
@@ -123,50 +123,48 @@ def snapshot_download(model_id: str,
 
     if isinstance(ignore_file_pattern, str):
         ignore_file_pattern = [ignore_file_pattern]
-    with tempfile.TemporaryDirectory(
-            dir=temporary_cache_dir) as temp_cache_dir:
-        for model_file in model_files:
-            if model_file['Type'] == 'tree' or \
-                    any([re.search(pattern, model_file['Name']) is not None for pattern in ignore_file_pattern]):
-                continue
-            # check model_file is exist in cache, if existed, skip download, otherwise download
-            if cache.exists(model_file):
-                file_name = os.path.basename(model_file['Name'])
-                logger.debug(
-                    f'File {file_name} already in cache, skip downloading!'
-                )
-                continue
+    for model_file in model_files:
+        if model_file['Type'] == 'tree' or \
+                any([re.search(pattern, model_file['Name']) is not None for pattern in ignore_file_pattern]):
+            continue
 
-            # get download url
-            url = get_file_download_url(
-                model_id=model_id,
-                file_path=model_file['Path'],
-                revision=revision)
+        # check model_file is exist in cache, if existed, skip download, otherwise download
+        if cache.exists(model_file):
+            file_name = os.path.basename(model_file['Name'])
+            logger.debug(
+                f'File {file_name} already in cache, skip downloading!')
+            continue
 
-            if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < model_file[
-                    'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1:
-                parallel_download(
-                    url,
-                    temp_cache_dir,
-                    model_file['Name'],
-                    headers=headers,
-                    cookies=None
-                    if cookies is None else cookies.get_dict(),
-                    file_size=model_file['Size'])
-            else:
-                http_get_file(
-                    url,
-                    temp_cache_dir,
-                    model_file['Name'],
-                    headers=headers,
-                    cookies=cookies)
+        # get download url
+        url = get_file_download_url(
+            model_id=model_id,
+            file_path=model_file['Path'],
+            revision=revision)
 
-            # check file integrity
-            temp_file = os.path.join(temp_cache_dir, model_file['Name'])
-            if FILE_HASH in model_file:
-                file_integrity_validation(temp_file, model_file[FILE_HASH])
-            # put file into to cache
-            cache.put_file(model_file, temp_file)
+        if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < model_file[
+                'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1:
+            parallel_download(
+                url,
+                temporary_cache_dir,
+                model_file['Name'],
+                headers=headers,
+                cookies=None if cookies is None else cookies.get_dict(),
+                file_size=model_file['Size'])
+        else:
+            http_get_file(
+                url,
+                temporary_cache_dir,
+                model_file['Name'],
+                file_size=model_file['Size'],
+                headers=headers,
+                cookies=cookies)
+
+        # check file integrity
+        temp_file = os.path.join(temporary_cache_dir, model_file['Name'])
+        if FILE_HASH in model_file:
+            file_integrity_validation(temp_file, model_file[FILE_HASH])
+        # put file into to cache
+        cache.put_file(model_file, temp_file)
 
     cache.save_model_version(revision_info=revision_detail)
     return os.path.join(cache.get_root_location())
diff --git a/modelscope/hub/utils/utils.py b/modelscope/hub/utils/utils.py
index 64d9f5bb..9d0fe660 100644
--- a/modelscope/hub/utils/utils.py
+++ b/modelscope/hub/utils/utils.py
@@ -72,6 +72,7 @@ def file_integrity_validation(file_path, expected_sha256):
     file_sha256 = compute_hash(file_path)
     if not file_sha256 == expected_sha256:
         os.remove(file_path)
-        msg = 'File %s integrity check failed, the download may be incomplete, please try again.' % file_path
+        msg = 'File %s integrity check failed, expected sha256 signature is %s, actual is %s, the download may be incomplete, please try again.' % (  # noqa E501
+            file_path, expected_sha256, file_sha256)
         logger.error(msg)
         raise FileIntegrityError(msg)
diff --git a/tests/hub/test_hub_retry.py b/tests/hub/test_hub_retry.py
index e294cb68..149e825a 100644
--- a/tests/hub/test_hub_retry.py
+++ b/tests/hub/test_hub_retry.py
@@ -113,6 +113,7 @@ class HubOperationTest(unittest.TestCase):
                 url=url,
                 local_dir='./',
                 file_name=test_file_name,
+                file_size=2957783,
                 headers={},
                 cookies=None)
 
@@ -154,10 +155,11 @@ class HubOperationTest(unittest.TestCase):
                 url=url,
                 local_dir='./',
                 file_name=test_file_name,
+                file_size=2957783,
                 headers={},
                 cookies=None)
 
-        assert not os.path.exists('./%s' % test_file_name)
+        assert os.stat('./%s' % test_file_name).st_size == 0
 
 
 if __name__ == '__main__':
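
Note for reviewers: the resumable logic this patch adds to download_part_with_retry and http_get_file boils down to counting the bytes already on disk, moving the HTTP Range start forward by that amount, and appending only the newly received bytes. The standalone sketch below illustrates that idea with the requests library only; the helper name, its parameters, and the timeout value are assumptions for illustration and are not code from the patch.

import os

import requests


def resume_range_download(url, part_path, start, end, chunk_size=1024 * 1024):
    # Count bytes already downloaded for this part and skip past them.
    partial = os.path.getsize(part_path) if os.path.exists(part_path) else 0
    if start + partial > end:
        return  # this part is already complete
    headers = {'Range': 'bytes=%d-%d' % (start + partial, end)}
    with requests.get(url, headers=headers, stream=True, timeout=60) as r:
        r.raise_for_status()
        # Append only the newly received bytes, so a rerun continues the part.
        with open(part_path, 'ab') as f:
            for chunk in r.iter_content(chunk_size=chunk_size):
                if chunk:  # filter out keep-alive chunks
                    f.write(chunk)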