From cf112dbbf7a2781fc9a7ba7f7b87b59dd40ad952 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=80=9D=E5=AE=8F?= Date: Tue, 14 Jun 2022 11:45:05 +0800 Subject: [PATCH 1/3] [to #42322933] init --- modelscope/models/nlp/__init__.py | 1 + .../nlp/sentiment_classification_model.py | 85 ++++++++++++++++++ modelscope/pipelines/builder.py | 3 + modelscope/pipelines/nlp/__init__.py | 1 + .../nlp/sentiment_classification_pipeline.py | 90 +++++++++++++++++++ modelscope/preprocessors/__init__.py | 3 +- modelscope/preprocessors/nlp.py | 65 ++++++++++++++ modelscope/utils/constant.py | 1 + 8 files changed, 248 insertions(+), 1 deletion(-) create mode 100644 modelscope/models/nlp/sentiment_classification_model.py create mode 100644 modelscope/pipelines/nlp/sentiment_classification_pipeline.py diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py index b2a1d43b..207f7065 100644 --- a/modelscope/models/nlp/__init__.py +++ b/modelscope/models/nlp/__init__.py @@ -1,2 +1,3 @@ +from .sentiment_classification_model import * # noqa F403 from .sequence_classification_model import * # noqa F403 from .text_generation_model import * # noqa F403 diff --git a/modelscope/models/nlp/sentiment_classification_model.py b/modelscope/models/nlp/sentiment_classification_model.py new file mode 100644 index 00000000..fb32aff2 --- /dev/null +++ b/modelscope/models/nlp/sentiment_classification_model.py @@ -0,0 +1,85 @@ +import os +from typing import Any, Dict + +import numpy as np +import torch +from sofa import SbertConfig, SbertModel +from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel +from torch import nn +from transformers.activations import ACT2FN, get_activation +from transformers.models.bert.modeling_bert import SequenceClassifierOutput + +from modelscope.utils.constant import Tasks +from ..base import Model, Tensor +from ..builder import MODELS + +__all__ = ['SbertForSentimentClassification'] + + +class SbertTextClassifier(SbertPreTrainedModel): + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + self.encoder = SbertModel(config, add_pooling_layer=True) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, input_ids=None, token_type_ids=None): + outputs = self.encoder( + input_ids, + token_type_ids=token_type_ids, + return_dict=None, + ) + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + return logits + + +@MODELS.register_module( + Tasks.sentiment_classification, + module_name=r'sbert-sentiment-classification') +class SbertForSentimentClassification(Model): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the text generation model from the `model_dir` path. + + Args: + model_dir (str): the model path. + model_cls (Optional[Any], optional): model loader, if None, use the + default loader to load model weights, by default None. + """ + super().__init__(model_dir, *args, **kwargs) + self.model_dir = model_dir + + self.model = SbertTextClassifier.from_pretrained( + model_dir, num_labels=3) + self.model.eval() + + def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: + """return the result by the model + + Args: + input (Dict[str, Any]): the preprocessed data + + Returns: + Dict[str, np.ndarray]: results + Example: + { + 'predictions': array([1]), # lable 0-negative 1-positive + 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), + 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value + } + """ + input_ids = torch.tensor(input['input_ids'], dtype=torch.long) + token_type_ids = torch.tensor( + input['token_type_ids'], dtype=torch.long) + with torch.no_grad(): + logits = self.model(input_ids, token_type_ids) + probs = logits.softmax(-1).numpy() + pred = logits.argmax(-1).numpy() + logits = logits.numpy() + res = {'predictions': pred, 'probabilities': probs, 'logits': logits} + return res diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 6495a5db..5957a367 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -18,6 +18,9 @@ PIPELINES = Registry('pipelines') DEFAULT_MODEL_FOR_PIPELINE = { # TaskName: (pipeline_module_name, model_repo) Tasks.image_matting: ('image-matting', 'damo/image-matting-person'), + Tasks.sentiment_classification: + ('sbert-sentiment-classification', + 'damo/nlp_structbert_sentiment-classification_chinese-base'), Tasks.text_classification: ('bert-sentiment-analysis', 'damo/bert-base-sst2'), Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'), diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py index 3dbbc1bb..677c097f 100644 --- a/modelscope/pipelines/nlp/__init__.py +++ b/modelscope/pipelines/nlp/__init__.py @@ -1,2 +1,3 @@ +from .sentiment_classification_pipeline import * # noqa F403 from .sequence_classification_pipeline import * # noqa F403 from .text_generation_pipeline import * # noqa F403 diff --git a/modelscope/pipelines/nlp/sentiment_classification_pipeline.py b/modelscope/pipelines/nlp/sentiment_classification_pipeline.py new file mode 100644 index 00000000..818c792d --- /dev/null +++ b/modelscope/pipelines/nlp/sentiment_classification_pipeline.py @@ -0,0 +1,90 @@ +import os +import uuid +from typing import Any, Dict, Union + +import json +import numpy as np + +from modelscope.models.nlp import SbertForSentimentClassification +from modelscope.preprocessors import SentimentClassificationPreprocessor +from modelscope.utils.constant import Tasks +from ...models import Model +from ..base import Input, Pipeline +from ..builder import PIPELINES + +__all__ = ['SentimentClassificationPipeline'] + + +@PIPELINES.register_module( + Tasks.sentiment_classification, + module_name=r'sbert-sentiment-classification') +class SentimentClassificationPipeline(Pipeline): + + def __init__(self, + model: Union[SbertForSentimentClassification, str], + preprocessor: SentimentClassificationPreprocessor = None, + **kwargs): + """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction + + Args: + model (SbertForSentimentClassification): a model instance + preprocessor (SentimentClassificationPreprocessor): a preprocessor instance + """ + assert isinstance(model, str) or isinstance(model, SbertForSentimentClassification), \ + 'model must be a single str or SbertForSentimentClassification' + sc_model = model if isinstance( + model, + SbertForSentimentClassification) else Model.from_pretrained(model) + if preprocessor is None: + preprocessor = SentimentClassificationPreprocessor( + sc_model.model_dir, + first_sequence='first_sequence', + second_sequence='second_sequence') + super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs) + + self.label_path = os.path.join(sc_model.model_dir, + 'label_mapping.json') + with open(self.label_path) as f: + self.label_mapping = json.load(f) + self.label_id_to_name = { + idx: name + for name, idx in self.label_mapping.items() + } + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): _description_ + + Returns: + Dict[str, str]: the prediction results + """ + + probs = inputs['probabilities'] + logits = inputs['logits'] + predictions = np.argsort(-probs, axis=-1) + preds = predictions[0] + b = 0 + new_result = list() + for pred in preds: + new_result.append({ + 'pred': self.label_id_to_name[pred], + 'prob': float(probs[b][pred]), + 'logit': float(logits[b][pred]) + }) + new_results = list() + new_results.append({ + 'id': + inputs['id'][b] if 'id' in inputs else str(uuid.uuid4()), + 'output': + new_result, + 'predictions': + new_result[0]['pred'], + 'probabilities': + ','.join([str(t) for t in inputs['probabilities'][b]]), + 'logits': + ','.join([str(t) for t in inputs['logits'][b]]) + }) + + return new_results[0] diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py index 518ea977..a686cb55 100644 --- a/modelscope/preprocessors/__init__.py +++ b/modelscope/preprocessors/__init__.py @@ -5,4 +5,5 @@ from .builder import PREPROCESSORS, build_preprocessor from .common import Compose from .image import LoadImage, load_image from .nlp import * # noqa F403 -from .nlp import TextGenerationPreprocessor +from .nlp import (SentimentClassificationPreprocessor, + TextGenerationPreprocessor) diff --git a/modelscope/preprocessors/nlp.py b/modelscope/preprocessors/nlp.py index 0de41bfc..31d5acb3 100644 --- a/modelscope/preprocessors/nlp.py +++ b/modelscope/preprocessors/nlp.py @@ -27,6 +27,71 @@ class Tokenize(Preprocessor): return data +@PREPROCESSORS.register_module( + Fields.sentiment_classification, + module_name=r'sbert-sentiment-classification') +class SentimentClassificationPreprocessor(Preprocessor): + + def __init__(self, model_dir: str, *args, **kwargs): + """preprocess the data via the vocab.txt from the `model_dir` path + + Args: + model_dir (str): model path + """ + + super().__init__(*args, **kwargs) + + from sofa import SbertTokenizer + self.model_dir: str = model_dir + self.first_sequence: str = kwargs.pop('first_sequence', + 'first_sequence') + self.second_sequence = kwargs.pop('second_sequence', 'second_sequence') + self.sequence_length = kwargs.pop('sequence_length', 128) + + self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir) + + @type_assert(object, str) + def __call__(self, data: str) -> Dict[str, Any]: + """process the raw input data + + Args: + data (str): a sentence + Example: + 'you are so handsome.' + + Returns: + Dict[str, Any]: the preprocessed data + """ + + new_data = {self.first_sequence: data} + # preprocess the data for the model input + + rst = { + 'id': [], + 'input_ids': [], + 'attention_mask': [], + 'token_type_ids': [] + } + + max_seq_length = self.sequence_length + + text_a = new_data[self.first_sequence] + text_b = new_data.get(self.second_sequence, None) + feature = self.tokenizer( + text_a, + text_b, + padding='max_length', + truncation=True, + max_length=max_seq_length) + + rst['id'].append(new_data.get('id', str(uuid.uuid4()))) + rst['input_ids'].append(feature['input_ids']) + rst['attention_mask'].append(feature['attention_mask']) + rst['token_type_ids'].append(feature['token_type_ids']) + + return rst + + @PREPROCESSORS.register_module( Fields.nlp, module_name=r'bert-sentiment-analysis') class SequenceClassificationPreprocessor(Preprocessor): diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index fa30dd2a..aa4d38c8 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -30,6 +30,7 @@ class Tasks(object): image_matting = 'image-matting' # nlp tasks + sentiment_classification = 'sentiment-classification' sentiment_analysis = 'sentiment-analysis' text_classification = 'text-classification' relation_extraction = 'relation-extraction' From 7140cdb670532db401117456ec8d386663ed5092 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=80=9D=E5=AE=8F?= Date: Tue, 14 Jun 2022 13:58:43 +0800 Subject: [PATCH 2/3] [to #42322933] init --- .../nlp/sentiment_classification_model.py | 2 +- .../test_sentiment_classification.py | 52 +++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 tests/pipelines/test_sentiment_classification.py diff --git a/modelscope/models/nlp/sentiment_classification_model.py b/modelscope/models/nlp/sentiment_classification_model.py index fb32aff2..d0ab6698 100644 --- a/modelscope/models/nlp/sentiment_classification_model.py +++ b/modelscope/models/nlp/sentiment_classification_model.py @@ -55,7 +55,7 @@ class SbertForSentimentClassification(Model): self.model_dir = model_dir self.model = SbertTextClassifier.from_pretrained( - model_dir, num_labels=3) + model_dir, num_labels=2) self.model.eval() def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: diff --git a/tests/pipelines/test_sentiment_classification.py b/tests/pipelines/test_sentiment_classification.py new file mode 100644 index 00000000..1576b335 --- /dev/null +++ b/tests/pipelines/test_sentiment_classification.py @@ -0,0 +1,52 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from maas_hub.snapshot_download import snapshot_download + +from modelscope.models import Model +from modelscope.models.nlp import SbertForSentimentClassification +from modelscope.pipelines import SentimentClassificationPipeline, pipeline +from modelscope.preprocessors import SentimentClassificationPreprocessor +from modelscope.utils.constant import Tasks + + +class SentimentClassificationTest(unittest.TestCase): + model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' + sentence1 = '四川商务职业学院和四川财经职业学院哪个好?' + + def test_run_from_local(self): + cache_path = snapshot_download(self.model_id) + tokenizer = SentimentClassificationPreprocessor(cache_path) + model = SbertForSentimentClassification( + cache_path, tokenizer=tokenizer) + pipeline1 = SentimentClassificationPipeline( + model, preprocessor=tokenizer) + pipeline2 = pipeline( + Tasks.sentence_similarity, model=model, preprocessor=tokenizer) + print(f'sentence1: {self.sentence1}\n' + f'pipeline1:{pipeline1(input=self.sentence1)}') + print() + print(f'sentence1: {self.sentence1}\n' + f'pipeline1: {pipeline2(input=self.sentence1)}') + + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + tokenizer = SentimentClassificationPreprocessor(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.sentence_similarity, + model=model, + preprocessor=tokenizer) + print(pipeline_ins(input=self.sentence1)) + + def test_run_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.sentence_similarity, model=self.model_id) + print(pipeline_ins(input=self.sentence1)) + + def test_run_with_default_model(self): + pipeline_ins = pipeline(task=Tasks.sentence_similarity) + print(pipeline_ins(input=self.sentence1)) + + +if __name__ == '__main__': + unittest.main() From 48da1619a74cc17d2e9c2028e79e448b9da0b404 Mon Sep 17 00:00:00 2001 From: "fubang.zfb" Date: Wed, 15 Jun 2022 17:07:04 +0800 Subject: [PATCH 3/3] [to #42322933] init --- modelscope/preprocessors/nlp.py | 3 +-- tests/pipelines/test_sentiment_classification.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/modelscope/preprocessors/nlp.py b/modelscope/preprocessors/nlp.py index 31d5acb3..c6632ce7 100644 --- a/modelscope/preprocessors/nlp.py +++ b/modelscope/preprocessors/nlp.py @@ -28,8 +28,7 @@ class Tokenize(Preprocessor): @PREPROCESSORS.register_module( - Fields.sentiment_classification, - module_name=r'sbert-sentiment-classification') + Fields.nlp, module_name=r'sbert-sentiment-classification') class SentimentClassificationPreprocessor(Preprocessor): def __init__(self, model_dir: str, *args, **kwargs): diff --git a/tests/pipelines/test_sentiment_classification.py b/tests/pipelines/test_sentiment_classification.py index 1576b335..9a1a8484 100644 --- a/tests/pipelines/test_sentiment_classification.py +++ b/tests/pipelines/test_sentiment_classification.py @@ -11,8 +11,8 @@ from modelscope.utils.constant import Tasks class SentimentClassificationTest(unittest.TestCase): - model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' - sentence1 = '四川商务职业学院和四川财经职业学院哪个好?' + model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' + sentence1 = '启动的时候很大声音,然后就会听到1.2秒的卡察的声音,类似齿轮摩擦的声音' def test_run_from_local(self): cache_path = snapshot_download(self.model_id) @@ -22,7 +22,9 @@ class SentimentClassificationTest(unittest.TestCase): pipeline1 = SentimentClassificationPipeline( model, preprocessor=tokenizer) pipeline2 = pipeline( - Tasks.sentence_similarity, model=model, preprocessor=tokenizer) + Tasks.sentiment_classification, + model=model, + preprocessor=tokenizer) print(f'sentence1: {self.sentence1}\n' f'pipeline1:{pipeline1(input=self.sentence1)}') print() @@ -33,18 +35,18 @@ class SentimentClassificationTest(unittest.TestCase): model = Model.from_pretrained(self.model_id) tokenizer = SentimentClassificationPreprocessor(model.model_dir) pipeline_ins = pipeline( - task=Tasks.sentence_similarity, + task=Tasks.sentiment_classification, model=model, preprocessor=tokenizer) print(pipeline_ins(input=self.sentence1)) def test_run_with_model_name(self): pipeline_ins = pipeline( - task=Tasks.sentence_similarity, model=self.model_id) + task=Tasks.sentiment_classification, model=self.model_id) print(pipeline_ins(input=self.sentence1)) def test_run_with_default_model(self): - pipeline_ins = pipeline(task=Tasks.sentence_similarity) + pipeline_ins = pipeline(task=Tasks.sentiment_classification) print(pipeline_ins(input=self.sentence1))