You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

87 lines
3.1 KiB

import json
import os
class PromptDSTEvaluator:
    """Collect (true state, generated state) pairs and score them with JGA metrics.

    Each state is a mapping from slot name to slot value. Values must be
    hashable for ``compute_jga_for_correct_values`` (it builds sets of them);
    presumably they are strings — TODO confirm against the producer of the
    outputs file.
    """

    def __init__(self, outputs_file_path=None):
        """Create an evaluator, optionally pre-loaded from a JSON outputs file.

        Args:
            outputs_file_path: Path to a JSON array of objects, each holding a
                ``'true_states'`` and a ``'gen_states'`` entry. If the path is
                ``None`` or does not point to an existing file, the evaluator
                starts empty (no error is raised).
        """
        self.true_states_list = []
        self.gen_states_list = []
        if outputs_file_path is not None and os.path.isfile(outputs_file_path):
            # Use a context manager so the file handle is closed promptly.
            with open(outputs_file_path, encoding='utf-8') as outputs_file:
                outputs = json.load(outputs_file)
            for item in outputs:
                self.true_states_list.append(item['true_states'])
                self.gen_states_list.append(item['gen_states'])

    def add_data_item(self, true_states, gen_states):
        """Append one turn's ground-truth state and generated state."""
        self.true_states_list.append(true_states)
        self.gen_states_list.append(gen_states)

    def _check_lengths(self):
        """Raise ValueError if the true and generated state lists differ in length."""
        if len(self.true_states_list) != len(self.gen_states_list):
            raise ValueError('Length mismatch!')

    @staticmethod
    def _as_percentage(count, total):
        """Return count/total as a percentage rounded to 2 decimals; 0.0 when total is 0."""
        # Guard the empty-evaluator case, which previously raised ZeroDivisionError.
        if total == 0:
            return 0.0
        return round((count / total) * 100, 2)

    def compute_joint_goal_accuracy(self, no_print=False):
        """Return Joint Goal Accuracy as a percentage rounded to 2 decimals.

        A turn counts as correct only if the generated state has exactly the
        same slots as the true state, with equal values for every slot.

        Args:
            no_print: If True, suppress the progress and result messages.

        Returns:
            The JGA score in [0, 100]; 0.0 when no turns have been added.

        Raises:
            ValueError: If the true and generated lists differ in length.
        """
        if not no_print:
            print('Computing Joint Goal Accuracy metric...')
        self._check_lengths()
        # Dict equality checks "same key set" and "same value per key" at once.
        correctly_generated = sum(
            1
            for truth, generated in zip(self.true_states_list, self.gen_states_list)
            if truth == generated
        )
        jga_score = self._as_percentage(correctly_generated, len(self.true_states_list))
        if not no_print:
            print('Joint Goal Accuracy (JGA) = ', jga_score)
        return jga_score

    def compute_jga_for_correct_values(self, no_print=False):
        """Return JGA* as a percentage: JGA restricted to correctly extracted values.

        For each turn, the "correct values" are those values that appear both
        among the true state's values and among the generated state's values
        (regardless of slot). A turn with no such values counts as incorrect.
        Otherwise, the turn counts as correct when every true slot whose value
        is a correct value is present in the generated state with that same
        value. True slots whose values were not extracted at all are ignored.

        Args:
            no_print: If True, suppress the progress and result messages.

        Returns:
            The JGA* score in [0, 100]; 0.0 when no turns have been added.

        Raises:
            ValueError: If the true and generated lists differ in length.
        """
        if not no_print:
            print('Computing Joint Goal Accuracy metric only where values are extracted correctly!')
        self._check_lengths()
        correctly_generated = 0
        for truth, generated in zip(self.true_states_list, self.gen_states_list):
            # Keep as a set so the membership test below is O(1) per slot.
            correct_values = set(truth.values()).intersection(generated.values())
            if not correct_values:
                continue
            slots_ok = all(
                slot in generated and generated[slot] == value
                for slot, value in truth.items()
                if value in correct_values
            )
            if slots_ok:
                correctly_generated += 1
        jga_star = self._as_percentage(correctly_generated, len(self.true_states_list))
        if not no_print:
            print('Joint Goal Accuracy* (JGA*) = ', jga_star)
        return jga_star