@@ -178,8 +178,8 @@ def main():
         loss_count = 0

         # Save the model after finishing an epoch
-        epoch_str = "{:02d}".format(epoch + 1)
-        output_dir = os.path.join(args.save_model_dir, '{}-{}'.format("epoch", epoch_str))
+        epoch_str = "epoch-{:02d}".format(epoch + 1)
+        output_dir = os.path.join(args.save_model_dir, epoch_str)
         if not os.path.exists(output_dir):
             os.makedirs(output_dir)
         # save training args for each epoch (useful when testing/generating)
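Note (reviewer sketch, not part of the patch): the two changed lines above are a rename-style refactor and the resulting checkpoint path is unchanged; the "epoch-" prefix is simply folded into epoch_str, likely so the formatted name can be reused elsewhere. A minimal check, assuming args.save_model_dir is "checkpoints" and epoch is the zero-based counter:

    import os

    epoch = 2                                      # zero-based epoch counter
    # old: "{:02d}" combined via '{}-{}'.format("epoch", ...) -> checkpoints/epoch-03
    # new: "epoch-{:02d}" joined directly                     -> checkpoints/epoch-03
    epoch_str = "epoch-{:02d}".format(epoch + 1)
    output_dir = os.path.join("checkpoints", epoch_str)
    print(output_dir)                              # checkpoints/epoch-03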
@@ -280,7 +280,15 @@ def save_args(save_dir, args):
         os.makedirs(save_dir)
     args_file = os.path.join(save_dir, 'args.json')
     args_dict = vars(args)
-    args_dict['prompt_templates'] = PROMPT_TEMPLATES
+    # save prompt templates used for training & inference
+    prompt_templates = PROMPT_TEMPLATES.copy()
+    if not args.with_inverse_prompt:
+        prompt_templates.pop(TYPE_INVERSE_PROMPT)
+    if args.with_prompt_ensemble:
+        prompt_templates.pop(TYPE_VALUE_BASED_PROMPT)
+    else:
+        prompt_templates.pop(TYPE_PROMPT_ENSEMBLE)
+    args_dict['prompt_templates'] = prompt_templates
     json.dump(args_dict, open(args_file, "w"), indent=2)
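Note (reviewer sketch, not part of the patch): with this change, args.json records only the prompt templates that were actually active for the run (the inverse prompt only when enabled, and either the value-based or the ensemble template, not both). A minimal illustration of reading it back at test/generation time; the helper name load_saved_args and the checkpoint path are hypothetical:

    import json
    import os

    def load_saved_args(save_dir):
        # read the per-epoch args.json written by save_args()
        with open(os.path.join(save_dir, "args.json")) as f:
            return json.load(f)

    saved = load_saved_args("checkpoints/epoch-01")   # hypothetical checkpoint dir
    prompt_templates = saved["prompt_templates"]      # only the templates used in training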
@@ -314,6 +322,36 @@ def train_prompting(args, history, slot_value_pair, prompt_type, tokenizer, mode
     # model outputs
     outputs = model(**encoded_prompt)

+    if prompt_type == TYPE_INVERSE_PROMPT:
+        # value ids
+        value_ids = tokenizer.encode(slot_value_pair[1], return_tensors="pt", add_prefix_space=True)
+        flipped_value_ids = torch.flip(value_ids, dims=[1])
+        flipped_logits = torch.flip(outputs.logits, dims=[1])
+
+        index = 1
+        # iterate through the value ids and compute loss (combined probability)
+        total_prob = None
+        for item in flipped_value_ids[0]:
+            token_logits = flipped_logits[:, index, :]
+            index += 1
+
+            # softmax probabilities
+            probs = torch.nn.functional.softmax(token_logits, dim=-1)
+
+            # this token generation probability
+            token_prob = torch.gather(probs, 1, torch.tensor([[item]], device=device)).squeeze(-1)
+            # multiply the probabilities for each word in belief state value
+            if total_prob is None:
+                total_prob = token_prob
+            else:
+                total_prob *= token_prob
+
+        loss = torch.negative(torch.log(total_prob))
+        # return loss and 'None' for generated values
+        return loss, None
+
+    # loss for slot generation using value-based prompt & the generated slot
+    if prompt_type == TYPE_VALUE_BASED_PROMPT:
         # get last token logits [-2 -> for last but one item]
         logits = outputs.logits[:, -2, :]
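Note (reviewer sketch, not part of the patch): the inverse-prompt branch above takes the negative log of the product of per-token probabilities over the belief-state value. Since -log(prod_t p_t) = sum_t -log p_t, the same loss can be accumulated in log space with log_softmax, which avoids underflow when the value spans many tokens. A minimal sketch that mirrors the indexing used above; the function name and start_index are my own:

    import torch
    import torch.nn.functional as F

    def inverse_prompt_loss(flipped_logits, flipped_value_ids, start_index=1):
        # flipped_logits: [1, seq_len, vocab_size], flipped along the sequence dimension
        # flipped_value_ids: [1, num_value_tokens], flipped the same way
        log_probs = F.log_softmax(flipped_logits, dim=-1)
        total_log_prob = flipped_logits.new_zeros(())
        for i, token_id in enumerate(flipped_value_ids[0]):
            # log-probability of the i-th value token at position start_index + i
            total_log_prob = total_log_prob + log_probs[0, start_index + i, token_id]
        return -total_log_prob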
@@ -325,9 +363,6 @@ def train_prompting(args, history, slot_value_pair, prompt_type, tokenizer, mode
         loss = torch.negative(torch.log(last_token_prob))

         # generated word -> the one with the highest probability
-        generated_word = None
-        if prompt_type == TYPE_VALUE_BASED_PROMPT:
-            # find the token with the highest probability, this will be the generated word
         gen_word_token = torch.argmax(logits, dim=-1)
         generated_word = tokenizer.decode(gen_word_token, skip_special_tokens=True).strip()