From cc281edce63d23f0e41c49a811b9c1ea388f9c8d Mon Sep 17 00:00:00 2001
From: Pavan Mandava
Date: Wed, 23 Nov 2022 20:35:41 +0100
Subject: [PATCH] Saving Validation summary in json and prompt templates in args file

---
 prompt-learning/prompt_train.py | 42 ++++++++++++++++++++++++++-------
 prompt-learning/prompt_utils.py | 28 ++++++++++------------
 2 files changed, 47 insertions(+), 23 deletions(-)

diff --git a/prompt-learning/prompt_train.py b/prompt-learning/prompt_train.py
index 4de020a..65fac6f 100644
--- a/prompt-learning/prompt_train.py
+++ b/prompt-learning/prompt_train.py
@@ -15,6 +15,7 @@ from prompt_utils import get_prompt_for_training
 from prompt_utils import TYPE_VALUE_BASED_PROMPT
 from prompt_utils import TYPE_INVERSE_PROMPT
 from prompt_utils import TYPE_PROMPT_ENSEMBLE, PROMPT_WEIGHT
+from prompt_utils import PROMPT_TEMPLATES
 from metrics import PromptDSTEvaluator
 from datetime import datetime
 
@@ -50,6 +51,9 @@ def main():
                         help="Flag to enable/disable the use of answered prompts while validation")
     parser.add_argument("--validation_file", default="", type=str,
                         help="Validation file for evaluating model after each epoch")
+    parser.add_argument("--validation_with_true_values", action="store_true",
+                        help="Flag for enabling/disabling the usage of TRUE values for slot generation during "
+                             "validation")
 
     # parse the arguments
     args = parser.parse_args()
@@ -111,6 +115,14 @@ def main():
 
     tqdm.write(str('Training starts now... [with_prompt_ensemble = '+str(args.with_prompt_ensemble)+']'))
 
+    validation_summary = {
+        'validation_with_true_values': args.validation_with_true_values
+    }
+
+    # outputs file extension (representing usage of prompt ensemble & answered prompts)
+    out_ext = "_pe" if args.with_prompt_ensemble else ""
+    out_ext += "_pa" if args.with_answered_prompts else ""
+
     # training loop
     for epoch in range(args.num_epochs):
         running_loss = 0.0
@@ -169,7 +181,8 @@ def main():
                 loss_count = 0
 
         # Save the model after finishing an epoch
-        output_dir = os.path.join(args.save_model_dir, '{}-{}'.format("epoch", "{:02d}".format(epoch + 1)))
+        epoch_str = "{:02d}".format(epoch + 1)
+        output_dir = os.path.join(args.save_model_dir, '{}-{}'.format("epoch", epoch_str))
         if not os.path.exists(output_dir):
             os.makedirs(output_dir)
         # save training args for each epoch (useful when testing/generating)
@@ -230,22 +243,33 @@ def main():
                 evaluator.add_data_item(true_states.copy(), gen_states.copy())
             validation_progress.close()
 
+            epoch_valid_summary = {}
             # compute JGA & print results
             tqdm.write(str('Computing Joint Goal Accuracy metric with TRUE values...'))
             jga_score = evaluator.compute_joint_goal_accuracy(no_print=True)
             tqdm.write(str('Joint Goal Accuracy(with True Values) [after Epoch-{}]: {}'.format((epoch + 1), jga_score)))
 
-            # output file extension (for prompt ensemble & answered prompts)
-            out_ext = "_pe" if args.with_prompt_ensemble else ""
-            out_ext += "_pa" if args.with_answered_prompts else ""
-
-            # save the outputs to trained epoch dir
             now = datetime.now()
             datetime_str = now.strftime("%Y%m%dT%H%M%S")
-            output_file = os.path.join(output_dir, 'outputs{}-{}.json'.format(out_ext, datetime_str))
+
+            # outputs file name
+            file_name = "outputs{}-{}.json".format(out_ext, datetime_str)
+            epoch_valid_summary['file_name'] = file_name
+            epoch_valid_summary['jga_score'] = round(jga_score, 3)
+            # add epoch summary to valid summary
+            validation_summary[epoch_str] = epoch_valid_summary
+
+            # save the outputs to trained epoch dir
+            output_file = os.path.join(output_dir, file_name)
             tqdm.write(str('Saving Validation Outputs file [after Epoch-{}] :: {}'.format((epoch + 1), output_file)))
             json.dump(outputs, open(output_file, 'w'), indent=2)
 
+    # save validation file summary (if there's data)
+    if len(validation_summary) > 1:
+        valid_summary_file = os.path.join(args.save_model_dir, 'validation{}.json'.format(out_ext))
+        tqdm.write(str('Saving Validation Summary :: {}'.format(valid_summary_file)))
+        json.dump(validation_summary, open(valid_summary_file, 'w'), indent=2)
+
 
 def set_seed(args):
     np.random.seed(args.seed)
@@ -258,7 +282,9 @@ def save_args(save_dir, args):
     if not os.path.exists(save_dir):
         os.makedirs(save_dir)
     args_file = os.path.join(save_dir, 'args.json')
-    json.dump(vars(args), open(args_file, "w"), indent=2)
+    args_dict = vars(args)
+    args_dict['prompt_templates'] = PROMPT_TEMPLATES
+    json.dump(args_dict, open(args_file, "w"), indent=2)
 
 
 def train_prompting(args, history, slot_value_pair, prompt_type, tokenizer, model, device):
diff --git a/prompt-learning/prompt_utils.py b/prompt-learning/prompt_utils.py
index 19ccaae..87e589f 100644
--- a/prompt-learning/prompt_utils.py
+++ b/prompt-learning/prompt_utils.py
@@ -11,7 +11,9 @@ PROMPT_TEMPLATES = {
         "training": "belief states: value = $value, slot = $slot",
         "generate": "belief states: value = $value, slot ="
     },
-    "inverse-prompt": "belief states: slot = $slot, value = $value",
+    "inverse-prompt": {
+        "training": "belief states: $slot = $value",
+    },
     "prompt-ensemble": {
         "training": {
             "p1": "belief states: value = $value, slot = $slot",
@@ -30,20 +32,16 @@ PROMPT_TEMPLATES = {
 
 
 def get_prompt_for_training(typ, slot_value):
-    if typ == TYPE_INVERSE_PROMPT:
-        template = Template(PROMPT_TEMPLATES[typ])
-        return template.substitute(slot=slot_value[0], value=slot_value[1])
-    else:
-        template = PROMPT_TEMPLATES[typ]['training']
-        if isinstance(template, str):
-            return Template(template).substitute(slot=slot_value[0], value=slot_value[1])
-        elif isinstance(template, dict):
-            template_list = template.values()
-            prompt_list = []
-            for template_str in template_list:
-                template = Template(template_str)
-                prompt_list.append(template.substitute(slot=slot_value[0], value=slot_value[1]))
-            return prompt_list
+    template = PROMPT_TEMPLATES[typ]['training']
+    if isinstance(template, str):
+        return Template(template).substitute(slot=slot_value[0], value=slot_value[1])
+    elif isinstance(template, dict):
+        template_list = template.values()
+        prompt_list = []
+        for template_str in template_list:
+            prompt = Template(template_str).substitute(slot=slot_value[0], value=slot_value[1])
+            prompt_list.append(prompt)
+        return prompt_list
 
 
 def get_value_based_prompt(value):
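
Illustrative note (not part of the patch): with this change, each run additionally writes one summary file, <save_model_dir>/validation<out_ext>.json, next to the per-epoch output files. Assuming a two-epoch run with --with_prompt_ensemble enabled (so out_ext is "_pe") and --validation_with_true_values left off, the summary would look roughly like the sketch below; the timestamps and JGA scores here are made-up placeholders.

{
  "validation_with_true_values": false,
  "01": {
    "file_name": "outputs_pe-20221123T203541.json",
    "jga_score": 0.412
  },
  "02": {
    "file_name": "outputs_pe-20221123T211502.json",
    "jga_score": 0.437
  }
}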