Files
modelscope/tests/pipelines/nlp/test_dialog_state_tracking.py

47 lines
1.5 KiB
Python
Raw Normal View History

2022-06-25 12:24:18 +08:00
# Copyright (c) Alibaba, Inc. and its affiliates.
2022-06-25 13:44:31 +08:00
import os
import os.path as osp
import tempfile
2022-06-25 12:24:18 +08:00
import unittest
2022-06-25 13:44:31 +08:00
from modelscope.hub.snapshot_download import snapshot_download
2022-06-25 12:24:18 +08:00
from modelscope.models import Model
from modelscope.models.nlp import DialogStateTrackingModel
from modelscope.pipelines import DialogStateTrackingPipeline, pipeline
from modelscope.preprocessors import DialogStateTrackingPreprocessor
from modelscope.utils.constant import Tasks
2022-06-25 13:44:31 +08:00
class DialogStateTrackingTest(unittest.TestCase):
model_id = 'damo/nlp_space_dialog-state-tracking'
test_case = [{
'utter': {
'User-1':
"I'm looking for a place to stay. It needs to be a guesthouse and include free wifi."
},
'history_states': [{}]
}]
2022-06-25 12:24:18 +08:00
def test_run(self):
cache_path = '/Users/yangliu/Space/maas_model/nlp_space_dialog-state-tracking'
2022-06-25 13:44:31 +08:00
# cache_path = snapshot_download(self.model_id)
model = DialogStateTrackingModel(cache_path)
preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path)
pipeline1 = DialogStateTrackingPipeline(
model=model, preprocessor=preprocessor)
history_states = {}
for step, item in enumerate(self.test_case):
history_states = pipeline1(item)
print(history_states)
2022-06-25 13:44:31 +08:00
@unittest.skip('test with snapshot_download')
def test_run_with_model_from_modelhub(self):
2022-06-25 12:24:18 +08:00
pass
if __name__ == '__main__':
    # Allow running this file directly (python test_dialog_state_tracking.py)
    # in addition to discovery by a test runner.
    unittest.main()