From e2c19e89e693905e716749838283b201b52f1b9f Mon Sep 17 00:00:00 2001
From: "lukeming.lkm"
Date: Wed, 2 Aug 2023 16:14:48 +0800
Subject: [PATCH] resolve stopword logits processor

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13504522

* resolve stopword logits processor
---
 modelscope/models/nlp/qwen/text_generation.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/modelscope/models/nlp/qwen/text_generation.py b/modelscope/models/nlp/qwen/text_generation.py
index a4b4f7ed..a25164cb 100644
--- a/modelscope/models/nlp/qwen/text_generation.py
+++ b/modelscope/models/nlp/qwen/text_generation.py
@@ -7,6 +7,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch.nn import CrossEntropyLoss
 from transformers import PreTrainedTokenizer
+from transformers.generation.logits_process import LogitsProcessorList
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
 from modelscope.metainfo import Models
@@ -173,9 +174,15 @@ class QWenForTextGeneration(QWenPreTrainedModel):
                                                  tokenizer)
         input_ids = torch.tensor([context_tokens]).to(self.device)
 
+        logits_processor_list = LogitsProcessorList([
+            StopWordsLogitsProcessor(
+                stop_words_ids=stop_words_ids,
+                eos_token_id=self.generation_config.eos_token_id)
+        ])
+
         outputs = self.generate(
             input_ids,
-            stop_words_ids=stop_words_ids,
+            logits_processor=logits_processor_list,
             return_dict_in_generate=False,
         )
 