mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
1. add build_doc linter script 2. add sphinx-docs support 3. add development doc and api doc 4. change version to 0.1.0 for the first internal release version Link: https://code.aone.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8775307
63 lines
2.2 KiB
Python
63 lines
2.2 KiB
Python
from typing import Any, Dict, Optional, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from maas_lib.utils.constant import Tasks
|
|
from ..base import Model
|
|
from ..builder import MODELS
|
|
|
|
__all__ = ['SequenceClassificationModel']
|
|
|
|
|
|
@MODELS.register_module(
    Tasks.text_classification, module_name=r'bert-sentiment-analysis')
class SequenceClassificationModel(Model):
    """Sequence classification model backed by an EasyNLP model predictor.

    Registered in ``MODELS`` for ``Tasks.text_classification`` under the
    module name ``bert-sentiment-analysis``.
    """

    def __init__(self,
                 model_dir: str,
                 model_cls: Optional[Any] = None,
                 *args,
                 **kwargs):
        """Initialize the sequence classification model from `model_dir`.

        Args:
            model_dir (str): the model path.
            model_cls (Optional[Any], optional): model loader. When it is
                not given (or is otherwise falsy), the EasyNLP
                `SequenceClassification` loader is used. By default None.
            *args: extra positional arguments forwarded to the base class.
            **kwargs: extra keyword arguments forwarded to the base class.
        """
        super().__init__(model_dir, model_cls, *args, **kwargs)

        # easynlp is imported inside the constructor, so it is only needed
        # once a model instance is actually created.
        from easynlp.appzoo import SequenceClassification
        from easynlp.core.predictor import get_model_predictor

        self.model_dir = model_dir

        # Fall back to the default loader when the caller supplied none.
        loader_cls = model_cls or SequenceClassification

        # Every model input is declared as a LongTensor.
        tensor_inputs = [
            (key, torch.LongTensor)
            for key in ('input_ids', 'attention_mask', 'token_type_ids')
        ]
        result_names = ['predictions', 'probabilities', 'logits']

        self.model = get_model_predictor(
            model_dir=model_dir,
            model_cls=loader_cls,
            input_keys=tensor_inputs,
            output_keys=result_names)

    def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
        """Run the predictor on the preprocessed data and return its result.

        Args:
            input (Dict[str, Any]): the preprocessed data.

        Returns:
            Dict[str, np.ndarray]: the output of the predictor.

        Example:
            {
                'predictions': array([1]),  # label 0-negative 1-positive
                'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
                'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32)  # true value
            }
        """
        outputs = self.model.predict(input)
        return outputs
|