[to #42322933] sentence-similarity

Add the new task sentence_similarity; the model is the sofa version of StructBERT
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9016402

    * sbert-sentence-similarity

* [to #42322933] pep8

* merge with master for file dirs update

* add test cases

* pre-commit lint check

* remove useless file

* download models again~

* skip time consuming test case

* update for pr reviews

* merge with master

* add test level

* reset test level to env level

* [to #42322933] init

* [to #42322933] init

* adding purge logic in test

* merge with head

* change test level

* using sequence classification processor for similarity
Author: zhangzhicheng.zzc
Date: 2022-06-15 23:35:12 +08:00
Committed by: huangjun.hj
Parent: d983bdfc8e
Commit: ba471d4492

11 changed files with 260 additions and 11 deletions

View File

@@ -2,4 +2,4 @@
 from .base import Model
 from .builder import MODELS, build_model
-from .nlp import BertForSequenceClassification
+from .nlp import BertForSequenceClassification, SbertForSentenceSimilarity
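
With the new export in place, the model can be resolved through the shared Model loader; a minimal usage sketch, using the model id registered later in this commit:

from modelscope.models import Model

# Loads the checkpoint from the hub and dispatches to
# SbertForSentenceSimilarity via the MODELS registry.
model = Model.from_pretrained(
    'damo/nlp_structbert_sentence-similarity_chinese-base')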

View File

@@ -1,2 +1,3 @@
+from .sentence_similarity_model import * # noqa F403
 from .sequence_classification_model import * # noqa F403
 from .text_generation_model import * # noqa F403

View File

@@ -0,0 +1,88 @@
import os
from typing import Any, Dict

import json
import numpy as np
import torch
from sofa import SbertModel
from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel
from torch import nn

from modelscope.utils.constant import Tasks
from ..base import Model
from ..builder import MODELS

__all__ = ['SbertForSentenceSimilarity']


class SbertTextClassifier(SbertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.encoder = SbertModel(config, add_pooling_layer=True)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_ids=None, token_type_ids=None):
        outputs = self.encoder(
            input_ids,
            token_type_ids=token_type_ids,
            return_dict=None,
        )
        # Classify on the pooled [CLS] representation.
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits


@MODELS.register_module(
    Tasks.sentence_similarity,
    module_name=r'sbert-base-chinese-sentence-similarity')
class SbertForSentenceSimilarity(Model):

    def __init__(self, model_dir: str, *args, **kwargs):
        """initialize the sentence similarity model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        self.model_dir = model_dir
        self.model = SbertTextClassifier.from_pretrained(
            model_dir, num_labels=2)
        self.model.eval()
        # label_mapping.json maps label names to class ids.
        self.label_path = os.path.join(self.model_dir, 'label_mapping.json')
        with open(self.label_path) as f:
            self.label_mapping = json.load(f)
        self.id2label = {idx: name for name, idx in self.label_mapping.items()}

    def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
        """return the result by the model

        Args:
            input (Dict[str, Any]): the preprocessed data

        Returns:
            Dict[str, np.ndarray]: results
                Example:
                    {
                        'predictions': array([1]),  # label: 0-negative, 1-positive
                        'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
                        'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32)  # raw logits
                    }
        """
        input_ids = torch.tensor(input['input_ids'], dtype=torch.long)
        token_type_ids = torch.tensor(
            input['token_type_ids'], dtype=torch.long)
        with torch.no_grad():
            logits = self.model(input_ids, token_type_ids)
        probs = logits.softmax(-1).numpy()
        pred = logits.argmax(-1).numpy()
        logits = logits.numpy()
        res = {'predictions': pred, 'probabilities': probs, 'logits': logits}
        return res
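
A hedged sketch of driving this forward() directly with pre-tokenized inputs; the token ids below are toy values, and the model directory is a hypothetical local path, not one taken from this commit:

from modelscope.models.nlp import SbertForSentenceSimilarity

# Hypothetical checkpoint directory containing the SBERT weights
# plus the label_mapping.json read in __init__.
model = SbertForSentenceSimilarity('/path/to/model_dir')
inputs = {
    'input_ids': [[101, 2769, 102, 872, 102]],   # toy ids, not real tokens
    'token_type_ids': [[0, 0, 0, 1, 1]],         # 0 = sentence1, 1 = sentence2
}
outputs = model.forward(inputs)
print(outputs['predictions'], outputs['probabilities'])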

View File

@@ -15,7 +15,7 @@ from modelscope.utils.logger import get_logger
 from .util import is_model_name

 Tensor = Union['torch.Tensor', 'tf.Tensor']
-Input = Union[str, PyDataset, 'PIL.Image.Image', 'numpy.ndarray']
+Input = Union[str, tuple, PyDataset, 'PIL.Image.Image', 'numpy.ndarray']
 InputModel = Union[str, Model]

 output_keys = [

View File

@@ -13,6 +13,9 @@ PIPELINES = Registry('pipelines')
 DEFAULT_MODEL_FOR_PIPELINE = {
     # TaskName: (pipeline_module_name, model_repo)
+    Tasks.sentence_similarity:
+    ('sbert-base-chinese-sentence-similarity',
+     'damo/nlp_structbert_sentence-similarity_chinese-base'),
     Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting_damo'),
     Tasks.text_classification:
     ('bert-sentiment-analysis', 'damo/bert-base-sst2'),
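
This default mapping is what lets callers omit the model entirely; a short sketch mirroring the default-model test later in this commit:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# With no model argument, the builder falls back to the
# (pipeline_module_name, model_repo) pair registered above.
p = pipeline(task=Tasks.sentence_similarity)
print(p(input=('今天气温比昨天高么?', '今天湿度比昨天高么?')))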

View File

@@ -1,2 +1,3 @@
+from .sentence_similarity_pipeline import * # noqa F403
 from .sequence_classification_pipeline import * # noqa F403
 from .text_generation_pipeline import * # noqa F403

View File

@@ -0,0 +1,65 @@
from typing import Any, Dict, Union

import numpy as np

from modelscope.models.nlp import SbertForSentenceSimilarity
from modelscope.preprocessors import SequenceClassificationPreprocessor
from modelscope.utils.constant import Tasks
from ...models import Model
from ..base import Input, Pipeline
from ..builder import PIPELINES

__all__ = ['SentenceSimilarityPipeline']


@PIPELINES.register_module(
    Tasks.sentence_similarity,
    module_name=r'sbert-base-chinese-sentence-similarity')
class SentenceSimilarityPipeline(Pipeline):

    def __init__(self,
                 model: Union[SbertForSentenceSimilarity, str],
                 preprocessor: SequenceClassificationPreprocessor = None,
                 **kwargs):
        """use `model` and `preprocessor` to create an NLP sentence similarity pipeline for prediction

        Args:
            model (SbertForSentenceSimilarity or str): a model instance or a model id/path
            preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
        """
        assert isinstance(model, (str, SbertForSentenceSimilarity)), \
            'model must be a single str or SbertForSentenceSimilarity'
        sc_model = model if isinstance(
            model,
            SbertForSentenceSimilarity) else Model.from_pretrained(model)
        if preprocessor is None:
            preprocessor = SequenceClassificationPreprocessor(
                sc_model.model_dir,
                first_sequence='first_sequence',
                second_sequence='second_sequence')
        super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)
        assert hasattr(self.model, 'id2label'), \
            'id2label map should be initialized in the init function.'

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """process the prediction results

        Args:
            inputs (Dict[str, Any]): the model outputs
                (predictions, probabilities, logits)

        Returns:
            Dict[str, str]: the prediction results
        """
        # Sort class ids by descending probability for the first sample.
        probs = inputs['probabilities'][0]
        cls_ids = np.argsort(-probs)
        probs = probs[cls_ids].tolist()
        cls_names = [self.model.id2label[cid] for cid in cls_ids]
        # Return only the top-ranked label and its score.
        return {'scores': probs[0], 'labels': cls_names[0]}
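
To make postprocess() concrete, here is the arithmetic on the example logits from the model docstring above; the actual label names come from label_mapping.json, which is not shown in this commit:

import numpy as np

logits = np.array([-0.53860897, 1.5029076])      # example values from forward()
probs = np.exp(logits) / np.exp(logits).sum()    # softmax -> [0.115, 0.885]
order = np.argsort(-probs)                       # descending order -> [1, 0]
# postprocess() keeps only the top entry, so the pipeline would return
# {'scores': 0.885..., 'labels': id2label[1]}
print(order, probs[order])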

View File

@@ -5,4 +5,3 @@ from .builder import PREPROCESSORS, build_preprocessor
 from .common import Compose
 from .image import LoadImage, load_image
 from .nlp import * # noqa F403
-from .nlp import TextGenerationPreprocessor

View File

@@ -10,7 +10,10 @@ from modelscope.utils.type_assert import type_assert
 from .base import Preprocessor
 from .builder import PREPROCESSORS

-__all__ = ['Tokenize', 'SequenceClassificationPreprocessor']
+__all__ = [
+    'Tokenize', 'SequenceClassificationPreprocessor',
+    'TextGenerationPreprocessor'
+]


 @PREPROCESSORS.register_module(Fields.nlp)
@@ -28,7 +31,7 @@ class Tokenize(Preprocessor):

 @PREPROCESSORS.register_module(
-    Fields.nlp, module_name=r'bert-sentiment-analysis')
+    Fields.nlp, module_name=r'bert-sequence-classification')
 class SequenceClassificationPreprocessor(Preprocessor):

     def __init__(self, model_dir: str, *args, **kwargs):
@@ -48,21 +51,42 @@ class SequenceClassificationPreprocessor(Preprocessor):
         self.sequence_length = kwargs.pop('sequence_length', 128)
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
         print(f'this is the tokenizer {self.tokenizer}')

-    @type_assert(object, str)
-    def __call__(self, data: str) -> Dict[str, Any]:
+    @type_assert(object, (str, tuple))
+    def __call__(self, data: Union[str, tuple]) -> Dict[str, Any]:
         """process the raw input data

         Args:
-            data (str): a sentence
+            data (str or tuple): a single sentence, or a (sentence1, sentence2) pair
+                Example:
+                    'you are so handsome.'
+                or
+                    ('you are so handsome.', 'you are so beautiful.')

         Returns:
             Dict[str, Any]: the preprocessed data
         """
-        new_data = {self.first_sequence: data}
+        # Normalize a single sentence to a (sentence, None) pair so both
+        # call forms go through the same tokenization path.
+        if not isinstance(data, tuple):
+            data = (data, None)
+        sentence1, sentence2 = data
+        new_data = {
+            self.first_sequence: sentence1,
+            self.second_sequence: sentence2
+        }

         # preprocess the data for the model input
         rst = {
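
A hedged sketch of the two call forms this change enables; the directory is a hypothetical local checkpoint, and the output keys assumed here (input_ids, token_type_ids) are the ones the model's forward() reads:

from modelscope.preprocessors import SequenceClassificationPreprocessor

# Hypothetical checkpoint directory with a compatible tokenizer.
prep = SequenceClassificationPreprocessor(
    '/path/to/model_dir',
    first_sequence='first_sequence',
    second_sequence='second_sequence')

single = prep('今天气温比昨天高么?')                          # one sentence
pair = prep(('今天气温比昨天高么?', '今天湿度比昨天高么?'))  # sentence pair
print(pair['input_ids'], pair['token_type_ids'])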

View File

@@ -31,6 +31,7 @@ class Tasks(object):
     # nlp tasks
     sentiment_analysis = 'sentiment-analysis'
+    sentence_similarity = 'sentence-similarity'
     text_classification = 'text-classification'
     relation_extraction = 'relation-extraction'
     zero_shot = 'zero-shot'

View File

@@ -0,0 +1,67 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import shutil
import unittest

from maas_hub.snapshot_download import snapshot_download

from modelscope.models import Model
from modelscope.models.nlp import SbertForSentenceSimilarity
from modelscope.pipelines import SentenceSimilarityPipeline, pipeline
from modelscope.preprocessors import SequenceClassificationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.hub import get_model_cache_dir
from modelscope.utils.test_utils import test_level


class SentenceSimilarityTest(unittest.TestCase):
    model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
    sentence1 = '今天气温比昨天高么?'
    sentence2 = '今天湿度比昨天高么?'

    def setUp(self) -> None:
        # switch to False if downloading every time is not desired
        purge_cache = True
        if purge_cache:
            shutil.rmtree(
                get_model_cache_dir(self.model_id), ignore_errors=True)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run(self):
        cache_path = snapshot_download(self.model_id)
        tokenizer = SequenceClassificationPreprocessor(cache_path)
        model = SbertForSentenceSimilarity(cache_path, tokenizer=tokenizer)
        pipeline1 = SentenceSimilarityPipeline(model, preprocessor=tokenizer)
        pipeline2 = pipeline(
            Tasks.sentence_similarity, model=model, preprocessor=tokenizer)
        print(f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n'
              f'pipeline1: {pipeline1(input=(self.sentence1, self.sentence2))}')
        print()
        print(
            f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n'
            f'pipeline2: {pipeline2(input=(self.sentence1, self.sentence2))}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id)
        tokenizer = SequenceClassificationPreprocessor(model.model_dir)
        pipeline_ins = pipeline(
            task=Tasks.sentence_similarity,
            model=model,
            preprocessor=tokenizer)
        print(pipeline_ins(input=(self.sentence1, self.sentence2)))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_name(self):
        pipeline_ins = pipeline(
            task=Tasks.sentence_similarity, model=self.model_id)
        print(pipeline_ins(input=(self.sentence1, self.sentence2)))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_default_model(self):
        pipeline_ins = pipeline(task=Tasks.sentence_similarity)
        print(pipeline_ins(input=(self.sentence1, self.sentence2)))


if __name__ == '__main__':
    unittest.main()
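
The skip decorators above key off test_level(); a minimal sketch of forcing a level before running this file, assuming test_level() reads a TEST_LEVEL environment variable (the variable name is an assumption; it is not shown in this diff, and the module name below is hypothetical):

import os
import unittest

# Assumption: test_level() in modelscope.utils.test_utils reads this
# environment variable; set it before the test module is imported.
os.environ['TEST_LEVEL'] = '0'

unittest.main(module='test_sentence_similarity', exit=False)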