mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-24 03:59:23 +01:00
solve comment: 1. change MaskedLMModelBase to MaskedLanguageModelBase 2. remove a useless import
This commit is contained in:
@@ -7,10 +7,10 @@ from ...utils.constant import Tasks
|
||||
from ..base import Model, Tensor
|
||||
from ..builder import MODELS
|
||||
|
||||
__all__ = ['StructBertForMaskedLM', 'VecoForMaskedLM', 'MaskedLMModelBase']
|
||||
__all__ = ['StructBertForMaskedLM', 'VecoForMaskedLM', 'MaskedLanguageModelBase']
|
||||
|
||||
|
||||
class MaskedLMModelBase(Model):
|
||||
class MaskedLanguageModelBase(Model):
|
||||
|
||||
def __init__(self, model_dir: str, *args, **kwargs):
|
||||
super().__init__(model_dir, *args, **kwargs)
|
||||
@@ -48,7 +48,7 @@ class MaskedLMModelBase(Model):
|
||||
|
||||
|
||||
@MODELS.register_module(Tasks.fill_mask, module_name=Models.structbert)
|
||||
class StructBertForMaskedLM(MaskedLMModelBase):
|
||||
class StructBertForMaskedLM(MaskedLanguageModelBase):
|
||||
|
||||
def build_model(self):
|
||||
from sofa import SbertForMaskedLM
|
||||
@@ -56,7 +56,7 @@ class StructBertForMaskedLM(MaskedLMModelBase):
|
||||
|
||||
|
||||
@MODELS.register_module(Tasks.fill_mask, module_name=Models.veco)
|
||||
class VecoForMaskedLM(MaskedLMModelBase):
|
||||
class VecoForMaskedLM(MaskedLanguageModelBase):
|
||||
|
||||
def build_model(self):
|
||||
from sofa import VecoForMaskedLM
|
||||
|
||||
@@ -4,7 +4,7 @@ import torch
|
||||
|
||||
from ...metainfo import Pipelines
|
||||
from ...models import Model
|
||||
from ...models.nlp.masked_language_model import MaskedLMModelBase
|
||||
from ...models.nlp.masked_language_model import MaskedLanguageModelBase
|
||||
from ...preprocessors import FillMaskPreprocessor
|
||||
from ...utils.constant import Tasks
|
||||
from ..base import Pipeline, Tensor
|
||||
@@ -17,18 +17,18 @@ __all__ = ['FillMaskPipeline']
|
||||
class FillMaskPipeline(Pipeline):
|
||||
|
||||
def __init__(self,
|
||||
model: Union[MaskedLMModelBase, str],
|
||||
model: Union[MaskedLanguageModelBase, str],
|
||||
preprocessor: Optional[FillMaskPreprocessor] = None,
|
||||
first_sequence='sentense',
|
||||
**kwargs):
|
||||
"""use `model` and `preprocessor` to create a nlp fill mask pipeline for prediction
|
||||
|
||||
Args:
|
||||
model (MaskedLMModelBase): a model instance
|
||||
model (MaskedLanguageModelBase): a model instance
|
||||
preprocessor (FillMaskPreprocessor): a preprocessor instance
|
||||
"""
|
||||
fill_mask_model = model if isinstance(
|
||||
model, MaskedLMModelBase) else Model.from_pretrained(model)
|
||||
model, MaskedLanguageModelBase) else Model.from_pretrained(model)
|
||||
assert fill_mask_model.config is not None
|
||||
|
||||
if preprocessor is None:
|
||||
|
||||
Reference in New Issue
Block a user