mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 20:19:51 +01:00
resolve stopword logits processor
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13504522 * resolve stopword logits processor
This commit is contained in:
committed by
wenmeng.zwm
parent
903c5690ed
commit
e2c19e89e6
@@ -7,6 +7,7 @@ import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers import PreTrainedTokenizer
|
||||
from transformers.generation.logits_process import LogitsProcessorList
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
@@ -173,9 +174,15 @@ class QWenForTextGeneration(QWenPreTrainedModel):
|
||||
tokenizer)
|
||||
input_ids = torch.tensor([context_tokens]).to(self.device)
|
||||
|
||||
logits_processor_list = LogitsProcessorList([
|
||||
StopWordsLogitsProcessor(
|
||||
stop_words_ids=stop_words_ids,
|
||||
eos_token_id=self.generation_config.eos_token_id)
|
||||
])
|
||||
|
||||
outputs = self.generate(
|
||||
input_ids,
|
||||
stop_words_ids=stop_words_ids,
|
||||
logits_processor=logits_processor_list,
|
||||
return_dict_in_generate=False,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user