ut run passed

This commit is contained in:
雨泓
2022-06-22 21:51:29 +08:00
parent 1009d54d22
commit 3987a3bf7d
7 changed files with 27 additions and 8 deletions

View File

@@ -17,5 +17,5 @@ class SbertForNLI(SbertForSequenceClassificationBase):
model_cls (Optional[Any], optional): model loader, if None, use the
default loader to load model weights, by default None.
"""
super().__init__(model_dir, *args, **kwargs)
super().__init__(model_dir, *args, model_args={"num_labels": 3}, **kwargs)
assert self.model.config.num_labels == 3

View File

@@ -1,13 +1,14 @@
from modelscope.utils.constant import Tasks
from .sbert_for_sequence_classification import SbertForSequenceClassificationBase
from ..builder import MODELS
from modelscope.metainfo import Models
__all__ = ['SbertForSentimentClassification']
@MODELS.register_module(
Tasks.sentiment_classification,
module_name=r'sbert-sentiment-classification')
module_name=Models.structbert)
class SbertForSentimentClassification(SbertForSequenceClassificationBase):
def __init__(self, model_dir: str, *args, **kwargs):

View File

@@ -4,6 +4,7 @@ from ..base import Model
import numpy as np
import json
import os
import torch
from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel, SbertModel
@@ -33,9 +34,11 @@ class SbertTextClassfier(SbertPreTrainedModel):
class SbertForSequenceClassificationBase(Model):
def __init__(self, model_dir: str, *args, **kwargs):
def __init__(self, model_dir: str, model_args=None, *args, **kwargs):
super().__init__(model_dir, *args, **kwargs)
self.model = SbertTextClassfier.from_pretrained(model_dir)
if model_args is None:
model_args = {}
self.model = SbertTextClassfier.from_pretrained(model_dir, **model_args)
self.id2label = {}
self.label_path = os.path.join(self.model_dir, 'label_mapping.json')
if os.path.exists(self.label_path):
@@ -43,8 +46,17 @@ class SbertForSequenceClassificationBase(Model):
self.label_mapping = json.load(f)
self.id2label = {idx: name for name, idx in self.label_mapping.items()}
def train(self):
return self.model.train()
def eval(self):
return self.model.eval()
def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
return self.model.forward(input)
input_ids = torch.tensor(input['input_ids'], dtype=torch.long)
token_type_ids = torch.tensor(
input['token_type_ids'], dtype=torch.long)
return self.model.forward(input_ids, token_type_ids)
def postprocess(self, input, **kwargs):
logits = input["logits"]

View File

@@ -26,6 +26,12 @@ class SbertForZeroShotClassification(Model):
from sofa import SbertForSequenceClassification
self.model = SbertForSequenceClassification.from_pretrained(model_dir)
def train(self):
return self.model.train()
def eval(self):
return self.model.eval()
def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
"""return the result by the model

View File

@@ -69,7 +69,7 @@ class NLIPipeline(Pipeline):
new_result = list()
for pred in preds:
new_result.append({
'pred': self.label_id_to_name[pred],
'pred': self.model.id2label[pred],
'prob': float(probs[b][pred]),
'logit': float(logits[b][pred])
})

View File

@@ -70,7 +70,7 @@ class SentimentClassificationPipeline(Pipeline):
new_result = list()
for pred in preds:
new_result.append({
'pred': self.label_id_to_name[pred],
'pred': self.model.id2label[pred],
'prob': float(probs[b][pred]),
'logit': float(logits[b][pred])
})

View File

@@ -44,7 +44,7 @@ class ZeroShotClassificationPipeline(Pipeline):
if preprocessor is None:
preprocessor = ZeroShotClassificationPreprocessor(
sc_model.model_dir)
model.eval()
sc_model.eval()
super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)
def _sanitize_parameters(self, **kwargs):