2022-06-10 10:17:27 +08:00
|
|
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
|
|
|
import os
|
|
|
|
|
import os.path as osp
|
|
|
|
|
import tempfile
|
|
|
|
|
import unittest
|
|
|
|
|
|
2022-06-12 14:14:26 +08:00
|
|
|
from tests.case.nlp.dialog_intent_case import test_case
|
2022-06-10 10:17:27 +08:00
|
|
|
|
2022-06-12 14:55:32 +08:00
|
|
|
from modelscope.models.nlp import DialogIntentModel
|
|
|
|
|
from modelscope.pipelines import DialogIntentPipeline, pipeline
|
|
|
|
|
from modelscope.preprocessors import DialogIntentPreprocessor
|
|
|
|
|
from modelscope.utils.constant import Tasks
|
2022-06-10 10:17:27 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class DialogGenerationTest(unittest.TestCase):
    """Smoke test that runs ``DialogIntentPipeline`` over every entry of ``test_case``.

    Note: despite the class name, this exercises the dialog *intent*
    pipeline; the name is kept unchanged so existing test selectors keep working.
    """

    # Environment variable pointing at a local copy of the model directory.
    MODEL_DIR_ENV = 'MODELSCOPE_DIALOG_INTENT_MODEL_DIR'
    # Fallback kept for backward compatibility with the original hard-coded path.
    DEFAULT_MODEL_DIR = '/Users/yangliu/Desktop/space-dialog-intent'

    def test_run(self):
        """Build preprocessor, model and pipeline, then run each test case.

        The test is skipped (rather than erroring) when the model directory
        does not exist, since the model is not bundled with the repository.
        """
        modeldir = os.environ.get(self.MODEL_DIR_ENV, self.DEFAULT_MODEL_DIR)
        if not osp.isdir(modeldir):
            self.skipTest(
                'model directory {!r} not found; set {} to run this test'.format(
                    modeldir, self.MODEL_DIR_ENV))

        preprocessor = DialogIntentPreprocessor(model_dir=modeldir)
        # The model shares the preprocessor's text field and config so both
        # sides agree on tokenization and settings.
        model = DialogIntentModel(
            model_dir=modeldir,
            text_field=preprocessor.text_field,
            config=preprocessor.config)

        pipeline1 = DialogIntentPipeline(
            model=model, preprocessor=preprocessor)
        # TODO: also cover the factory path, e.g.
        #   pipeline(task=Tasks.dialog_intent, model=model, preprocessor=preprocessor)

        for index, item in enumerate(test_case):
            # subTest reports each failing case separately instead of
            # aborting on the first one.
            with self.subTest(index=index):
                print(pipeline1(item))
|
2022-06-10 10:17:27 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the test cases defined in this module when executed as a script.
    unittest.main(module='__main__')
|