mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 20:19:51 +01:00
[to #43887377]fix: sdk api concurrent call snapshort download file will conflict
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9672696 * [to #43887377]fix: sdk api concurrent call snapshort download file will conflict
This commit is contained in:
@@ -79,6 +79,8 @@ def model_file_download(
|
||||
cache_dir = get_cache_dir()
|
||||
if isinstance(cache_dir, Path):
|
||||
cache_dir = str(cache_dir)
|
||||
temporary_cache_dir = os.path.join(cache_dir, 'temp')
|
||||
os.makedirs(temporary_cache_dir, exist_ok=True)
|
||||
|
||||
group_or_owner, name = model_id_to_group_owner_name(model_id)
|
||||
|
||||
@@ -152,12 +154,13 @@ def model_file_download(
|
||||
temp_file_name = next(tempfile._get_candidate_names())
|
||||
http_get_file(
|
||||
url_to_download,
|
||||
cache_dir,
|
||||
temporary_cache_dir,
|
||||
temp_file_name,
|
||||
headers=headers,
|
||||
cookies=None if cookies is None else cookies.get_dict())
|
||||
return cache.put_file(file_to_download_info,
|
||||
os.path.join(cache_dir, temp_file_name))
|
||||
return cache.put_file(
|
||||
file_to_download_info,
|
||||
os.path.join(temporary_cache_dir, temp_file_name))
|
||||
|
||||
|
||||
def http_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
@@ -58,6 +59,8 @@ def snapshot_download(model_id: str,
|
||||
cache_dir = get_cache_dir()
|
||||
if isinstance(cache_dir, Path):
|
||||
cache_dir = str(cache_dir)
|
||||
temporary_cache_dir = os.path.join(cache_dir, 'temp')
|
||||
os.makedirs(temporary_cache_dir, exist_ok=True)
|
||||
|
||||
group_or_owner, name = model_id_to_group_owner_name(model_id)
|
||||
|
||||
@@ -98,31 +101,35 @@ def snapshot_download(model_id: str,
|
||||
headers=snapshot_header,
|
||||
)
|
||||
|
||||
for model_file in model_files:
|
||||
if model_file['Type'] == 'tree':
|
||||
continue
|
||||
# check model_file is exist in cache, if exist, skip download, otherwise download
|
||||
if cache.exists(model_file):
|
||||
file_name = os.path.basename(model_file['Name'])
|
||||
logger.info(
|
||||
f'File {file_name} already in cache, skip downloading!')
|
||||
continue
|
||||
with tempfile.TemporaryDirectory(
|
||||
dir=temporary_cache_dir) as temp_cache_dir:
|
||||
for model_file in model_files:
|
||||
if model_file['Type'] == 'tree':
|
||||
continue
|
||||
# check model_file is exist in cache, if exist, skip download, otherwise download
|
||||
if cache.exists(model_file):
|
||||
file_name = os.path.basename(model_file['Name'])
|
||||
logger.info(
|
||||
f'File {file_name} already in cache, skip downloading!'
|
||||
)
|
||||
continue
|
||||
|
||||
# get download url
|
||||
url = get_file_download_url(
|
||||
model_id=model_id,
|
||||
file_path=model_file['Path'],
|
||||
revision=revision)
|
||||
# get download url
|
||||
url = get_file_download_url(
|
||||
model_id=model_id,
|
||||
file_path=model_file['Path'],
|
||||
revision=revision)
|
||||
|
||||
# First download to /tmp
|
||||
http_get_file(
|
||||
url=url,
|
||||
local_dir=cache_dir,
|
||||
file_name=model_file['Name'],
|
||||
headers=headers,
|
||||
cookies=cookies)
|
||||
# put file to cache
|
||||
cache.put_file(model_file,
|
||||
os.path.join(cache_dir, model_file['Name']))
|
||||
# First download to /tmp
|
||||
http_get_file(
|
||||
url=url,
|
||||
local_dir=temp_cache_dir,
|
||||
file_name=model_file['Name'],
|
||||
headers=headers,
|
||||
cookies=cookies)
|
||||
# put file to cache
|
||||
cache.put_file(
|
||||
model_file, os.path.join(temp_cache_dir,
|
||||
model_file['Name']))
|
||||
|
||||
return os.path.join(cache.get_root_location())
|
||||
|
||||
@@ -21,9 +21,6 @@ DEFAULT_GIT_PATH = 'git'
|
||||
download_model_file_name = 'test.bin'
|
||||
|
||||
|
||||
@unittest.skip(
|
||||
"Access token is always change, we can't login with same access token, so skip!"
|
||||
)
|
||||
class HubOperationTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
||||
@@ -18,9 +18,6 @@ from .test_utils import (TEST_ACCESS_TOKEN1, TEST_ACCESS_TOKEN2,
|
||||
delete_credential)
|
||||
|
||||
|
||||
@unittest.skip(
|
||||
"Access token is always change, we can't login with same access token, so skip!"
|
||||
)
|
||||
class HubPrivateFileDownloadTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
||||
@@ -15,9 +15,6 @@ from .test_utils import (TEST_ACCESS_TOKEN1, TEST_ACCESS_TOKEN2,
|
||||
DEFAULT_GIT_PATH = 'git'
|
||||
|
||||
|
||||
@unittest.skip(
|
||||
"Access token is always change, we can't login with same access token, so skip!"
|
||||
)
|
||||
class HubPrivateRepositoryTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
||||
@@ -24,9 +24,6 @@ logger.setLevel('DEBUG')
|
||||
DEFAULT_GIT_PATH = 'git'
|
||||
|
||||
|
||||
@unittest.skip(
|
||||
"Access token is always change, we can't login with same access token, so skip!"
|
||||
)
|
||||
class HubRepositoryTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
||||
@@ -6,8 +6,8 @@ from os.path import expanduser
|
||||
from modelscope.hub.constants import DEFAULT_CREDENTIALS_PATH
|
||||
|
||||
# for user citest and sdkdev
|
||||
TEST_ACCESS_TOKEN1 = 'OVAzNU9aZ2FYbXFhdGNzZll6VHRtalQ0T1BpZTNGeWVhMkxSSGpTSzU0dkM5WE5ObDFKdFRQWGc2U2ZIdjdPdg=='
|
||||
TEST_ACCESS_TOKEN2 = 'aXRocHhGeG0rNXRWQWhBSnJpTTZUQ0RDbUlkcUJRS1dQR2lNb0xIa0JjRDBrT1JKYklZV05DVzROTTdtamxWcg=='
|
||||
TEST_ACCESS_TOKEN1 = 'RGZZdkh2Z3BlMFU1VktjUkdIcUJtdjdqdnhQUEQrUVROdVBjclAzUGVycHFhU1BFZFBIaGtUOHB1eHQ2OTV3dQ=='
|
||||
TEST_ACCESS_TOKEN2 = 'dFpadllseTZQbHlyK0E4amQxVC84a2RtZHdkUVhmMUl3M1VXZXU4dS9GZlRuVmFUTW5yQm8yTENYWEw2SVh0Uw=='
|
||||
|
||||
TEST_MODEL_CHINESE_NAME = '内部测试模型'
|
||||
TEST_MODEL_ORG = 'citest'
|
||||
|
||||
Reference in New Issue
Block a user