merge with fill_mask
@@ -1,4 +1,5 @@
from .bert_for_sequence_classification import *  # noqa F403
+from .masked_language_model import *  # noqa F403
from .nli_model import *  # noqa F403
from .palm_for_text_generation import *  # noqa F403
from .sbert_for_sentence_similarity import *  # noqa F403
modelscope/models/nlp/masked_language_model.py (new file, 50 lines)
@@ -0,0 +1,50 @@
from typing import Any, Dict, Optional, Union

import numpy as np

from ...utils.constant import Tasks
from ..base import Model, Tensor
from ..builder import MODELS

__all__ = ['StructBertForMaskedLM', 'VecoForMaskedLM']


class AliceMindBaseForMaskedLM(Model):

    def __init__(self, model_dir: str, *args, **kwargs):
        from sofa.utils.backend import AutoConfig, AutoModelForMaskedLM
        self.model_dir = model_dir
        super().__init__(model_dir, *args, **kwargs)

        self.config = AutoConfig.from_pretrained(model_dir)
        self.model = AutoModelForMaskedLM.from_pretrained(
            model_dir, config=self.config)

    def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, np.ndarray]:
        """Return the prediction results of the model.

        Args:
            inputs (Dict[str, Tensor]): the preprocessed data

        Returns:
            Dict[str, np.ndarray]: results
        """
        rst = self.model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            token_type_ids=inputs['token_type_ids'])
        return {'logits': rst['logits'], 'input_ids': inputs['input_ids']}


@MODELS.register_module(Tasks.fill_mask, module_name=r'sbert')
class StructBertForMaskedLM(AliceMindBaseForMaskedLM):
    # StructBert for MaskedLM uses the same underlying model structure
    # as the base class.
    pass


@MODELS.register_module(Tasks.fill_mask, module_name=r'veco')
class VecoForMaskedLM(AliceMindBaseForMaskedLM):
    # Veco for MaskedLM uses the same underlying model structure
    # as the base class.
    pass
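For orientation, a minimal usage sketch of the class above (not part of the commit): `model_dir` is a hypothetical local checkpoint directory readable by sofa's auto classes, and the dummy tensors mirror the shapes FillMaskPreprocessor produces.

import torch

from modelscope.models.nlp import StructBertForMaskedLM

model_dir = '/path/to/nlp_structbert_fill-mask-chinese_large'  # hypothetical local path
model = StructBertForMaskedLM(model_dir)

# Dummy batch: one sequence of length 8. Real inputs come from FillMaskPreprocessor.
inputs = {
    'input_ids': torch.randint(0, model.config.vocab_size, (1, 8)),
    'attention_mask': torch.ones(1, 8, dtype=torch.long),
    'token_type_ids': torch.zeros(1, 8, dtype=torch.long),
}
outputs = model.forward(inputs)
print(outputs['logits'].shape)  # (1, 8, vocab_size)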
@@ -38,6 +38,7 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                             'damo/cv_unet_person-image-cartoon_compound-models'),
    Tasks.ocr_detection: ('ocr-detection',
                          'damo/cv_resnet18_ocr-detection-line-level_damo'),
+    Tasks.fill_mask: ('veco', 'damo/nlp_veco_fill-mask_large')
}
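With this default registered, constructing the pipeline by task alone resolves to the Veco checkpoint, as the new test_run_with_default_model case below exercises. A sketch (the model is downloaded on first use):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Resolves to ('veco', 'damo/nlp_veco_fill-mask_large') via DEFAULT_MODEL_FOR_PIPELINE.
pipe = pipeline(task=Tasks.fill_mask)
print(pipe('Everything in <mask> you call reality is really just a reflection.'))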
@@ -1,3 +1,4 @@
+from .fill_mask_pipeline import *  # noqa F403
from .nli_pipeline import *  # noqa F403
from .sentence_similarity_pipeline import *  # noqa F403
from .sentiment_classification_pipeline import *  # noqa F403
modelscope/pipelines/nlp/fill_mask_pipeline.py (new file, 93 lines)
@@ -0,0 +1,93 @@
from typing import Dict, Optional, Union

from modelscope.models import Model
from modelscope.models.nlp.masked_language_model import \
    AliceMindBaseForMaskedLM
from modelscope.preprocessors import FillMaskPreprocessor
from modelscope.utils.constant import Tasks
from ..base import Pipeline, Tensor
from ..builder import PIPELINES

__all__ = ['FillMaskPipeline']


@PIPELINES.register_module(Tasks.fill_mask, module_name=r'sbert')
@PIPELINES.register_module(Tasks.fill_mask, module_name=r'veco')
class FillMaskPipeline(Pipeline):

    def __init__(self,
                 model: Union[AliceMindBaseForMaskedLM, str],
                 preprocessor: Optional[FillMaskPreprocessor] = None,
                 **kwargs):
        """Use `model` and `preprocessor` to create an NLP fill-mask pipeline for prediction.

        Args:
            model (AliceMindBaseForMaskedLM or str): a model instance or a model id
            preprocessor (FillMaskPreprocessor): a preprocessor instance
        """
        fill_mask_model = model if isinstance(
            model, AliceMindBaseForMaskedLM) else Model.from_pretrained(model)
        if preprocessor is None:
            preprocessor = FillMaskPreprocessor(
                fill_mask_model.model_dir,
                first_sequence='sentence',
                second_sequence=None)
        # Pass the resolved model instance rather than the raw `model`
        # argument, so a model id string is only loaded once.
        super().__init__(
            model=fill_mask_model, preprocessor=preprocessor, **kwargs)
        self.preprocessor = preprocessor
        self.tokenizer = preprocessor.tokenizer
        # Id of the mask token in each backbone's vocabulary:
        # 250001 is <mask> for veco (XLM-R vocab), 103 is [MASK] for sbert (BERT vocab).
        self.mask_id = {'veco': 250001, 'sbert': 103}

        # Special tokens stripped from the decoded string, per backbone.
        # Note: keys are replaced literally by str.replace in rep_tokens, so
        # the r' +' entries only match a literal space followed by '+'.
        self.rep_map = {
            'sbert': {
                '[unused0]': '',
                '[PAD]': '',
                '[unused1]': '',
                r' +': ' ',
                '[SEP]': '',
                '[unused2]': '',
                '[CLS]': '',
                '[UNK]': ''
            },
            'veco': {
                r' +': ' ',
                '<mask>': '<q>',
                '<pad>': '',
                '<s>': '',
                '</s>': '',
                '<unk>': ' '
            }
        }

    def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Process the prediction results.

        Args:
            inputs (Dict[str, Tensor]): the model outputs, holding 'logits'
                and the original 'input_ids'

        Returns:
            Dict[str, str]: the prediction results
        """
        import numpy as np
        logits = inputs['logits'].detach().numpy()
        input_ids = inputs['input_ids'].detach().numpy()
        pred_ids = np.argmax(logits, axis=-1)
        model_type = self.model.config.model_type
        # Splice the predicted ids into the sequence at mask positions only.
        rst_ids = np.where(input_ids == self.mask_id[model_type], pred_ids,
                           input_ids)

        def rep_tokens(string, rep_map):
            for k, v in rep_map.items():
                string = string.replace(k, v)
            return string.strip()

        pred_strings = []
        for ids in rst_ids:  # batch
            if self.model.config.vocab_size == 21128:  # zh bert
                pred_string = self.tokenizer.convert_ids_to_tokens(ids)
                pred_string = ''.join(pred_string)
            else:
                pred_string = self.tokenizer.decode(ids)
            pred_string = rep_tokens(pred_string, self.rep_map[model_type])
            pred_strings.append(pred_string)

        return {'text': pred_strings}
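The heart of postprocess is the np.where splice: argmax predictions are taken only at mask positions, while every other position keeps its original token id. A self-contained sketch with toy ids (all values made up):

import numpy as np

mask_id = 103  # [MASK] in a BERT-style vocab
input_ids = np.array([[101, 7592, 103, 2088, 103, 102]])  # one sequence, two masks
logits = np.random.rand(1, 6, 30522)  # (batch, seq_len, vocab_size)

pred_ids = np.argmax(logits, axis=-1)
rst_ids = np.where(input_ids == mask_id, pred_ids, input_ids)

# Unmasked positions are untouched; only the two mask slots changed.
assert (rst_ids[input_ids != mask_id] == input_ids[input_ids != mask_id]).all()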
@@ -76,6 +76,12 @@ TASK_OUTPUTS = {
    # }
    Tasks.text_generation: ['text'],

+    # fill mask result for single sample
+    # {
+    #     "text": "this is the text whose masks are filled by the model."
+    # }
+    Tasks.fill_mask: ['text'],
+
    # word segmentation result for single sample
    # {
    #     "output": "今天 天气 不错 , 适合 出去 游玩"
@@ -14,7 +14,7 @@ __all__ = [
    'Tokenize', 'SequenceClassificationPreprocessor',
    'TextGenerationPreprocessor', 'ZeroShotClassificationPreprocessor',
    'TokenClassifcationPreprocessor', 'NLIPreprocessor',
-    'SentimentClassificationPreprocessor'
+    'SentimentClassificationPreprocessor', 'FillMaskPreprocessor'
]
@@ -311,6 +311,61 @@ class TextGenerationPreprocessor(Preprocessor):
        rst['input_ids'].append(feature['input_ids'])
        rst['attention_mask'].append(feature['attention_mask'])
        rst['token_type_ids'].append(feature['token_type_ids'])
        return {k: torch.tensor(v) for k, v in rst.items()}
+
+
+@PREPROCESSORS.register_module(Fields.nlp)
+class FillMaskPreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """Preprocess the data using the tokenizer vocabulary found under `model_dir`.
+
+        Args:
+            model_dir (str): model path
+        """
+        super().__init__(*args, **kwargs)
+        from sofa.utils.backend import AutoTokenizer
+        self.model_dir = model_dir
+        self.first_sequence: str = kwargs.pop('first_sequence',
+                                              'first_sequence')
+        self.sequence_length = kwargs.pop('sequence_length', 128)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_dir, use_fast=False)
+
+    @type_assert(object, str)
+    def __call__(self, data: str) -> Dict[str, Any]:
+        """Process the raw input data.
+
+        Args:
+            data (str): a sentence
+                Example:
+                    'you are so handsome.'
+
+        Returns:
+            Dict[str, Any]: the preprocessed data
+        """
+        import torch
+
+        new_data = {self.first_sequence: data}
+
+        # preprocess the data for the model input
+        rst = {'input_ids': [], 'attention_mask': [], 'token_type_ids': []}
+
+        max_seq_length = self.sequence_length
+
+        text_a = new_data[self.first_sequence]
+        feature = self.tokenizer(
+            text_a,
+            padding='max_length',
+            truncation=True,
+            max_length=max_seq_length,
+            return_token_type_ids=True)
+
+        rst['input_ids'].append(feature['input_ids'])
+        rst['attention_mask'].append(feature['attention_mask'])
+        rst['token_type_ids'].append(feature['token_type_ids'])
+
+        return {k: torch.tensor(v) for k, v in rst.items()}
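A short sketch of the preprocessor in isolation (the checkpoint path is hypothetical; any directory with a sofa-compatible tokenizer works). One sentence becomes a batch of size 1, padded to sequence_length:

from modelscope.preprocessors import FillMaskPreprocessor

model_dir = '/path/to/nlp_structbert_fill-mask-english_large'  # hypothetical local path
preprocessor = FillMaskPreprocessor(
    model_dir, first_sequence='sentence', second_sequence=None)

features = preprocessor('Everything in [MASK] you call reality.')
# Each value is a (1, 128) tensor by default (sequence_length=128).
print({k: tuple(v.shape) for k, v in features.items()})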
@@ -47,7 +47,7 @@ class Tasks(object):
    table_question_answering = 'table-question-answering'
    feature_extraction = 'feature-extraction'
    sentence_similarity = 'sentence-similarity'
-    fill_mask = 'fill-mask '
+    fill_mask = 'fill-mask'
    summarization = 'summarization'
    question_answering = 'question-answering'
@@ -1 +1 @@
-https://alinlp.alibaba-inc.com/pypi/sofa-1.0.2-py3-none-any.whl
+https://alinlp.alibaba-inc.com/pypi/sofa-1.0.3-py3-none-any.whl
tests/pipelines/test_fill_mask.py (new file, 133 lines)
@@ -0,0 +1,133 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import unittest

from maas_hub.snapshot_download import snapshot_download

from modelscope.models import Model
from modelscope.models.nlp import StructBertForMaskedLM, VecoForMaskedLM
from modelscope.pipelines import FillMaskPipeline, pipeline
from modelscope.preprocessors import FillMaskPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.hub import get_model_cache_dir
from modelscope.utils.test_utils import test_level


class FillMaskTest(unittest.TestCase):
    model_id_sbert = {
        'zh': 'damo/nlp_structbert_fill-mask-chinese_large',
        'en': 'damo/nlp_structbert_fill-mask-english_large'
    }
    model_id_veco = 'damo/nlp_veco_fill-mask_large'

    ori_texts = {
        'zh':
        '段誉轻挥折扇,摇了摇头,说道:“你师父是你的师父,你师父可不是我的师父。'
        '你师父差得动你,你师父可差不动我。',
        'en':
        'Everything in what you call reality is really just a reflection of your '
        'consciousness. Your whole universe is just a mirror reflection of your story.'
    }

    test_inputs = {
        'zh':
        '段誉轻[MASK]折扇,摇了摇[MASK],[MASK]道:“你师父是你的[MASK][MASK],你'
        '师父可不是[MASK]的师父。你师父差得动你,你师父可[MASK]不动我。',
        'en':
        'Everything in [MASK] you call reality is really [MASK] a reflection of your '
        '[MASK]. Your [MASK] universe is just a mirror [MASK] of your story.'
    }

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_by_direct_model_download(self):
        # sbert
        for language in ['zh', 'en']:
            model_dir = snapshot_download(self.model_id_sbert[language])
            preprocessor = FillMaskPreprocessor(
                model_dir, first_sequence='sentence', second_sequence=None)
            model = StructBertForMaskedLM(model_dir)
            pipeline1 = FillMaskPipeline(model, preprocessor)
            pipeline2 = pipeline(
                Tasks.fill_mask, model=model, preprocessor=preprocessor)
            ori_text = self.ori_texts[language]
            test_input = self.test_inputs[language]
            print(
                f'\nori_text: {ori_text}\ninput: {test_input}\npipeline1: '
                f'{pipeline1(test_input)}\npipeline2: {pipeline2(test_input)}\n'
            )

        # veco
        model_dir = snapshot_download(self.model_id_veco)
        preprocessor = FillMaskPreprocessor(
            model_dir, first_sequence='sentence', second_sequence=None)
        model = VecoForMaskedLM(model_dir)
        pipeline1 = FillMaskPipeline(model, preprocessor)
        pipeline2 = pipeline(
            Tasks.fill_mask, model=model, preprocessor=preprocessor)
        for language in ['zh', 'en']:
            ori_text = self.ori_texts[language]
            test_input = self.test_inputs[language].replace('[MASK]', '<mask>')
            print(
                f'\nori_text: {ori_text}\ninput: {test_input}\npipeline1: '
                f'{pipeline1(test_input)}\npipeline2: {pipeline2(test_input)}\n'
            )

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        # sbert
        for language in ['zh', 'en']:
            print(self.model_id_sbert[language])
            model = Model.from_pretrained(self.model_id_sbert[language])
            preprocessor = FillMaskPreprocessor(
                model.model_dir,
                first_sequence='sentence',
                second_sequence=None)
            pipeline_ins = pipeline(
                task=Tasks.fill_mask, model=model, preprocessor=preprocessor)
            print(
                f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: '
                f'{pipeline_ins(self.test_inputs[language])}\n')

        # veco
        model = Model.from_pretrained(self.model_id_veco)
        preprocessor = FillMaskPreprocessor(
            model.model_dir, first_sequence='sentence', second_sequence=None)
        pipeline_ins = pipeline(
            Tasks.fill_mask, model=model, preprocessor=preprocessor)
        for language in ['zh', 'en']:
            ori_text = self.ori_texts[language]
            test_input = self.test_inputs[language].replace('[MASK]', '<mask>')
            print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: '
                  f'{pipeline_ins(test_input)}\n')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_name(self):
        # veco
        pipeline_ins = pipeline(task=Tasks.fill_mask, model=self.model_id_veco)
        for language in ['zh', 'en']:
            ori_text = self.ori_texts[language]
            test_input = self.test_inputs[language].replace('[MASK]', '<mask>')
            print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: '
                  f'{pipeline_ins(test_input)}\n')

        # structBert
        language = 'zh'
        pipeline_ins = pipeline(
            task=Tasks.fill_mask, model=self.model_id_sbert[language])
        print(
            f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: '
            f'{pipeline_ins(self.test_inputs[language])}\n')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_default_model(self):
        pipeline_ins = pipeline(task=Tasks.fill_mask)
        language = 'en'
        ori_text = self.ori_texts[language]
        test_input = self.test_inputs[language].replace('[MASK]', '<mask>')
        print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: '
              f'{pipeline_ins(test_input)}\n')


if __name__ == '__main__':
    unittest.main()
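To run the cheap cases locally, the skipUnless gates need the test level set before the module is imported. A sketch, assuming test_level() reads the TEST_LEVEL environment variable (how modelscope's test utilities are typically wired):

import os
os.environ.setdefault('TEST_LEVEL', '0')  # assumption: the gate read by test_level()

import unittest
from tests.pipelines.test_fill_mask import FillMaskTest

suite = unittest.defaultTestLoader.loadTestsFromTestCase(FillMaskTest)
unittest.TextTestRunner(verbosity=2).run(suite)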