diff --git a/modelscope/outputs/outputs.py b/modelscope/outputs/outputs.py
index 82c5ce10..368abad6 100644
--- a/modelscope/outputs/outputs.py
+++ b/modelscope/outputs/outputs.py
@@ -562,6 +562,7 @@ TASK_OUTPUTS = {
     # }
     Tasks.facial_expression_recognition:
     [OutputKeys.SCORES, OutputKeys.LABELS],
+    Tasks.general_recognition: [OutputKeys.SCORES, OutputKeys.LABELS],
 
     # face processing base result for single img
     # {
diff --git a/modelscope/pipeline_inputs.py b/modelscope/pipeline_inputs.py
index f465a722..bffbebbd 100644
--- a/modelscope/pipeline_inputs.py
+++ b/modelscope/pipeline_inputs.py
@@ -82,14 +82,28 @@ def check_input_type(input_type, input):
 
 
 TASK_INPUTS = {
+
+    # if task input is single var, value is InputType
+    # if task input is a tuple, value is tuple of InputType
+    # if task input is a dict, value is a dict of InputType, where key
+    # equals the one needed in pipeline input dict
+    # if task input is a list, value is a set of input format, in which
+    # each element corresponds to one input format as described above and
+    # must include a dict format.
+
+    Tasks.task_template: {
+        'image': InputType.IMAGE,
+        'text': InputType.TEXT
+    },
+
+    # ============ vision tasks ===================
+
     Tasks.image_text_retrieval: {
         InputKeys.IMAGE: InputType.IMAGE,
         InputKeys.TEXT: InputType.TEXT
     },
-    Tasks.general_recognition: {
-        InputKeys.IMAGE: InputType.IMAGE,
-        InputKeys.TEXT: InputType.TEXT
-    },
+    Tasks.general_recognition:
+    InputType.IMAGE,
     Tasks.video_depth_estimation: {
         InputKeys.IMAGE: InputType.IMAGE,
         InputKeys.TEXT: InputType.TEXT
@@ -110,20 +124,6 @@ TASK_INPUTS = {
     InputType.VIDEO,
-
-    Tasks.task_template: {
-        'image': InputType.IMAGE,
-        'text': InputType.TEXT
-    },
-    # if task input is single var, value is InputType
-    # if task input is a tuple, value is tuple of InputType
-    # if task input is a dict, value is a dict of InputType, where key
-    # equals the one needed in pipeline input dict
-    # if task input is a list, value is a set of input format, in which
-    # each element corresponds to one input format as described above and
-    # must include a dict format.
-
-    # ============ vision tasks ===================
     Tasks.ocr_detection:
     InputType.IMAGE,
     Tasks.ocr_recognition:
diff --git a/modelscope/utils/hf_util.py b/modelscope/utils/hf_util.py
index 3abcce6d..6ef98ccf 100644
--- a/modelscope/utils/hf_util.py
+++ b/modelscope/utils/hf_util.py
@@ -91,12 +91,13 @@ def check_hf_code(model_dir: str, auto_class: type,
         raise FileNotFoundError(f'{config_path} is not found')
     config_dict = PretrainedConfig.get_config_dict(config_path)[0]
     auto_class_name = auto_class.__name__
+    if auto_class is AutoTokenizerHF:
+        tokenizer_config = get_tokenizer_config(model_dir)
     # load from repo
     if trust_remote_code:
         has_remote_code = False
         if auto_class is AutoTokenizerHF:
-            tokenizer_config_dict = get_tokenizer_config(model_dir)
-            auto_map = tokenizer_config_dict.get('auto_map', None)
+            auto_map = tokenizer_config.get('auto_map', None)
             if auto_map is not None:
                 module_name = auto_map.get(auto_class_name, None)
                 if module_name is not None:
@@ -129,6 +130,9 @@ def check_hf_code(model_dir: str, auto_class: type,
                 f'{model_type} not found in HF `CONFIG_MAPPING`{trust_remote_code_info}'
             )
         elif auto_class is AutoTokenizerHF:
+            tokenizer_class = tokenizer_config.get('tokenizer_class')
+            if tokenizer_class is not None:
+                return
             if model_type not in TOKENIZER_MAPPING_NAMES:
                 raise ValueError(
                     f'{model_type} not found in HF `TOKENIZER_MAPPING_NAMES`{trust_remote_code_info}'
diff --git a/modelscope/utils/input_output.py b/modelscope/utils/input_output.py
index dbe5861d..d8e32cce 100644
--- a/modelscope/utils/input_output.py
+++ b/modelscope/utils/input_output.py
@@ -656,7 +656,7 @@ def service_base64_input_to_pipeline_input(task_name, body):
     if isinstance(service_input, (str, int, float)):
         return service_input, parameters
 
-    task_input_info = TASK_INPUTS[task_name]
+    task_input_info = TASK_INPUTS.get(task_name, None)
     if isinstance(task_input_info, str):  # no input key default
         if isinstance(service_input, dict):
             return base64_decoder_map[task_input_info](list(
@@ -767,9 +767,7 @@ def pipeline_output_to_service_base64_output(task_name, pipeline_output):
         pipeline_output (object): The pipeline output.
     """
     json_serializable_output = {}
-    task_outputs = []
-    if task_name in TASK_OUTPUTS:
-        task_outputs = TASK_OUTPUTS[task_name]
+    task_outputs = TASK_OUTPUTS.get(task_name, [])
     # TODO: for batch
     if isinstance(pipeline_output, list):
         pipeline_output = pipeline_output[0]
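
Note on the TASK_INPUTS convention that the first pipeline_inputs.py hunk moves to the top of the table: a value may be a bare InputType, a tuple of InputType, a dict whose keys match the pipeline input dict, or a list of alternative formats that must include a dict form. The sketch below shows how a caller might branch on those shapes; input_shape is a hypothetical helper written for illustration, not modelscope API, and it reuses the same .get() fallback the input_output.py hunk introduces.

# A minimal sketch, assuming an installed modelscope; `input_shape` is hypothetical.
from modelscope.pipeline_inputs import TASK_INPUTS
from modelscope.utils.constant import Tasks


def input_shape(task_name: str) -> str:
    spec = TASK_INPUTS.get(task_name, None)  # same fallback the patch adds in input_output.py
    if spec is None:
        return 'no input spec registered'
    if isinstance(spec, str):          # single var: the value is one InputType
        return f'single input of type {spec!r}'
    if isinstance(spec, tuple):        # tuple: positional inputs, one InputType each
        return f'positional inputs of types {spec!r}'
    if isinstance(spec, dict):         # dict: keys match the pipeline input dict
        return f'dict input with keys {sorted(spec)!r}'
    if isinstance(spec, list):         # list: alternative formats, one of them a dict
        return f'{len(spec)} accepted input formats'
    return f'unexpected spec {spec!r}'


print(input_shape(Tasks.general_recognition))  # with this patch: a single image input

With the outputs.py hunk, the same task also gains a TASK_OUTPUTS entry of [OutputKeys.SCORES, OutputKeys.LABELS], so the service serializer no longer falls back to an empty list for it.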
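The hf_util.py hunks read tokenizer_config.json once before the trust_remote_code branch and let a repo that pins an explicit tokenizer_class bypass the TOKENIZER_MAPPING_NAMES lookup, so a custom model_type paired with a stock tokenizer no longer raises. A minimal sketch of the case that now passes, under stated assumptions: the local directory and its JSON contents are made up, and the trust_remote_code keyword name is inferred from the function body, since the signature is truncated in the hunk header.

# Illustration only: 'my_model' and its JSON contents are invented examples.
#   my_model/config.json           -> {"model_type": "my_custom_type", ...}
#   my_model/tokenizer_config.json -> {"tokenizer_class": "LlamaTokenizer", ...}
from modelscope.utils.hf_util import AutoTokenizerHF, check_hf_code

# Before this patch the call raised ValueError because 'my_custom_type' is not in
# HF TOKENIZER_MAPPING_NAMES; now the declared tokenizer_class short-circuits
# that lookup and check_hf_code returns without error.
check_hf_code('my_model', AutoTokenizerHF, trust_remote_code=False)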