diff --git a/modelscope/models/nlp/space/dialog_state_tracking_model.py b/modelscope/models/nlp/space/dialog_state_tracking_model.py index b24b3f94..00fc968a 100644 --- a/modelscope/models/nlp/space/dialog_state_tracking_model.py +++ b/modelscope/models/nlp/space/dialog_state_tracking_model.py @@ -97,5 +97,6 @@ class SpaceForDialogStateTrackingModel(Model): 'input_ids_unmasked': input_ids_unmasked, 'values': values, 'inform': inform, - 'prefix': 'final' + 'prefix': 'final', + 'ds': input['ds'] } diff --git a/modelscope/pipelines/__init__.py b/modelscope/pipelines/__init__.py index 74f5507f..b0bd7489 100644 --- a/modelscope/pipelines/__init__.py +++ b/modelscope/pipelines/__init__.py @@ -1,7 +1,6 @@ -from .audio import LinearAECPipeline -from .audio.ans_pipeline import ANSPipeline +# from .audio import LinearAECPipeline +# from .audio.ans_pipeline import ANSPipeline from .base import Pipeline from .builder import pipeline -from .cv import * # noqa F403 from .multi_modal import * # noqa F403 from .nlp import * # noqa F403 diff --git a/modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py b/modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py index d4fa4bef..1a1d542a 100644 --- a/modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py +++ b/modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py @@ -45,7 +45,8 @@ class DialogStateTrackingPipeline(Pipeline): values = inputs['values'] inform = inputs['inform'] prefix = inputs['prefix'] - ds = {slot: 'none' for slot in self.config.dst_slot_list} + # ds = {slot: 'none' for slot in self.config.dst_slot_list} + ds = inputs['ds'] ds = predict_and_format(self.config, self.tokenizer, _inputs, _outputs[2], _outputs[3], _outputs[4], @@ -113,7 +114,11 @@ def predict_and_format(config, tokenizer, features, per_slot_class_logits, 'false'): dialog_state[slot] = 'false' elif class_prediction == config.dst_class_types.index('inform'): - dialog_state[slot] = '§§' + inform[i][slot] + # dialog_state[slot] = '§§' + inform[i][slot] + if isinstance(inform[i][slot], str): + dialog_state[slot] = inform[i][slot] + elif isinstance(inform[i][slot], list): + dialog_state[slot] = inform[i][slot][0] # Referral case is handled below prediction_addendum['slot_prediction_%s' diff --git a/modelscope/pipelines/outputs.py b/modelscope/pipelines/outputs.py index 8e62939b..5b9e36b7 100644 --- a/modelscope/pipelines/outputs.py +++ b/modelscope/pipelines/outputs.py @@ -114,6 +114,44 @@ TASK_OUTPUTS = { # "scores": [0.9, 0.1, 0.05, 0.05] # } Tasks.nli: ['scores', 'labels'], + Tasks.dialog_modeling: [], + Tasks.dialog_intent_prediction: [], + + # { + # "dialog_states": { + # "taxi-leaveAt": "none", + # "taxi-destination": "none", + # "taxi-departure": "none", + # "taxi-arriveBy": "none", + # "restaurant-book_people": "none", + # "restaurant-book_day": "none", + # "restaurant-book_time": "none", + # "restaurant-food": "none", + # "restaurant-pricerange": "none", + # "restaurant-name": "none", + # "restaurant-area": "none", + # "hotel-book_people": "none", + # "hotel-book_day": "none", + # "hotel-book_stay": "none", + # "hotel-name": "none", + # "hotel-area": "none", + # "hotel-parking": "none", + # "hotel-pricerange": "cheap", + # "hotel-stars": "none", + # "hotel-internet": "none", + # "hotel-type": "true", + # "attraction-type": "none", + # "attraction-name": "none", + # "attraction-area": "none", + # "train-book_people": "none", + # "train-leaveAt": "none", + # "train-destination": "none", + # "train-day": "none", + # "train-arriveBy": "none", + # "train-departure": "none" + # } + # } + Tasks.dialog_state_tracking: ['dialog_states'], # ============ audio tasks =================== @@ -153,43 +191,5 @@ TASK_OUTPUTS = { # { # "image": np.ndarray with shape [height, width, 3] # } - Tasks.text_to_image_synthesis: ['image'], - Tasks.dialog_modeling: [], - Tasks.dialog_intent_prediction: [], - - # { - # "dialog_states": { - # "taxi-leaveAt": "none", - # "taxi-destination": "none", - # "taxi-departure": "none", - # "taxi-arriveBy": "none", - # "restaurant-book_people": "none", - # "restaurant-book_day": "none", - # "restaurant-book_time": "none", - # "restaurant-food": "none", - # "restaurant-pricerange": "none", - # "restaurant-name": "none", - # "restaurant-area": "none", - # "hotel-book_people": "none", - # "hotel-book_day": "none", - # "hotel-book_stay": "none", - # "hotel-name": "none", - # "hotel-area": "none", - # "hotel-parking": "none", - # "hotel-pricerange": "cheap", - # "hotel-stars": "none", - # "hotel-internet": "none", - # "hotel-type": "true", - # "attraction-type": "none", - # "attraction-name": "none", - # "attraction-area": "none", - # "train-book_people": "none", - # "train-leaveAt": "none", - # "train-destination": "none", - # "train-day": "none", - # "train-arriveBy": "none", - # "train-departure": "none" - # } - # } - Tasks.dialog_state_tracking: ['dialog_states'] + Tasks.text_to_image_synthesis: ['image'] } diff --git a/modelscope/preprocessors/space/dialog_state_tracking_preprocessor.py b/modelscope/preprocessors/space/dialog_state_tracking_preprocessor.py index 60a7bf4f..c1509eec 100644 --- a/modelscope/preprocessors/space/dialog_state_tracking_preprocessor.py +++ b/modelscope/preprocessors/space/dialog_state_tracking_preprocessor.py @@ -118,8 +118,14 @@ class DialogStateTrackingPreprocessor(Preprocessor): for slot in self.config.dst_slot_list } + if len(history_states) > 2: + ds = history_states[-2] + else: + ds = {slot: 'none' for slot in self.config.dst_slot_list} + return { 'batch': dataset, 'features': features, - 'diag_state': diag_state + 'diag_state': diag_state, + 'ds': ds } diff --git a/modelscope/preprocessors/space/dst_processors.py b/modelscope/preprocessors/space/dst_processors.py index 01a7e3c7..12f7f1f8 100644 --- a/modelscope/preprocessors/space/dst_processors.py +++ b/modelscope/preprocessors/space/dst_processors.py @@ -432,6 +432,7 @@ class multiwoz22Processor(DSTProcessor): usr_sys_switch = True turn_itr = 0 + inform_dict = {slot: 'none' for slot in slot_list} for utt in utterances: # Assert that system and user utterances alternate is_sys_utt = utt['metadata'] != {} @@ -1501,7 +1502,7 @@ if __name__ == '__main__': } }, {}] - example = processor.create_example(utter3, history_states3, set_type, + example = processor.create_example(utter2, history_states2, set_type, slot_list, {}, append_history, use_history_labels, swap_utterances, label_value_repetitions,