|
|
|
@ -51,9 +51,6 @@ def main():
|
|
|
|
help="Flag to enable/disable the use of answered prompts while validation")
|
|
|
|
help="Flag to enable/disable the use of answered prompts while validation")
|
|
|
|
parser.add_argument("--validation_file", default="", type=str,
|
|
|
|
parser.add_argument("--validation_file", default="", type=str,
|
|
|
|
help="Validation file for evaluating model after each epoch")
|
|
|
|
help="Validation file for evaluating model after each epoch")
|
|
|
|
parser.add_argument("--validation_with_true_values", action="store_true",
|
|
|
|
|
|
|
|
help="Flag for enabling/disabling the usage of TRUE values for slot generation during "
|
|
|
|
|
|
|
|
"validation")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# parse the arguments
|
|
|
|
# parse the arguments
|
|
|
|
args = parser.parse_args()
|
|
|
|
args = parser.parse_args()
|
|
|
|
@ -62,6 +59,9 @@ def main():
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
args.n_gpu = torch.cuda.device_count()
|
|
|
|
args.n_gpu = torch.cuda.device_count()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# set seed
|
|
|
|
|
|
|
|
set_seed(args)
|
|
|
|
|
|
|
|
|
|
|
|
# Print args
|
|
|
|
# Print args
|
|
|
|
print('Training Args :: ', json.dumps(vars(args), indent=2))
|
|
|
|
print('Training Args :: ', json.dumps(vars(args), indent=2))
|
|
|
|
save_args(args.save_model_dir, args)
|
|
|
|
save_args(args.save_model_dir, args)
|
|
|
|
@ -73,9 +73,6 @@ def main():
|
|
|
|
# set the device to the model
|
|
|
|
# set the device to the model
|
|
|
|
model.to(device)
|
|
|
|
model.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
# set seed
|
|
|
|
|
|
|
|
set_seed(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# load training dataset
|
|
|
|
# load training dataset
|
|
|
|
training_data = PromptDstDataset(args.train_data_file)
|
|
|
|
training_data = PromptDstDataset(args.train_data_file)
|
|
|
|
|
|
|
|
|
|
|
|
@ -116,7 +113,7 @@ def main():
|
|
|
|
tqdm.write(str('Training starts now... [with_prompt_ensemble = '+str(args.with_prompt_ensemble)+']'))
|
|
|
|
tqdm.write(str('Training starts now... [with_prompt_ensemble = '+str(args.with_prompt_ensemble)+']'))
|
|
|
|
|
|
|
|
|
|
|
|
validation_summary = {
|
|
|
|
validation_summary = {
|
|
|
|
'validation_with_true_values': args.validation_with_true_values
|
|
|
|
'validation_with_true_values': True
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
# outputs file extension (representing usage of prompt ensemble & answered prompts)
|
|
|
|
# outputs file extension (representing usage of prompt ensemble & answered prompts)
|
|
|
|
|