Files
modelscope/tests/pipelines/test_dialog_intent_prediction.py

60 lines
2.0 KiB
Python
Raw Normal View History

2022-06-10 10:17:27 +08:00
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
2022-06-22 20:14:03 +08:00
from modelscope.hub.snapshot_download import snapshot_download
2022-06-12 23:27:26 +08:00
from modelscope.models import Model
from modelscope.models.nlp import SpaceForDialogIntentModel
2022-06-17 14:04:28 +08:00
from modelscope.pipelines import DialogIntentPredictionPipeline, pipeline
from modelscope.preprocessors import DialogIntentPredictionPreprocessor
2022-06-12 14:55:32 +08:00
from modelscope.utils.constant import Tasks
2022-06-10 10:17:27 +08:00
2022-06-17 14:04:28 +08:00
class DialogIntentPredictionTest(unittest.TestCase):
    """Smoke tests for the dialog intent prediction pipelines.

    Each test builds the pipeline two ways: by instantiating
    ``DialogIntentPredictionPipeline`` directly, and through the generic
    ``pipeline`` factory. It then runs every sentence in ``test_case``
    through both.
    """

    model_id = 'damo/nlp_space_dialog-intent-prediction'

    # Sentences fed to every pipeline under test.
    test_case = [
        'How do I locate my card?',
        'I still have not received my new card, I ordered over a week ago.'
    ]

    def _build_pipelines(self, model, preprocessor):
        """Return the two pipeline variants under test.

        Both variants share the same ``model`` and ``preprocessor``. This
        means the direct constructor and the ``pipeline`` factory are
        exercised with identical components.

        Args:
            model: the dialog intent model instance to wrap.
            preprocessor: the ``DialogIntentPredictionPreprocessor`` to use.

        Returns:
            list: ``[direct_pipeline, factory_pipeline]``.
        """
        return [
            DialogIntentPredictionPipeline(
                model=model, preprocessor=preprocessor),
            pipeline(
                task=Tasks.dialog_intent_prediction,
                model=model,
                preprocessor=preprocessor)
        ]

    def _run_pipelines(self, pipelines):
        """Run every pipeline on every entry of ``test_case``.

        Each pipeline/sentence combination runs in its own ``subTest``,
        so one failing combination does not hide the others.

        Args:
            pipelines: callables taking a sentence and returning a result.

        Returns:
            list: the outputs in pipeline-major order (all sentences for
            the first pipeline, then all for the second, ...).
        """
        results = []
        # Nested loops rather than zip(): zip pairs pipeline i with
        # sentence i only, and silently drops unmatched entries.
        for index, my_pipeline in enumerate(pipelines):
            for item in self.test_case:
                with self.subTest(pipeline=index, item=item):
                    result = my_pipeline(item)
                    self.assertIsNotNone(result)
                    print(result)
                    results.append(result)
        return results

    @unittest.skip('test with snapshot_download')
    def test_run(self):
        """Build the model from the directory returned by snapshot_download."""
        cache_path = snapshot_download(self.model_id)
        preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path)
        model = SpaceForDialogIntentModel(
            model_dir=cache_path,
            text_field=preprocessor.text_field,
            config=preprocessor.config)

        self._run_pipelines(self._build_pipelines(model, preprocessor))

    def test_run_with_model_from_modelhub(self):
        """Load the model via ``Model.from_pretrained`` and run it."""
        model = Model.from_pretrained(self.model_id)
        # The preprocessor reads its files from the model's own directory.
        preprocessor = DialogIntentPredictionPreprocessor(
            model_dir=model.model_dir)

        self._run_pipelines(self._build_pipelines(model, preprocessor))
2022-06-10 10:17:27 +08:00
if __name__ == '__main__':
    # Allow running this test module directly, e.g. ``python <this file>``;
    # '__main__' is unittest.main's default module and is spelled out here.
    unittest.main(module='__main__')