revert a mis modification

This commit is contained in:
雨泓
2022-06-23 11:20:54 +08:00
parent 75dd131ada
commit 2eb633ec93

View File

@@ -31,20 +31,20 @@ class MaskedLMModelBase(Model):
return self.model.config
return None
def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, np.ndarray]:
def forward(self, input: Dict[str, Tensor]) -> Dict[str, np.ndarray]:
"""return the result by the model
Args:
inputs (Dict[str, Any]): the preprocessed data
input (Dict[str, Any]): the preprocessed data
Returns:
Dict[str, np.ndarray]: results
"""
rst = self.model(
input_ids=inputs['input_ids'],
attention_mask=inputs['attention_mask'],
token_type_ids=inputs['token_type_ids'])
return {'logits': rst['logits'], 'input_ids': inputs['input_ids']}
input_ids=input['input_ids'],
attention_mask=input['attention_mask'],
token_type_ids=input['token_type_ids'])
return {'logits': rst['logits'], 'input_ids': input['input_ids']}
@MODELS.register_module(Tasks.fill_mask, module_name=Models.structbert)