mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
merge feat nlp
This commit is contained in:
@@ -8,3 +8,4 @@ from .sentiment_classification_pipeline import * # noqa F403
|
||||
from .sequence_classification_pipeline import * # noqa F403
|
||||
from .text_generation_pipeline import * # noqa F403
|
||||
from .word_segmentation_pipeline import * # noqa F403
|
||||
from .zero_shot_classification_pipeline import * # noqa F403
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
from scipy.special import softmax
|
||||
|
||||
from ...metainfo import Pipelines
|
||||
from ...models import Model
|
||||
from ...models.nlp import SbertForZeroShotClassification
|
||||
from ...preprocessors import ZeroShotClassificationPreprocessor
|
||||
from ...utils.constant import Tasks
|
||||
from ..base import Input, Pipeline
|
||||
from ..builder import PIPELINES
|
||||
|
||||
__all__ = ['ZeroShotClassificationPipeline']
|
||||
|
||||
|
||||
@PIPELINES.register_module(
    Tasks.zero_shot_classification,
    module_name=Pipelines.zero_shot_classification)
class ZeroShotClassificationPipeline(Pipeline):
    """Pipeline that scores a text against caller-supplied candidate labels.

    The candidate labels are passed at call time through the
    ``candidate_labels`` keyword; ``hypothesis_template`` and ``multi_label``
    are optional keywords (see ``_sanitize_parameters``).
    """

    def __init__(self,
                 model: Union[SbertForZeroShotClassification, str],
                 preprocessor: ZeroShotClassificationPreprocessor = None,
                 **kwargs):
        """Use `model` and `preprocessor` to create a zero-shot classification pipeline.

        Args:
            model (SbertForZeroShotClassification or str): a model instance,
                or a string handed to ``Model.from_pretrained`` to load one.
            preprocessor (ZeroShotClassificationPreprocessor): a preprocessor
                instance. When omitted, one is built from ``model.model_dir``.
            **kwargs: forwarded unchanged to ``Pipeline.__init__``.

        Raises:
            TypeError: if `model` is neither a str nor a
                SbertForZeroShotClassification.
        """
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would let an invalid model reach from_pretrained.
        if not isinstance(model, (str, SbertForZeroShotClassification)):
            raise TypeError('model must be a single str or '
                            'SbertForZeroShotClassification')
        if isinstance(model, str):
            model = Model.from_pretrained(model)

        # Column indices into the last axis of the model's logits.
        # NOTE(review): presumably the NLI head orders its classes as
        # (entailment, neutral, contradiction) -- confirm against the model.
        self.entailment_id = 0
        self.contradiction_id = 2

        if preprocessor is None:
            preprocessor = ZeroShotClassificationPreprocessor(model.model_dir)
        model.eval()
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        """Split call-time keyword arguments into per-stage parameter dicts.

        Recognised keywords:
            candidate_labels: required, non-empty collection of labels; sent
                to both the preprocess and postprocess stages.
            hypothesis_template: optional, defaults to ``'{}'``; sent to the
                preprocess stage.
            multi_label: optional, defaults to ``False``; sent to the
                postprocess stage.

        Returns:
            tuple: ``(preprocess_params, forward_params, postprocess_params)``
            where ``forward_params`` is always an empty dict.

        Raises:
            ValueError: if ``candidate_labels`` is missing, ``None`` or empty.
        """
        candidate_labels = kwargs.pop('candidate_labels', None)
        # `len(...) == 0` rather than truthiness so array-like label
        # containers do not trip over ambiguous boolean conversion.
        if candidate_labels is None or len(candidate_labels) == 0:
            raise ValueError('You must include at least one label.')

        preprocess_params = {
            'candidate_labels': candidate_labels,
            'hypothesis_template': kwargs.pop('hypothesis_template', '{}'),
        }
        postprocess_params = {
            'candidate_labels': candidate_labels,
            'multi_label': kwargs.pop('multi_label', False),
        }
        return preprocess_params, {}, postprocess_params

    def forward(self, inputs: Dict[str, Any],
                **forward_params) -> Dict[str, Any]:
        """Run the parent class's forward step with gradient tracking disabled."""
        with torch.no_grad():
            return super().forward(inputs, **forward_params)

    def postprocess(self,
                    inputs: Dict[str, Any],
                    candidate_labels,
                    multi_label=False) -> Dict[str, Any]:
        """Turn model logits into candidate labels ranked by score.

        Args:
            inputs (Dict[str, Any]): forward output; only ``inputs['logits']``
                is read. Assumes one row of logits per candidate label, in
                the same order as `candidate_labels` -- TODO confirm against
                the preprocessor/model.
            candidate_labels: the labels being scored.
            multi_label (bool): when True, every label is scored
                independently instead of competing with the other labels.

        Returns:
            Dict[str, Any]: ``{'labels': [...], 'scores': [...]}`` with both
            lists ordered from highest to lowest score.
        """
        logits = inputs['logits']
        if multi_label or len(candidate_labels) == 1:
            # Independent scoring: softmax over (contradiction, entailment)
            # for each label and keep the entailment probability.
            pair_logits = logits[...,
                                 [self.contradiction_id, self.entailment_id]]
            scores = softmax(pair_logits, axis=-1)[..., 1]
        else:
            # Competitive scoring: softmax of the entailment logits across
            # all candidate labels.
            entail_logits = logits[..., self.entailment_id]
            scores = softmax(entail_logits, axis=-1)

        # argsort is ascending; reverse it to rank best-first.
        ranking = scores.argsort()[::-1]
        return {
            'labels': [candidate_labels[i] for i in ranking],
            'scores': [scores[i].item() for i in ranking],
        }
|
||||
@@ -1,8 +1,8 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
# from .audio import LinearAECAndFbank
|
||||
from .audio import LinearAECAndFbank
|
||||
from .base import Preprocessor
|
||||
# from .builder import PREPROCESSORS, build_preprocessor
|
||||
from .builder import PREPROCESSORS, build_preprocessor
|
||||
from .common import Compose
|
||||
from .image import LoadImage, load_image
|
||||
from .kws import WavToLists
|
||||
@@ -11,5 +11,4 @@ from .nlp import * # noqa F403
|
||||
from .space.dialog_intent_prediction_preprocessor import * # noqa F403
|
||||
from .space.dialog_modeling_preprocessor import * # noqa F403
|
||||
from .space.dialog_state_tracking_preprocessor import * # noqa F403
|
||||
|
||||
# from .text_to_speech import * # noqa F403
|
||||
from .text_to_speech import * # noqa F403
|
||||
|
||||
Reference in New Issue
Block a user