fix pipeline check error (#455)

* fix pipeline check error

* update
This commit is contained in:
wenmeng zhou
2023-08-11 15:52:53 +08:00
committed by wenmeng.zwm
parent 0e7ae942d7
commit 54fdc76fdb
3 changed files with 50 additions and 11 deletions

View File

@@ -222,10 +222,13 @@ TASK_INPUTS = {
},
# ============ nlp tasks ===================
Tasks.chat: {
'text': InputType.TEXT,
'history': InputType.LIST,
},
Tasks.chat: [
InputType.TEXT,
{
'text': InputType.TEXT,
'history': InputType.LIST,
}
],
Tasks.text_classification: [
InputType.TEXT,
(InputType.TEXT, InputType.TEXT),
@@ -234,7 +237,13 @@ TASK_INPUTS = {
'text2': InputType.TEXT
},
],
Tasks.sentence_similarity: (InputType.TEXT, InputType.TEXT),
Tasks.sentence_similarity: [
(InputType.TEXT, InputType.TEXT),
{
'source_text': InputType.TEXT,
'target_text': InputType.TEXT,
},
],
Tasks.nli: (InputType.TEXT, InputType.TEXT),
Tasks.sentiment_classification:
InputType.TEXT,
@@ -275,10 +284,6 @@ TASK_INPUTS = {
},
Tasks.fill_mask:
InputType.TEXT,
Tasks.task_oriented_conversation: {
'user_input': InputType.TEXT,
'history': InputType.DICT,
},
Tasks.table_question_answering: {
'question': InputType.TEXT,
'history_sql': InputType.DICT,

View File

@@ -346,12 +346,13 @@ class Pipeline(ABC):
if isinstance(input_type, str):
check_input_type(input_type, input)
elif isinstance(input_type, tuple):
assert isinstance(input, tuple), 'input should be a tuple'
for t, input_ele in zip(input_type, input):
check_input_type(t, input_ele)
elif isinstance(input_type, dict):
for k in input_type.keys():
# allow single input for multi-modal models
if k in input:
if isinstance(input, dict) and k in input:
check_input_type(input_type[k], input[k])
else:
raise ValueError(f'invalid input_type definition {input_type}')
@@ -373,7 +374,7 @@ class Pipeline(ABC):
input = input.keys() if isinstance(input,
(dict, ModelOutputBase)) else input
for k in output_keys:
if k not in input:
if isinstance(k, (dict, ModelOutputBase)) and k not in input:
missing_keys.append(k)
if len(missing_keys) > 0:
raise ValueError(f'expected output keys are {output_keys}, '

View File

@@ -141,6 +141,39 @@ class CustomPipelineTest(unittest.TestCase):
self.assertEqual(out['url'], img_url + 'dummy_end')
self.assertEqual(out['img'].shape, (1, 640, 640, 3))
def test_chat_task(self):
dummy_module = 'dummy_module'
@PIPELINES.register_module(
group_key=Tasks.chat, module_name=dummy_module)
class CustomChat(Pipeline):
def __init__(self,
config_file: str = None,
model=None,
preprocessor=None,
**kwargs):
def f(x):
return x
preprocessor = f
super().__init__(config_file, model, preprocessor, **kwargs)
def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
""" Provide default implementation using self.model and user can reimplement it
"""
return inputs
def postprocess(self, out, **kwargs):
return {'response': 'xxx', 'history': []}
pipe = pipeline(
task=Tasks.chat, pipeline_name=dummy_module, model=self.model_dir)
pipe('text')
inputs = {'text': 'aaa', 'history': [('dfd', 'fds')]}
pipe(inputs)
def test_custom(self):
dummy_task = 'dummy-task'