inherit bug fix

This commit is contained in:
ly119399
2022-06-30 10:20:35 +08:00
parent 6512a57909
commit 31a49a1830
6 changed files with 59 additions and 47 deletions

View File

@@ -97,5 +97,6 @@ class SpaceForDialogStateTrackingModel(Model):
'input_ids_unmasked': input_ids_unmasked,
'values': values,
'inform': inform,
'prefix': 'final'
'prefix': 'final',
'ds': input['ds']
}

View File

@@ -1,7 +1,6 @@
from .audio import LinearAECPipeline
from .audio.ans_pipeline import ANSPipeline
# from .audio import LinearAECPipeline
# from .audio.ans_pipeline import ANSPipeline
from .base import Pipeline
from .builder import pipeline
from .cv import * # noqa F403
from .multi_modal import * # noqa F403
from .nlp import * # noqa F403

View File

@@ -45,7 +45,8 @@ class DialogStateTrackingPipeline(Pipeline):
values = inputs['values']
inform = inputs['inform']
prefix = inputs['prefix']
ds = {slot: 'none' for slot in self.config.dst_slot_list}
# ds = {slot: 'none' for slot in self.config.dst_slot_list}
ds = inputs['ds']
ds = predict_and_format(self.config, self.tokenizer, _inputs,
_outputs[2], _outputs[3], _outputs[4],
@@ -113,7 +114,11 @@ def predict_and_format(config, tokenizer, features, per_slot_class_logits,
'false'):
dialog_state[slot] = 'false'
elif class_prediction == config.dst_class_types.index('inform'):
dialog_state[slot] = '§§' + inform[i][slot]
# dialog_state[slot] = '§§' + inform[i][slot]
if isinstance(inform[i][slot], str):
dialog_state[slot] = inform[i][slot]
elif isinstance(inform[i][slot], list):
dialog_state[slot] = inform[i][slot][0]
# Referral case is handled below
prediction_addendum['slot_prediction_%s'

View File

@@ -114,6 +114,44 @@ TASK_OUTPUTS = {
# "scores": [0.9, 0.1, 0.05, 0.05]
# }
Tasks.nli: ['scores', 'labels'],
Tasks.dialog_modeling: [],
Tasks.dialog_intent_prediction: [],
# {
# "dialog_states": {
# "taxi-leaveAt": "none",
# "taxi-destination": "none",
# "taxi-departure": "none",
# "taxi-arriveBy": "none",
# "restaurant-book_people": "none",
# "restaurant-book_day": "none",
# "restaurant-book_time": "none",
# "restaurant-food": "none",
# "restaurant-pricerange": "none",
# "restaurant-name": "none",
# "restaurant-area": "none",
# "hotel-book_people": "none",
# "hotel-book_day": "none",
# "hotel-book_stay": "none",
# "hotel-name": "none",
# "hotel-area": "none",
# "hotel-parking": "none",
# "hotel-pricerange": "cheap",
# "hotel-stars": "none",
# "hotel-internet": "none",
# "hotel-type": "true",
# "attraction-type": "none",
# "attraction-name": "none",
# "attraction-area": "none",
# "train-book_people": "none",
# "train-leaveAt": "none",
# "train-destination": "none",
# "train-day": "none",
# "train-arriveBy": "none",
# "train-departure": "none"
# }
# }
Tasks.dialog_state_tracking: ['dialog_states'],
# ============ audio tasks ===================
@@ -153,43 +191,5 @@ TASK_OUTPUTS = {
# {
# "image": np.ndarray with shape [height, width, 3]
# }
Tasks.text_to_image_synthesis: ['image'],
Tasks.dialog_modeling: [],
Tasks.dialog_intent_prediction: [],
# {
# "dialog_states": {
# "taxi-leaveAt": "none",
# "taxi-destination": "none",
# "taxi-departure": "none",
# "taxi-arriveBy": "none",
# "restaurant-book_people": "none",
# "restaurant-book_day": "none",
# "restaurant-book_time": "none",
# "restaurant-food": "none",
# "restaurant-pricerange": "none",
# "restaurant-name": "none",
# "restaurant-area": "none",
# "hotel-book_people": "none",
# "hotel-book_day": "none",
# "hotel-book_stay": "none",
# "hotel-name": "none",
# "hotel-area": "none",
# "hotel-parking": "none",
# "hotel-pricerange": "cheap",
# "hotel-stars": "none",
# "hotel-internet": "none",
# "hotel-type": "true",
# "attraction-type": "none",
# "attraction-name": "none",
# "attraction-area": "none",
# "train-book_people": "none",
# "train-leaveAt": "none",
# "train-destination": "none",
# "train-day": "none",
# "train-arriveBy": "none",
# "train-departure": "none"
# }
# }
Tasks.dialog_state_tracking: ['dialog_states']
Tasks.text_to_image_synthesis: ['image']
}

View File

@@ -118,8 +118,14 @@ class DialogStateTrackingPreprocessor(Preprocessor):
for slot in self.config.dst_slot_list
}
if len(history_states) > 2:
ds = history_states[-2]
else:
ds = {slot: 'none' for slot in self.config.dst_slot_list}
return {
'batch': dataset,
'features': features,
'diag_state': diag_state
'diag_state': diag_state,
'ds': ds
}

View File

@@ -432,6 +432,7 @@ class multiwoz22Processor(DSTProcessor):
usr_sys_switch = True
turn_itr = 0
inform_dict = {slot: 'none' for slot in slot_list}
for utt in utterances:
# Assert that system and user utterances alternate
is_sys_utt = utt['metadata'] != {}
@@ -1501,7 +1502,7 @@ if __name__ == '__main__':
}
}, {}]
example = processor.create_example(utter3, history_states3, set_type,
example = processor.create_example(utter2, history_states2, set_type,
slot_list, {}, append_history,
use_history_labels, swap_utterances,
label_value_repetitions,