solve comment: 1. change MaskedLMModelBase to MaskedLanguageModelBase 2. remove a useless import

This commit is contained in:
雨泓
2022-06-23 14:51:46 +08:00
parent aa0cebb3ec
commit c376d59143
2 changed files with 8 additions and 8 deletions

View File

@@ -7,10 +7,10 @@ from ...utils.constant import Tasks
from ..base import Model, Tensor
from ..builder import MODELS
__all__ = ['StructBertForMaskedLM', 'VecoForMaskedLM', 'MaskedLMModelBase']
__all__ = ['StructBertForMaskedLM', 'VecoForMaskedLM', 'MaskedLanguageModelBase']
class MaskedLMModelBase(Model):
class MaskedLanguageModelBase(Model):
def __init__(self, model_dir: str, *args, **kwargs):
super().__init__(model_dir, *args, **kwargs)
@@ -48,7 +48,7 @@ class MaskedLMModelBase(Model):
@MODELS.register_module(Tasks.fill_mask, module_name=Models.structbert)
class StructBertForMaskedLM(MaskedLMModelBase):
class StructBertForMaskedLM(MaskedLanguageModelBase):
def build_model(self):
from sofa import SbertForMaskedLM
@@ -56,7 +56,7 @@ class StructBertForMaskedLM(MaskedLMModelBase):
@MODELS.register_module(Tasks.fill_mask, module_name=Models.veco)
class VecoForMaskedLM(MaskedLMModelBase):
class VecoForMaskedLM(MaskedLanguageModelBase):
def build_model(self):
from sofa import VecoForMaskedLM

View File

@@ -4,7 +4,7 @@ import torch
from ...metainfo import Pipelines
from ...models import Model
from ...models.nlp.masked_language_model import MaskedLMModelBase
from ...models.nlp.masked_language_model import MaskedLanguageModelBase
from ...preprocessors import FillMaskPreprocessor
from ...utils.constant import Tasks
from ..base import Pipeline, Tensor
@@ -17,18 +17,18 @@ __all__ = ['FillMaskPipeline']
class FillMaskPipeline(Pipeline):
def __init__(self,
model: Union[MaskedLMModelBase, str],
model: Union[MaskedLanguageModelBase, str],
preprocessor: Optional[FillMaskPreprocessor] = None,
first_sequence='sentense',
**kwargs):
"""use `model` and `preprocessor` to create a nlp fill mask pipeline for prediction
Args:
model (MaskedLMModelBase): a model instance
model (MaskedLanguageModelBase): a model instance
preprocessor (FillMaskPreprocessor): a preprocessor instance
"""
fill_mask_model = model if isinstance(
model, MaskedLMModelBase) else Model.from_pretrained(model)
model, MaskedLanguageModelBase) else Model.from_pretrained(model)
assert fill_mask_model.config is not None
if preprocessor is None: