Mirror of https://github.com/modelscope/modelscope.git
update: simplify the output_logits path in LLMPipeline by returning {'logits', 'tokens'} directly from the no-grad forward pass instead of merging a logits_d dict into the postprocessed response.
@@ -108,8 +108,6 @@ class LLMPipeline(Pipeline):
         is_messages = isinstance(inputs, dict) and 'messages' in inputs
         tokens = self.preprocess(inputs, is_messages, **preprocess_params)
 
-        response = dict()
-
         if output_logits:
             with torch.no_grad():
                 if hasattr(self.model, 'model'):
@@ -117,6 +115,7 @@ class LLMPipeline(Pipeline):
                 else:
                     outputs = self.model(
                         tokens['inputs']) # [batch, seq, vocab]
+            return {'logits': outputs[0], 'tokens': tokens['inputs']}
         else:
             if hasattr(self.model, 'generate'):
                 outputs = self.model.generate(**tokens, **forward_params)
@@ -130,16 +129,7 @@ class LLMPipeline(Pipeline):
         response = self.postprocess(outputs, is_messages,
                                     **postprocess_params)
 
-        if output_logits:
-            logits_d: dict = {
-                'logits': outputs[0].cpu(),
-                'tokens': tokens['inputs'].cpu(),
-                'inputs_len': len(tokens['inputs'][0]),
-                'decode_func': self.tokenizer.decode
-            }
-            response.update(logits_d)
-
-            return response
+        return response
 
     def preprocess(self, inputs: Union[str, Dict], is_messages: bool,
                    **kwargs):
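For context, a minimal sketch of how the reworked output_logits path would be exercised from user code. The task name, the model id, and the assumption that output_logits arrives as a call-time kwarg are illustrative guesses, not something this commit specifies:

    from modelscope.pipelines import pipeline

    # Task name and model id below are placeholder assumptions; substitute
    # a model that LLMPipeline actually serves in your environment.
    llm = pipeline(task='chat', model='qwen/Qwen-7B-Chat')

    # After this commit, the output_logits branch returns the raw
    # forward-pass results directly, skipping postprocess():
    #   {'logits': Tensor[batch, seq, vocab], 'tokens': Tensor[batch, seq]}
    result = llm('Hello!', output_logits=True)
    print(result['logits'].shape)

    # The removed logits_d dict also exposed inputs_len and decode_func.
    # Callers that relied on them can recover both from the returned tokens
    # and the pipeline's tokenizer (assuming llm.tokenizer is reachable):
    inputs_len = result['tokens'].shape[1]
    prompt_text = llm.tokenizer.decode(result['tokens'][0])

Note the behavioral change: with output_logits set, the pipeline now returns before postprocess() runs, and the returned tensors are no longer moved to CPU with .cpu() as the removed code did.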