restore http_get_file interface

This commit is contained in:
mulin.lyh
2024-05-28 15:09:05 +08:00
parent e2d8a6d45f
commit 4a22b05891
3 changed files with 92 additions and 8 deletions

View File

@@ -3,9 +3,11 @@
import copy
import io
import os
import tempfile
import urllib
import uuid
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from http.cookiejar import CookieJar
from pathlib import Path
from typing import Dict, Optional, Union
@@ -22,7 +24,7 @@ from modelscope.hub.constants import (
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
from modelscope.utils.file_utils import get_model_cache_root
from modelscope.utils.logger import get_logger
from .errors import NotExistError
from .errors import FileDownloadError, NotExistError
from .utils.caching import ModelFileSystemCache
from .utils.utils import (file_integrity_validation, get_endpoint,
model_id_to_group_owner_name)
@@ -143,7 +145,7 @@ def model_file_download(
cookies=None if cookies is None else cookies.get_dict(),
file_size=file_to_download_info['Size'])
else:
http_get_file(
http_get_model_file(
url_to_download,
temporary_cache_dir,
file_path,
@@ -290,7 +292,7 @@ def parallel_download(
os.remove(part_file_name)
def http_get_file(
def http_get_model_file(
url: str,
local_dir: str,
file_name: str,
@@ -367,3 +369,85 @@ def http_get_file(
retry.sleep()
logger.debug('storing %s in cache at %s', url, local_dir)
def http_get_file(
url: str,
local_dir: str,
file_name: str,
cookies: CookieJar,
headers: Optional[Dict[str, str]] = None,
):
"""Download remote file, will retry 5 times before giving up on errors.
Args:
url(str):
actual download url of the file
local_dir(str):
local directory where the downloaded file stores
file_name(str):
name of the file stored in `local_dir`
cookies(CookieJar):
cookies used to authentication the user, which is used for downloading private repos
headers(Dict[str, str], optional):
http headers to carry necessary info when requesting the remote file
Raises:
FileDownloadError: File download failed.
"""
total = -1
temp_file_manager = partial(
tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False)
get_headers = {} if headers is None else copy.deepcopy(headers)
get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
with temp_file_manager() as temp_file:
logger.debug('downloading %s to %s', url, temp_file.name)
# retry sleep 0.5s, 1s, 2s, 4s
retry = Retry(
total=API_FILE_DOWNLOAD_RETRY_TIMES,
backoff_factor=1,
allowed_methods=['GET'])
while True:
try:
downloaded_size = temp_file.tell()
get_headers['Range'] = 'bytes=%d-' % downloaded_size
r = requests.get(
url,
stream=True,
headers=get_headers,
cookies=cookies,
timeout=API_FILE_DOWNLOAD_TIMEOUT)
r.raise_for_status()
content_length = r.headers.get('Content-Length')
total = int(
content_length) if content_length is not None else None
progress = tqdm(
unit='B',
unit_scale=True,
unit_divisor=1024,
total=total,
initial=downloaded_size,
desc='Downloading',
)
for chunk in r.iter_content(
chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()
break
except (Exception) as e: # no matter what happen, we will retry.
retry = retry.increment('GET', url, error=e)
retry.sleep()
logger.debug('storing %s in cache at %s', url, local_dir)
downloaded_length = os.path.getsize(temp_file.name)
if total != downloaded_length:
os.remove(temp_file.name)
msg = 'File %s download incomplete, content_length: %s but the \
file downloaded length: %s, please download again' % (
file_name, total, downloaded_length)
logger.error(msg)
raise FileDownloadError(msg)
os.replace(temp_file.name, os.path.join(local_dir, file_name))

View File

@@ -12,7 +12,7 @@ from modelscope.utils.logger import get_logger
from .constants import (FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS,
MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB)
from .file_download import (create_temporary_directory_and_cache,
get_file_download_url, http_get_file,
get_file_download_url, http_get_model_file,
parallel_download)
from .utils.utils import file_integrity_validation
@@ -156,7 +156,7 @@ def snapshot_download(
cookies=None if cookies is None else cookies.get_dict(),
file_size=model_file['Size'])
else:
http_get_file(
http_get_model_file(
url,
temporary_cache_dir,
model_file['Name'],

View File

@@ -9,7 +9,7 @@ import requests
from urllib3.exceptions import MaxRetryError
from modelscope.hub.api import HubApi
from modelscope.hub.file_download import http_get_file
from modelscope.hub.file_download import http_get_model_file
class HubOperationTest(unittest.TestCase):
@@ -109,7 +109,7 @@ class HubOperationTest(unittest.TestCase):
success_rsp,
]
url = 'http://www.modelscope.cn/api/v1/models/%s' % test_file_name
http_get_file(
http_get_model_file(
url=url,
local_dir='./',
file_name=test_file_name,
@@ -151,7 +151,7 @@ class HubOperationTest(unittest.TestCase):
]
url = 'http://www.modelscope.cn/api/v1/models/%s' % test_file_name
with self.assertRaises(MaxRetryError):
http_get_file(
http_get_model_file(
url=url,
local_dir='./',
file_name=test_file_name,