fix message

This commit is contained in:
Yingda Chen
2024-11-21 19:07:36 +08:00
parent 2970a5f8fd
commit f5adc992a2

View File

@@ -370,6 +370,7 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
def preprocess(self, inputs: Union[str, Dict], **kwargs):
is_messages = kwargs.pop('is_messages')
print(kwargs)
if is_messages:
tokens = self.format_messages(inputs, self.tokenizer, **kwargs)
else:
@@ -440,6 +441,7 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
# for compatibility, also support input list, but we shall wrap it into Dict
if isinstance(messages, list):
messages = {'messages': messages}
kwargs['is_message'] = True
for role, content in LLMPipeline._message_iter(messages):
tokens = LLMPipeline._concat_with_special_tokens(
tokens, role, content, tokenizer)