You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
50 lines
1.6 KiB
50 lines
1.6 KiB
import json
|
|
import os
|
|
|
|
|
|
class PromptDSTEvaluator:
    """Evaluator for dialogue state tracking (DST) predictions.

    Accumulates pairs of (true_states, gen_states) dialogue-state dicts —
    either preloaded from a JSON outputs file or added one at a time via
    ``add_data_item`` — and computes the Joint Goal Accuracy (JGA) metric
    over all accumulated turns.
    """

    def __init__(self, outputs_file_path=None):
        """Initialize the evaluator, optionally preloading data from a file.

        Args:
            outputs_file_path: Optional path to a JSON file containing a list
                of items, each a dict with 'true_states' and 'gen_states'
                keys. Ignored when None or when the path is not an existing
                file.
        """
        # Parallel lists: index i of each holds one turn's state dicts.
        self.true_states_list = []
        self.gen_states_list = []

        if outputs_file_path is not None and os.path.isfile(outputs_file_path):
            # Context manager closes the handle promptly; the original
            # `json.load(open(...))` left the file to the garbage collector.
            with open(outputs_file_path, encoding='utf-8') as outputs_file:
                outputs = json.load(outputs_file)

            for item in outputs:
                self.true_states_list.append(item['true_states'])
                self.gen_states_list.append(item['gen_states'])

    def add_data_item(self, true_states, gen_states):
        """Append one (ground-truth, generated) state pair.

        Args:
            true_states: Dict mapping slot names to ground-truth values.
            gen_states: Dict mapping slot names to predicted values.
        """
        self.true_states_list.append(true_states)
        self.gen_states_list.append(gen_states)

    def compute_joint_goal_accuracy(self, no_print=False):
        """Compute Joint Goal Accuracy (JGA) over all accumulated turns.

        A turn counts as correct only when the generated state matches the
        true state exactly: same slot keys AND same value for every slot.

        Args:
            no_print: If True, suppress the progress/result prints.

        Returns:
            JGA as a percentage in [0, 100].

        Raises:
            ValueError: If the two state lists have different lengths, or if
                no data items have been added (previously this surfaced as an
                opaque ZeroDivisionError).
        """
        if not no_print:
            print('Computing Joint Goal Accuracy metric...')

        if len(self.true_states_list) != len(self.gen_states_list):
            raise ValueError('Length mismatch!')
        if not self.true_states_list:
            # Guard the division below; an empty evaluator is a usage error.
            raise ValueError('No data items to evaluate!')

        total_turns = len(self.true_states_list)
        # Dict equality already checks both the key sets and every slot
        # value, replacing the original manual key-set + per-slot loop.
        correctly_predicted = sum(
            1
            for truth, generated in zip(self.true_states_list,
                                        self.gen_states_list)
            if truth == generated
        )

        jga_score = (correctly_predicted / total_turns) * 100

        if not no_print:
            print('Joint Goal Accuracy = ', jga_score)
        return jga_score
|