This commit is contained in:
mulin.lyh
2023-09-27 23:57:46 +08:00
4 changed files with 27 additions and 24 deletions

View File

@@ -562,6 +562,7 @@ TASK_OUTPUTS = {
# }
Tasks.facial_expression_recognition:
[OutputKeys.SCORES, OutputKeys.LABELS],
Tasks.general_recognition: [OutputKeys.SCORES, OutputKeys.LABELS],
# face processing base result for single img
# {

View File

@@ -82,14 +82,28 @@ def check_input_type(input_type, input):
TASK_INPUTS = {
# if task input is single var, value is InputType
# if task input is a tuple, value is tuple of InputType
# if task input is a dict, value is a dict of InputType, where key
# equals the one needed in pipeline input dict
# if task input is a list, value is a set of input format, in which
# each element corresponds to one input format as described above and
# must include a dict format.
Tasks.task_template: {
'image': InputType.IMAGE,
'text': InputType.TEXT
},
# ============ vision tasks ===================
Tasks.image_text_retrieval: {
InputKeys.IMAGE: InputType.IMAGE,
InputKeys.TEXT: InputType.TEXT
},
Tasks.general_recognition: {
InputKeys.IMAGE: InputType.IMAGE,
InputKeys.TEXT: InputType.TEXT
},
Tasks.general_recognition:
InputType.IMAGE,
Tasks.video_depth_estimation: {
InputKeys.IMAGE: InputType.IMAGE,
InputKeys.TEXT: InputType.TEXT
@@ -110,20 +124,6 @@ TASK_INPUTS = {
InputType.VIDEO,
Tasks.task_template: {
'image': InputType.IMAGE,
'text': InputType.TEXT
},
# if task input is single var, value is InputType
# if task input is a tuple, value is tuple of InputType
# if task input is a dict, value is a dict of InputType, where key
# equals the one needed in pipeline input dict
# if task input is a list, value is a set of input format, in which
# each element corresponds to one input format as described above and
# must include a dict format.
# ============ vision tasks ===================
Tasks.ocr_detection:
InputType.IMAGE,
Tasks.ocr_recognition:

View File

@@ -91,12 +91,13 @@ def check_hf_code(model_dir: str, auto_class: type,
raise FileNotFoundError(f'{config_path} is not found')
config_dict = PretrainedConfig.get_config_dict(config_path)[0]
auto_class_name = auto_class.__name__
if auto_class is AutoTokenizerHF:
tokenizer_config = get_tokenizer_config(model_dir)
# load from repo
if trust_remote_code:
has_remote_code = False
if auto_class is AutoTokenizerHF:
tokenizer_config_dict = get_tokenizer_config(model_dir)
auto_map = tokenizer_config_dict.get('auto_map', None)
auto_map = tokenizer_config.get('auto_map', None)
if auto_map is not None:
module_name = auto_map.get(auto_class_name, None)
if module_name is not None:
@@ -129,6 +130,9 @@ def check_hf_code(model_dir: str, auto_class: type,
f'{model_type} not found in HF `CONFIG_MAPPING`{trust_remote_code_info}'
)
elif auto_class is AutoTokenizerHF:
tokenizer_class = tokenizer_config.get('tokenizer_class')
if tokenizer_class is not None:
return
if model_type not in TOKENIZER_MAPPING_NAMES:
raise ValueError(
f'{model_type} not found in HF `TOKENIZER_MAPPING_NAMES`{trust_remote_code_info}'

View File

@@ -656,7 +656,7 @@ def service_base64_input_to_pipeline_input(task_name, body):
if isinstance(service_input, (str, int, float)):
return service_input, parameters
task_input_info = TASK_INPUTS[task_name]
task_input_info = TASK_INPUTS.get(task_name, None)
if isinstance(task_input_info, str): # no input key default
if isinstance(service_input, dict):
return base64_decoder_map[task_input_info](list(
@@ -767,9 +767,7 @@ def pipeline_output_to_service_base64_output(task_name, pipeline_output):
pipeline_output (object): The pipeline output.
"""
json_serializable_output = {}
task_outputs = []
if task_name in TASK_OUTPUTS:
task_outputs = TASK_OUTPUTS[task_name]
task_outputs = TASK_OUTPUTS.get(task_name, [])
# TODO: for batch
if isinstance(pipeline_output, list):
pipeline_output = pipeline_output[0]