This commit is contained in:
xingjun.wang
2023-10-22 18:55:07 +08:00
parent 010d23f9a2
commit 9ecf424d1b
3 changed files with 25 additions and 2 deletions

View File

@@ -30,10 +30,22 @@ class TorchModel(Model, torch.nn.Module):
def __call__(self, *args, **kwargs) -> Dict[str, Any]:
    """Run the model's forward pass and post-process its output.

    Adapts models whose ``forward`` accepts a single dict argument
    (detected by ``func_receive_dict_inputs``): in that case only the
    first positional argument is forwarded, as the dict of inputs.

    Args:
        *args: Positional inputs; for dict-input models, ``args[0]``
            is the input dict.
        **kwargs: Extra keyword arguments passed through to ``forward``.

    Returns:
        Dict[str, Any]: the result of ``self.postprocess`` applied to
        the raw ``forward`` output.
    """
    # NOTE(review): removed leftover debug print() calls and the
    # commented-out duplicates of the return statements — library code
    # should not write to stdout on every forward pass.
    if func_receive_dict_inputs(self.forward):
        # Model's forward signature expects one dict input.
        return self.postprocess(self.forward(args[0], **kwargs))
    return self.postprocess(self.forward(*args, **kwargs))
def _load_pretrained(self,
net,

View File

@@ -1205,7 +1205,17 @@ class ChatGLM2ForConditionalGeneration(ChatGLMPreTrainedModel):
**kwargs
}
inputs = self.build_inputs(tokenizer, query, history=history)
print(
f'\n>>inputs in ChatGLM2ForConditionalGeneration._chat:\n {inputs}, \n>shape: {inputs.shape}'
)
outputs = self.generate(**inputs, **gen_kwargs)
print(
f'\n>>outputs in ChatGLM2ForConditionalGeneration._chat:\n {outputs}, \n>shape: {outputs.shape}'
)
outputs = outputs.tolist()[0][len(inputs['input_ids'][0]):]
response = tokenizer.decode(outputs)
response = self.process_response(response)

View File

@@ -121,6 +121,7 @@ class LLMPipeline(Pipeline):
if hasattr(self.model, 'generate'):
outputs = self.model.generate(**tokens, **forward_params)
print(f'>>>self.model.generate: {self.model.generate}')
elif hasattr(self.model, 'model') and hasattr(self.model.model,
'generate'):
outputs = self.model.model.generate(**tokens, **forward_params)