mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
restore http_get_file interface
This commit is contained in:
@@ -3,9 +3,11 @@
|
||||
import copy
|
||||
import io
|
||||
import os
|
||||
import tempfile
|
||||
import urllib
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from http.cookiejar import CookieJar
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
@@ -22,7 +24,7 @@ from modelscope.hub.constants import (
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
|
||||
from modelscope.utils.file_utils import get_model_cache_root
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .errors import NotExistError
|
||||
from .errors import FileDownloadError, NotExistError
|
||||
from .utils.caching import ModelFileSystemCache
|
||||
from .utils.utils import (file_integrity_validation, get_endpoint,
|
||||
model_id_to_group_owner_name)
|
||||
@@ -143,7 +145,7 @@ def model_file_download(
|
||||
cookies=None if cookies is None else cookies.get_dict(),
|
||||
file_size=file_to_download_info['Size'])
|
||||
else:
|
||||
http_get_file(
|
||||
http_get_model_file(
|
||||
url_to_download,
|
||||
temporary_cache_dir,
|
||||
file_path,
|
||||
@@ -290,7 +292,7 @@ def parallel_download(
|
||||
os.remove(part_file_name)
|
||||
|
||||
|
||||
def http_get_file(
|
||||
def http_get_model_file(
|
||||
url: str,
|
||||
local_dir: str,
|
||||
file_name: str,
|
||||
@@ -367,3 +369,85 @@ def http_get_file(
|
||||
retry.sleep()
|
||||
|
||||
logger.debug('storing %s in cache at %s', url, local_dir)
|
||||
|
||||
|
||||
def http_get_file(
|
||||
url: str,
|
||||
local_dir: str,
|
||||
file_name: str,
|
||||
cookies: CookieJar,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""Download remote file, will retry 5 times before giving up on errors.
|
||||
|
||||
Args:
|
||||
url(str):
|
||||
actual download url of the file
|
||||
local_dir(str):
|
||||
local directory where the downloaded file stores
|
||||
file_name(str):
|
||||
name of the file stored in `local_dir`
|
||||
cookies(CookieJar):
|
||||
cookies used to authentication the user, which is used for downloading private repos
|
||||
headers(Dict[str, str], optional):
|
||||
http headers to carry necessary info when requesting the remote file
|
||||
|
||||
Raises:
|
||||
FileDownloadError: File download failed.
|
||||
|
||||
"""
|
||||
total = -1
|
||||
temp_file_manager = partial(
|
||||
tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False)
|
||||
get_headers = {} if headers is None else copy.deepcopy(headers)
|
||||
get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
|
||||
with temp_file_manager() as temp_file:
|
||||
logger.debug('downloading %s to %s', url, temp_file.name)
|
||||
# retry sleep 0.5s, 1s, 2s, 4s
|
||||
retry = Retry(
|
||||
total=API_FILE_DOWNLOAD_RETRY_TIMES,
|
||||
backoff_factor=1,
|
||||
allowed_methods=['GET'])
|
||||
while True:
|
||||
try:
|
||||
downloaded_size = temp_file.tell()
|
||||
get_headers['Range'] = 'bytes=%d-' % downloaded_size
|
||||
r = requests.get(
|
||||
url,
|
||||
stream=True,
|
||||
headers=get_headers,
|
||||
cookies=cookies,
|
||||
timeout=API_FILE_DOWNLOAD_TIMEOUT)
|
||||
r.raise_for_status()
|
||||
content_length = r.headers.get('Content-Length')
|
||||
total = int(
|
||||
content_length) if content_length is not None else None
|
||||
progress = tqdm(
|
||||
unit='B',
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
total=total,
|
||||
initial=downloaded_size,
|
||||
desc='Downloading',
|
||||
)
|
||||
for chunk in r.iter_content(
|
||||
chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
progress.update(len(chunk))
|
||||
temp_file.write(chunk)
|
||||
progress.close()
|
||||
break
|
||||
except (Exception) as e: # no matter what happen, we will retry.
|
||||
retry = retry.increment('GET', url, error=e)
|
||||
retry.sleep()
|
||||
|
||||
logger.debug('storing %s in cache at %s', url, local_dir)
|
||||
downloaded_length = os.path.getsize(temp_file.name)
|
||||
if total != downloaded_length:
|
||||
os.remove(temp_file.name)
|
||||
msg = 'File %s download incomplete, content_length: %s but the \
|
||||
file downloaded length: %s, please download again' % (
|
||||
file_name, total, downloaded_length)
|
||||
logger.error(msg)
|
||||
raise FileDownloadError(msg)
|
||||
os.replace(temp_file.name, os.path.join(local_dir, file_name))
|
||||
|
||||
@@ -12,7 +12,7 @@ from modelscope.utils.logger import get_logger
|
||||
from .constants import (FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS,
|
||||
MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB)
|
||||
from .file_download import (create_temporary_directory_and_cache,
|
||||
get_file_download_url, http_get_file,
|
||||
get_file_download_url, http_get_model_file,
|
||||
parallel_download)
|
||||
from .utils.utils import file_integrity_validation
|
||||
|
||||
@@ -156,7 +156,7 @@ def snapshot_download(
|
||||
cookies=None if cookies is None else cookies.get_dict(),
|
||||
file_size=model_file['Size'])
|
||||
else:
|
||||
http_get_file(
|
||||
http_get_model_file(
|
||||
url,
|
||||
temporary_cache_dir,
|
||||
model_file['Name'],
|
||||
|
||||
@@ -9,7 +9,7 @@ import requests
|
||||
from urllib3.exceptions import MaxRetryError
|
||||
|
||||
from modelscope.hub.api import HubApi
|
||||
from modelscope.hub.file_download import http_get_file
|
||||
from modelscope.hub.file_download import http_get_model_file
|
||||
|
||||
|
||||
class HubOperationTest(unittest.TestCase):
|
||||
@@ -109,7 +109,7 @@ class HubOperationTest(unittest.TestCase):
|
||||
success_rsp,
|
||||
]
|
||||
url = 'http://www.modelscope.cn/api/v1/models/%s' % test_file_name
|
||||
http_get_file(
|
||||
http_get_model_file(
|
||||
url=url,
|
||||
local_dir='./',
|
||||
file_name=test_file_name,
|
||||
@@ -151,7 +151,7 @@ class HubOperationTest(unittest.TestCase):
|
||||
]
|
||||
url = 'http://www.modelscope.cn/api/v1/models/%s' % test_file_name
|
||||
with self.assertRaises(MaxRetryError):
|
||||
http_get_file(
|
||||
http_get_model_file(
|
||||
url=url,
|
||||
local_dir='./',
|
||||
file_name=test_file_name,
|
||||
|
||||
Reference in New Issue
Block a user