Files
modelscope/tests/pipelines/nlp/test_dialog_state_tracking.py

77 lines
2.5 KiB
Python
Raw Normal View History

2022-06-25 12:24:18 +08:00
# Copyright (c) Alibaba, Inc. and its affiliates.
2022-06-25 13:44:31 +08:00
import os
import os.path as osp
import tempfile
2022-06-25 12:24:18 +08:00
import unittest
2022-06-25 13:44:31 +08:00
from modelscope.hub.snapshot_download import snapshot_download
2022-06-25 12:24:18 +08:00
from modelscope.models import Model
from modelscope.models.nlp import DialogStateTrackingModel
from modelscope.pipelines import DialogStateTrackingPipeline, pipeline
from modelscope.preprocessors import DialogStateTrackingPreprocessor
from modelscope.utils.constant import Tasks
2022-06-25 13:44:31 +08:00
class DialogStateTrackingTest(unittest.TestCase):
    """Smoke test for the dialog state tracking pipeline.

    A short multi-turn conversation is fed to the pipeline one step at a
    time; the state returned for each step is printed and appended to the
    history passed to the next step.
    """

    # Hub identifier of the model; resolved to a local directory via
    # ``snapshot_download`` in ``test_run``.
    model_id = 'damo/nlp_space_dialog-state-tracking'

    # One dict per step. Each dict carries the keys that are merged into the
    # accumulated utterance dict at that step (previous system reply and
    # dialog act, if any, plus the new user utterance).
    test_case = [{
        'User-1':
        'am looking for a place to to stay that has cheap price range it should be in a type of hotel'
    }, {
        'System-1':
        'Okay, do you have a specific area you want to stay in?',
        'Dialog_Act-1': {
            'Hotel-Request': [['Area', '?']]
        },
        'User-2':
        "no, i just need to make sure it's cheap. oh, and i need parking"
    }, {
        'System-2':
        'I found 1 cheap hotel for you that includes parking. Do you like me to book it?',
        'Dialog_Act-2': {
            'Booking-Inform': [['none', 'none']],
            'Hotel-Inform': [['Price', 'cheap'], ['Choice', '1'],
                             ['Parking', 'none']]
        },
        'User-3':
        'Yes, please. 6 people 3 nights starting on tuesday.'
    }]

    def test_run(self):
        """Run the conversation through both pipeline construction paths.

        The model directory is obtained with ``snapshot_download`` rather
        than a machine-specific absolute path, so the test does not depend
        on one developer's local checkout.
        """
        cache_path = snapshot_download(self.model_id)

        model = DialogStateTrackingModel(cache_path)
        preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path)

        # The same model/preprocessor pair is wrapped twice: once through the
        # pipeline class directly and once through the ``pipeline`` factory.
        pipelines = [
            DialogStateTrackingPipeline(
                model=model, preprocessor=preprocessor),
            pipeline(
                task=Tasks.dialog_state_tracking,
                model=model,
                preprocessor=preprocessor)
        ]

        history_states = [{}]
        utter = {}
        pipelines_len = len(pipelines)
        for step, item in enumerate(self.test_case):
            # Utterances accumulate across steps; each call sees the whole
            # conversation so far.
            utter.update(item)
            # Alternate between the two pipeline objects from step to step.
            ds = pipelines[step % pipelines_len]({
                'utter':
                utter,
                'history_states':
                history_states
            })
            print(ds)

            # Record the returned state followed by a fresh empty dict for
            # the next step.
            history_states.extend([ds, {}])

    @unittest.skip('test with snapshot_download')
    def test_run_with_model_from_modelhub(self):
        """Placeholder; unconditionally skipped and has no body yet."""
        pass
if __name__ == '__main__':
    # Executed as a script: hand control to the unittest runner, which
    # discovers and runs the test cases defined in this module.
    unittest.main()