You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
128 lines
4.5 KiB
128 lines
4.5 KiB
import json
|
|
|
|
# Prefix that introduces a serialized belief state, e.g. "belief : hotel area = north".
BELIEF_PREFIX = 'belief :'

# Slot values treated as "not set"; slot-value pairs carrying one of these are
# dropped. Compared after the value has been stripped and lower-cased.
INVALID_SLOT_VALUES = [
    "",
    "dontcare",
    "not mentioned",
    "don't care",
    "dont care",
    "do n't care",
    "none",
]


def create_slot_value_map_from_belief_prediction(belief_prediction, include_domain=False):
    """Parse a serialized belief state string into a ``{slot: value}`` dict.

    The parsed format is an optional ``BELIEF_PREFIX`` followed by
    ``|``-separated domain segments. Each segment starts with a domain token
    and continues with ``;``-separated ``slot = value`` pairs, e.g.::

        belief : hotel area = north ; stars = 4 | train day = monday

    Slots and values are stripped and lower-cased. The following are skipped:
    segments whose domain token is ``none``; pairs that do not split into
    exactly two parts on ``' = '``; and pairs whose value is in
    ``INVALID_SLOT_VALUES``.

    Args:
        belief_prediction: serialized belief state string.
        include_domain: if True, key each entry as ``"<domain> <slot>"`` so the
            same slot name in different domains does not collide. Defaults to
            False, which keys by slot name only. In that mode a later domain's
            value overwrites an earlier domain's value for the same slot name.

    Returns:
        dict mapping slot name (or ``"domain slot"``) to value.
    """
    # Strip first so leading whitespace cannot hide the prefix. Otherwise
    # "belief" would be parsed as a domain and ":" glued onto the first slot.
    belief_prediction = belief_prediction.strip()
    if belief_prediction.startswith(BELIEF_PREFIX):
        belief_prediction = belief_prediction[len(BELIEF_PREFIX):]

    belief_slot_value_map = {}
    for belief_state in belief_prediction.split('|'):
        tokens = belief_state.split()
        if not tokens:
            continue
        domain, remaining_tokens = tokens[0], tokens[1:]
        if domain == 'none':
            continue

        # Re-join the tokens after the domain (normalising whitespace), then
        # split into the individual "slot = value" pairs.
        for slot_value_pair in ' '.join(remaining_tokens).split(';'):
            if not slot_value_pair.strip():
                continue
            slot_values = slot_value_pair.split(' = ')
            if len(slot_values) != 2:
                # Malformed pair (no ' = ', or more than one): skip it.
                continue
            slot, value = (part.strip().lower() for part in slot_values)
            if value in INVALID_SLOT_VALUES:
                continue
            if include_domain:
                slot = domain.lower() + ' ' + slot
            belief_slot_value_map[slot] = value

    return belief_slot_value_map
|
|
|
|
|
|
class BaselineDSTEvaluator:
    """Evaluate baseline dialogue-state-tracking predictions with Joint Goal Accuracy."""

    def __init__(self, predictions_output_file, evaluation_file):
        """
        create an Evaluator object for evaluating Baseline DST predictions

        Args:
            predictions_output_file: path of the predictions output JSON file
            evaluation_file: path of the test/valid data file for extracting ground truth data

        Raises:
            ValueError: if either file contains no items, or the two files
                contain a different number of items.
        """
        # load predictions output json file (closed deterministically via `with`)
        with open(predictions_output_file, encoding='utf-8') as predictions_fp:
            self.predictions = json.load(predictions_fp)

        # load test/valid data json file
        with open(evaluation_file, encoding='utf-8') as eval_fp:
            self.eval_data = json.load(eval_fp)

        # do some length checks here
        if len(self.predictions) == 0 or len(self.eval_data) == 0:
            raise ValueError('Invalid Data (no items) in prediction or evaluation data file!')

        if len(self.predictions) != len(self.eval_data):
            raise ValueError('Length mismatch!')

    def parse_prediction_belief_states(self):
        """Return one slot->value dict per prediction, parsed from its last element.

        NOTE(review): each prediction is indexed with ``[-1]`` and the result is
        stripped. It is presumably a list of generated strings whose last entry
        holds the belief state; confirm against the generation script.
        """
        belief_state_list = []
        for prediction in self.predictions:
            last_line = prediction[-1].strip()
            # remove system responses from the predictions
            belief_prediction = last_line.split('system :')[0]
            belief_state_list.append(create_slot_value_map_from_belief_prediction(belief_prediction))
        return belief_state_list

    def parse_true_belief_states(self):
        """Return one slot->value dict per evaluation item, parsed from ``item['belief']``."""
        return [
            create_slot_value_map_from_belief_prediction(item['belief'])
            for item in self.eval_data
        ]

    def compute_joint_goal_accuracy(self, true_states, prediction_states):
        """Compute, print and return Joint Goal Accuracy as a percentage.

        A turn counts as correct only when the predicted slot->value dict equals
        the ground-truth dict exactly: same slots, same values.

        Args:
            true_states: list of ground-truth slot->value dicts, one per turn.
            prediction_states: list of predicted slot->value dicts, aligned with
                ``true_states``.

        Returns:
            Accuracy in percent (0-100).

        Raises:
            ValueError: if the lists differ in length or are empty.
        """
        print('Computing Joint Goal Accuracy metric...!')

        if len(true_states) != len(prediction_states):
            raise ValueError('Length mismatch!')

        total_turns = len(true_states)
        if total_turns == 0:
            # Previously this surfaced as a ZeroDivisionError.
            raise ValueError('No turns to evaluate!')

        # Dict equality == identical key sets AND identical values for every key.
        correctly_predicted = sum(
            1 for truth, prediction in zip(true_states, prediction_states) if truth == prediction
        )

        joint_goal_accuracy = (correctly_predicted / total_turns) * 100
        print('Evaluation :: Joint Goal Accuracy = ', joint_goal_accuracy)
        return joint_goal_accuracy
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not on import.
    # CLI flags are optional; their defaults reproduce the original hard-coded run.
    import argparse

    parser = argparse.ArgumentParser(
        description='Compute Joint Goal Accuracy for baseline DST predictions.')
    parser.add_argument(
        '--predictions',
        default='../outputs/baseline/50-dpd/checkpoint-55000/output_test.json',
        help='path of the predictions output JSON file')
    parser.add_argument(
        '--eval-data',
        default='../data/baseline/test/test.soloist.json',
        help='path of the test/valid data JSON file with ground-truth belief states')
    args = parser.parse_args()

    evaluator = BaselineDSTEvaluator(args.predictions, args.eval_data)

    predicted_belief_states = evaluator.parse_prediction_belief_states()

    true_belief_states = evaluator.parse_true_belief_states()

    evaluator.compute_joint_goal_accuracy(true_belief_states, predicted_belief_states)
|
|
|