merge with nli task
@@ -5,4 +5,5 @@ from .audio.tts.vocoder import Hifigan16k
 from .base import Model
 from .builder import MODELS, build_model
 from .multi_model import OfaForImageCaptioning
-from .nlp import BertForSequenceClassification, SbertForSentenceSimilarity
+from .nlp import (BertForSequenceClassification, SbertForNLI,
+                  SbertForSentenceSimilarity)
@@ -1,4 +1,5 @@
 from .bert_for_sequence_classification import *  # noqa F403
+from .nli_model import *  # noqa F403
 from .palm_for_text_generation import *  # noqa F403
 from .sbert_for_sentence_similarity import *  # noqa F403
 from .sbert_for_token_classification import *  # noqa F403
modelscope/models/nlp/nli_model.py (new file, 84 lines)
@@ -0,0 +1,84 @@
import os
from typing import Any, Dict

import numpy as np
import torch
from sofa import SbertConfig, SbertModel
from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel
from torch import nn
from transformers.activations import ACT2FN, get_activation
from transformers.models.bert.modeling_bert import SequenceClassifierOutput

from modelscope.utils.constant import Tasks
from ..base import Model, Tensor
from ..builder import MODELS

__all__ = ['SbertForNLI']


class SbertTextClassifier(SbertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.encoder = SbertModel(config, add_pooling_layer=True)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_ids=None, token_type_ids=None):
        outputs = self.encoder(
            input_ids,
            token_type_ids=token_type_ids,
            return_dict=None,
        )
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits


@MODELS.register_module(
    Tasks.nli, module_name=r'nlp_structbert_nli_chinese-base')
class SbertForNLI(Model):

    def __init__(self, model_dir: str, *args, **kwargs):
        """initialize the NLI model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        self.model_dir = model_dir

        self.model = SbertTextClassifier.from_pretrained(
            model_dir, num_labels=3)
        self.model.eval()

    def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
        """return the prediction results of the model

        Args:
            input (Dict[str, Any]): the preprocessed data

        Returns:
            Dict[str, np.ndarray]: the results
            Example:
                {
                    'predictions': array([1]),  # index of the predicted label
                    'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
                    'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32)  # raw model outputs
                }
        """
        input_ids = torch.tensor(input['input_ids'], dtype=torch.long)
        token_type_ids = torch.tensor(
            input['token_type_ids'], dtype=torch.long)
        with torch.no_grad():
            logits = self.model(input_ids, token_type_ids)
            probs = logits.softmax(-1).numpy()
            pred = logits.argmax(-1).numpy()
            logits = logits.numpy()
        res = {'predictions': pred, 'probabilities': probs, 'logits': logits}
        return res
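For orientation, a minimal sketch of the forward contract this file establishes. The checkpoint path and token ids below are placeholders, and running it assumes the `sofa` dependency plus a downloaded StructBERT NLI checkpoint:

# Sketch only: the path and token ids are placeholders, not real values.
from modelscope.models.nlp import SbertForNLI

model = SbertForNLI('/path/to/nlp_structbert_nli_chinese-base')
outputs = model.forward({
    'input_ids': [[101, 2769, 812, 102, 800, 812, 102]],  # dummy ids
    'token_type_ids': [[0, 0, 0, 0, 1, 1, 1]],  # 0 = premise, 1 = hypothesis
})
print(outputs['predictions'])    # index of the highest-scoring label
print(outputs['probabilities'])  # softmax over the three NLI labels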
@@ -20,6 +20,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     ('sbert-base-chinese-sentence-similarity',
      'damo/nlp_structbert_sentence-similarity_chinese-base'),
     Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting'),
+    Tasks.nli: ('nlp_structbert_nli_chinese-base',
+                'damo/nlp_structbert_nli_chinese-base'),
     Tasks.text_classification:
     ('bert-sentiment-analysis', 'damo/bert-base-sst2'),
     Tasks.text_generation: ('palm2.0',
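With this default registered, the task name alone is enough to construct a working pipeline; a minimal sketch mirroring `test_run_with_default_model` from the test file added below:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# The default entry above maps Tasks.nli to
# damo/nlp_structbert_nli_chinese-base, so no explicit model id is needed.
nli = pipeline(task=Tasks.nli)
print(nli(input=('四川商务职业学院和四川财经职业学院哪个好?',
                 '四川商务职业学院商务管理在哪个校区?')))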
@@ -1,3 +1,4 @@
+from .nli_pipeline import *  # noqa F403
 from .sentence_similarity_pipeline import *  # noqa F403
 from .sequence_classification_pipeline import *  # noqa F403
 from .text_generation_pipeline import *  # noqa F403
modelscope/pipelines/nlp/nli_pipeline.py (new file, 88 lines)
@@ -0,0 +1,88 @@
import os
import uuid
from typing import Any, Dict, Union

import json
import numpy as np

from modelscope.models.nlp import SbertForNLI
from modelscope.preprocessors import NLIPreprocessor
from modelscope.utils.constant import Tasks
from ...models import Model
from ..base import Input, Pipeline
from ..builder import PIPELINES

__all__ = ['NLIPipeline']


@PIPELINES.register_module(
    Tasks.nli, module_name=r'nlp_structbert_nli_chinese-base')
class NLIPipeline(Pipeline):

    def __init__(self,
                 model: Union[SbertForNLI, str],
                 preprocessor: NLIPreprocessor = None,
                 **kwargs):
        """use `model` and `preprocessor` to create an NLI pipeline for prediction

        Args:
            model (SbertForNLI or str): a model instance or a model id/path
            preprocessor (NLIPreprocessor): a preprocessor instance
        """
        assert isinstance(model, (str, SbertForNLI)), \
            'model must be a single str or SbertForNLI'
        sc_model = model if isinstance(
            model, SbertForNLI) else Model.from_pretrained(model)
        if preprocessor is None:
            preprocessor = NLIPreprocessor(
                sc_model.model_dir,
                first_sequence='first_sequence',
                second_sequence='second_sequence')
        super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)

        self.label_path = os.path.join(sc_model.model_dir,
                                       'label_mapping.json')
        with open(self.label_path) as f:
            self.label_mapping = json.load(f)
        self.label_id_to_name = {
            idx: name
            for name, idx in self.label_mapping.items()
        }

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """process the prediction results

        Args:
            inputs (Dict[str, Any]): the outputs of the model's forward pass

        Returns:
            Dict[str, str]: the prediction results
        """
        probs = inputs['probabilities']
        logits = inputs['logits']
        predictions = np.argsort(-probs, axis=-1)
        preds = predictions[0]
        b = 0  # single-sample batch, so only index 0 is read
        new_result = list()
        for pred in preds:
            new_result.append({
                'pred': self.label_id_to_name[pred],
                'prob': float(probs[b][pred]),
                'logit': float(logits[b][pred])
            })
        new_results = list()
        new_results.append({
            'id':
            inputs['id'][b] if 'id' in inputs else str(uuid.uuid4()),
            'output':
            new_result,
            'predictions':
            new_result[0]['pred'],
            'probabilities':
            ','.join([str(t) for t in inputs['probabilities'][b]]),
            'logits':
            ','.join([str(t) for t in inputs['logits'][b]])
        })

        return new_results[0]
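The postprocessing above hinges on inverting the checkpoint's `label_mapping.json`; a standalone sketch of that transformation, with a hypothetical three-label mapping (the actual label names ship with the checkpoint):

import numpy as np

# Hypothetical mapping; the real one is read from label_mapping.json.
label_mapping = {'entailment': 0, 'neutral': 1, 'contradiction': 2}
label_id_to_name = {idx: name for name, idx in label_mapping.items()}

probs = np.array([[0.1, 0.2, 0.7]])
ranked = np.argsort(-probs, axis=-1)[0]  # label ids, most probable first
output = [{'pred': label_id_to_name[i], 'prob': float(probs[0][i])}
          for i in ranked]
print(output[0]['pred'])  # -> 'contradiction' for these probabilities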
@@ -7,4 +7,5 @@ from .common import Compose
 from .image import LoadImage, load_image
 from .multi_model import OfaImageCaptionPreprocessor
 from .nlp import *  # noqa F403
+from .nlp import NLIPreprocessor, TextGenerationPreprocessor
 from .text_to_speech import *  # noqa F403
@@ -12,7 +12,8 @@ from .builder import PREPROCESSORS

 __all__ = [
     'Tokenize', 'SequenceClassificationPreprocessor',
-    'TextGenerationPreprocessor', 'TokenClassifcationPreprocessor'
+    'TextGenerationPreprocessor', 'TokenClassifcationPreprocessor',
+    'NLIPreprocessor'
 ]
@@ -30,6 +31,77 @@ class Tokenize(Preprocessor):
         return data


+@PREPROCESSORS.register_module(
+    Fields.nlp, module_name=r'nlp_structbert_nli_chinese-base')
+class NLIPreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """preprocess the data via the vocab.txt from the `model_dir` path
+
+        Args:
+            model_dir (str): the model path
+        """
+        super().__init__(*args, **kwargs)
+
+        from sofa import SbertTokenizer
+        self.model_dir: str = model_dir
+        self.first_sequence: str = kwargs.pop('first_sequence',
+                                              'first_sequence')
+        self.second_sequence = kwargs.pop('second_sequence', 'second_sequence')
+        self.sequence_length = kwargs.pop('sequence_length', 128)
+
+        self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)
+
+    @type_assert(object, tuple)
+    def __call__(self, data: tuple) -> Dict[str, Any]:
+        """process the raw input data
+
+        Args:
+            data (tuple): (sentence1, sentence2), e.g.
+                ('you are so handsome.', 'you are so beautiful.')
+
+        Returns:
+            Dict[str, Any]: the preprocessed data
+        """
+        sentence1, sentence2 = data
+        new_data = {
+            self.first_sequence: sentence1,
+            self.second_sequence: sentence2
+        }
+
+        # preprocess the data for the model input
+        rst = {
+            'id': [],
+            'input_ids': [],
+            'attention_mask': [],
+            'token_type_ids': []
+        }
+
+        max_seq_length = self.sequence_length
+
+        text_a = new_data[self.first_sequence]
+        text_b = new_data[self.second_sequence]
+        feature = self.tokenizer(
+            text_a,
+            text_b,
+            padding=False,
+            truncation=True,
+            max_length=max_seq_length)
+
+        rst['id'].append(new_data.get('id', str(uuid.uuid4())))
+        rst['input_ids'].append(feature['input_ids'])
+        rst['attention_mask'].append(feature['attention_mask'])
+        rst['token_type_ids'].append(feature['token_type_ids'])
+
+        return rst
+
+
 @PREPROCESSORS.register_module(
     Fields.nlp, module_name=r'bert-sequence-classification')
 class SequenceClassificationPreprocessor(Preprocessor):
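For reference, the sentence-pair encoding the preprocessor relies on, shown with the stock HuggingFace tokenizer; the assumption is that `SbertTokenizer` follows the same `__call__` contract, as the code above implies:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
feature = tokenizer(
    '句子一', '句子二', padding=False, truncation=True, max_length=128)
# token_type_ids mark the first sentence (and [CLS]/[SEP]) with 0 and the
# second sentence with 1, which is what SbertForNLI consumes downstream.
print(feature['input_ids'])
print(feature['token_type_ids'])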
@@ -32,6 +32,7 @@ class Tasks(object):

     # nlp tasks
     word_segmentation = 'word-segmentation'
+    nli = 'nli'
     sentiment_analysis = 'sentiment-analysis'
     sentence_similarity = 'sentence-similarity'
     text_classification = 'text-classification'
test.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from modelscope.models import SbertForNLI
from modelscope.pipelines import pipeline
from modelscope.preprocessors import NLIPreprocessor

model = SbertForNLI('../nlp_structbert_nli_chinese-base')
print(model)
tokenizer = NLIPreprocessor(model.model_dir)

semantic_cls = pipeline('nli', model=model, preprocessor=tokenizer)
print(type(semantic_cls))

# Premise: 'I think one more thing has also hurt teacher recruitment:
# they have lost a great deal of authority in the classroom.'
# Hypothesis: 'Teachers losing authority in the classroom has reduced
# the number of people who want to enter the profession.'
print(
    semantic_cls(
        input=('我想还有一件事也伤害到了老师的招聘,那就是他们在课堂上失去了很多的权威',
               '教师在课堂上失去权威,导致想要进入这一职业的人减少了。')))
tests/pipelines/test_nli.py (new file, 49 lines)
@@ -0,0 +1,49 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from maas_hub.snapshot_download import snapshot_download

from modelscope.models import Model
from modelscope.models.nlp import SbertForNLI
from modelscope.pipelines import NLIPipeline, pipeline
from modelscope.preprocessors import NLIPreprocessor
from modelscope.utils.constant import Tasks


class NLITest(unittest.TestCase):
    model_id = 'damo/nlp_structbert_nli_chinese-base'
    # 'Which is better, Sichuan Business Vocational College or Sichuan
    # Finance and Economics Vocational College?'
    sentence1 = '四川商务职业学院和四川财经职业学院哪个好?'
    # 'Which campus is the business administration program of Sichuan
    # Business Vocational College on?'
    sentence2 = '四川商务职业学院商务管理在哪个校区?'

    @unittest.skip('skip temporarily to save test time')
    def test_run_from_local(self):
        cache_path = snapshot_download(self.model_id)
        tokenizer = NLIPreprocessor(cache_path)
        model = SbertForNLI(cache_path, tokenizer=tokenizer)
        pipeline1 = NLIPipeline(model, preprocessor=tokenizer)
        pipeline2 = pipeline(Tasks.nli, model=model, preprocessor=tokenizer)
        print(f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n'
              f'pipeline1: {pipeline1(input=(self.sentence1, self.sentence2))}')
        print()
        print(
            f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n'
            f'pipeline2: {pipeline2(input=(self.sentence1, self.sentence2))}')

    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id)
        tokenizer = NLIPreprocessor(model.model_dir)
        pipeline_ins = pipeline(
            task=Tasks.nli, model=model, preprocessor=tokenizer)
        print(pipeline_ins(input=(self.sentence1, self.sentence2)))

    def test_run_with_model_name(self):
        pipeline_ins = pipeline(task=Tasks.nli, model=self.model_id)
        print(pipeline_ins(input=(self.sentence1, self.sentence2)))

    def test_run_with_default_model(self):
        pipeline_ins = pipeline(task=Tasks.nli)
        print(pipeline_ins(input=(self.sentence1, self.sentence2)))


if __name__ == '__main__':
    unittest.main()