This commit is contained in:
xingjun.wang
2023-10-22 01:13:32 +08:00
parent 88f98820a7
commit 5b817f916e

View File

@@ -108,6 +108,9 @@ class LlamaForTextGeneration(MsModelMixin, LlamaForCausalLM, TorchModel):
max_length=gen_kwargs['max_length'],
tokenizer=tokenizer)
input_ids = prompt_ids.to(self.device)
print(f'>>>input_ids in text_generation: {input_ids}')
generate_ids = self.generate(input_ids, **gen_kwargs)
# remove input tokens
generate_ids = generate_ids[:, input_ids.shape[1]:]