Merge branch 'master-github' into merge_github_0911

This commit is contained in:
mulin.lyh
2023-09-11 13:56:24 +08:00
4 changed files with 41 additions and 21 deletions

View File

@@ -19,7 +19,7 @@ REQUESTS_API_HTTP_METHOD = ['get', 'head', 'post', 'put', 'patch', 'delete']
API_HTTP_CLIENT_TIMEOUT = 60
API_RESPONSE_FIELD_DATA = 'Data'
API_FILE_DOWNLOAD_RETRY_TIMES = 5
API_FILE_DOWNLOAD_TIMEOUT = 60 * 5
API_FILE_DOWNLOAD_TIMEOUT = 30
API_FILE_DOWNLOAD_CHUNK_SIZE = 1024 * 1024 * 16
API_RESPONSE_FIELD_GIT_ACCESS_TOKEN = 'AccessToken'
API_RESPONSE_FIELD_USERNAME = 'Username'

View File

@@ -187,23 +187,36 @@ def get_file_download_url(model_id: str, file_path: str, revision: str):
)
def download_part(params):
def download_part_with_retry(params):
# unpack parameters
progress, start, end, url, file_name, cookies, headers = params
get_headers = {} if headers is None else copy.deepcopy(headers)
get_headers['Range'] = 'bytes=%s-%s' % (start, end)
with open(file_name, 'rb+') as f:
f.seek(start)
r = requests.get(
url,
stream=True,
headers=get_headers,
cookies=cookies,
timeout=API_FILE_DOWNLOAD_TIMEOUT)
for chunk in r.iter_content(chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
progress.update(len(chunk))
retry = Retry(
total=API_FILE_DOWNLOAD_RETRY_TIMES,
backoff_factor=1,
allowed_methods=['GET'])
while True:
try:
with open(file_name, 'rb+') as f:
f.seek(start)
r = requests.get(
url,
stream=True,
headers=get_headers,
cookies=cookies,
timeout=API_FILE_DOWNLOAD_TIMEOUT)
for chunk in r.iter_content(
chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
progress.update(end - start)
break
except (Exception) as e: # no matter what exception, we will retry.
retry = retry.increment('GET', url, error=e)
logger.warning('Download file from: %s to: %s failed, will retry' %
(start, end))
retry.sleep()
def parallel_download(
@@ -226,7 +239,7 @@ def parallel_download(
initial=0,
desc='Downloading',
)
PART_SIZE = 160 * 1024 * 1012 # every part is 160M
PART_SIZE = 160 * 1024 * 1024 # every part is 160M
tasks = []
for idx in range(int(file_size / PART_SIZE)):
start = idx * PART_SIZE
@@ -240,7 +253,7 @@ def parallel_download(
with ThreadPoolExecutor(
max_workers=parallels,
thread_name_prefix='download') as executor:
list(executor.map(download_part, tasks))
list(executor.map(download_part_with_retry, tasks))
progress.close()

View File

@@ -66,7 +66,7 @@ def try_to_load_hf_model(model_dir: str, task_name: str,
if use_hf and automodel_class is None:
raise ValueError(f'Model import failed. You used `use_hf={use_hf}`, '
'but the model is not a model of hf')
'but the model is not a model of hf.')
model = None
if automodel_class is not None:

View File

@@ -112,22 +112,29 @@ def check_hf_code(model_dir: str, auto_class: type,
# trust_remote_code is False or has_remote_code is False
model_type = config_dict.get('model_type', None)
if model_type is None:
raise ValueError(f'`model_type` key is not found in {config_path}')
raise ValueError(f'`model_type` key is not found in {config_path}.')
trust_remote_code_info = '.'
if not trust_remote_code:
trust_remote_code_info = ', You can try passing `trust_remote_code=True`.'
if auto_class is AutoConfigHF:
if model_type not in CONFIG_MAPPING:
raise ValueError(f'{model_type} not found in HF CONFIG_MAPPING')
raise ValueError(
f'{model_type} not found in HF `CONFIG_MAPPING`{trust_remote_code_info}'
)
elif auto_class is AutoTokenizerHF:
if model_type not in TOKENIZER_MAPPING_NAMES:
raise ValueError(
f'{model_type} not found in HF TOKENIZER_MAPPING_NAMES')
f'{model_type} not found in HF `TOKENIZER_MAPPING_NAMES`{trust_remote_code_info}'
)
else:
mapping_names = [
m.model_type for m in auto_class._model_mapping.keys()
]
if model_type not in mapping_names:
raise ValueError(
f'{model_type} not found in HF auto_class._model_mapping')
f'{model_type} not found in HF `auto_class._model_mapping`{trust_remote_code_info}'
)
def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs):