Saving Validation summary in json and prompt templates in args file

main
Pavan Mandava 3 years ago
parent b9539b05a0
commit cc281edce6

@@ -15,6 +15,7 @@ from prompt_utils import get_prompt_for_training
 from prompt_utils import TYPE_VALUE_BASED_PROMPT
 from prompt_utils import TYPE_INVERSE_PROMPT
 from prompt_utils import TYPE_PROMPT_ENSEMBLE, PROMPT_WEIGHT
+from prompt_utils import PROMPT_TEMPLATES
 from metrics import PromptDSTEvaluator
 from datetime import datetime
@@ -50,6 +51,9 @@ def main():
                         help="Flag to enable/disable the use of answered prompts while validation")
     parser.add_argument("--validation_file", default="", type=str,
                         help="Validation file for evaluating model after each epoch")
+    parser.add_argument("--validation_with_true_values", action="store_true",
+                        help="Flag for enabling/disabling the usage of TRUE values for slot generation during "
+                             "validation")
     # parse the arguments
     args = parser.parse_args()
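
For context, a minimal sketch of how the new flag parses (the reduced parser and the sample argv values here are invented for illustration):

import argparse

# Reduced parser with just the two validation options from this hunk
parser = argparse.ArgumentParser()
parser.add_argument("--validation_file", default="", type=str)
parser.add_argument("--validation_with_true_values", action="store_true")

args = parser.parse_args(["--validation_file", "dev.json", "--validation_with_true_values"])
assert args.validation_with_true_values is True  # store_true flag defaults to False when omitted
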
@@ -111,6 +115,14 @@ def main():
     tqdm.write(str('Training starts now... [with_prompt_ensemble = '+str(args.with_prompt_ensemble)+']'))
+    validation_summary = {
+        'validation_with_true_values': args.validation_with_true_values
+    }
+    # outputs file extension (representing usage of prompt ensemble & answered prompts)
+    out_ext = "_pe" if args.with_prompt_ensemble else ""
+    out_ext += "_pa" if args.with_answered_prompts else ""
     # training loop
     for epoch in range(args.num_epochs):
         running_loss = 0.0
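
The out_ext suffix computed above drives all output file names for the run. A small sketch of how it resolves for each flag combination (build_out_ext is a hypothetical helper mirroring the hunk's logic):

def build_out_ext(with_prompt_ensemble, with_answered_prompts):
    # mirrors the out_ext logic added in the hunk above
    out_ext = "_pe" if with_prompt_ensemble else ""
    out_ext += "_pa" if with_answered_prompts else ""
    return out_ext

assert build_out_ext(False, False) == ""
assert build_out_ext(True, False) == "_pe"
assert build_out_ext(False, True) == "_pa"
assert build_out_ext(True, True) == "_pe_pa"
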
@@ -169,7 +181,8 @@ def main():
                 loss_count = 0
         # Save the model after finishing an epoch
-        output_dir = os.path.join(args.save_model_dir, '{}-{}'.format("epoch", "{:02d}".format(epoch + 1)))
+        epoch_str = "{:02d}".format(epoch + 1)
+        output_dir = os.path.join(args.save_model_dir, '{}-{}'.format("epoch", epoch_str))
         if not os.path.exists(output_dir):
             os.makedirs(output_dir)
         # save training args for each epoch (useful when testing/generating)
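
The extracted epoch_str is reused later as a key in the validation summary; the directory naming itself is unchanged, e.g. (the save_model_dir value here is assumed):

import os

epoch = 0  # first loop iteration
epoch_str = "{:02d}".format(epoch + 1)
output_dir = os.path.join("saved_model", "{}-{}".format("epoch", epoch_str))
print(output_dir)  # -> saved_model/epoch-01
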
@@ -230,22 +243,33 @@ def main():
             evaluator.add_data_item(true_states.copy(), gen_states.copy())
         validation_progress.close()
+        epoch_valid_summary = {}
         # compute JGA & print results
         tqdm.write(str('Computing Joint Goal Accuracy metric with TRUE values...'))
         jga_score = evaluator.compute_joint_goal_accuracy(no_print=True)
         tqdm.write(str('Joint Goal Accuracy(with True Values) [after Epoch-{}]: {}'.format((epoch + 1), jga_score)))
-        # output file extension (for prompt ensemble & answered prompts)
-        out_ext = "_pe" if args.with_prompt_ensemble else ""
-        out_ext += "_pa" if args.with_answered_prompts else ""
-        # save the outputs to trained epoch dir
         now = datetime.now()
         datetime_str = now.strftime("%Y%m%dT%H%M%S")
-        output_file = os.path.join(output_dir, 'outputs{}-{}.json'.format(out_ext, datetime_str))
+        # outputs file name
+        file_name = "outputs{}-{}.json".format(out_ext, datetime_str)
+        epoch_valid_summary['file_name'] = file_name
+        epoch_valid_summary['jga_score'] = round(jga_score, 3)
+        # add epoch summary to valid summary
+        validation_summary[epoch_str] = epoch_valid_summary
+        # save the outputs to trained epoch dir
+        output_file = os.path.join(output_dir, file_name)
         tqdm.write(str('Saving Validation Outputs file [after Epoch-{}] :: {}'.format((epoch + 1), output_file)))
         json.dump(outputs, open(output_file, 'w'), indent=2)
+        # save validation file summary (if there's data)
+        if len(validation_summary) > 1:
+            valid_summary_file = os.path.join(args.save_model_dir, 'validation{}.json'.format(out_ext))
+            tqdm.write(str('Saving Validation Summary :: {}'.format(valid_summary_file)))
+            json.dump(validation_summary, open(valid_summary_file, 'w'), indent=2)

 def set_seed(args):
     np.random.seed(args.seed)
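
Putting the pieces together, the summary file accumulates one entry per epoch keyed by epoch_str. A hypothetical validation_pe.json after two epochs might hold a dict like this (scores and timestamps are illustrative, not real results):

validation_summary = {
    "validation_with_true_values": True,
    "01": {"file_name": "outputs_pe-20210101T120000.json", "jga_score": 0.412},
    "02": {"file_name": "outputs_pe-20210102T090000.json", "jga_score": 0.447},
}

Note the len(validation_summary) > 1 guard above: the dict always starts with the single flag entry, so the file is only written once at least one epoch has actually been validated.
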
@@ -258,7 +282,9 @@ def save_args(save_dir, args):
     if not os.path.exists(save_dir):
         os.makedirs(save_dir)
     args_file = os.path.join(save_dir, 'args.json')
-    json.dump(vars(args), open(args_file, "w"), indent=2)
+    args_dict = vars(args)
+    args_dict['prompt_templates'] = PROMPT_TEMPLATES
+    json.dump(args_dict, open(args_file, "w"), indent=2)

 def train_prompting(args, history, slot_value_pair, prompt_type, tokenizer, model, device):
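
Since save_args now embeds PROMPT_TEMPLATES in args.json, a downstream generation script can recover the exact templates used at training time instead of importing whatever the current prompt_utils defines. A sketch of that consumer side (the path is an assumption based on the epoch directory naming above):

import json

with open("saved_model/epoch-01/args.json") as f:
    saved = json.load(f)

templates = saved["prompt_templates"]  # key written by save_args above
print(templates["inverse-prompt"]["training"])  # -> belief states: $slot = $value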

@@ -11,7 +11,9 @@ PROMPT_TEMPLATES = {
         "training": "belief states: value = $value, slot = $slot",
         "generate": "belief states: value = $value, slot ="
     },
-    "inverse-prompt": "belief states: slot = $slot, value = $value",
+    "inverse-prompt": {
+        "training": "belief states: $slot = $value",
+    },
     "prompt-ensemble": {
         "training": {
             "p1": "belief states: value = $value, slot = $slot",
@@ -30,10 +32,6 @@ PROMPT_TEMPLATES = {
 def get_prompt_for_training(typ, slot_value):
-    if typ == TYPE_INVERSE_PROMPT:
-        template = Template(PROMPT_TEMPLATES[typ])
-        return template.substitute(slot=slot_value[0], value=slot_value[1])
-    else:
     template = PROMPT_TEMPLATES[typ]['training']
     if isinstance(template, str):
         return Template(template).substitute(slot=slot_value[0], value=slot_value[1])
@@ -41,8 +39,8 @@ def get_prompt_for_training(typ, slot_value):
         template_list = template.values()
         prompt_list = []
         for template_str in template_list:
-            template = Template(template_str)
-            prompt_list.append(template.substitute(slot=slot_value[0], value=slot_value[1]))
+            prompt = Template(template_str).substitute(slot=slot_value[0], value=slot_value[1])
+            prompt_list.append(prompt)
         return prompt_list
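
With the special case gone, every prompt type flows through the same dict lookup: a string template yields one prompt, an ensemble dict yields one prompt per template. Illustrative calls, assuming the type constants equal the dict keys and using a made-up slot/value pair:

print(get_prompt_for_training(TYPE_INVERSE_PROMPT, ("hotel-area", "centre")))
# -> belief states: hotel-area = centre

print(get_prompt_for_training(TYPE_PROMPT_ENSEMBLE, ("hotel-area", "centre")))
# -> ['belief states: value = centre, slot = hotel-area', ...]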
