Merge branch 'master' of http://gitlab.alibaba-inc.com/Ali-MaaS/MaaS-lib
@@ -562,6 +562,7 @@ TASK_OUTPUTS = {
    # }
    Tasks.facial_expression_recognition:
    [OutputKeys.SCORES, OutputKeys.LABELS],
    Tasks.general_recognition: [OutputKeys.SCORES, OutputKeys.LABELS],

    # face processing base result for single img
    # {
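For context, TASK_OUTPUTS maps each task name to the output keys its pipeline result is expected to expose; the hunk above registers general_recognition with SCORES and LABELS. Below is a minimal, self-contained sketch of how a caller could validate a result against such a registry; plain strings stand in for the Tasks / OutputKeys constants and the dict stands in for the real table, so nothing here is the actual modelscope API.

```python
# Hypothetical stand-in for modelscope's TASK_OUTPUTS registry.
TASK_OUTPUTS = {
    'facial-expression-recognition': ['scores', 'labels'],
    'general-recognition': ['scores', 'labels'],
}


def validate_output(task_name: str, pipeline_output: dict) -> None:
    """Raise if the pipeline output is missing a registered key."""
    expected = TASK_OUTPUTS.get(task_name, [])
    missing = [key for key in expected if key not in pipeline_output]
    if missing:
        raise ValueError(f'{task_name} output is missing keys: {missing}')


# Example: a general-recognition result carrying both registered keys passes.
validate_output('general-recognition', {'scores': [0.92], 'labels': ['cat']})
```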
@@ -82,14 +82,28 @@ def check_input_type(input_type, input):
TASK_INPUTS = {
    # if task input is single var, value is InputType
    # if task input is a tuple, value is tuple of InputType
    # if task input is a dict, value is a dict of InputType, where key
    # equals the one needed in pipeline input dict
    # if task input is a list, value is a set of input format, in which
    # each element corresponds to one input format as described above and
    # must include a dict format.

    Tasks.task_template: {
        'image': InputType.IMAGE,
        'text': InputType.TEXT
    },
    # ============ vision tasks ===================

    Tasks.image_text_retrieval: {
        InputKeys.IMAGE: InputType.IMAGE,
        InputKeys.TEXT: InputType.TEXT
    },
    Tasks.general_recognition: {
        InputKeys.IMAGE: InputType.IMAGE,
        InputKeys.TEXT: InputType.TEXT
    },
    Tasks.general_recognition:
    InputType.IMAGE,
    Tasks.video_depth_estimation: {
        InputKeys.IMAGE: InputType.IMAGE,
        InputKeys.TEXT: InputType.TEXT
@@ -110,20 +124,6 @@ TASK_INPUTS = {
    InputType.VIDEO,

    Tasks.task_template: {
        'image': InputType.IMAGE,
        'text': InputType.TEXT
    },
    # if task input is single var, value is InputType
    # if task input is a tuple, value is tuple of InputType
    # if task input is a dict, value is a dict of InputType, where key
    # equals the one needed in pipeline input dict
    # if task input is a list, value is a set of input format, in which
    # each element corresponds to one input format as described above and
    # must include a dict format.

    # ============ vision tasks ===================
    Tasks.ocr_detection:
    InputType.IMAGE,
    Tasks.ocr_recognition:
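The comments in the hunks above describe the four value shapes a TASK_INPUTS entry may take: a single InputType, a tuple of InputType, a dict of InputType keyed like the pipeline input dict, or a list of accepted formats that must include a dict form. The following is a rough, self-contained sketch of how an input could be checked against each shape; the string constants and task names are illustrative stand-ins, not the real InputType / Tasks values, and the helper is not modelscope's check_input_type.

```python
# Hypothetical sketch of the dispatch described by the comments above.
IMAGE, TEXT = 'image', 'text'

TASK_INPUTS = {
    'ocr-detection': IMAGE,                      # single var -> InputType
    'image-text-retrieval': {'image': IMAGE,     # dict -> dict of InputType
                             'text': TEXT},
    'visual-question-answering': (IMAGE, TEXT),  # tuple -> tuple of InputType
    'video-editing': [IMAGE, {'image': IMAGE,    # list -> set of accepted formats,
                              'text': TEXT}],    # must include a dict format
}


def matches(input_format, pipeline_input) -> bool:
    """Return True if pipeline_input fits one registered input format."""
    if isinstance(input_format, str):
        # single-var format: anything that is not a container
        return not isinstance(pipeline_input, (dict, tuple, list))
    if isinstance(input_format, tuple):
        return isinstance(pipeline_input, tuple) and \
            len(pipeline_input) == len(input_format)
    if isinstance(input_format, dict):
        return isinstance(pipeline_input, dict) and \
            set(input_format) <= set(pipeline_input)
    if isinstance(input_format, list):
        # any of the listed formats is acceptable
        return any(matches(fmt, pipeline_input) for fmt in input_format)
    return False


assert matches(TASK_INPUTS['image-text-retrieval'],
               {'image': 'img.jpg', 'text': 'a query'})
```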
@@ -91,12 +91,13 @@ def check_hf_code(model_dir: str, auto_class: type,
        raise FileNotFoundError(f'{config_path} is not found')
    config_dict = PretrainedConfig.get_config_dict(config_path)[0]
    auto_class_name = auto_class.__name__
    if auto_class is AutoTokenizerHF:
        tokenizer_config = get_tokenizer_config(model_dir)
    # load from repo
    if trust_remote_code:
        has_remote_code = False
        if auto_class is AutoTokenizerHF:
            tokenizer_config_dict = get_tokenizer_config(model_dir)
            auto_map = tokenizer_config_dict.get('auto_map', None)
            auto_map = tokenizer_config.get('auto_map', None)
        if auto_map is not None:
            module_name = auto_map.get(auto_class_name, None)
            if module_name is not None:
@@ -129,6 +130,9 @@ def check_hf_code(model_dir: str, auto_class: type,
                f'{model_type} not found in HF `CONFIG_MAPPING`{trust_remote_code_info}'
            )
        elif auto_class is AutoTokenizerHF:
            tokenizer_class = tokenizer_config.get('tokenizer_class')
            if tokenizer_class is not None:
                return
            if model_type not in TOKENIZER_MAPPING_NAMES:
                raise ValueError(
                    f'{model_type} not found in HF `TOKENIZER_MAPPING_NAMES`{trust_remote_code_info}'
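Taken together, the two check_hf_code hunks hoist the tokenizer_config lookup out of the trust_remote_code branch and let an explicit tokenizer_class declared in tokenizer_config.json short-circuit the TOKENIZER_MAPPING_NAMES check. The snippet below is a simplified, self-contained sketch of that resolution order only; the real function reads config.json and tokenizer_config.json through transformers helpers, and the mapping contents, function name, and control flow here are approximations for illustration.

```python
# Hypothetical sketch of the resolution order checked by check_hf_code.
TOKENIZER_MAPPING_NAMES = {'bert', 'gpt2', 'llama'}  # illustrative subset


def resolve_tokenizer(config_dict: dict, tokenizer_config: dict,
                      trust_remote_code: bool,
                      auto_class_name: str = 'AutoTokenizer') -> str:
    if trust_remote_code:
        # Remote-code path: tokenizer_config.json must map the auto class
        # to a module in the repo, e.g. "tokenization_foo.FooTokenizer".
        auto_map = tokenizer_config.get('auto_map', None)
        if auto_map is not None and auto_map.get(auto_class_name) is not None:
            return auto_map[auto_class_name]
        raise ValueError(f'no remote code registered for {auto_class_name}')
    # Local path: an explicit tokenizer_class wins outright ...
    tokenizer_class = tokenizer_config.get('tokenizer_class')
    if tokenizer_class is not None:
        return tokenizer_class
    # ... otherwise the model_type must belong to a known builtin mapping.
    model_type = config_dict.get('model_type')
    if model_type not in TOKENIZER_MAPPING_NAMES:
        raise ValueError(f'{model_type} not found in TOKENIZER_MAPPING_NAMES')
    return model_type


print(resolve_tokenizer({'model_type': 'bert'}, {}, trust_remote_code=False))
```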
@@ -656,7 +656,7 @@ def service_base64_input_to_pipeline_input(task_name, body):
    if isinstance(service_input, (str, int, float)):
        return service_input, parameters
    task_input_info = TASK_INPUTS[task_name]
    task_input_info = TASK_INPUTS.get(task_name, None)
    if isinstance(task_input_info, str): # no input key default
        if isinstance(service_input, dict):
            return base64_decoder_map[task_input_info](list(
@@ -767,9 +767,7 @@ def pipeline_output_to_service_base64_output(task_name, pipeline_output):
        pipeline_output (object): The pipeline output.
    """
    json_serializable_output = {}
    task_outputs = []
    if task_name in TASK_OUTPUTS:
        task_outputs = TASK_OUTPUTS[task_name]
    task_outputs = TASK_OUTPUTS.get(task_name, [])
    # TODO: for batch
    if isinstance(pipeline_output, list):
        pipeline_output = pipeline_output[0]
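Both service-facing hunks replace direct dict indexing with dict.get plus a default, so a task with no entry in TASK_INPUTS or TASK_OUTPUTS no longer raises KeyError on the service path. A tiny illustration of the behavioural difference, using a stand-in registry and a hypothetical unregistered task name:

```python
# Stand-in registry; the real tables are TASK_INPUTS / TASK_OUTPUTS.
TASK_OUTPUTS = {'general-recognition': ['scores', 'labels']}

task_name = 'my-custom-task'  # hypothetical task with no registered entry

# Old behaviour: direct indexing fails for unregistered tasks.
try:
    task_outputs = TASK_OUTPUTS[task_name]
except KeyError:
    task_outputs = None
print(task_outputs)           # None (only because the error was caught here)

# New behaviour: fall back to an empty key list and serialize the raw output.
task_outputs = TASK_OUTPUTS.get(task_name, [])
print(task_outputs)           # []
```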