ut run passed
@@ -17,5 +17,5 @@ class SbertForNLI(SbertForSequenceClassificationBase):
             model_cls (Optional[Any], optional): model loader, if None, use the
                 default loader to load model weights, by default None.
         """
-        super().__init__(model_dir, *args, **kwargs)
+        super().__init__(model_dir, *args, model_args={"num_labels": 3}, **kwargs)
         assert self.model.config.num_labels == 3
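The base-class side of this change appears in a later hunk; as a rough sketch (the model directory below is a placeholder, not from this commit), the NLI constructor now boils down to:

```python
# Sketch only: SbertTextClassfier is the head defined in
# sbert_for_sequence_classification.py; the path is hypothetical.
model_dir = '/path/to/nli_model'
model = SbertTextClassfier.from_pretrained(model_dir, num_labels=3)
assert model.config.num_labels == 3  # three-way entailment head
```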
@@ -1,13 +1,14 @@
 from modelscope.utils.constant import Tasks
 from .sbert_for_sequence_classification import SbertForSequenceClassificationBase
 from ..builder import MODELS
+from modelscope.metainfo import Models

 __all__ = ['SbertForSentimentClassification']


 @MODELS.register_module(
     Tasks.sentiment_classification,
-    module_name=r'sbert-sentiment-classification')
+    module_name=Models.structbert)
 class SbertForSentimentClassification(SbertForSequenceClassificationBase):

     def __init__(self, model_dir: str, *args, **kwargs):
@@ -4,6 +4,7 @@ from ..base import Model
 import numpy as np
 import json
 import os
+import torch
 from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel, SbertModel


@@ -33,9 +34,11 @@ class SbertTextClassfier(SbertPreTrainedModel):

 class SbertForSequenceClassificationBase(Model):

-    def __init__(self, model_dir: str, *args, **kwargs):
+    def __init__(self, model_dir: str, model_args=None, *args, **kwargs):
         super().__init__(model_dir, *args, **kwargs)
-        self.model = SbertTextClassfier.from_pretrained(model_dir)
+        if model_args is None:
+            model_args = {}
+        self.model = SbertTextClassfier.from_pretrained(model_dir, **model_args)
         self.id2label = {}
         self.label_path = os.path.join(self.model_dir, 'label_mapping.json')
         if os.path.exists(self.label_path):
@@ -43,8 +46,17 @@ class SbertForSequenceClassificationBase(Model):
                 self.label_mapping = json.load(f)
             self.id2label = {idx: name for name, idx in self.label_mapping.items()}

+    def train(self):
+        return self.model.train()
+
+    def eval(self):
+        return self.model.eval()
+
     def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
-        return self.model.forward(input)
+        input_ids = torch.tensor(input['input_ids'], dtype=torch.long)
+        token_type_ids = torch.tensor(
+            input['token_type_ids'], dtype=torch.long)
+        return self.model.forward(input_ids, token_type_ids)

     def postprocess(self, input, **kwargs):
         logits = input["logits"]
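The rewritten `forward` takes the preprocessed feature dict and builds the tensors itself; a hedged example of the conversion it performs (token ids below are made up):

```python
import torch

# Hypothetical preprocessor output for one sentence pair.
features = {
    'input_ids': [[101, 2769, 1962, 102, 678, 7433, 102]],
    'token_type_ids': [[0, 0, 0, 0, 1, 1, 1]],
}

# Same conversion as in forward(): plain lists become LongTensors
# before being handed to SbertTextClassfier.
input_ids = torch.tensor(features['input_ids'], dtype=torch.long)
token_type_ids = torch.tensor(features['token_type_ids'], dtype=torch.long)
assert input_ids.shape == (1, 7) and token_type_ids.shape == (1, 7)
```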
@@ -26,6 +26,12 @@ class SbertForZeroShotClassification(Model):
         from sofa import SbertForSequenceClassification
         self.model = SbertForSequenceClassification.from_pretrained(model_dir)

+    def train(self):
+        return self.model.train()
+
+    def eval(self):
+        return self.model.eval()
+
     def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
         """return the result by the model

@@ -69,7 +69,7 @@ class NLIPipeline(Pipeline):
         new_result = list()
         for pred in preds:
             new_result.append({
-                'pred': self.label_id_to_name[pred],
+                'pred': self.model.id2label[pred],
                 'prob': float(probs[b][pred]),
                 'logit': float(logits[b][pred])
             })
@@ -70,7 +70,7 @@ class SentimentClassificationPipeline(Pipeline):
         new_result = list()
         for pred in preds:
             new_result.append({
-                'pred': self.label_id_to_name[pred],
+                'pred': self.model.id2label[pred],
                 'prob': float(probs[b][pred]),
                 'logit': float(logits[b][pred])
             })
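Both pipelines now take label names from the model wrapper's `id2label`, which the base class builds by inverting `label_mapping.json`; a sketch with a made-up two-class mapping:

```python
import json

# Hypothetical label_mapping.json content; the real file ships with the model.
label_mapping = json.loads('{"negative": 0, "positive": 1}')

# Same inversion as in SbertForSequenceClassificationBase.__init__.
id2label = {idx: name for name, idx in label_mapping.items()}
assert id2label[1] == 'positive'  # what self.model.id2label[pred] looks up
```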
@@ -44,7 +44,7 @@ class ZeroShotClassificationPipeline(Pipeline):
         if preprocessor is None:
             preprocessor = ZeroShotClassificationPreprocessor(
                 sc_model.model_dir)
-        model.eval()
+        sc_model.eval()
         super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)

     def _sanitize_parameters(self, **kwargs):
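The last hunk points `eval()` at the resolved model wrapper; together with the delegating `eval()` methods added above, the call reaches the underlying torch module. A minimal stand-in (not the repo's classes) showing the effect:

```python
import torch

# Mirrors the delegation pattern added in this commit: the wrapper's eval()
# simply forwards to the wrapped torch module.
class Wrapper:
    def __init__(self):
        self.model = torch.nn.Linear(4, 3)  # placeholder for the SBERT classifier

    def eval(self):
        return self.model.eval()

sc_model = Wrapper()
sc_model.eval()
assert not sc_model.model.training  # the underlying module is now in inference mode
```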