resolve stopword logits processor

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13504522

* resolve stopword logits processor
This commit is contained in:
lukeming.lkm
2023-08-02 16:14:48 +08:00
committed by wenmeng.zwm
parent 903c5690ed
commit e2c19e89e6

View File

@@ -7,6 +7,7 @@ import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedTokenizer
from transformers.generation.logits_process import LogitsProcessorList
from transformers.modeling_outputs import CausalLMOutputWithPast
from modelscope.metainfo import Models
@@ -173,9 +174,15 @@ class QWenForTextGeneration(QWenPreTrainedModel):
tokenizer)
input_ids = torch.tensor([context_tokens]).to(self.device)
logits_processor_list = LogitsProcessorList([
StopWordsLogitsProcessor(
stop_words_ids=stop_words_ids,
eos_token_id=self.generation_config.eos_token_id)
])
outputs = self.generate(
input_ids,
stop_words_ids=stop_words_ids,
logits_processor=logits_processor_list,
return_dict_in_generate=False,
)