import json
import os


class PromptDSTEvaluator:
    """Evaluator for prompt-based dialogue state tracking (DST) outputs."""

    def __init__(self, outputs_file_path=None):
        self.true_states_list = []
        self.gen_states_list = []
        # Optionally preload (true_states, gen_states) pairs from a JSON file.
        if outputs_file_path is not None and os.path.isfile(outputs_file_path):
            with open(outputs_file_path) as f:
                outputs = json.load(f)
            for item in outputs:
                self.true_states_list.append(item['true_states'])
                self.gen_states_list.append(item['gen_states'])

    def add_data_item(self, true_states, gen_states):
        self.true_states_list.append(true_states)
        self.gen_states_list.append(gen_states)

    def compute_joint_goal_accuracy(self, no_print=False):
        if not no_print:
            print('Computing Joint Goal Accuracy metric...')
        if len(self.true_states_list) != len(self.gen_states_list):
            raise ValueError('Length mismatch between true and generated states!')
        if not self.true_states_list:
            raise ValueError('No data items to evaluate!')
        # A turn counts as correct only if the predicted state matches the
        # ground truth exactly: same slots and same values (joint goal).
        correctly_predicted, total_turns = 0, 0
        for truth, generated in zip(self.true_states_list, self.gen_states_list):
            total_turns += 1
            if set(truth.keys()) != set(generated.keys()):
                continue
            if all(truth[slot] == generated[slot] for slot in truth):
                correctly_predicted += 1
        jga_score = (correctly_predicted / total_turns) * 100
        if not no_print:
            print('Joint Goal Accuracy =', jga_score)
        return jga_score
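

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only; the toy slot -> value dicts
    # below are hypothetical examples, not part of the original data).
    # Each state dict represents the dialogue state at one turn.
    evaluator = PromptDSTEvaluator()
    evaluator.add_data_item(
        true_states={'hotel-area': 'north', 'hotel-stars': '4'},
        gen_states={'hotel-area': 'north', 'hotel-stars': '4'},
    )
    evaluator.add_data_item(
        true_states={'restaurant-food': 'italian'},
        gen_states={'restaurant-food': 'chinese'},
    )
    # One of the two turns matches exactly, so this prints a JGA of 50.0.
    evaluator.compute_joint_goal_accuracy()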