import json
import os


class PromptDSTEvaluator:
    """Accumulates (true, generated) dialogue states and computes Joint Goal Accuracy."""

    def __init__(self, outputs_file_path=None):
        self.true_states_list = []
        self.gen_states_list = []
        # Optionally pre-load evaluation data from a JSON file of model outputs.
        if outputs_file_path is not None and os.path.isfile(outputs_file_path):
            with open(outputs_file_path) as f:
                outputs = json.load(f)
            for item in outputs:
                self.true_states_list.append(item['true_states'])
                self.gen_states_list.append(item['gen_states'])

    def add_data_item(self, true_states, gen_states):
        """Add one turn: a ground-truth state dict and a generated state dict."""
        self.true_states_list.append(true_states)
        self.gen_states_list.append(gen_states)

    def compute_joint_goal_accuracy(self, no_print=False):
        """Return the percentage of turns whose generated state matches the true state exactly."""
        if not no_print:
            print('Computing Joint Goal Accuracy metric...')
        if len(self.true_states_list) != len(self.gen_states_list):
            raise ValueError('Length mismatch!')
        if not self.true_states_list:
            raise ValueError('No data items to evaluate!')
        # Keep a count for computing JGA: a turn is correct only if every slot matches.
        correctly_predicted, total_turns = 0, 0
        for truth, generated in zip(self.true_states_list, self.gen_states_list):
            total_turns += 1
            # The predicted state must contain exactly the same slots as the ground truth.
            if set(truth.keys()) != set(generated.keys()):
                continue
            has_wrong_slot_value = False
            for slot in truth:
                if truth[slot] != generated[slot]:
                    has_wrong_slot_value = True
                    break
            if not has_wrong_slot_value:
                correctly_predicted += 1
        jga_score = (correctly_predicted / total_turns) * 100
        if not no_print:
            print('Joint Goal Accuracy =', jga_score)
        return jga_score
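

# Minimal usage sketch. The slot names below (e.g. 'hotel-area') are illustrative
# assumptions, not taken from the class itself; each state is a flat slot -> value dict.
if __name__ == '__main__':
    evaluator = PromptDSTEvaluator()
    # Turn 1: prediction matches the ground truth exactly -> counted as correct.
    evaluator.add_data_item(
        true_states={'hotel-area': 'north', 'hotel-stars': '4'},
        gen_states={'hotel-area': 'north', 'hotel-stars': '4'},
    )
    # Turn 2: one slot value differs -> counted as incorrect.
    evaluator.add_data_item(
        true_states={'restaurant-food': 'italian'},
        gen_states={'restaurant-food': 'chinese'},
    )
    score = evaluator.compute_joint_goal_accuracy()  # prints and returns 50.0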