
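"""Train a prompt-based dialogue state tracking (DST) model.

Fine-tunes a pre-trained causal language model (e.g. SOLOIST) to generate
slot names from dialogue history using value-based prompts, optionally with
an inverse prompt and/or a prompt ensemble, and evaluates Joint Goal
Accuracy on a validation set after each epoch.
"""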
import argparse
import json
import os
from datetime import datetime

import numpy as np
import torch
from torch.optim import AdamW
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, GPT2Tokenizer, get_scheduler

from dataset import PromptDstDataset
from metrics import PromptDSTEvaluator
from prompt_utils import (
    PROMPT_TEMPLATES,
    PROMPT_WEIGHT,
    TYPE_INVERSE_PROMPT,
    TYPE_PROMPT_ENSEMBLE,
    TYPE_VALUE_BASED_PROMPT,
    get_ensemble_prompts,
    get_prompt_for_training,
    get_value_based_prompt,
)
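
# Example invocation (script name and paths are illustrative; substitute your
# own data files and checkpoint locations):
#
#   python train.py \
#       --train_data_file data/train.json \
#       --save_model_dir checkpoints/prompt_dst \
#       --pretrained_model_path checkpoints/soloist_base \
#       --num_epochs 5 \
#       --with_inverse_prompt --inverse_prompt_weight 0.1 \
#       --with_prompt_ensemble \
#       --validation_file data/valid.json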


def main():
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--train_data_file", default=None, type=str, required=True,
                        help="The input training data file <JSON path>.")
    parser.add_argument("--save_model_dir", default=None, type=str, required=True,
                        help="The directory where the model should be saved.")
    parser.add_argument("--pretrained_model_path", default=None, type=str, required=True,
                        help="The pre-trained model path for fine-tuning [either the original SOLOIST "
                             "model or a saved model checkpoint].")
    # Optional parameters
    parser.add_argument("--num_epochs", default=5, type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for initialization.")
    parser.add_argument("--batch_size", default=1, type=int, help="Batch size for training.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for the AdamW optimizer.")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight decay.")
    parser.add_argument("--with_inverse_prompt", action="store_true",
                        help="Enable the inverse prompt during training.")
    parser.add_argument("--inverse_prompt_weight", default=0.1, type=float,
                        help="Weight to adjust the influence of the inverse prompt, a decimal in (0, 1).")
    parser.add_argument("--with_prompt_ensemble", action="store_true",
                        help="Enable prompt ensembling during training.")
    parser.add_argument("--with_answered_prompts", action="store_true",
                        help="Enable the use of answered prompts during validation.")
    parser.add_argument("--validation_file", default="", type=str,
                        help="Validation file for evaluating the model after each epoch.")
    # parse the arguments
    args = parser.parse_args()
    # set up the CUDA device for training on GPU (if available)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.n_gpu = torch.cuda.device_count()
    # set seed
    set_seed(args)
    # print args
    print('Training Args :: ', json.dumps(vars(args), indent=2))
    save_args(args.save_model_dir, args)
    # prepare model & tokenizer -> load the pre-trained model
    tokenizer = GPT2Tokenizer.from_pretrained(args.pretrained_model_path, do_lower_case=True)
    model = AutoModelForCausalLM.from_pretrained(args.pretrained_model_path, pad_token_id=tokenizer.eos_token_id)
    # move the model to the device
    model.to(device)
    # load the training dataset
    training_data = PromptDstDataset(args.train_data_file)
    # load the validation dataset (if provided)
    validation_data = None
    if args.validation_file:
        validation_data = PromptDstDataset(args.validation_file)
    # create an optimizer and learning rate scheduler to fine-tune the model
    no_decay = ["bias", "layer_norm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=(args.num_epochs * training_data.total_num_slot_value_pairs)
    )
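    # NOTE: the scheduler horizon counts optimizer steps. The training loop
    # below takes one optimizer step per (slot, value) pair, so the linear
    # decay runs over num_epochs * total_num_slot_value_pairs steps, while
    # the progress bar ticks once per dialogue.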
    # set tqdm progress bars for epochs & number of training steps
    num_training_steps = args.num_epochs * training_data.len()
    epochs = tqdm(total=args.num_epochs, desc="Epochs", position=0)
    training_progress = tqdm(total=num_training_steps, desc="Training Progress", position=1)
    # set the model to training mode
    model.train()
    tqdm.write('Training starts now... [with_prompt_ensemble = {}]'.format(args.with_prompt_ensemble))
    validation_summary = {
        'validation_with_true_values': True
    }
    # outputs file extension (encodes usage of prompt ensemble & answered prompts)
    out_ext = "_pe" if args.with_prompt_ensemble else ""
    out_ext += "_pa" if args.with_answered_prompts else ""
    # training loop
    for epoch in range(args.num_epochs):
        running_loss = 0.0
        loss_count = 0
        # reset the model to training mode (validation switches it to eval mode)
        model.train()
        for i, item in enumerate(training_data.dataset_items, start=1):
            history = item['history']
            # iterate through (slot, value) pairs
            for slot, value in item['belief_states']:
                # train/generate using the value-based prompt first
                loss, gen_slot = train_prompting(args=args,
                                                 history=history,
                                                 slot_value_pair=(slot, value),
                                                 prompt_type=TYPE_VALUE_BASED_PROMPT,
                                                 tokenizer=tokenizer,
                                                 model=model,
                                                 device=device)
                if args.with_inverse_prompt:
                    # use the slot generated by the value-based prompt;
                    # clean it up (strip whitespace & convert to lower case)
                    generated_slot = gen_slot.strip().lower()
                    # train value generation from the generated slot using the inverse prompt
                    inv_loss, _ = train_prompting(args=args,
                                                  history=history,
                                                  slot_value_pair=(generated_slot, value),
                                                  prompt_type=TYPE_INVERSE_PROMPT,
                                                  tokenizer=tokenizer,
                                                  model=model,
                                                  device=device)
                    # compute the total loss for this slot-value pair
                    loss = loss + (args.inverse_prompt_weight * inv_loss)
                # accumulate the loss for logging
                running_loss += loss.item()
                loss_count += 1
                # backward pass & step
                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
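                # NOTE: one optimizer update per (slot, value) pair; gradients
                # are not accumulated across pairs, so each prompt effectively
                # forms its own mini-batch.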
            # update progress (one tick per dialogue)
            training_progress.update(1)
            # print the loss every 100 dialogues
            if i % 100 == 0:
                last_loss = running_loss / loss_count
                tqdm.write('Training Loss [Iteration {}, Epoch {}] = {}'.format(i, (epoch + 1), last_loss))
                running_loss = 0.0
                loss_count = 0
        # save the model after finishing the epoch
        epoch_str = "epoch-{:02d}".format(epoch + 1)
        output_dir = os.path.join(args.save_model_dir, epoch_str)
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        # save training args for each epoch (useful when testing/generating)
        save_args(output_dir, args)
        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
        torch.save(args, os.path.join(output_dir, 'training_args.bin'))
        tqdm.write('Saving model (after Epoch {}) to :: {}'.format((epoch + 1), output_dir))
        # update epoch progress
        epochs.update(1)
        # epoch finished -> run validation if a validation file was provided
        if args.validation_file and validation_data is not None:
            tqdm.write('Validation In Progress... [with_prompt_ensemble = {}]'.format(args.with_prompt_ensemble))
            # set a tqdm progress bar for validation progress
            validation_progress = tqdm(total=validation_data.len(), desc="Validation", leave=False)
            # set eval mode
            model.eval()
            # outputs array -> to be saved to output_dir
            outputs = []
            # JGA metric: a turn counts as correct only if the generated belief
            # state exactly matches the true belief state
            evaluator = PromptDSTEvaluator()
            # iterate through the validation dataset and generate slots using the value-based prompt
            for item in validation_data.dataset_items:
                history = item['history']
                true_states = {}
                gen_states = {}
                # iterate through (slot, value) pairs and generate each slot from its value
                for slot, value in item['belief_states']:
                    true_states[slot] = value
                    # generate the slot using the value-based prompt
                    generated_slot = generate_slot_from_prompt(args=args,
                                                               history=history,
                                                               value=value,
                                                               tokenizer=tokenizer,
                                                               model=model,
                                                               device=device)
                    # add the generated slot to the generated states
                    gen_states[generated_slot] = value
                # update tqdm progress
                validation_progress.update(1)
                # add true & generated belief states to the outputs
                outputs.append({"true_states": true_states, "gen_states": gen_states})
                # add true & generated belief states to the evaluator for computing JGA
                evaluator.add_data_item(true_states.copy(), gen_states.copy())
            validation_progress.close()
            epoch_valid_summary = {}
            # compute JGA & print the results
            tqdm.write('Computing Joint Goal Accuracy metric with TRUE values...')
            jga_score = evaluator.compute_joint_goal_accuracy(no_print=True)
            tqdm.write('Joint Goal Accuracy (with True Values) [after Epoch-{}]: {}'.format((epoch + 1), jga_score))
            now = datetime.now()
            datetime_str = now.strftime("%Y%m%dT%H%M%S")
            # outputs file name
            file_name = "outputs{}-{}.json".format(out_ext, datetime_str)
            epoch_valid_summary['file_name'] = file_name
            epoch_valid_summary['jga_score'] = round(jga_score, 3)
            # add the epoch summary to the validation summary
            validation_summary[epoch_str] = epoch_valid_summary
            # save the outputs to the trained epoch dir
            output_file = os.path.join(output_dir, file_name)
            tqdm.write('Saving Validation Outputs file [after Epoch-{}] :: {}'.format((epoch + 1), output_file))
            with open(output_file, 'w') as f:
                json.dump(outputs, f, indent=2)
    # save the validation summary (if any validation ran)
    if len(validation_summary) > 1:
        valid_summary_file = os.path.join(args.save_model_dir, 'validation{}.json'.format(out_ext))
        tqdm.write('Saving Validation Summary :: {}'.format(valid_summary_file))
        with open(valid_summary_file, 'w') as f:
            json.dump(validation_summary, f, indent=2)


def set_seed(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


def save_args(save_dir, args):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    args_file = os.path.join(save_dir, 'args.json')
    # copy so that adding 'prompt_templates' does not mutate the args namespace
    args_dict = vars(args).copy()
    # save the prompt templates used for training & inference
    prompt_templates = PROMPT_TEMPLATES.copy()
    if not args.with_inverse_prompt:
        prompt_templates.pop(TYPE_INVERSE_PROMPT)
    if args.with_prompt_ensemble:
        prompt_templates.pop(TYPE_VALUE_BASED_PROMPT)
    else:
        prompt_templates.pop(TYPE_PROMPT_ENSEMBLE)
    args_dict['prompt_templates'] = prompt_templates
    with open(args_file, "w") as f:
        json.dump(args_dict, f, indent=2)


def train_prompting(args, history, slot_value_pair, prompt_type, tokenizer, model, device):
    # slot_value_pair = (slot, value)
    # use the prompt ensemble when it is enabled in the args
    if prompt_type == TYPE_VALUE_BASED_PROMPT and args.with_prompt_ensemble:
        return train_prompt_ensemble(history=history,
                                     slot_value_pair=slot_value_pair,
                                     prompt_type=TYPE_PROMPT_ENSEMBLE,
                                     tokenizer=tokenizer,
                                     model=model,
                                     device=device)
    # get the prompt for training based on the prompt type
    prompt = get_prompt_for_training(prompt_type, slot_value_pair)
    # combine history and prompt
    input_prompt = history + prompt
    # encode the history & prompt using the tokenizer
    encoded_prompt = tokenizer(input_prompt, return_tensors="pt").to(device)
    # get the last token id
    # this could be a slot or a value token depending on the prompt type
    last_token = encoded_prompt['input_ids'][:, -1:]
    # model outputs
    outputs = model(**encoded_prompt)
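    # Inverse-prompt loss: the prompt ends with the (possibly multi-token)
    # value, and the loss is the negative log of the joint probability of
    # the value tokens:
    #     L_inv = -log( prod_i p(v_i | context) )
    # Flipping both the value ids and the logits along the sequence dimension
    # walks the value from the end of the prompt backwards: flipped logits at
    # index k predict the token at original position -(k + 1), so index 1
    # predicts the last token, index 2 the one before it, and so on.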
    if prompt_type == TYPE_INVERSE_PROMPT:
        # value token ids
        value_ids = tokenizer.encode(slot_value_pair[1], return_tensors="pt", add_prefix_space=True)
        flipped_value_ids = torch.flip(value_ids, dims=[1])
        flipped_logits = torch.flip(outputs.logits, dims=[1])
        index = 1
        # iterate through the value ids and compute the loss (joint probability)
        total_prob = None
        for item in flipped_value_ids[0]:
            token_logits = flipped_logits[:, index, :]
            index += 1
            # softmax probabilities
            probs = torch.nn.functional.softmax(token_logits, dim=-1)
            # generation probability of this token
            token_prob = torch.gather(probs, 1, item.view(1, 1).to(device)).squeeze(-1)
            # multiply the probabilities of each token in the belief state value
            if total_prob is None:
                total_prob = token_prob
            else:
                total_prob = total_prob * token_prob
        loss = torch.negative(torch.log(total_prob))
        # return the loss and 'None' for the generated value
        return loss, None
    # loss for slot generation using the value-based prompt:
    # the prompt ends with the slot token, and the loss is the negative log
    # probability of generating that last token
    if prompt_type == TYPE_VALUE_BASED_PROMPT:
        # logits at position -2 predict the last token of the prompt
        logits = outputs.logits[:, -2, :]
        # softmax probabilities
        probs = torch.nn.functional.softmax(logits, dim=-1)
        # generation probability of the last token
        last_token_prob = torch.gather(probs, 1, last_token).squeeze(-1)
        # loss is the negative log of that probability
        loss = torch.negative(torch.log(last_token_prob))
        # generated word -> the token with the highest probability
        gen_word_token = torch.argmax(logits, dim=-1)
        generated_word = tokenizer.decode(gen_word_token, skip_special_tokens=True).strip()
        return loss, generated_word
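

# Prompt-ensemble loss: each prompt template contributes its probability of
# generating the target slot token, and the loss is the negative log of the
# weighted mixture (PROMPT_WEIGHT is the per-prompt weight; the weights are
# expected to sum to 1):
#     L_pe = -log( sum_k w_k * p_k(slot | history, prompt_k) )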
def train_prompt_ensemble(history, slot_value_pair, prompt_type, tokenizer, model, device):
    # slot_value_pair = (slot, value)
    # get the list of prompts for training
    prompts = get_prompt_for_training(prompt_type, slot_value_pair)
    # total probability across all prompt functions
    total_prob = None
    gen_probs, gen_words = [], []
    # iterate through each prompt
    for prompt in prompts:
        # combine history and prompt
        input_prompt = history + prompt
        # encode the history & prompt using the tokenizer
        encoded_prompt = tokenizer(input_prompt, return_tensors="pt").to(device)
        # get the last token id (the slot token to be generated)
        last_token = encoded_prompt['input_ids'][:, -1:]
        # model outputs
        outputs = model(**encoded_prompt)
        # logits at position -2 predict the last token of the prompt
        logits = outputs.logits[:, -2, :]
        # softmax probabilities
        probs = torch.nn.functional.softmax(logits, dim=-1)
        # generation probability of the last token
        last_token_prob = torch.gather(probs, 1, last_token).squeeze(-1)
        # weighted probability | the prompt weights must sum to 1
        token_prob = PROMPT_WEIGHT * last_token_prob
        # sum the weighted probabilities over all prompt functions
        if total_prob is None:
            total_prob = token_prob
        else:
            total_prob = total_prob + token_prob
        # generated slot:
        # the token with the highest probability is the generated word
        gen_word_token = torch.argmax(logits, dim=-1)
        gen_word_prob = torch.gather(probs, 1, gen_word_token[:, None]).squeeze(-1)
        gen_word = tokenizer.decode(gen_word_token, skip_special_tokens=True).strip()
        # add the generated word and its probability to the lists
        gen_probs.append(gen_word_prob.item())
        gen_words.append(gen_word)
    # pick the generated word with the highest probability across prompts
    generated_word = gen_words[gen_probs.index(max(gen_probs))]
    # loss is the negative log of the mixture probability
    loss = torch.negative(torch.log(total_prob))
    return loss, generated_word
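

# Note: at training time the ensemble mixes token probabilities into a single
# loss, whereas at generation time (below) each prompt generates independently
# and the word with the highest single-prompt probability wins.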
# Use this to generate the next word (validation after each epoch)
def generate_slot_from_prompt(args, history, value, tokenizer, model, device):
    # check whether the prompt ensemble is enabled in the arguments
    if args.with_prompt_ensemble:
        return generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device)
    # get the value-based prompt for generation
    prompt = get_value_based_prompt(value)
    # combine history and prompt
    prompt = history + prompt
    # encode the history & prompt
    encoded_prompt = tokenizer(prompt, return_tensors="pt").to(device)
    # generate 1 token (slot names are assumed to be single tokens)
    outputs = model.generate(**encoded_prompt, max_new_tokens=1)
    gen_token = outputs[:, encoded_prompt['input_ids'].shape[-1]:]
    generated_word = tokenizer.decode(gen_token.item(), skip_special_tokens=True)
    return generated_word.strip().lower()


def generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device):
    # get the prompts for ensemble generation
    prompts = get_ensemble_prompts(value)
    gen_probs, gen_words = [], []
    for prompt in prompts:
        # combine history and prompt
        prompt = history + prompt
        # encode the history & prompt
        encoded_prompt = tokenizer(prompt, return_tensors="pt").to(device)
        # generate 1 token (max length of a slot = 1)
        outputs = model.generate(**encoded_prompt,
                                 return_dict_in_generate=True,
                                 output_scores=True,
                                 max_new_tokens=1)
        gen_token = outputs.sequences[:, encoded_prompt['input_ids'].shape[-1]:]
        gen_word = tokenizer.decode(gen_token.item(), skip_special_tokens=True).strip()
        probs = torch.nn.functional.softmax(outputs.scores[0], dim=-1)
        gen_prob = torch.gather(probs, 1, gen_token).squeeze(-1)
        # add the generated word and its probability to the lists
        gen_probs.append(gen_prob.item())
        gen_words.append(gen_word)
    # return the word with the highest probability
    generated_word = gen_words[gen_probs.index(max(gen_probs))]
    return generated_word.strip().lower()


if __name__ == "__main__":
    main()