Saving Validation summary in json and prompt templates in args file

main
Pavan Mandava 3 years ago
parent b9539b05a0
commit cc281edce6

@ -15,6 +15,7 @@ from prompt_utils import get_prompt_for_training
from prompt_utils import TYPE_VALUE_BASED_PROMPT
from prompt_utils import TYPE_INVERSE_PROMPT
from prompt_utils import TYPE_PROMPT_ENSEMBLE, PROMPT_WEIGHT
from prompt_utils import PROMPT_TEMPLATES
from metrics import PromptDSTEvaluator
from datetime import datetime
@ -50,6 +51,9 @@ def main():
help="Flag to enable/disable the use of answered prompts while validation")
parser.add_argument("--validation_file", default="", type=str,
help="Validation file for evaluating model after each epoch")
parser.add_argument("--validation_with_true_values", action="store_true",
help="Flag for enabling/disabling the usage of TRUE values for slot generation during "
"validation")
# parse the arguments
args = parser.parse_args()
@ -111,6 +115,14 @@ def main():
tqdm.write(str('Training starts now... [with_prompt_ensemble = '+str(args.with_prompt_ensemble)+']'))
validation_summary = {
'validation_with_true_values': args.validation_with_true_values
}
# outputs file extension (representing usage of prompt ensemble & answered prompts)
out_ext = "_pe" if args.with_prompt_ensemble else ""
out_ext += "_pa" if args.with_answered_prompts else ""
# training loop
for epoch in range(args.num_epochs):
running_loss = 0.0
@ -169,7 +181,8 @@ def main():
loss_count = 0
# Save the model after finishing an epoch
output_dir = os.path.join(args.save_model_dir, '{}-{}'.format("epoch", "{:02d}".format(epoch + 1)))
epoch_str = "{:02d}".format(epoch + 1)
output_dir = os.path.join(args.save_model_dir, '{}-{}'.format("epoch", epoch_str))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# save training args for each epoch (useful when testing/generating)
@ -230,22 +243,33 @@ def main():
evaluator.add_data_item(true_states.copy(), gen_states.copy())
validation_progress.close()
epoch_valid_summary = {}
# compute JGA & print results
tqdm.write(str('Computing Joint Goal Accuracy metric with TRUE values...'))
jga_score = evaluator.compute_joint_goal_accuracy(no_print=True)
tqdm.write(str('Joint Goal Accuracy(with True Values) [after Epoch-{}]: {}'.format((epoch + 1), jga_score)))
# output file extension (for prompt ensemble & answered prompts)
out_ext = "_pe" if args.with_prompt_ensemble else ""
out_ext += "_pa" if args.with_answered_prompts else ""
# save the outputs to trained epoch dir
now = datetime.now()
datetime_str = now.strftime("%Y%m%dT%H%M%S")
output_file = os.path.join(output_dir, 'outputs{}-{}.json'.format(out_ext, datetime_str))
# outputs file name
file_name = "outputs{}-{}.json".format(out_ext, datetime_str)
epoch_valid_summary['file_name'] = file_name
epoch_valid_summary['jga_score'] = round(jga_score, 3)
# add epoch summary to valid summary
validation_summary[epoch_str] = epoch_valid_summary
# save the outputs to trained epoch dir
output_file = os.path.join(output_dir, file_name)
tqdm.write(str('Saving Validation Outputs file [after Epoch-{}] :: {}'.format((epoch + 1), output_file)))
json.dump(outputs, open(output_file, 'w'), indent=2)
# save validation file summary (if there's data)
if len(validation_summary) > 1:
valid_summary_file = os.path.join(args.save_model_dir, 'validation{}.json'.format(out_ext))
tqdm.write(str('Saving Validation Summary :: {}'.format(valid_summary_file)))
json.dump(validation_summary, open(valid_summary_file, 'w'), indent=2)
def set_seed(args):
np.random.seed(args.seed)
@ -258,7 +282,9 @@ def save_args(save_dir, args):
if not os.path.exists(save_dir):
os.makedirs(save_dir)
args_file = os.path.join(save_dir, 'args.json')
json.dump(vars(args), open(args_file, "w"), indent=2)
args_dict = vars(args)
args_dict['prompt_templates'] = PROMPT_TEMPLATES
json.dump(args_dict, open(args_file, "w"), indent=2)
def train_prompting(args, history, slot_value_pair, prompt_type, tokenizer, model, device):

@ -11,7 +11,9 @@ PROMPT_TEMPLATES = {
"training": "belief states: value = $value, slot = $slot",
"generate": "belief states: value = $value, slot ="
},
"inverse-prompt": "belief states: slot = $slot, value = $value",
"inverse-prompt": {
"training": "belief states: $slot = $value",
},
"prompt-ensemble": {
"training": {
"p1": "belief states: value = $value, slot = $slot",
@ -30,20 +32,16 @@ PROMPT_TEMPLATES = {
def get_prompt_for_training(typ, slot_value):
if typ == TYPE_INVERSE_PROMPT:
template = Template(PROMPT_TEMPLATES[typ])
return template.substitute(slot=slot_value[0], value=slot_value[1])
else:
template = PROMPT_TEMPLATES[typ]['training']
if isinstance(template, str):
return Template(template).substitute(slot=slot_value[0], value=slot_value[1])
elif isinstance(template, dict):
template_list = template.values()
prompt_list = []
for template_str in template_list:
template = Template(template_str)
prompt_list.append(template.substitute(slot=slot_value[0], value=slot_value[1]))
return prompt_list
template = PROMPT_TEMPLATES[typ]['training']
if isinstance(template, str):
return Template(template).substitute(slot=slot_value[0], value=slot_value[1])
elif isinstance(template, dict):
template_list = template.values()
prompt_list = []
for template_str in template_list:
prompt = Template(template_str).substitute(slot=slot_value[0], value=slot_value[1])
prompt_list.append(prompt)
return prompt_list
def get_value_based_prompt(value):

Loading…
Cancel
Save