2022-06-25 12:24:18 +08:00
|
|
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
2022-06-25 13:44:31 +08:00
|
|
|
import os
|
|
|
|
|
import os.path as osp
|
|
|
|
|
import tempfile
|
2022-06-25 12:24:18 +08:00
|
|
|
import unittest
|
|
|
|
|
|
2022-06-25 13:44:31 +08:00
|
|
|
from modelscope.hub.snapshot_download import snapshot_download
|
2022-06-25 12:24:18 +08:00
|
|
|
from modelscope.models import Model
|
|
|
|
|
from modelscope.models.nlp import DialogStateTrackingModel
|
|
|
|
|
from modelscope.pipelines import DialogStateTrackingPipeline, pipeline
|
|
|
|
|
from modelscope.preprocessors import DialogStateTrackingPreprocessor
|
|
|
|
|
from modelscope.utils.constant import Tasks
|
|
|
|
|
|
|
|
|
|
|
2022-06-25 13:44:31 +08:00
|
|
|
class DialogStateTrackingTest(unittest.TestCase):
    """Smoke tests for the dialog-state-tracking pipeline."""

    model_id = 'damo/nlp_space_dialog-state-tracking'

    # One dict per dialogue step. Each dict holds only the NEW entries for
    # that step (user utterance, system utterance, system dialog acts);
    # `test_run` merges them cumulatively into a single `utter` dict.
    test_case = [{
        'User-1':
        'am looking for a place to to stay that has cheap price range it should be in a type of hotel'
    }, {
        'System-1':
        'Okay, do you have a specific area you want to stay in?',
        'Dialog_Act-1': {
            'Hotel-Request': [['Area', '?']]
        },
        'User-2':
        "no, i just need to make sure it's cheap. oh, and i need parking"
    }, {
        'System-2':
        'I found 1 cheap hotel for you that includes parking. Do you like me to book it?',
        'Dialog_Act-2': {
            'Booking-Inform': [['none', 'none']],
            'Hotel-Inform': [['Price', 'cheap'], ['Choice', '1'],
                             ['Parking', 'none']]
        },
        'User-3':
        'Yes, please. 6 people 3 nights starting on tuesday.'
    }]

    def test_run(self):
        """Feed the dialogue step by step through the tracking pipelines.

        The model files are obtained with ``snapshot_download`` rather than
        read from a developer-specific absolute path, so the test is not tied
        to one machine's filesystem.
        """
        cache_path = snapshot_download(self.model_id)

        model = DialogStateTrackingModel(cache_path)
        preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path)
        # Exercise both the explicit pipeline class and the `pipeline()`
        # factory; both share the same model and preprocessor instances.
        pipelines = [
            DialogStateTrackingPipeline(
                model=model, preprocessor=preprocessor),
            pipeline(
                task=Tasks.dialog_state_tracking,
                model=model,
                preprocessor=preprocessor)
        ]

        history_states = [{}]
        # `utter` accumulates every utterance / dialog act seen so far.
        utter = {}
        pipelines_len = len(pipelines)
        for step, item in enumerate(self.test_case):
            utter.update(item)
            # Alternate between the two pipelines from one step to the next.
            ds = pipelines[step % pipelines_len]({
                'utter': utter,
                'history_states': history_states
            })
            print(ds)

            # Record this step's predicted state followed by an empty dict.
            # NOTE(review): the trailing `{}` presumably reserves a slot for
            # the next utterance's state -- confirm against the preprocessor.
            history_states.extend([ds, {}])

    @unittest.skip('test with snapshot_download')
    def test_run_with_model_from_modelhub(self):
        # TODO: placeholder, not implemented yet (always skipped).
        pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed directly as a script: discover and run this module's tests.
    unittest.main()
|