@@ -15,6 +15,7 @@ from prompt_utils import get_prompt_for_training
from prompt_utils import TYPE_VALUE_BASED_PROMPT
from prompt_utils import TYPE_INVERSE_PROMPT
from prompt_utils import TYPE_PROMPT_ENSEMBLE, PROMPT_WEIGHT
from prompt_utils import PROMPT_TEMPLATES
from metrics import PromptDSTEvaluator
from datetime import datetime
@@ -50,6 +51,9 @@ def main():
                        help="Flag to enable/disable the use of answered prompts during validation")
    parser.add_argument("--validation_file", default="", type=str,
                        help="Validation file for evaluating model after each epoch")
    parser.add_argument("--validation_with_true_values", action="store_true",
                        help="Flag for enabling/disabling the use of TRUE values for slot generation during "
                             "validation")

    # parse the arguments
    args = parser.parse_args()
@@ -111,6 +115,14 @@ def main():
    tqdm.write(str('Training starts now... [with_prompt_ensemble = '+str(args.with_prompt_ensemble)+']'))

    validation_summary = {
        'validation_with_true_values': args.validation_with_true_values
    }

    # outputs file extension (representing usage of prompt ensemble & answered prompts)
    out_ext = "_pe" if args.with_prompt_ensemble else ""
    out_ext += "_pa" if args.with_answered_prompts else ""

    # training loop
    for epoch in range(args.num_epochs):
        running_loss = 0.0
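For orientation, a minimal sketch (not part of the diff) of how the out_ext suffix and the per-epoch validation_summary entries introduced above combine; the flag values, file name, and score below are placeholders, not taken from this change.

# Illustration only: with both flags enabled the suffix becomes "_pe_pa",
# and validation_summary gains one entry per epoch keyed by the zero-padded
# epoch number.
with_prompt_ensemble = True        # assumed flag value
with_answered_prompts = True       # assumed flag value

out_ext = "_pe" if with_prompt_ensemble else ""
out_ext += "_pa" if with_answered_prompts else ""
print(out_ext)  # _pe_pa

validation_summary = {'validation_with_true_values': False}
validation_summary['01'] = {'file_name': 'outputs_pe_pa-<timestamp>.json',  # placeholder name
                            'jga_score': 0.5}                               # placeholder score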
@@ -169,7 +181,8 @@ def main():
            loss_count = 0

        # Save the model after finishing an epoch
        output_dir = os.path.join(args.save_model_dir, '{}-{}'.format("epoch", "{:02d}".format(epoch + 1)))
        epoch_str = "{:02d}".format(epoch + 1)
        output_dir = os.path.join(args.save_model_dir, '{}-{}'.format("epoch", epoch_str))
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        # save training args for each epoch (useful when testing/generating)
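As a quick illustration of the epoch_str refactor above (assuming a hypothetical save_model_dir value; not part of the diff), checkpoints land in zero-padded epoch-NN directories:

import os

save_model_dir = "trained_models"            # hypothetical value for illustration
for epoch in range(3):
    epoch_str = "{:02d}".format(epoch + 1)
    print(os.path.join(save_model_dir, '{}-{}'.format("epoch", epoch_str)))
# trained_models/epoch-01
# trained_models/epoch-02
# trained_models/epoch-03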
@@ -230,22 +243,33 @@ def main():
            evaluator.add_data_item(true_states.copy(), gen_states.copy())

        validation_progress.close()
        epoch_valid_summary = {}

        # compute JGA & print results
        tqdm.write(str('Computing Joint Goal Accuracy metric with TRUE values...'))
        jga_score = evaluator.compute_joint_goal_accuracy(no_print=True)
        tqdm.write(str('Joint Goal Accuracy (with True Values) [after Epoch-{}]: {}'.format((epoch + 1), jga_score)))

        # output file extension (for prompt ensemble & answered prompts)
        out_ext = "_pe" if args.with_prompt_ensemble else ""
        out_ext += "_pa" if args.with_answered_prompts else ""

        # save the outputs to trained epoch dir
        now = datetime.now()
        datetime_str = now.strftime("%Y%m%dT%H%M%S")
        output_file = os.path.join(output_dir, 'outputs{}-{}.json'.format(out_ext, datetime_str))

        # outputs file name
        file_name = "outputs{}-{}.json".format(out_ext, datetime_str)
        epoch_valid_summary['file_name'] = file_name
        epoch_valid_summary['jga_score'] = round(jga_score, 3)
        # add epoch summary to valid summary
        validation_summary[epoch_str] = epoch_valid_summary

        # save the outputs to trained epoch dir
        output_file = os.path.join(output_dir, file_name)
        tqdm.write(str('Saving Validation Outputs file [after Epoch-{}] :: {}'.format((epoch + 1), output_file)))
        json.dump(outputs, open(output_file, 'w'), indent=2)

        # save validation file summary (if there's data)
        if len(validation_summary) > 1:
            valid_summary_file = os.path.join(args.save_model_dir, 'validation{}.json'.format(out_ext))
            tqdm.write(str('Saving Validation Summary :: {}'.format(valid_summary_file)))
            json.dump(validation_summary, open(valid_summary_file, 'w'), indent=2)


def set_seed(args):
    np.random.seed(args.seed)
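PromptDSTEvaluator.compute_joint_goal_accuracy, used in the validation block above, is project-specific; as a rough sketch of what the metric means (a turn only counts if every slot-value pair in the generated state matches the reference exactly), assuming states are slot-to-value dicts:

def joint_goal_accuracy_sketch(true_states, gen_states):
    """Fraction of turns whose generated dialogue state equals the true state
    exactly (all slots and values). Sketch only; the project's evaluator may
    apply its own value normalization before comparing."""
    if not true_states:
        return 0.0
    correct = sum(1 for t, g in zip(true_states, gen_states) if t == g)
    return correct / len(true_states)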
@@ -258,7 +282,9 @@ def save_args(save_dir, args):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    args_file = os.path.join(save_dir, 'args.json')
    json.dump(vars(args), open(args_file, "w"), indent=2)
    args_dict = vars(args)
    args_dict['prompt_templates'] = PROMPT_TEMPLATES
    json.dump(args_dict, open(args_file, "w"), indent=2)


def train_prompting(args, history, slot_value_pair, prompt_type, tokenizer, model, device):
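One behavioral detail of the save_args change worth noting: vars(args) returns the Namespace's own __dict__, so assigning args_dict['prompt_templates'] also attaches a prompt_templates attribute to args itself. If that side effect is not wanted, copying first avoids it; a sketch reusing the names from save_args above, not what the diff does:

    args_dict = dict(vars(args))                     # copy, so args is left untouched
    args_dict['prompt_templates'] = PROMPT_TEMPLATES
    json.dump(args_dict, open(args_file, "w"), indent=2)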