You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
87 lines
3.1 KiB
87 lines
3.1 KiB
import json
|
|
import os
|
|
|
|
|
|
class PromptDSTEvaluator:
    """Evaluates dialogue state tracking (DST) outputs against ground truth.

    Accumulates pairs of (true_states, gen_states) slot->value dicts — one
    pair per dialogue turn — and computes Joint Goal Accuracy metrics.
    """

    def __init__(self, outputs_file_path=None):
        """Optionally preload evaluation data from a JSON outputs file.

        Args:
            outputs_file_path: Path to a JSON file containing a list of
                items, each with 'true_states' and 'gen_states' keys.
                Silently skipped when None or not an existing file.
        """
        # Parallel lists: index i holds the true / generated state dicts
        # for the same dialogue turn.
        self.true_states_list = []
        self.gen_states_list = []

        if outputs_file_path is not None and os.path.isfile(outputs_file_path):
            # Context manager guarantees the file handle is closed
            # (the original `json.load(open(...))` leaked it).
            with open(outputs_file_path) as f:
                outputs = json.load(f)
            for item in outputs:
                self.true_states_list.append(item['true_states'])
                self.gen_states_list.append(item['gen_states'])

    def add_data_item(self, true_states, gen_states):
        """Append one turn's true and generated state dicts."""
        self.true_states_list.append(true_states)
        self.gen_states_list.append(gen_states)

    def compute_joint_goal_accuracy(self, no_print=False):
        """Return Joint Goal Accuracy (JGA) as a percentage rounded to 2 dp.

        A turn counts as correct only when the generated state has exactly
        the same slot keys AND the same value for every slot as the truth.
        Returns 0.0 when no data has been added (instead of dividing by zero).

        Args:
            no_print: suppress progress/result printing when True.

        Raises:
            ValueError: when the true/generated lists differ in length.
        """
        if not no_print:
            print('Computing Joint Goal Accuracy metric...')

        if len(self.true_states_list) != len(self.gen_states_list):
            raise ValueError('Length mismatch!')

        # keep a count for computing JGA
        correctly_generated, total_turns = 0, 0

        for truth, generated in zip(self.true_states_list, self.gen_states_list):
            total_turns += 1

            # Key sets must match exactly; extra or missing slots fail the turn.
            if set(truth.keys()) != set(generated.keys()):
                continue

            # Every slot value must also match for the turn to count.
            if all(truth[slot] == generated[slot] for slot in truth):
                correctly_generated += 1

        # Guard: no turns added -> define the metric as 0.0 rather than
        # raising ZeroDivisionError (the original crashed here).
        if total_turns == 0:
            return 0.0

        jga_score = round((correctly_generated / total_turns) * 100, 2)
        if not no_print:
            print('Joint Goal Accuracy (JGA) = ', jga_score)
        return jga_score

    def compute_jga_for_correct_values(self, no_print=False):
        """Return JGA* — JGA restricted to values the model extracted correctly.

        For each turn, only slots whose true value appears among the
        generated values are checked; turns with no correctly extracted
        value are counted in the denominator but never scored as correct.
        Returns 0.0 when no data has been added.

        Args:
            no_print: suppress progress/result printing when True.

        Raises:
            ValueError: when the true/generated lists differ in length.
        """
        if not no_print:
            print('Computing Joint Goal Accuracy metric for the values that are extracted correctly! (turn-level)')

        if len(self.true_states_list) != len(self.gen_states_list):
            raise ValueError('Unable to compute the metric. Length mismatch in the outputs!')

        # keep a count for computing JGA
        correctly_generated, total_turns = 0, 0

        for truth, generated in zip(self.true_states_list, self.gen_states_list):
            total_turns += 1

            # compare the extracted values with true state values
            # use only the correctly extracted values while computing JGA*
            # (a set suffices: it is used for membership tests only)
            correct_values = set(truth.values()).intersection(generated.values())

            # if no extracted correct values, then continue
            if not correct_values:
                continue

            # Every slot whose true value was extracted must appear in the
            # generated state under the same slot with the same value.
            if all(
                slot in generated and generated[slot] == value
                for slot, value in truth.items()
                if value in correct_values
            ):
                correctly_generated += 1

        # Guard against empty data (see compute_joint_goal_accuracy).
        if total_turns == 0:
            return 0.0

        jga_star = round((correctly_generated / total_turns) * 100, 2)
        if not no_print:
            print('Joint Goal Accuracy* (JGA*) = ', jga_star)
        return jga_star
|