import json
BELIEF_PREFIX = 'belief :'
INVALID_SLOT_VALUES = ["", "dontcare", "not mentioned", "don't care", "dont care", "do n't care", "none"]

def create_slot_value_map_from_belief_prediction(belief_prediction):
    """Parse a flat belief-state string into a slot -> value dictionary."""
    # remove the 'belief :' prefix from the beginning
    if belief_prediction.startswith(BELIEF_PREFIX):
        belief_prediction = belief_prediction[len(BELIEF_PREFIX):]
    # each '|'-separated segment holds the slot-value pairs of one domain
    belief_state_splits = belief_prediction.split('|')
    belief_slot_value_map = {}
    for belief_state in belief_state_splits:
        if belief_state.strip() == '':
            continue
        domain = belief_state.split()[0]
        if domain == 'none':
            continue
        # remove the domain token from the belief state
        belief_state = ' '.join(belief_state.split()[1:])
        # split the belief state into slot-value pairs
        slot_value_pairs = belief_state.split(';')
        for slot_value_pair in slot_value_pairs:
            if slot_value_pair.strip() == '':
                continue
            slot_values = slot_value_pair.split(' = ')
            if len(slot_values) != 2:
                continue
            slot, value = slot_values
            slot = slot.strip().lower()
            value = value.strip().lower()
            # skip empty, 'dontcare' and 'none' style values
            if value in INVALID_SLOT_VALUES:
                continue
            belief_slot_value_map[slot] = value
    return belief_slot_value_map
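
# Illustrative example (an assumption, not taken from the original file): the
# parser above expects a SOLOIST-style belief string where '|' separates
# domains, the first token of a segment is the domain name, ';' separates
# slot-value pairs and ' = ' separates slot from value, e.g.
#
#   s = 'belief : hotel name = acorn guest house ; area = north | restaurant food = italian'
#   create_slot_value_map_from_belief_prediction(s)
#   -> {'name': 'acorn guest house', 'area': 'north', 'food': 'italian'}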


class BaselineDSTEvaluator:
    def __init__(self, predictions_output_file, evaluation_file):
        """
        Create an Evaluator object for evaluating baseline DST predictions.
        Args:
            predictions_output_file: path of the predictions output JSON file
            evaluation_file: path of the test/valid data file for extracting ground-truth data
        """
        # load the predictions output JSON file
        with open(predictions_output_file) as f:
            self.predictions = json.load(f)
        # load the test/valid data JSON file
        with open(evaluation_file) as f:
            self.eval_data = json.load(f)
        # sanity checks: both files must be non-empty and aligned turn by turn
        if len(self.predictions) == 0 or len(self.eval_data) == 0:
            raise ValueError('Invalid data (no items) in prediction or evaluation data file!')
        if len(self.predictions) != len(self.eval_data):
            raise ValueError('Length mismatch between predictions and evaluation data!')

    def parse_prediction_belief_states(self):
        belief_state_list = []
        for prediction in self.predictions:
            last_line = prediction[-1]
            last_line = last_line.strip()
            # remove system responses from the predictions
            belief_prediction = last_line.split('system :')[0]
            belief_slot_value_map = create_slot_value_map_from_belief_prediction(belief_prediction)
            belief_state_list.append(belief_slot_value_map)
        return belief_state_list
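
    # Note (an assumption inferred from the parsing above, not documented in
    # the original file): each entry of self.predictions is expected to be a
    # list of generated lines whose last element looks like
    #   'belief : <belief state> system : <system response>'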

    def parse_true_belief_states(self):
        true_belief_state_list = []
        for item in self.eval_data:
            belief_prediction = item['belief']
            belief_slot_value_map = create_slot_value_map_from_belief_prediction(belief_prediction)
            true_belief_state_list.append(belief_slot_value_map)
        return true_belief_state_list

    def compute_joint_goal_accuracy(self, true_states, prediction_states):
        print('Computing Joint Goal Accuracy metric...')
        if len(true_states) != len(prediction_states):
            raise ValueError('Length mismatch between true and predicted belief states!')
        correctly_predicted, total_turns = 0, 0
        for truth, prediction in zip(true_states, prediction_states):
            # print("Truth :: ", truth)
            # print("Prediction :: ", prediction)
            total_turns += 1
            # a turn counts as correct only if predicted slots and values match the truth exactly
            if set(truth.keys()) != set(prediction.keys()):
                continue
            has_wrong_slot_value = False
            for slot in truth:
                if truth[slot] != prediction[slot]:
                    has_wrong_slot_value = True
                    break
            if not has_wrong_slot_value:
                correctly_predicted += 1
        print('Evaluation :: Joint Goal Accuracy = ', (correctly_predicted / total_turns) * 100)
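
# Worked example (toy numbers for illustration, not from the original
# evaluation): joint goal accuracy counts a turn as correct only when the
# predicted slot-value map matches the ground truth exactly, e.g.
#   true_states       = [{'area': 'north'}, {'food': 'italian', 'area': 'centre'}]
#   prediction_states = [{'area': 'north'}, {'food': 'italian'}]
# gives 1 correct turn out of 2, i.e. a joint goal accuracy of 50.0.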


if __name__ == '__main__':
    evaluator = BaselineDSTEvaluator(
        '../outputs/baseline/experiment-20220829/50-dpd/checkpoint-55000/output_test.json',
        '../data/baseline/test/test.soloist.json')
    predicted_belief_states = evaluator.parse_prediction_belief_states()
    true_belief_states = evaluator.parse_true_belief_states()
    evaluator.compute_joint_goal_accuracy(true_belief_states, predicted_belief_states)