@@ -46,6 +46,8 @@ def main():
                         help="Weight to adjust the influence of the Inverse Prompt, a decimal in (0, 1)")
     parser.add_argument("--with_prompt_ensemble", action="store_true",
                         help="Flag to enable/disable prompt ensembling during training")
+    parser.add_argument("--with_answered_prompts", action="store_true",
+                        help="Flag to enable/disable the use of answered prompts during validation")
     parser.add_argument("--validation_file", default="", type=str,
                         help="Validation file for evaluating the model after each epoch")
 
@@ -56,6 +58,10 @@ def main():
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     args.n_gpu = torch.cuda.device_count()
 
+    # Print args
+    print('Training Args :: ', json.dumps(vars(args), indent=2))
+    save_args(args.save_model_dir, args)
+
     # prepare model & tokenizer -> load pre-trained model
     tokenizer = GPT2Tokenizer.from_pretrained(args.pretrained_model_path, do_lower_case=True)
     model = AutoModelForCausalLM.from_pretrained(args.pretrained_model_path, pad_token_id=tokenizer.eos_token_id)
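Note: GPT-2 ships without a pad token, which is why the model's pad id is pinned to the eos id above. A standalone sketch of the same setup, using the public "gpt2" checkpoint in place of args.pretrained_model_path:

    from transformers import AutoModelForCausalLM, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2", do_lower_case=True)
    model = AutoModelForCausalLM.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)
    assert model.config.pad_token_id == tokenizer.eos_token_id  # pad falls back to eos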
@@ -166,6 +172,8 @@ def main():
         output_dir = os.path.join(args.save_model_dir, '{}-{}'.format("epoch", "{:02d}".format(epoch + 1)))
         if not os.path.exists(output_dir):
             os.makedirs(output_dir)
+        # save training args for each epoch (useful when testing/generating)
+        save_args(output_dir, args)
         model.save_pretrained(output_dir)
         tokenizer.save_pretrained(output_dir)
         torch.save(args, os.path.join(output_dir, 'training_args.bin'))
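Note: with save_args added, each epoch directory is self-describing. A sketch of reloading one checkpoint for generation (the epoch-03 path is hypothetical):

    ckpt = os.path.join(args.save_model_dir, "epoch-03")
    model = AutoModelForCausalLM.from_pretrained(ckpt)
    tokenizer = GPT2Tokenizer.from_pretrained(ckpt)
    train_args = json.load(open(os.path.join(ckpt, "args.json")))  # written by save_args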
@@ -227,10 +235,14 @@ def main():
         jga_score = evaluator.compute_joint_goal_accuracy(no_print=True)
         tqdm.write(str('Joint Goal Accuracy (with True Values) [after Epoch-{}]: {}'.format((epoch + 1), jga_score)))
 
+        # output file extension (for prompt ensemble & answered prompts)
+        out_ext = "_pe" if args.with_prompt_ensemble else ""
+        out_ext += "_pa" if args.with_answered_prompts else ""
+
         # save the outputs to trained epoch dir
         now = datetime.now()
         datetime_str = now.strftime("%Y%m%dT%H%M%S")
-        output_file = os.path.join(output_dir, 'outputs-{}.json'.format(datetime_str))
+        output_file = os.path.join(output_dir, 'outputs{}-{}.json'.format(out_ext, datetime_str))
         tqdm.write(str('Saving Validation Outputs file [after Epoch-{}] :: {}'.format((epoch + 1), output_file)))
         json.dump(outputs, open(output_file, 'w'), indent=2)
 
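Note: the four output filenames this block can now produce, shown with an example timestamp:

    outputs-20240101T120000.json        # neither flag set
    outputs_pe-20240101T120000.json     # --with_prompt_ensemble
    outputs_pa-20240101T120000.json     # --with_answered_prompts
    outputs_pe_pa-20240101T120000.json  # both flags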
@@ -242,6 +254,13 @@ def set_seed(args):
         torch.cuda.manual_seed_all(args.seed)
 
 
+def save_args(save_dir, args):
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    args_file = os.path.join(save_dir, 'args.json')
+    json.dump(vars(args), open(args_file, "w"), indent=2)
+
+
 def train_prompting(args, history, slot_value_pair, prompt_type, tokenizer, model, device):
     # slot_value_pair = (slot, value)
 
@@ -298,8 +317,8 @@ def train_prompt_ensemble(history, slot_value_pair, prompt_type, tokenizer, mode
     # get list of prompts for training
     prompts = get_prompt_for_training(prompt_type, slot_value_pair)
 
-    # return total loss
     total_loss = None
+    # total probability of all prompt functions
+    total_prob = None
 
     gen_probs, gen_words = [], []
 
@@ -329,13 +348,15 @@ def train_prompt_ensemble(history, slot_value_pair, prompt_type, tokenizer, mode
 
         # last token generation probability
         last_token_prob = torch.gather(probs, 1, last_token).squeeze(-1)
-        # weighted probability
+        # weighted probability | the sum of all prompt weights must be equal to 1
         token_prob = PROMPT_WEIGHT * last_token_prob
         loss = torch.negative(torch.log(token_prob))
         if total_loss is None:
             total_loss = loss
 
+        # sum the probs for all prompt functions
+        if total_prob is None:
+            total_prob = token_prob
         else:
             total_loss += loss
+            total_prob += token_prob
 
         # generated slot
         # find the token with the highest probability, this will be the generated word
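Note: the torch.gather call above reads, for each batch row, the probability assigned to one chosen token id. A self-contained sketch with made-up numbers:

    import torch

    probs = torch.tensor([[0.1, 0.7, 0.2]])         # (batch=1, vocab=3)
    last_token = torch.tensor([[1]])                # one token id per row, shape (1, 1)
    torch.gather(probs, 1, last_token).squeeze(-1)  # -> tensor([0.7000])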
@@ -349,7 +370,8 @@ def train_prompt_ensemble(history, slot_value_pair, prompt_type, tokenizer, mode
 
     generated_word = gen_words[gen_probs.index(max(gen_probs))]
     # loss is the negative log of the total probability
-    return total_loss, generated_word
+    loss = torch.negative(torch.log(total_prob))
+    return loss, generated_word
 
 
 # Use this for generating next word (Validation after each epoch)
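Note: the return change above is the core of the ensembling fix: instead of summing the per-prompt losses, the function now takes the loss of the summed weighted probability. A minimal sketch of the difference, assuming K prompts with equal weights (PROMPT_WEIGHT = 1/K):

    import torch

    p = torch.tensor([0.6, 0.2, 0.1])         # per-prompt probabilities of the target token
    w = 1.0 / len(p)                          # uniform prompt weight
    old_loss = torch.sum(-torch.log(w * p))   # old: sum of per-prompt losses
    new_loss = -torch.log(torch.sum(w * p))   # new: loss of the ensemble probability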
@@ -390,7 +412,7 @@ def generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device)
         encoded_prompt = tokenizer(prompt, return_tensors="pt")
         encoded_prompt.to(device)
 
-        # generate 1 token
+        # generate 1 token (max length of slot = 1)
         outputs = model.generate(**encoded_prompt,
                                  return_dict_in_generate=True,
                                  output_scores=True,
@@ -406,9 +428,9 @@ def generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device)
         gen_probs.append(gen_prob.item())
         gen_words.append(gen_word)
 
-        # return word with the highest probability
-        generated_word = gen_words[gen_probs.index(max(gen_probs))]
-        return generated_word.strip().lower()
+    # return word with the highest probability
+    generated_word = gen_words[gen_probs.index(max(gen_probs))]
+    return generated_word.strip().lower()
 
 
 if __name__ == "__main__":
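Note: the dedent above moves the final return out of the per-prompt loop, so every prompt is scored before the highest-probability word is picked. A condensed sketch of the whole loop (names partly assumed, since the middle of the function is not shown in this diff):

    import torch

    gen_probs, gen_words = [], []
    for prompt in prompts:  # prompts, tokenizer, model, device come from the surrounding code
        enc = tokenizer(prompt, return_tensors="pt").to(device)
        out = model.generate(**enc, max_new_tokens=1,
                             return_dict_in_generate=True, output_scores=True)
        probs = torch.softmax(out.scores[0], dim=-1)  # distribution over the one new token
        gen_prob, token_id = probs.max(dim=-1)
        gen_probs.append(gen_prob.item())
        gen_words.append(tokenizer.decode(token_id))
    generated_word = gen_words[gen_probs.index(max(gen_probs))]  # highest probability wins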