diff --git a/modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py b/modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py index b20c3c7c..4b2e29dd 100644 --- a/modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py +++ b/modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py @@ -28,6 +28,7 @@ class DialogIntentPredictionPipeline(Pipeline): super().__init__(model=model, preprocessor=preprocessor, **kwargs) self.model = model + self.categories = preprocessor.categories def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: """process the prediction results @@ -42,6 +43,10 @@ class DialogIntentPredictionPipeline(Pipeline): pred = inputs['pred'] pos = np.where(pred == np.max(pred)) - result = {'pred': pred, 'label': pos[0]} + result = { + 'pred': pred, + 'label_pos': pos[0], + 'label': self.categories[pos[0][0]] + } return result diff --git a/modelscope/pipelines/outputs.py b/modelscope/pipelines/outputs.py index 88c0e9a5..76a06bf1 100644 --- a/modelscope/pipelines/outputs.py +++ b/modelscope/pipelines/outputs.py @@ -122,6 +122,31 @@ TASK_OUTPUTS = { # } Tasks.nli: ['scores', 'labels'], + # {'pred': array([2.62349960e-03, 4.12110658e-03, 4.12748595e-05, 3.77560973e-05, + # 1.08599677e-04, 1.72710388e-05, 2.95618793e-05, 1.93638436e-04, + # 6.45841064e-05, 1.15997791e-04, 5.11605394e-05, 9.87020373e-01, + # 2.66957268e-05, 4.72324500e-05, 9.74208378e-05, 4.18022355e-05, + # 2.97343540e-05, 5.81317654e-05, 5.44203431e-05, 6.28319322e-05, + # 7.34537680e-05, 6.61411541e-05, 3.62534920e-05, 8.58885178e-05, + # 8.24327726e-05, 4.66077945e-05, 5.32869453e-05, 4.16190960e-05, + # 5.97518992e-05, 3.92273068e-05, 3.44069012e-05, 9.92335918e-05, + # 9.25978165e-05, 6.26462061e-05, 3.32317031e-05, 1.32061413e-03, + # 2.01607945e-05, 3.36636294e-05, 3.99156743e-05, 5.84108493e-05, + # 2.53432900e-05, 4.95731190e-04, 2.64443643e-05, 4.46992999e-05, + # 2.42672231e-05, 4.75615161e-05, 2.66230145e-05, 4.00083954e-05, + # 2.90536875e-04, 4.23891543e-05, 8.63691166e-05, 4.98188965e-05, + # 3.47019341e-05, 4.52718523e-05, 4.20905781e-05, 5.50173208e-05, + # 4.92360487e-05, 3.56021264e-05, 2.13957210e-05, 6.17428886e-05, + # 1.43893281e-04, 7.32152112e-05, 2.91354867e-04, 2.46623786e-05, + # 3.61441926e-05, 3.38475402e-05, 3.44323053e-05, 5.70138109e-05, + # 4.31488479e-05, 4.94503947e-05, 4.30105974e-05, 1.00963116e-04, + # 2.82062047e-05, 1.15582036e-04, 4.48261271e-05, 3.99339879e-05, + # 7.27692823e-05], dtype=float32), 'label_pos': array([11]), 'label': 'lost_or_stolen_card'} + Tasks.dialog_intent_prediction: ['pred', 'label_pos', 'label'], + + # sys : ['you', 'are', 'welcome', '.', 'have', 'a', 'great', 'day', '!'] + Tasks.dialog_modeling: ['sys'], + # ============ audio tasks =================== # audio processed for single file in PCM format diff --git a/modelscope/preprocessors/space/dialog_intent_prediction_preprocessor.py b/modelscope/preprocessors/space/dialog_intent_prediction_preprocessor.py index 1528495b..2ceede02 100644 --- a/modelscope/preprocessors/space/dialog_intent_prediction_preprocessor.py +++ b/modelscope/preprocessors/space/dialog_intent_prediction_preprocessor.py @@ -3,6 +3,8 @@ import os from typing import Any, Dict +import json + from ...metainfo import Preprocessors from ...utils.config import Config from ...utils.constant import Fields, ModelFile @@ -32,6 +34,11 @@ class DialogIntentPredictionPreprocessor(Preprocessor): self.text_field = IntentBPETextField( self.model_dir, config=self.config) + self.categories = None + with open(os.path.join(self.model_dir, 'categories.json'), 'r') as f: + self.categories = json.load(f) + assert len(self.categories) == 77 + @type_assert(object, str) def __call__(self, data: str) -> Dict[str, Any]: """process the raw input data