From 2eb633ec931c322bb359521787598032c12e1f28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Thu, 23 Jun 2022 11:20:54 +0800 Subject: [PATCH] revert a mis modification --- modelscope/models/nlp/masked_language_model.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/modelscope/models/nlp/masked_language_model.py b/modelscope/models/nlp/masked_language_model.py index 928bda7b..6a8c6626 100644 --- a/modelscope/models/nlp/masked_language_model.py +++ b/modelscope/models/nlp/masked_language_model.py @@ -31,20 +31,20 @@ class MaskedLMModelBase(Model): return self.model.config return None - def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, np.ndarray]: + def forward(self, input: Dict[str, Tensor]) -> Dict[str, np.ndarray]: """return the result by the model Args: - inputs (Dict[str, Any]): the preprocessed data + input (Dict[str, Any]): the preprocessed data Returns: Dict[str, np.ndarray]: results """ rst = self.model( - input_ids=inputs['input_ids'], - attention_mask=inputs['attention_mask'], - token_type_ids=inputs['token_type_ids']) - return {'logits': rst['logits'], 'input_ids': inputs['input_ids']} + input_ids=input['input_ids'], + attention_mask=input['attention_mask'], + token_type_ids=input['token_type_ids']) + return {'logits': rst['logits'], 'input_ids': input['input_ids']} @MODELS.register_module(Tasks.fill_mask, module_name=Models.structbert)