merge with sentiment_classification

智丞
2022-06-22 11:25:23 +08:00
9 changed files with 302 additions and 5 deletions

View File

@@ -3,3 +3,4 @@ from .nli_model import * # noqa F403
from .palm_for_text_generation import * # noqa F403
from .sbert_for_sentence_similarity import * # noqa F403
from .sbert_for_token_classification import * # noqa F403
from .sentiment_classification_model import * # noqa F403

View File

@@ -0,0 +1,85 @@
from typing import Any, Dict

import numpy as np
import torch
from sofa import SbertModel
from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel
from torch import nn

from modelscope.utils.constant import Tasks
from ..base import Model
from ..builder import MODELS

__all__ = ['SbertForSentimentClassification']
class SbertTextClassifier(SbertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.config = config
self.encoder = SbertModel(config, add_pooling_layer=True)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None):
        # return_dict=False keeps tuple outputs, so outputs[1] is the pooled output
        outputs = self.encoder(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=False,
        )
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
return logits
@MODELS.register_module(
Tasks.sentiment_classification,
module_name=r'sbert-sentiment-classification')
class SbertForSentimentClassification(Model):
def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the text generation model from the `model_dir` path.
Args:
model_dir (str): the model path.
model_cls (Optional[Any], optional): model loader, if None, use the
default loader to load model weights, by default None.
"""
super().__init__(model_dir, *args, **kwargs)
self.model_dir = model_dir
self.model = SbertTextClassifier.from_pretrained(
model_dir, num_labels=2)
self.model.eval()
def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
"""return the result by the model
Args:
input (Dict[str, Any]): the preprocessed data
Returns:
Dict[str, np.ndarray]: results
Example:
{
'predictions': array([1]), # lable 0-negative 1-positive
'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
}
"""
        input_ids = torch.tensor(input['input_ids'], dtype=torch.long)
        attention_mask = torch.tensor(
            input['attention_mask'], dtype=torch.long)
        token_type_ids = torch.tensor(
            input['token_type_ids'], dtype=torch.long)
        with torch.no_grad():
            # pass the attention mask so padded positions are ignored
            logits = self.model(input_ids, attention_mask, token_type_ids)
        probs = logits.softmax(-1).numpy()
        pred = logits.argmax(-1).numpy()
        logits = logits.numpy()
res = {'predictions': pred, 'probabilities': probs, 'logits': logits}
return res
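Note: a minimal sketch of calling the model directly with preprocessed features; the directory path and token ids below are placeholders, not real tokenizer output.

from modelscope.models.nlp import SbertForSentimentClassification

# '/path/to/model_dir' stands in for a downloaded model directory
model = SbertForSentimentClassification('/path/to/model_dir')
features = {
    'input_ids': [[101, 2769, 1599, 3614, 102]],  # illustrative token ids
    'attention_mask': [[1, 1, 1, 1, 1]],
    'token_type_ids': [[0, 0, 0, 0, 0]],
}
outputs = model.forward(features)
# outputs carries 'predictions', 'probabilities' and 'logits' as numpy arrays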

View File

@@ -22,8 +22,11 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting'),
Tasks.nli: ('nlp_structbert_nli_chinese-base',
'damo/nlp_structbert_nli_chinese-base'),
Tasks.sentiment_classification:
('sbert-sentiment-classification',
'damo/nlp_structbert_sentiment-classification_chinese-base'),
Tasks.text_classification: ('bert-sentiment-analysis',
'damo/bert-base-sst2'),
Tasks.text_generation: ('palm2.0',
'damo/nlp_palm2.0_text-generation_chinese-base'),
Tasks.image_captioning: ('ofa', 'damo/ofa_image-caption_coco_large_en'),
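Note: with this mapping entry in place, constructing the pipeline without an explicit model falls back to the StructBERT sentiment model registered above; a minimal sketch, with an illustrative input sentence.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# no model argument: resolved through DEFAULT_MODEL_FOR_PIPELINE
pipe = pipeline(Tasks.sentiment_classification)
print(pipe(input='the service was great'))  # illustrative input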

View File

@@ -1,5 +1,6 @@
from .nli_pipeline import * # noqa F403
from .sentence_similarity_pipeline import * # noqa F403
from .sentiment_classification_pipeline import * # noqa F403
from .sequence_classification_pipeline import * # noqa F403
from .text_generation_pipeline import * # noqa F403
from .word_segmentation_pipeline import * # noqa F403

View File

@@ -0,0 +1,90 @@
import os
import uuid
from typing import Any, Dict, Union
import json
import numpy as np
from modelscope.models.nlp import SbertForSentimentClassification
from modelscope.preprocessors import SentimentClassificationPreprocessor
from modelscope.utils.constant import Tasks
from ...models import Model
from ..base import Input, Pipeline
from ..builder import PIPELINES
__all__ = ['SentimentClassificationPipeline']
@PIPELINES.register_module(
Tasks.sentiment_classification,
module_name=r'sbert-sentiment-classification')
class SentimentClassificationPipeline(Pipeline):
def __init__(self,
model: Union[SbertForSentimentClassification, str],
preprocessor: SentimentClassificationPreprocessor = None,
**kwargs):
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction
Args:
model (SbertForSentimentClassification): a model instance
preprocessor (SentimentClassificationPreprocessor): a preprocessor instance
"""
assert isinstance(model, str) or isinstance(model, SbertForSentimentClassification), \
'model must be a single str or SbertForSentimentClassification'
sc_model = model if isinstance(
model,
SbertForSentimentClassification) else Model.from_pretrained(model)
if preprocessor is None:
preprocessor = SentimentClassificationPreprocessor(
sc_model.model_dir,
first_sequence='first_sequence',
second_sequence='second_sequence')
super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)
self.label_path = os.path.join(sc_model.model_dir,
'label_mapping.json')
with open(self.label_path) as f:
self.label_mapping = json.load(f)
self.label_id_to_name = {
idx: name
for name, idx in self.label_mapping.items()
}
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""process the prediction results
Args:
inputs (Dict[str, Any]): _description_
Returns:
Dict[str, str]: the prediction results
"""
        probs = inputs['probabilities']
        logits = inputs['logits']
        # label ids sorted by descending probability; single-sample batch (b = 0)
        predictions = np.argsort(-probs, axis=-1)
        preds = predictions[0]
        b = 0
        new_result = [{
            'pred': self.label_id_to_name[pred],
            'prob': float(probs[b][pred]),
            'logit': float(logits[b][pred])
        } for pred in preds]
        return {
            'id': inputs['id'][b] if 'id' in inputs else str(uuid.uuid4()),
            'output': new_result,
            'predictions': new_result[0]['pred'],
            'probabilities': ','.join(str(t) for t in probs[b]),
            'logits': ','.join(str(t) for t in logits[b])
        }
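Note: for a single sentence, and assuming a hypothetical binary label_mapping.json of {"negative": 0, "positive": 1}, the returned dict would look roughly like this (the numbers echo the example in the model docstring; the id is a generated uuid):

{
    'id': '<uuid4>',
    'output': [
        {'pred': 'positive', 'prob': 0.8850876, 'logit': 1.5029076},
        {'pred': 'negative', 'prob': 0.11491239, 'logit': -0.53860897}
    ],
    'predictions': 'positive',
    'probabilities': '0.11491239,0.8850876',
    'logits': '-0.53860897,1.5029076'
}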

View File

@@ -7,5 +7,4 @@ from .common import Compose
from .image import LoadImage, load_image
from .multi_model import OfaImageCaptionPreprocessor
from .nlp import * # noqa F403
from .text_to_speech import * # noqa F403

View File

@@ -13,7 +13,7 @@ from .builder import PREPROCESSORS
__all__ = [
'Tokenize', 'SequenceClassificationPreprocessor',
'TextGenerationPreprocessor', 'TokenClassifcationPreprocessor',
'NLIPreprocessor', 'SentimentClassificationPreprocessor'
]
@@ -65,7 +65,6 @@ class NLIPreprocessor(Preprocessor):
sentence2 (str): a sentence
Example:
'you are so beautiful.'
Returns:
Dict[str, Any]: the preprocessed data
"""
@@ -102,6 +101,70 @@ class NLIPreprocessor(Preprocessor):
return rst
@PREPROCESSORS.register_module(
Fields.nlp, module_name=r'sbert-sentiment-classification')
class SentimentClassificationPreprocessor(Preprocessor):
def __init__(self, model_dir: str, *args, **kwargs):
"""preprocess the data via the vocab.txt from the `model_dir` path
Args:
model_dir (str): model path
"""
super().__init__(*args, **kwargs)
from sofa import SbertTokenizer
self.model_dir: str = model_dir
self.first_sequence: str = kwargs.pop('first_sequence',
'first_sequence')
self.second_sequence = kwargs.pop('second_sequence', 'second_sequence')
self.sequence_length = kwargs.pop('sequence_length', 128)
self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)
@type_assert(object, str)
def __call__(self, data: str) -> Dict[str, Any]:
"""process the raw input data
Args:
data (str): a sentence
Example:
'you are so handsome.'
Returns:
Dict[str, Any]: the preprocessed data
"""
new_data = {self.first_sequence: data}
# preprocess the data for the model input
rst = {
'id': [],
'input_ids': [],
'attention_mask': [],
'token_type_ids': []
}
max_seq_length = self.sequence_length
text_a = new_data[self.first_sequence]
text_b = new_data.get(self.second_sequence, None)
feature = self.tokenizer(
text_a,
text_b,
padding='max_length',
truncation=True,
max_length=max_seq_length)
rst['id'].append(new_data.get('id', str(uuid.uuid4())))
rst['input_ids'].append(feature['input_ids'])
rst['attention_mask'].append(feature['attention_mask'])
rst['token_type_ids'].append(feature['token_type_ids'])
return rst
@PREPROCESSORS.register_module(
Fields.nlp, module_name=r'bert-sequence-classification')
class SequenceClassificationPreprocessor(Preprocessor):
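Note: a minimal sketch of using the new preprocessor on its own, assuming the sofa dependency is installed; the model directory path is a placeholder and the input echoes the docstring example.

from modelscope.preprocessors import SentimentClassificationPreprocessor

# '/path/to/model_dir' must contain the StructBERT vocab files
preprocessor = SentimentClassificationPreprocessor('/path/to/model_dir')
features = preprocessor('you are so handsome.')
# features holds 'id', 'input_ids', 'attention_mask' and 'token_type_ids',
# each wrapped in a single-element batch list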

View File

@@ -33,6 +33,7 @@ class Tasks(object):
# nlp tasks
word_segmentation = 'word-segmentation'
nli = 'nli'
sentiment_classification = 'sentiment-classification'
sentiment_analysis = 'sentiment-analysis'
sentence_similarity = 'sentence-similarity'
text_classification = 'text-classification'

View File

@@ -0,0 +1,54 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
from maas_hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import SbertForSentimentClassification
from modelscope.pipelines import SentimentClassificationPipeline, pipeline
from modelscope.preprocessors import SentimentClassificationPreprocessor
from modelscope.utils.constant import Tasks
class SentimentClassificationTest(unittest.TestCase):
model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base'
sentence1 = '启动的时候很大声音然后就会听到1.2秒的卡察的声音,类似齿轮摩擦的声音'
def test_run_from_local(self):
cache_path = snapshot_download(self.model_id)
tokenizer = SentimentClassificationPreprocessor(cache_path)
model = SbertForSentimentClassification(
cache_path, tokenizer=tokenizer)
pipeline1 = SentimentClassificationPipeline(
model, preprocessor=tokenizer)
pipeline2 = pipeline(
Tasks.sentiment_classification,
model=model,
preprocessor=tokenizer)
        print(f'sentence1: {self.sentence1}\n'
              f'pipeline1: {pipeline1(input=self.sentence1)}')
        print()
        print(f'sentence1: {self.sentence1}\n'
              f'pipeline2: {pipeline2(input=self.sentence1)}')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
tokenizer = SentimentClassificationPreprocessor(model.model_dir)
pipeline_ins = pipeline(
task=Tasks.sentiment_classification,
model=model,
preprocessor=tokenizer)
print(pipeline_ins(input=self.sentence1))
def test_run_with_model_name(self):
pipeline_ins = pipeline(
task=Tasks.sentiment_classification, model=self.model_id)
print(pipeline_ins(input=self.sentence1))
def test_run_with_default_model(self):
pipeline_ins = pipeline(task=Tasks.sentiment_classification)
print(pipeline_ins(input=self.sentence1))
if __name__ == '__main__':
unittest.main()