damo/nlp_seqgpt-560m pipeline bugfix (#511)

Co-authored-by: chengchen.cc <chengchen.cc@MacBook-Pro-7.local>
This commit is contained in:
ccyhxg
2023-08-29 20:53:26 +08:00
committed by GitHub
parent 39922160ae
commit 1ec8fe4a96

View File

@@ -445,9 +445,9 @@ class SeqGPTPipeline(Pipeline):
# define the forward pass
def forward(self, prompt: str, **forward_params) -> Dict[str, Any]:
# gen & decode
prompt += '[GEN]'
# prompt += '[GEN]'
input_ids = self.tokenizer(
prompt,
prompt + '[GEN]',
return_tensors='pt',
padding=True,
truncation=True,