mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-25 04:30:48 +01:00
committed by
wenmeng.zwm
parent
0e7ae942d7
commit
54fdc76fdb
@@ -222,10 +222,13 @@ TASK_INPUTS = {
|
||||
},
|
||||
|
||||
# ============ nlp tasks ===================
|
||||
Tasks.chat: {
|
||||
'text': InputType.TEXT,
|
||||
'history': InputType.LIST,
|
||||
},
|
||||
Tasks.chat: [
|
||||
InputType.TEXT,
|
||||
{
|
||||
'text': InputType.TEXT,
|
||||
'history': InputType.LIST,
|
||||
}
|
||||
],
|
||||
Tasks.text_classification: [
|
||||
InputType.TEXT,
|
||||
(InputType.TEXT, InputType.TEXT),
|
||||
@@ -234,7 +237,13 @@ TASK_INPUTS = {
|
||||
'text2': InputType.TEXT
|
||||
},
|
||||
],
|
||||
Tasks.sentence_similarity: (InputType.TEXT, InputType.TEXT),
|
||||
Tasks.sentence_similarity: [
|
||||
(InputType.TEXT, InputType.TEXT),
|
||||
{
|
||||
'source_text': InputType.TEXT,
|
||||
'target_text': InputType.TEXT,
|
||||
},
|
||||
],
|
||||
Tasks.nli: (InputType.TEXT, InputType.TEXT),
|
||||
Tasks.sentiment_classification:
|
||||
InputType.TEXT,
|
||||
@@ -275,10 +284,6 @@ TASK_INPUTS = {
|
||||
},
|
||||
Tasks.fill_mask:
|
||||
InputType.TEXT,
|
||||
Tasks.task_oriented_conversation: {
|
||||
'user_input': InputType.TEXT,
|
||||
'history': InputType.DICT,
|
||||
},
|
||||
Tasks.table_question_answering: {
|
||||
'question': InputType.TEXT,
|
||||
'history_sql': InputType.DICT,
|
||||
|
||||
@@ -346,12 +346,13 @@ class Pipeline(ABC):
|
||||
if isinstance(input_type, str):
|
||||
check_input_type(input_type, input)
|
||||
elif isinstance(input_type, tuple):
|
||||
assert isinstance(input, tuple), 'input should be a tuple'
|
||||
for t, input_ele in zip(input_type, input):
|
||||
check_input_type(t, input_ele)
|
||||
elif isinstance(input_type, dict):
|
||||
for k in input_type.keys():
|
||||
# allow single input for multi-modal models
|
||||
if k in input:
|
||||
if isinstance(input, dict) and k in input:
|
||||
check_input_type(input_type[k], input[k])
|
||||
else:
|
||||
raise ValueError(f'invalid input_type definition {input_type}')
|
||||
@@ -373,7 +374,7 @@ class Pipeline(ABC):
|
||||
input = input.keys() if isinstance(input,
|
||||
(dict, ModelOutputBase)) else input
|
||||
for k in output_keys:
|
||||
if k not in input:
|
||||
if isinstance(k, (dict, ModelOutputBase)) and k not in input:
|
||||
missing_keys.append(k)
|
||||
if len(missing_keys) > 0:
|
||||
raise ValueError(f'expected output keys are {output_keys}, '
|
||||
|
||||
@@ -141,6 +141,39 @@ class CustomPipelineTest(unittest.TestCase):
|
||||
self.assertEqual(out['url'], img_url + 'dummy_end')
|
||||
self.assertEqual(out['img'].shape, (1, 640, 640, 3))
|
||||
|
||||
def test_chat_task(self):
|
||||
dummy_module = 'dummy_module'
|
||||
|
||||
@PIPELINES.register_module(
|
||||
group_key=Tasks.chat, module_name=dummy_module)
|
||||
class CustomChat(Pipeline):
|
||||
|
||||
def __init__(self,
|
||||
config_file: str = None,
|
||||
model=None,
|
||||
preprocessor=None,
|
||||
**kwargs):
|
||||
|
||||
def f(x):
|
||||
return x
|
||||
|
||||
preprocessor = f
|
||||
super().__init__(config_file, model, preprocessor, **kwargs)
|
||||
|
||||
def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
""" Provide default implementation using self.model and user can reimplement it
|
||||
"""
|
||||
return inputs
|
||||
|
||||
def postprocess(self, out, **kwargs):
|
||||
return {'response': 'xxx', 'history': []}
|
||||
|
||||
pipe = pipeline(
|
||||
task=Tasks.chat, pipeline_name=dummy_module, model=self.model_dir)
|
||||
pipe('text')
|
||||
inputs = {'text': 'aaa', 'history': [('dfd', 'fds')]}
|
||||
pipe(inputs)
|
||||
|
||||
def test_custom(self):
|
||||
dummy_task = 'dummy-task'
|
||||
|
||||
|
||||
Reference in New Issue
Block a user