From 1d65ca348963a7b52938f6eee44cafe47fed6dd5 Mon Sep 17 00:00:00 2001 From: Pavan Mandava Date: Thu, 24 Nov 2022 08:28:02 +0100 Subject: [PATCH] Added more info to validation summary --- prompt-learning/prompt_train.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/prompt-learning/prompt_train.py b/prompt-learning/prompt_train.py index 65fac6f..5fcec7c 100644 --- a/prompt-learning/prompt_train.py +++ b/prompt-learning/prompt_train.py @@ -51,9 +51,6 @@ def main(): help="Flag to enable/disable the use of answered prompts while validation") parser.add_argument("--validation_file", default="", type=str, help="Validation file for evaluating model after each epoch") - parser.add_argument("--validation_with_true_values", action="store_true", - help="Flag for enabling/disabling the usage of TRUE values for slot generation during " - "validation") # parse the arguments args = parser.parse_args() @@ -62,6 +59,9 @@ def main(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") args.n_gpu = torch.cuda.device_count() + # set seed + set_seed(args) + # Print args print('Training Args :: ', json.dumps(vars(args), indent=2)) save_args(args.save_model_dir, args) @@ -73,9 +73,6 @@ def main(): # set the device to the model model.to(device) - # set seed - set_seed(args) - # load training dataset training_data = PromptDstDataset(args.train_data_file) @@ -116,7 +113,7 @@ def main(): tqdm.write(str('Training starts now... [with_prompt_ensemble = '+str(args.with_prompt_ensemble)+']')) validation_summary = { - 'validation_with_true_values': args.validation_with_true_values + 'validation_with_true_values': True } # outputs file extension (representing usage of prompt ensemble & answered prompts)