diff --git a/modelscope/models/nlp/qwen/text_generation.py b/modelscope/models/nlp/qwen/text_generation.py index a4b4f7ed..a25164cb 100644 --- a/modelscope/models/nlp/qwen/text_generation.py +++ b/modelscope/models/nlp/qwen/text_generation.py @@ -7,6 +7,7 @@ import torch.nn.functional as F import torch.utils.checkpoint from torch.nn import CrossEntropyLoss from transformers import PreTrainedTokenizer +from transformers.generation.logits_process import LogitsProcessorList from transformers.modeling_outputs import CausalLMOutputWithPast from modelscope.metainfo import Models @@ -173,9 +174,15 @@ class QWenForTextGeneration(QWenPreTrainedModel): tokenizer) input_ids = torch.tensor([context_tokens]).to(self.device) + logits_processor_list = LogitsProcessorList([ + StopWordsLogitsProcessor( + stop_words_ids=stop_words_ids, + eos_token_id=self.generation_config.eos_token_id) + ]) + outputs = self.generate( input_ids, - stop_words_ids=stop_words_ids, + logits_processor=logits_processor_list, return_dict_in_generate=False, )