import json
import os


class PromptDSTEvaluator:
    """Accumulate (true, generated) dialogue-state pairs and score them with JGA metrics.

    Each state is a dict mapping slot names to slot values. Pairs can be loaded
    from a JSON outputs file at construction time or appended one at a time
    with :meth:`add_data_item`.
    """

    def __init__(self, outputs_file_path=None):
        """Create an evaluator, optionally pre-loaded from a JSON outputs file.

        Args:
            outputs_file_path: Optional path to a JSON file holding a list of
                objects, each with ``'true_states'`` and ``'gen_states'`` keys.
                If the path is ``None`` or is not an existing file, the
                evaluator starts empty.
        """
        self.true_states_list = []
        self.gen_states_list = []
        if outputs_file_path is not None and os.path.isfile(outputs_file_path):
            # Close the handle deterministically and decode as UTF-8 (the
            # encoding JSON requires) instead of the platform default.
            with open(outputs_file_path, encoding='utf-8') as f:
                outputs = json.load(f)
            for item in outputs:
                self.true_states_list.append(item['true_states'])
                self.gen_states_list.append(item['gen_states'])

    def add_data_item(self, true_states, gen_states):
        """Append one (reference state, generated state) pair for a single turn."""
        self.true_states_list.append(true_states)
        self.gen_states_list.append(gen_states)

    def _check_lengths(self):
        """Raise ValueError if the true and generated state lists differ in length."""
        if len(self.true_states_list) != len(self.gen_states_list):
            raise ValueError('Length mismatch!')

    @staticmethod
    def _to_percentage(correct, total):
        """Return ``correct / total`` as a percentage rounded to 2 decimals.

        Returns 0.0 when ``total`` is 0 so an empty evaluator does not crash
        with ZeroDivisionError.
        """
        if total == 0:
            return 0.0
        return round((correct / total) * 100, 2)

    def compute_joint_goal_accuracy(self, no_print=False):
        """Compute Joint Goal Accuracy (JGA) over all stored turns.

        A turn counts as correct only when the generated state has exactly the
        same slots as the true state and every slot value matches.

        Args:
            no_print: If True, suppress progress/result printing.

        Returns:
            The percentage of correct turns, rounded to 2 decimals
            (0.0 when no turns have been added).

        Raises:
            ValueError: If the true and generated state lists differ in length.
        """
        if not no_print:
            print('Computing Joint Goal Accuracy metric...')
        self._check_lengths()

        correctly_generated = 0
        for truth, generated in zip(self.true_states_list, self.gen_states_list):
            # Slot sets must match exactly; then no value may differ.
            if set(truth.keys()) != set(generated.keys()):
                continue
            if not any(truth[slot] != generated[slot] for slot in truth):
                correctly_generated += 1

        total_turns = len(self.true_states_list)
        jga_score = self._to_percentage(correctly_generated, total_turns)
        if not no_print:
            print('Joint Goal Accuracy (JGA) = ', jga_score)
        return jga_score

    def compute_jga_for_correct_values(self, no_print=False):
        """Compute JGA* — JGA restricted to values that were extracted correctly.

        For each turn, the "correct values" are the values that appear both in
        the true state and in the generated state. Turns with no such value are
        never counted as correct, but they still count toward the denominator.
        Otherwise a turn is correct when every true slot whose value is among
        the correct values is present in the generated state with that same
        value. True slots whose values were not extracted are ignored.

        Args:
            no_print: If True, suppress progress/result printing.

        Returns:
            The JGA* percentage, rounded to 2 decimals
            (0.0 when no turns have been added).

        Raises:
            ValueError: If the true and generated state lists differ in length.
        """
        if not no_print:
            print('Computing Joint Goal Accuracy metric only where values are extracted correctly!')
        self._check_lengths()

        correctly_generated = 0
        for truth, generated in zip(self.true_states_list, self.gen_states_list):
            # Values present in both the reference and the generated state.
            correct_values = set(truth.values()).intersection(generated.values())
            # If no value was extracted correctly, skip the turn (still counted
            # in the total).
            if not correct_values:
                continue
            # Only slots whose reference value was extracted somewhere are
            # checked; each must be assigned to the right slot.
            has_wrong_slot_value = any(
                slot not in generated or value != generated[slot]
                for slot, value in truth.items()
                if value in correct_values
            )
            if not has_wrong_slot_value:
                correctly_generated += 1

        total_turns = len(self.true_states_list)
        jga_star = self._to_percentage(correctly_generated, total_turns)
        if not no_print:
            print('Joint Goal Accuracy* (JGA*) = ', jga_star)
        return jga_star