mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-24 03:59:23 +01:00
revert a mis modification
This commit is contained in:
@@ -31,20 +31,20 @@ class MaskedLMModelBase(Model):
|
||||
return self.model.config
|
||||
return None
|
||||
|
||||
def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, np.ndarray]:
|
||||
def forward(self, input: Dict[str, Tensor]) -> Dict[str, np.ndarray]:
|
||||
"""return the result by the model
|
||||
|
||||
Args:
|
||||
inputs (Dict[str, Any]): the preprocessed data
|
||||
input (Dict[str, Any]): the preprocessed data
|
||||
|
||||
Returns:
|
||||
Dict[str, np.ndarray]: results
|
||||
"""
|
||||
rst = self.model(
|
||||
input_ids=inputs['input_ids'],
|
||||
attention_mask=inputs['attention_mask'],
|
||||
token_type_ids=inputs['token_type_ids'])
|
||||
return {'logits': rst['logits'], 'input_ids': inputs['input_ids']}
|
||||
input_ids=input['input_ids'],
|
||||
attention_mask=input['attention_mask'],
|
||||
token_type_ids=input['token_type_ids'])
|
||||
return {'logits': rst['logits'], 'input_ids': input['input_ids']}
|
||||
|
||||
|
||||
@MODELS.register_module(Tasks.fill_mask, module_name=Models.structbert)
|
||||
|
||||
Reference in New Issue
Block a user