fix hf bug (#567)

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/14181647
* fix hf bug (#567)
This commit is contained in:
mulin.lyh
2023-09-27 09:33:32 +08:00
parent 1e7a170089
commit f426e49d3b

View File

@@ -91,12 +91,13 @@ def check_hf_code(model_dir: str, auto_class: type,
raise FileNotFoundError(f'{config_path} is not found')
config_dict = PretrainedConfig.get_config_dict(config_path)[0]
auto_class_name = auto_class.__name__
if auto_class is AutoTokenizerHF:
tokenizer_config = get_tokenizer_config(model_dir)
# load from repo
if trust_remote_code:
has_remote_code = False
if auto_class is AutoTokenizerHF:
tokenizer_config_dict = get_tokenizer_config(model_dir)
auto_map = tokenizer_config_dict.get('auto_map', None)
auto_map = tokenizer_config.get('auto_map', None)
if auto_map is not None:
module_name = auto_map.get(auto_class_name, None)
if module_name is not None:
@@ -129,7 +130,10 @@ def check_hf_code(model_dir: str, auto_class: type,
f'{model_type} not found in HF `CONFIG_MAPPING`{trust_remote_code_info}'
)
elif auto_class is AutoTokenizerHF:
if model_type not in TOKENIZER_MAPPING_NAMES:
tokenizer_class = tokenizer_config.get('tokenizer_class')
if tokenizer_class is not None:
return
if model_type in TOKENIZER_MAPPING_NAMES:
raise ValueError(
f'{model_type} not found in HF `TOKENIZER_MAPPING_NAMES`{trust_remote_code_info}'
)