From 678b837e6ff864455f1cceed5df95d809dba7723 Mon Sep 17 00:00:00 2001
From: Pavan Mandava
Date: Sun, 27 Nov 2022 10:31:54 +0100
Subject: [PATCH] added JGA* metric compute, improved outputs file with
 extracted values, improved inverse prompt

---
 prompt-learning/metrics.py       | 45 ++++++++++++++++++--
 prompt-learning/prompt_decode.py |  4 ++
 prompt-learning/prompt_train.py  | 71 ++++++++++++++++++++++++--------
 prompt-learning/prompt_utils.py  |  2 +-
 4 files changed, 99 insertions(+), 23 deletions(-)

diff --git a/prompt-learning/metrics.py b/prompt-learning/metrics.py
index 8170015..24ec519 100644
--- a/prompt-learning/metrics.py
+++ b/prompt-learning/metrics.py
@@ -27,7 +27,7 @@ class PromptDSTEvaluator:
             raise ValueError('Length mismatch!')
 
         # keep a count for computing JGA
-        correctly_predicted, total_turns = 0, 0
+        correctly_generated, total_turns = 0, 0
 
         for truth, generated in zip(self.true_states_list, self.gen_states_list):
             total_turns += 1
@@ -41,9 +41,46 @@
                     has_wrong_slot_value = True
                     break
             if not has_wrong_slot_value:
-                correctly_predicted += 1
+                correctly_generated += 1
 
-        jga_score = (correctly_predicted / total_turns) * 100
+        jga_score = round((correctly_generated / total_turns) * 100, 2)
         if not no_print:
-            print('Joint Goal Accuracy = ', jga_score)
+            print('Joint Goal Accuracy (JGA) = ', jga_score)
         return jga_score
+
+    def compute_jga_for_correct_values(self, no_print=False):
+        if not no_print:
+            print('Computing Joint Goal Accuracy metric only where values are extracted correctly!')
+
+        if len(self.true_states_list) != len(self.gen_states_list):
+            raise ValueError('Length mismatch!')
+
+        # keep a count for computing JGA
+        correctly_generated, total_turns = 0, 0
+
+        for truth, generated in zip(self.true_states_list, self.gen_states_list):
+            total_turns += 1
+
+            # compare the extracted values with true state values
+            # use only the correctly extracted values while computing JGA*
+            extracted_values = list(generated.values())
+            true_values = list(truth.values())
+
+            correct_values = list(set(true_values).intersection(extracted_values))
+            # if no extracted correct values, then continue
+            if len(correct_values) <= 0:
+                continue
+
+            has_wrong_slot_value = False
+            for slot, value in truth.items():
+                if value in correct_values:
+                    if slot not in generated or truth[slot] != generated[slot]:
+                        has_wrong_slot_value = True
+                        break
+            if not has_wrong_slot_value:
+                correctly_generated += 1
+
+        jga_star = round((correctly_generated / total_turns) * 100, 2)
+        if not no_print:
+            print('Joint Goal Accuracy* (JGA*) = ', jga_star)
+        return jga_star
diff --git a/prompt-learning/prompt_decode.py b/prompt-learning/prompt_decode.py
index 8f14523..de2eef9 100644
--- a/prompt-learning/prompt_decode.py
+++ b/prompt-learning/prompt_decode.py
@@ -169,14 +169,18 @@ def main():
 
         # add true belief states & generated belief states to outputs
         outputs.append({"history": history,
+                        "extracted_values": item['values'],
                         "true_states": true_states,
                         "gen_states": gen_states})
 
         # add true & generated belief states to evaluator for computing JGA
         evaluator.add_data_item(true_states.copy(), gen_states.copy())
+    progress.close()
 
     # compute JGA & print results
     evaluator.compute_joint_goal_accuracy()
+    # compute JGA* & print results (JGA* -> consider values that are extracted correctly)
+    evaluator.compute_jga_for_correct_values()
 
     # output file extension (for prompt ensemble & answered prompts)
     out_ext = "_pe" if args.with_prompt_ensemble else ""
diff --git a/prompt-learning/prompt_train.py b/prompt-learning/prompt_train.py
index 5fcec7c..4a7a58b 100644
--- a/prompt-learning/prompt_train.py
+++ b/prompt-learning/prompt_train.py
@@ -110,7 +110,7 @@ def main():
     # set the model in training mode
     model.train()
 
-    tqdm.write(str('Training starts now... [with_prompt_ensemble = '+str(args.with_prompt_ensemble)+']'))
+    tqdm.write(str('Training starts now... [with_prompt_ensemble = ' + str(args.with_prompt_ensemble) + ']'))
 
     validation_summary = {
         'validation_with_true_values': True
@@ -178,8 +178,8 @@
                 loss_count = 0
 
         # Save the model after finishing an epoch
-        epoch_str = "{:02d}".format(epoch + 1)
-        output_dir = os.path.join(args.save_model_dir, '{}-{}'.format("epoch", epoch_str))
+        epoch_str = "epoch-{:02d}".format(epoch + 1)
+        output_dir = os.path.join(args.save_model_dir, epoch_str)
         if not os.path.exists(output_dir):
             os.makedirs(output_dir)
         # save training args for each epoch (useful when testing/generating)
@@ -196,7 +196,7 @@
 
         # if validation file is provided, run evaluation here (after each epoch)
         if args.validation_file and validation_data is not None:
-            tqdm.write(str('Validation In Progress...[with_prompt_ensemble = '+str(args.with_prompt_ensemble)+']'))
+            tqdm.write(str('Validation In Progress...[with_prompt_ensemble = ' + str(args.with_prompt_ensemble) + ']'))
 
             # set tqdm progress bars for testing progress
             validation_progress = tqdm(total=validation_data.len(), desc="Validation", leave=False)
@@ -280,7 +280,15 @@ def save_args(save_dir, args):
         os.makedirs(save_dir)
     args_file = os.path.join(save_dir, 'args.json')
     args_dict = vars(args)
-    args_dict['prompt_templates'] = PROMPT_TEMPLATES
+    # save prompt templates used for training & inference
+    prompt_templates = PROMPT_TEMPLATES.copy()
+    if not args.with_inverse_prompt:
+        prompt_templates.pop(TYPE_INVERSE_PROMPT)
+    if args.with_prompt_ensemble:
+        prompt_templates.pop(TYPE_VALUE_BASED_PROMPT)
+    else:
+        prompt_templates.pop(TYPE_PROMPT_ENSEMBLE)
+    args_dict['prompt_templates'] = prompt_templates
     json.dump(args_dict, open(args_file, "w"), indent=2)
 
 
@@ -314,25 +322,52 @@ def train_prompting(args, history, slot_value_pair, prompt_type, tokenizer, mode
     # model outputs
     outputs = model(**encoded_prompt)
 
-    # get last token logits [-2 -> for last but one item]
-    logits = outputs.logits[:, -2, :]
+    if prompt_type == TYPE_INVERSE_PROMPT:
+        # value ids
+        value_ids = tokenizer.encode(slot_value_pair[1], return_tensors="pt", add_prefix_space=True)
+        flipped_value_ids = torch.flip(value_ids, dims=[1])
+        flipped_logits = torch.flip(outputs.logits, dims=[1])
+
+        index = 1
+        # iterate through the value ids and compute loss (combined probability)
+        total_prob = None
+        for item in flipped_value_ids[0]:
+            token_logits = flipped_logits[:, index, :]
+            index += 1
+
+            # softmax probabilities
+            probs = torch.nn.functional.softmax(token_logits, dim=-1)
+
+            # this token generation probability
+            token_prob = torch.gather(probs, 1, torch.tensor([[item]], device=device)).squeeze(-1)
+            # multiply the probabilities for each word in belief state value
+            if total_prob is None:
+                total_prob = token_prob
+            else:
+                total_prob *= token_prob
+
+        loss = torch.negative(torch.log(total_prob))
+        # return loss and 'None' for generated values
+        return loss, None
+
+    # loss for slot generation using value-based prompt & the generated slot
+    if prompt_type == TYPE_VALUE_BASED_PROMPT:
+        # get last token logits [-2 -> for last but one item]
+        logits = outputs.logits[:, -2, :]
 
-    # softmax probabilities
-    probs = torch.nn.functional.softmax(logits, dim=-1)
+        # softmax probabilities
+        probs = torch.nn.functional.softmax(logits, dim=-1)
 
-    # last token generation probability
-    last_token_prob = torch.gather(probs, 1, last_token).squeeze(-1)
-    loss = torch.negative(torch.log(last_token_prob))
+        # last token generation probability
+        last_token_prob = torch.gather(probs, 1, last_token).squeeze(-1)
+        loss = torch.negative(torch.log(last_token_prob))
 
-    # generated word -> the one with the highest probability
-    generated_word = None
-    if prompt_type == TYPE_VALUE_BASED_PROMPT:
-        # find the token with the highest probability, this will be the generated word
+        # generated word -> the one with the highest probability
         gen_word_token = torch.argmax(logits, dim=-1)
         generated_word = tokenizer.decode(gen_word_token, skip_special_tokens=True).strip()
 
-    # loss is the log of probability
-    return loss, generated_word
+        # loss is the log of probability
+        return loss, generated_word
 
 
 def train_prompt_ensemble(history, slot_value_pair, prompt_type, tokenizer, model, device):
diff --git a/prompt-learning/prompt_utils.py b/prompt-learning/prompt_utils.py
index acefd9d..87e589f 100644
--- a/prompt-learning/prompt_utils.py
+++ b/prompt-learning/prompt_utils.py
@@ -12,7 +12,7 @@ PROMPT_TEMPLATES = {
         "generate": "belief states: value = $value, slot ="
     },
    "inverse-prompt": {
-        "training": "belief states: slot = $slot, value = $value",
+        "training": "belief states: $slot = $value",
     },
     "prompt-ensemble": {
         "training": {
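
For illustration, the difference between JGA and the new JGA* can be reproduced outside the repository. The sketch below mirrors the logic of compute_joint_goal_accuracy and compute_jga_for_correct_values from metrics.py as standalone functions; the slot names and toy dialogue states are invented for the example, and in the actual code both scores are produced through PromptDSTEvaluator (add_data_item plus the two compute methods).

# Standalone sketch of JGA vs. JGA* (mirrors metrics.py above).
# The toy states are hypothetical; the repo computes these scores via
# PromptDSTEvaluator, not free functions.

def jga(true_states_list, gen_states_list):
    # strict JGA: a turn is correct only if every true slot is generated
    # with exactly the right value
    correct = sum(
        1 for truth, generated in zip(true_states_list, gen_states_list)
        if all(generated.get(slot) == value for slot, value in truth.items())
    )
    return round((correct / len(true_states_list)) * 100, 2)


def jga_star(true_states_list, gen_states_list):
    # JGA*: judge each turn only on the slots whose values were extracted
    # correctly; turns with no correctly extracted value are skipped in the
    # numerator but still counted in the denominator
    correct, total = 0, 0
    for truth, generated in zip(true_states_list, gen_states_list):
        total += 1
        correct_values = set(truth.values()).intersection(generated.values())
        if not correct_values:
            continue
        if all(generated.get(slot) == value
               for slot, value in truth.items() if value in correct_values):
            correct += 1
    return round((correct / total) * 100, 2)


# hypothetical belief states for two dialogue turns
true_list = [{"hotel-area": "east", "hotel-stars": "4"}, {"train-day": "monday"}]
gen_list = [{"hotel-area": "east", "hotel-stars": "5"}, {"train-day": "monday"}]

print(jga(true_list, gen_list))       # 50.0  -> turn 1 has one wrong value
print(jga_star(true_list, gen_list))  # 100.0 -> turn 1 judged only on "east"

Note that a turn whose values are all extracted incorrectly still counts in the JGA* denominator; the continue in compute_jga_for_correct_values only prevents it from being counted as correct.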