This commit is contained in:
xingjun.wang
2023-10-23 16:49:43 +08:00
parent 0af63143f8
commit 2ff17e1c0f

View File

@@ -108,8 +108,6 @@ class LLMPipeline(Pipeline):
is_messages = isinstance(inputs, dict) and 'messages' in inputs
tokens = self.preprocess(inputs, is_messages, **preprocess_params)
response = dict()
if output_logits:
with torch.no_grad():
if hasattr(self.model, 'model'):
@@ -117,6 +115,7 @@ class LLMPipeline(Pipeline):
else:
outputs = self.model(
tokens['inputs']) # [batch, seq, vocab]
return {'logits': outputs[0], 'tokens': tokens['inputs']}
else:
if hasattr(self.model, 'generate'):
outputs = self.model.generate(**tokens, **forward_params)
@@ -130,16 +129,7 @@ class LLMPipeline(Pipeline):
response = self.postprocess(outputs, is_messages,
**postprocess_params)
if output_logits:
logits_d: dict = {
'logits': outputs[0].cpu(),
'tokens': tokens['inputs'].cpu(),
'inputs_len': len(tokens['inputs'][0]),
'decode_func': self.tokenizer.decode
}
response.update(logits_d)
return response
return response
def preprocess(self, inputs: Union[str, Dict], is_messages: bool,
**kwargs):