Added JGA* metric computation, improved outputs file with extracted values, improved inverse prompt

main
Pavan Mandava 3 years ago
parent 52f9bfa8d1
commit 678b837e6f

@@ -27,7 +27,7 @@ class PromptDSTEvaluator:
raise ValueError('Length mismatch!')
# keep a count for computing JGA
correctly_predicted, total_turns = 0, 0
correctly_generated, total_turns = 0, 0
for truth, generated in zip(self.true_states_list, self.gen_states_list):
total_turns += 1
@@ -41,9 +41,46 @@ class PromptDSTEvaluator:
has_wrong_slot_value = True
break
if not has_wrong_slot_value:
correctly_predicted += 1
correctly_generated += 1
jga_score = (correctly_predicted / total_turns) * 100
jga_score = round((correctly_generated / total_turns) * 100, 2)
if not no_print:
print('Joint Goal Accuracy = ', jga_score)
print('Joint Goal Accuracy (JGA) = ', jga_score)
return jga_score
def compute_jga_for_correct_values(self, no_print=False):
if not no_print:
print('Computing Joint Goal Accuracy metric only where values are extracted correctly!')
if len(self.true_states_list) != len(self.gen_states_list):
raise ValueError('Length mismatch!')
# keep a count for computing JGA
correctly_generated, total_turns = 0, 0
for truth, generated in zip(self.true_states_list, self.gen_states_list):
total_turns += 1
# compare the extracted values with true state values
# use only the correctly extracted values while computing JGA*
extracted_values = list(generated.values())
true_values = list(truth.values())
correct_values = list(set(true_values).intersection(extracted_values))
# if no extracted correct values, then continue
if len(correct_values) <= 0:
continue
has_wrong_slot_value = False
for slot, value in truth.items():
if value in correct_values:
if slot not in generated or truth[slot] != generated[slot]:
has_wrong_slot_value = True
break
if not has_wrong_slot_value:
correctly_generated += 1
jga_star = round((correctly_generated / total_turns) * 100, 2)
if not no_print:
print('Joint Goal Accuracy* (JGA*) = ', jga_star)
return jga_star
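A rough worked example of how JGA and the new JGA* above differ on a single turn, with made-up slot names and values: JGA requires every true slot-value pair to be generated exactly, while JGA* first intersects the true and generated values and only checks the slots whose true value was extracted at all, skipping the turn if that intersection is empty.

# hypothetical single turn (slot names and values are illustrative only)
truth = {"hotel-area": "centre", "hotel-stars": "4"}
generated = {"hotel-area": "centre", "hotel-stars": "3"}

# JGA: the full belief states must match -> this turn counts as incorrect
jga_hit = truth == generated  # False

# JGA*: only slots whose true value appears among the extracted values are checked
correct_values = set(truth.values()).intersection(generated.values())  # {"centre"}
jga_star_hit = all(generated.get(slot) == value
                   for slot, value in truth.items() if value in correct_values)  # True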

@@ -169,14 +169,18 @@ def main():
# add true belief states & generated belief states to outputs
outputs.append({"history": history,
"extracted_values": item['values'],
"true_states": true_states,
"gen_states": gen_states})
# add true & generated belief states to evaluator for computing JGA
evaluator.add_data_item(true_states.copy(), gen_states.copy())
progress.close()
# compute JGA & print results
evaluator.compute_joint_goal_accuracy()
# compute JGA* & print results (JGA* -> consider values that are extracted correctly)
evaluator.compute_jga_for_correct_values()
# output file extension (for prompt ensemble & answered prompts)
out_ext = "_pe" if args.with_prompt_ensemble else ""
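For illustration, one entry of the outputs list built above might look roughly like this; the dialogue text and values are invented, and the exact shape of item['values'] depends on the value-extraction step.

# hypothetical outputs entry (content is illustrative only)
{
    "history": "user: i am looking for a hotel in the centre ...",
    "extracted_values": ["centre", "4"],
    "true_states": {"hotel-area": "centre", "hotel-stars": "4"},
    "gen_states": {"hotel-area": "centre", "hotel-stars": "4"},
}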

@@ -110,7 +110,7 @@ def main():
# set the model in training mode
model.train()
tqdm.write(str('Training starts now... [with_prompt_ensemble = '+str(args.with_prompt_ensemble)+']'))
tqdm.write(str('Training starts now... [with_prompt_ensemble = ' + str(args.with_prompt_ensemble) + ']'))
validation_summary = {
'validation_with_true_values': True
@@ -178,8 +178,8 @@ def main():
loss_count = 0
# Save the model after finishing an epoch
epoch_str = "{:02d}".format(epoch + 1)
output_dir = os.path.join(args.save_model_dir, '{}-{}'.format("epoch", epoch_str))
epoch_str = "epoch-{:02d}".format(epoch + 1)
output_dir = os.path.join(args.save_model_dir, epoch_str)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# save training args for each epoch (useful when testing/generating)
@@ -196,7 +196,7 @@ def main():
# if validation file is provided, run evaluation here (after each epoch)
if args.validation_file and validation_data is not None:
tqdm.write(str('Validation In Progress...[with_prompt_ensemble = '+str(args.with_prompt_ensemble)+']'))
tqdm.write(str('Validation In Progress...[with_prompt_ensemble = ' + str(args.with_prompt_ensemble) + ']'))
# set tqdm progress bars for testing progress
validation_progress = tqdm(total=validation_data.len(), desc="Validation", leave=False)
@@ -280,7 +280,15 @@ def save_args(save_dir, args):
os.makedirs(save_dir)
args_file = os.path.join(save_dir, 'args.json')
args_dict = vars(args)
args_dict['prompt_templates'] = PROMPT_TEMPLATES
# save prompt templates used for training & inference
prompt_templates = PROMPT_TEMPLATES.copy()
if not args.with_inverse_prompt:
prompt_templates.pop(TYPE_INVERSE_PROMPT)
if args.with_prompt_ensemble:
prompt_templates.pop(TYPE_VALUE_BASED_PROMPT)
else:
prompt_templates.pop(TYPE_PROMPT_ENSEMBLE)
args_dict['prompt_templates'] = prompt_templates
json.dump(args_dict, open(args_file, "w"), indent=2)
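A small sketch of the effect of the filtering above on the saved args.json; the "value-based-prompt" key string is an assumption (only "inverse-prompt" and "prompt-ensemble" are visible in the templates file), and the run is assumed to use --with_prompt_ensemble without --with_inverse_prompt.

templates = {"value-based-prompt": "...", "inverse-prompt": "...", "prompt-ensemble": "..."}
templates.pop("inverse-prompt")       # inverse prompt not trained in this run, so drop its template
templates.pop("value-based-prompt")   # the ensemble replaces the single value-based prompt
print(list(templates))                # ['prompt-ensemble'] is what ends up in args.json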
@@ -314,6 +322,36 @@ def train_prompting(args, history, slot_value_pair, prompt_type, tokenizer, mode
# model outputs
outputs = model(**encoded_prompt)
if prompt_type == TYPE_INVERSE_PROMPT:
# value ids
value_ids = tokenizer.encode(slot_value_pair[1], return_tensors="pt", add_prefix_space=True)
flipped_value_ids = torch.flip(value_ids, dims=[1])
flipped_logits = torch.flip(outputs.logits, dims=[1])
index = 1
# iterate through the value ids and compute loss (combined probability)
total_prob = None
for item in flipped_value_ids[0]:
token_logits = flipped_logits[:, index, :]
index += 1
# softmax probabilities
probs = torch.nn.functional.softmax(token_logits, dim=-1)
# this token generation probability
token_prob = torch.gather(probs, 1, torch.tensor([[item]], device=device)).squeeze(-1)
# multiply the probabilities for each word in belief state value
if total_prob is None:
total_prob = token_prob
else:
total_prob *= token_prob
loss = torch.negative(torch.log(total_prob))
# return loss and 'None' for generated values
return loss, None
# loss for slot generation using value-based prompt & the generated slot
if prompt_type == TYPE_VALUE_BASED_PROMPT:
# get last token logits [-2 -> for last but one item]
logits = outputs.logits[:, -2, :]
@@ -325,9 +363,6 @@ train_prompting(args, history, slot_value_pair, prompt_type, tokenizer, mode
loss = torch.negative(torch.log(last_token_prob))
# generated word -> the one with the highest probability
generated_word = None
if prompt_type == TYPE_VALUE_BASED_PROMPT:
# find the token with the highest probability, this will be the generated word
gen_word_token = torch.argmax(logits, dim=-1)
generated_word = tokenizer.decode(gen_word_token, skip_special_tokens=True).strip()
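A minimal, self-contained sketch of the inverse-prompt loss in the branch above, where random logits stand in for outputs.logits and the token ids are made up: the loop multiplies the per-token generation probabilities of the value, read backwards from the end of the sequence, and the negative log of that product equals the summed per-token negative log-likelihood.

import torch

torch.manual_seed(0)
vocab_size, seq_len = 50, 6
logits = torch.randn(1, seq_len, vocab_size)   # stand-in for outputs.logits
value_ids = torch.tensor([[7, 3]])             # stand-in for the tokenized slot value

flipped_value_ids = torch.flip(value_ids, dims=[1])
flipped_logits = torch.flip(logits, dims=[1])
total_prob = None
for index, item in enumerate(flipped_value_ids[0], start=1):
    probs = torch.nn.functional.softmax(flipped_logits[:, index, :], dim=-1)
    token_prob = probs[:, item]                # probability of generating this value token
    total_prob = token_prob if total_prob is None else total_prob * token_prob
loss = torch.negative(torch.log(total_prob))

# equivalent formulation: sum of per-token negative log-probabilities
log_probs = torch.log_softmax(flipped_logits[:, 1:1 + value_ids.size(1), :], dim=-1)
loss_alt = -log_probs.gather(-1, flipped_value_ids.unsqueeze(-1)).sum()
print(loss.item(), loss_alt.item())            # both print the same value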

@@ -12,7 +12,7 @@ PROMPT_TEMPLATES = {
"generate": "belief states: value = $value, slot ="
},
"inverse-prompt": {
"training": "belief states: slot = $slot, value = $value",
"training": "belief states: $slot = $value",
},
"prompt-ensemble": {
"training": {
