Files
modelscope/maas_lib/models/nlp/sequence_classification_model.py
wenmeng.zwm db4a8be9c5 [to #41669377] docs and tools refinement and release
1. add build_doc linter script
2. add sphinx-docs support
3. add development doc and api doc
4. change version to 0.1.0 for the first internal release version

Link: https://code.aone.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8775307
2022-05-20 16:51:34 +08:00

63 lines
2.2 KiB
Python

from typing import Any, Dict, Optional, Union
import numpy as np
import torch
from maas_lib.utils.constant import Tasks
from ..base import Model
from ..builder import MODELS
# Names exported by ``from <this module> import *``.
__all__ = ["SequenceClassificationModel"]
@MODELS.register_module(
    Tasks.text_classification, module_name=r'bert-sentiment-analysis')
class SequenceClassificationModel(Model):
    """Sequence (text) classification model wrapping an EasyNLP predictor.

    Registered in ``MODELS`` for ``Tasks.text_classification`` under the
    module name ``'bert-sentiment-analysis'``.
    """

    def __init__(self,
                 model_dir: str,
                 model_cls: Optional[Any] = None,
                 *args,
                 **kwargs):
        """Initialize the sequence classification model from ``model_dir``.

        Args:
            model_dir (str): the model path.
            model_cls (Optional[Any], optional): model loader class. If None,
                ``easynlp.appzoo.SequenceClassification`` is used to load the
                model weights. Defaults to None.
            *args: forwarded unchanged to ``Model.__init__``.
            **kwargs: forwarded unchanged to ``Model.__init__``.
        """
        super().__init__(model_dir, model_cls, *args, **kwargs)

        # Imported here rather than at module level, so easynlp is only
        # needed when a model is actually constructed, not when this module
        # is imported (and the class registered).
        from easynlp.appzoo import SequenceClassification
        from easynlp.core.predictor import get_model_predictor

        self.model_dir = model_dir
        # Fall back to the default loader only when no class was supplied,
        # matching the documented "if None" contract.
        if model_cls is None:
            model_cls = SequenceClassification
        self.model = get_model_predictor(
            model_dir=model_dir,
            model_cls=model_cls,
            input_keys=[('input_ids', torch.LongTensor),
                        ('attention_mask', torch.LongTensor),
                        ('token_type_ids', torch.LongTensor)],
            output_keys=['predictions', 'probabilities', 'logits'])

    def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
        """Run the wrapped predictor on preprocessed inputs.

        Args:
            input (Dict[str, Any]): the preprocessed data, passed unchanged to
                ``self.model.predict``. Presumably it must carry the
                ``input_ids`` / ``attention_mask`` / ``token_type_ids``
                entries declared as ``input_keys`` in ``__init__`` -- confirm
                against the preprocessor.

        Returns:
            Dict[str, np.ndarray]: the predictor's results, for example::

                {
                    'predictions': array([1]),  # label: 0-negative, 1-positive
                    'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
                    'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32)  # raw, un-normalized scores
                }
        """
        # ``input`` shadows the builtin but is kept so callers passing it by
        # keyword keep working.
        return self.model.predict(input)
...