[to #42322933] sentence-similarity
Adds the new sentence_similarity task, where the model is the sofa implementation of StructBERT.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9016402
* sbert-sentence-similarity
* [to #42322933] pep8
* merge with master for file dirs update
* add test cases
* pre-commit lint check
* remove useless file
* download models again~
* skip time-consuming test case
* update for pr reviews
* merge with master
* add test level
* reset test level to env level
* [to #42322933] init
* [to #42322933] init
* adding purge logic in test
* merge with head
* change test level
* using sequence classification processor for similarity
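
For reference, a minimal end-to-end usage sketch of what this commit enables, based on the new test cases (the default model for the task resolves to damo/nlp_structbert_sentence-similarity_chinese-base):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# build the default sentence-similarity pipeline registered in this commit
similarity = pipeline(Tasks.sentence_similarity)
# 'Is the temperature/humidity higher today than yesterday?'
print(similarity(input=('今天气温比昨天高么?', '今天湿度比昨天高么?')))
# -> {'scores': ..., 'labels': ...} as produced by postprocess()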
Committed by: huangjun.hj
Parent: d983bdfc8e
Commit: ba471d4492
modelscope/models/__init__.py
@@ -2,4 +2,4 @@
 from .base import Model
 from .builder import MODELS, build_model
-from .nlp import BertForSequenceClassification
+from .nlp import BertForSequenceClassification, SbertForSentenceSimilarity
 
modelscope/models/nlp/__init__.py
@@ -1,2 +1,3 @@
+from .sentence_similarity_model import *  # noqa F403
 from .sequence_classification_model import *  # noqa F403
 from .text_generation_model import *  # noqa F403
modelscope/models/nlp/sentence_similarity_model.py (new file, 88 lines)
import os
from typing import Any, Dict

import json
import numpy as np
import torch
from sofa import SbertModel
from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel
from torch import nn

from modelscope.utils.constant import Tasks
from ..base import Model, Tensor
from ..builder import MODELS

__all__ = ['SbertForSentenceSimilarity']


class SbertTextClassifier(SbertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.encoder = SbertModel(config, add_pooling_layer=True)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_ids=None, token_type_ids=None):
        outputs = self.encoder(
            input_ids,
            token_type_ids=token_type_ids,
            return_dict=None,
        )
        # pool, regularize, and project to the similarity classes
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits


@MODELS.register_module(
    Tasks.sentence_similarity,
    module_name=r'sbert-base-chinese-sentence-similarity')
class SbertForSentenceSimilarity(Model):

    def __init__(self, model_dir: str, *args, **kwargs):
        """initialize the sentence similarity model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        self.model_dir = model_dir

        self.model = SbertTextClassifier.from_pretrained(
            model_dir, num_labels=2)
        self.model.eval()
        self.label_path = os.path.join(self.model_dir, 'label_mapping.json')
        with open(self.label_path) as f:
            self.label_mapping = json.load(f)
        self.id2label = {idx: name for name, idx in self.label_mapping.items()}

    def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
        """return the result by the model

        Args:
            input (Dict[str, Any]): the preprocessed data

        Returns:
            Dict[str, np.ndarray]: results
                Example:
                    {
                        'predictions': array([1]),  # label 0-negative 1-positive
                        'probabilities': array([[0.11491239, 0.8850876]], dtype=float32),
                        'logits': array([[-0.53860897, 1.5029076]], dtype=float32)  # raw values
                    }
        """
        input_ids = torch.tensor(input['input_ids'], dtype=torch.long)
        token_type_ids = torch.tensor(
            input['token_type_ids'], dtype=torch.long)
        with torch.no_grad():
            logits = self.model(input_ids, token_type_ids)
            probs = logits.softmax(-1).numpy()
            pred = logits.argmax(-1).numpy()
            logits = logits.numpy()
        res = {'predictions': pred, 'probabilities': probs, 'logits': logits}
        return res
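
As a quick sanity check of the docstring example above: applying a softmax to the example logits does reproduce the example probabilities. A standalone NumPy sketch (illustrative, not part of the commit):

import numpy as np

# softmax over the docstring's example logits
logits = np.array([[-0.53860897, 1.5029076]], dtype=np.float32)
probs = np.exp(logits) / np.exp(logits).sum(-1, keepdims=True)
print(probs)              # ~[[0.1149, 0.8851]] -- matches 'probabilities'
print(logits.argmax(-1))  # [1] -- matches 'predictions'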
modelscope/pipelines/base.py
@@ -15,7 +15,7 @@ from modelscope.utils.logger import get_logger
 from .util import is_model_name
 
 Tensor = Union['torch.Tensor', 'tf.Tensor']
-Input = Union[str, PyDataset, 'PIL.Image.Image', 'numpy.ndarray']
+Input = Union[str, tuple, PyDataset, 'PIL.Image.Image', 'numpy.ndarray']
 InputModel = Union[str, Model]
 
 output_keys = [
modelscope/pipelines/builder.py
@@ -13,6 +13,9 @@ PIPELINES = Registry('pipelines')
 
 DEFAULT_MODEL_FOR_PIPELINE = {
     # TaskName: (pipeline_module_name, model_repo)
+    Tasks.sentence_similarity:
+    ('sbert-base-chinese-sentence-similarity',
+     'damo/nlp_structbert_sentence-similarity_chinese-base'),
     Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting_damo'),
     Tasks.text_classification:
     ('bert-sentiment-analysis', 'damo/bert-base-sst2'),
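
This mapping is what lets pipeline(Tasks.sentence_similarity) run with no explicit model (exercised by test_run_with_default_model below). A simplified sketch of the lookup; the function name is illustrative, not the actual builder code:

def resolve_default_model(task: str):
    # look up the registered pipeline module and its default model repo
    module_name, model_repo = DEFAULT_MODEL_FOR_PIPELINE[task]
    return module_name, model_repo

# resolve_default_model(Tasks.sentence_similarity)
# -> ('sbert-base-chinese-sentence-similarity',
#     'damo/nlp_structbert_sentence-similarity_chinese-base')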
modelscope/pipelines/nlp/__init__.py
@@ -1,2 +1,3 @@
+from .sentence_similarity_pipeline import *  # noqa F403
 from .sequence_classification_pipeline import *  # noqa F403
 from .text_generation_pipeline import *  # noqa F403
modelscope/pipelines/nlp/sentence_similarity_pipeline.py (new file, 65 lines)
import os
import uuid
from typing import Any, Dict, Union

import json
import numpy as np

from modelscope.models.nlp import SbertForSentenceSimilarity
from modelscope.preprocessors import SequenceClassificationPreprocessor
from modelscope.utils.constant import Tasks
from ...models import Model
from ..base import Input, Pipeline
from ..builder import PIPELINES

__all__ = ['SentenceSimilarityPipeline']


@PIPELINES.register_module(
    Tasks.sentence_similarity,
    module_name=r'sbert-base-chinese-sentence-similarity')
class SentenceSimilarityPipeline(Pipeline):

    def __init__(self,
                 model: Union[SbertForSentenceSimilarity, str],
                 preprocessor: SequenceClassificationPreprocessor = None,
                 **kwargs):
        """use `model` and `preprocessor` to create an NLP sentence similarity pipeline for prediction

        Args:
            model (SbertForSentenceSimilarity): a model instance
            preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
        """
        assert isinstance(model, str) or isinstance(model, SbertForSentenceSimilarity), \
            'model must be a single str or SbertForSentenceSimilarity'
        sc_model = model if isinstance(
            model,
            SbertForSentenceSimilarity) else Model.from_pretrained(model)
        if preprocessor is None:
            preprocessor = SequenceClassificationPreprocessor(
                sc_model.model_dir,
                first_sequence='first_sequence',
                second_sequence='second_sequence')
        super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)

        assert hasattr(self.model, 'id2label'), \
            'id2label map should be initialized in the init function.'

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """process the prediction results

        Args:
            inputs (Dict[str, Any]): the model prediction outputs

        Returns:
            Dict[str, str]: the prediction results
        """
        probs = inputs['probabilities'][0]
        num_classes = probs.shape[0]
        # rank all classes by descending probability
        top_indices = np.argpartition(probs, -num_classes)[-num_classes:]
        cls_ids = top_indices[np.argsort(-probs[top_indices], axis=-1)]
        probs = probs[cls_ids].tolist()
        cls_names = [self.model.id2label[cid] for cid in cls_ids]
        b = 0
        return {'scores': probs[b], 'labels': cls_names[b]}
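
A side note on postprocess: since num_classes spans every class, the argpartition step selects all of them, so the logic reduces to a plain descending sort. A minimal equivalent NumPy sketch (illustrative, not part of the commit):

import numpy as np

probs = np.array([0.11491239, 0.8850876], dtype=np.float32)
cls_ids = np.argsort(-probs)      # descending order: [1, 0]
print(probs[cls_ids].tolist())    # [0.885..., 0.114...]
# the pipeline then returns only the top entry (b = 0)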
modelscope/preprocessors/__init__.py
@@ -5,4 +5,3 @@ from .builder import PREPROCESSORS, build_preprocessor
 from .common import Compose
 from .image import LoadImage, load_image
 from .nlp import *  # noqa F403
-from .nlp import TextGenerationPreprocessor
modelscope/preprocessors/nlp.py
@@ -10,7 +10,10 @@ from modelscope.utils.type_assert import type_assert
 from .base import Preprocessor
 from .builder import PREPROCESSORS
 
-__all__ = ['Tokenize', 'SequenceClassificationPreprocessor']
+__all__ = [
+    'Tokenize', 'SequenceClassificationPreprocessor',
+    'TextGenerationPreprocessor'
+]
 
 
 @PREPROCESSORS.register_module(Fields.nlp)
@@ -28,7 +31,7 @@ class Tokenize(Preprocessor):
 
 
 @PREPROCESSORS.register_module(
-    Fields.nlp, module_name=r'bert-sentiment-analysis')
+    Fields.nlp, module_name=r'bert-sequence-classification')
 class SequenceClassificationPreprocessor(Preprocessor):
 
     def __init__(self, model_dir: str, *args, **kwargs):
@@ -48,21 +51,42 @@ class SequenceClassificationPreprocessor(Preprocessor):
         self.sequence_length = kwargs.pop('sequence_length', 128)
 
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
+        print(f'this is the tokenizer {self.tokenizer}')
 
-    @type_assert(object, str)
-    def __call__(self, data: str) -> Dict[str, Any]:
+    @type_assert(object, (str, tuple))
+    def __call__(self, data: Union[str, tuple]) -> Dict[str, Any]:
         """process the raw input data
 
         Args:
-            data (str): a sentence
-                Example:
-                    'you are so handsome.'
+            data (str or tuple):
+                sentence1 (str): a sentence
+                    Example:
+                        'you are so handsome.'
+                or
+                (sentence1, sentence2)
+                    sentence1 (str): a sentence
+                        Example:
+                            'you are so handsome.'
+                    sentence2 (str): a sentence
+                        Example:
+                            'you are so beautiful.'
 
         Returns:
             Dict[str, Any]: the preprocessed data
         """
 
-        new_data = {self.first_sequence: data}
+        if not isinstance(data, tuple):
+            data = (
+                data,
+                None,
+            )
+
+        sentence1, sentence2 = data
+        new_data = {
+            self.first_sequence: sentence1,
+            self.second_sequence: sentence2
+        }
 
         # preprocess the data for the model input
 
         rst = {
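
The net effect of the new __call__ logic: a bare string is normalized to a (sentence, None) pair so single- and pair-sentence inputs share one code path. A standalone sketch; the literal dict keys stand in for the preprocessor's first_sequence/second_sequence attributes:

def normalize(data):
    # promote a single sentence to a (sentence1, sentence2) pair
    if not isinstance(data, tuple):
        data = (data, None)
    sentence1, sentence2 = data
    return {'first_sequence': sentence1, 'second_sequence': sentence2}

print(normalize('you are so handsome.'))
print(normalize(('you are so handsome.', 'you are so beautiful.')))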
modelscope/utils/constant.py
@@ -31,6 +31,7 @@ class Tasks(object):
 
     # nlp tasks
     sentiment_analysis = 'sentiment-analysis'
+    sentence_similarity = 'sentence-similarity'
     text_classification = 'text-classification'
     relation_extraction = 'relation-extraction'
     zero_shot = 'zero-shot'
tests/pipelines/test_sentence_similarity.py (new file, 67 lines)
# Copyright (c) Alibaba, Inc. and its affiliates.
import shutil
import unittest

from maas_hub.snapshot_download import snapshot_download

from modelscope.models import Model
from modelscope.models.nlp import SbertForSentenceSimilarity
from modelscope.pipelines import SentenceSimilarityPipeline, pipeline
from modelscope.preprocessors import SequenceClassificationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.hub import get_model_cache_dir
from modelscope.utils.test_utils import test_level


class SentenceSimilarityTest(unittest.TestCase):
    model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
    sentence1 = '今天气温比昨天高么?'  # 'Is the temperature higher today than yesterday?'
    sentence2 = '今天湿度比昨天高么?'  # 'Is the humidity higher today than yesterday?'

    def setUp(self) -> None:
        # switch to False if downloading every time is not desired
        purge_cache = True
        if purge_cache:
            shutil.rmtree(
                get_model_cache_dir(self.model_id), ignore_errors=True)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run(self):
        cache_path = snapshot_download(self.model_id)
        tokenizer = SequenceClassificationPreprocessor(cache_path)
        model = SbertForSentenceSimilarity(cache_path, tokenizer=tokenizer)
        pipeline1 = SentenceSimilarityPipeline(model, preprocessor=tokenizer)
        pipeline2 = pipeline(
            Tasks.sentence_similarity, model=model, preprocessor=tokenizer)
        print('test1')
        print(f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n'
              f'pipeline1: {pipeline1(input=(self.sentence1, self.sentence2))}')
        print()
        print(
            f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n'
            f'pipeline2: {pipeline2(input=(self.sentence1, self.sentence2))}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id)
        tokenizer = SequenceClassificationPreprocessor(model.model_dir)
        pipeline_ins = pipeline(
            task=Tasks.sentence_similarity,
            model=model,
            preprocessor=tokenizer)
        print(pipeline_ins(input=(self.sentence1, self.sentence2)))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_name(self):
        pipeline_ins = pipeline(
            task=Tasks.sentence_similarity, model=self.model_id)
        print(pipeline_ins(input=(self.sentence1, self.sentence2)))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_default_model(self):
        pipeline_ins = pipeline(task=Tasks.sentence_similarity)
        print(pipeline_ins(input=(self.sentence1, self.sentence2)))


if __name__ == '__main__':
    unittest.main()