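"""Fine-tune a pre-trained causal LM (e.g. SOLOIST) for dialogue state tracking
with value-based prompts, an optional inverse-prompt objective, and optional
prompt ensembling. Saves a checkpoint (and, if a validation file is given,
Joint Goal Accuracy plus generated outputs) after every epoch."""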
import argparse
import numpy as np
import os
import json
import torch
from transformers import AutoModelForCausalLM, GPT2Tokenizer
from dataset import PromptDstDataset
from torch.optim import AdamW
from transformers import get_scheduler
from tqdm.auto import tqdm
from prompt_utils import (get_value_based_prompt, get_prompt_for_training,
                          TYPE_VALUE_BASED_PROMPT, TYPE_INVERSE_PROMPT,
                          TYPE_PROMPT_ENSEMBLE, PROMPT_WEIGHT)
from metrics import PromptDSTEvaluator
from datetime import datetime


def main():
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--train_data_file", default=None, type=str, required=True,
                        help="The input training data file <JSON Path>.")
    parser.add_argument("--save_model_dir", default=None, type=str, required=True,
                        help="The directory where the model should be saved.")
    parser.add_argument("--pretrained_model_path", default=None, type=str, required=True,
                        help="The pre-trained model path for fine-tuning [either the original "
                             "SOLOIST model or a saved model checkpoint].")
    # Optional parameters
    parser.add_argument("--num_epochs", default=5, type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for initialization.")
    parser.add_argument("--batch_size", default=1, type=int, help="Batch size for training.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for the AdamW optimizer.")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight decay.")
    parser.add_argument("--with_inverse_prompt", action="store_true",
                        help="Enable the inverse prompt objective during training.")
    parser.add_argument("--inverse_prompt_weight", default=0.1, type=float,
                        help="Weight controlling the influence of the inverse prompt loss, "
                             "a decimal in (0, 1).")
    parser.add_argument("--with_prompt_ensemble", action="store_true",
                        help="Enable prompt ensembling during training.")
    parser.add_argument("--validation_file", default="", type=str,
                        help="Validation file for evaluating the model after each epoch.")
    # parse the arguments
    args = parser.parse_args()
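    # Example invocation (illustrative only: the data paths are placeholders,
    # and "train.py" assumes this script's file name):
    #   python train.py --train_data_file data/train.json \
    #                   --save_model_dir checkpoints/ \
    #                   --pretrained_model_path <soloist-or-checkpoint-path> \
    #                   --with_inverse_prompt --inverse_prompt_weight 0.1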
    # setup CUDA device for training on GPU (if available)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.n_gpu = torch.cuda.device_count()
    # prepare model & tokenizer -> load the pre-trained model
    tokenizer = GPT2Tokenizer.from_pretrained(args.pretrained_model_path, do_lower_case=True)
    model = AutoModelForCausalLM.from_pretrained(args.pretrained_model_path, pad_token_id=tokenizer.eos_token_id)
    # move the model to the device
    model.to(device)
    # set seed
    set_seed(args)
    # load the training dataset
    training_data = PromptDstDataset(args.train_data_file)
    # load the validation dataset (optional)
    validation_data = None
    if args.validation_file:
        validation_data = PromptDstDataset(args.validation_file)
    # create an optimizer and learning rate scheduler to fine-tune the model
    no_decay = ["bias", "layer_norm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    # one optimizer step is taken per (slot, value) pair, so the schedule
    # spans num_epochs * total_num_slot_value_pairs steps
    num_training_steps = args.num_epochs * training_data.total_num_slot_value_pairs
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps
    )
    # set tqdm progress bars for epochs & training steps
    epochs = tqdm(total=args.num_epochs, desc="Epochs", position=0)
    training_progress = tqdm(total=num_training_steps, desc="Training Progress", position=1)
    # set the model in training mode
    model.train()
    # training loop
    for epoch in range(args.num_epochs):
        running_loss = 0.0
        loss_count = 0
        # re-enable training mode (validation below switches to eval mode)
        model.train()
        for i, item in enumerate(training_data.dataset_items, start=1):
            history = item['history']
            # iterate through (slot, value) pairs
            for slot, value in item['belief_states']:
                # train/generate using the value-based prompt first
                loss, gen_slot = train_prompting(args=args,
                                                 history=history,
                                                 slot_value_pair=(slot, value),
                                                 prompt_type=TYPE_VALUE_BASED_PROMPT,
                                                 tokenizer=tokenizer,
                                                 model=model,
                                                 device=device)
                if args.with_inverse_prompt:
                    # use the slot generated by the value-based prompt;
                    # clean it up (strip whitespace & convert to lower case)
                    generated_slot = gen_slot.strip().lower()
                    # train slot generation using the inverse prompt
                    inv_loss, _ = train_prompting(args=args,
                                                  history=history,
                                                  slot_value_pair=(generated_slot, value),
                                                  prompt_type=TYPE_INVERSE_PROMPT,
                                                  tokenizer=tokenizer,
                                                  model=model,
                                                  device=device)
                    # total loss for this slot-value pair
                    loss = loss + (args.inverse_prompt_weight * inv_loss)
                # accumulate the loss for logging
                running_loss += loss.item()
                loss_count += 1
                # backward pass & optimizer/scheduler step
                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                # update progress (one step per slot-value pair)
                training_progress.update(1)
            # print the average loss every 100 dialogue items
            if i % 100 == 0:
                last_loss = running_loss / loss_count
                tqdm.write('Training Loss [Iteration {}, Epoch {}] = {}'.format(i, epoch + 1, last_loss))
                running_loss = 0.0
                loss_count = 0
        # save the model after finishing the epoch
        output_dir = os.path.join(args.save_model_dir, 'epoch-{:02d}'.format(epoch + 1))
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
        torch.save(args, os.path.join(output_dir, 'training_args.bin'))
        tqdm.write('Saving model (after Epoch {}) to :: {}'.format(epoch + 1, output_dir))
        # update epoch progress
        epochs.update(1)
        # epoch finished -> run validation if a validation file was provided
        if args.validation_file and validation_data is not None:
            # progress bar over validation slot-value pairs
            validation_progress = tqdm(total=validation_data.total_num_slot_value_pairs,
                                       desc="Validation", leave=False)
            # set eval mode
            model.eval()
            # outputs array -> to be saved to output_dir
            outputs = []
            # JGA metric
            evaluator = PromptDSTEvaluator()
            # iterate through the validation dataset and generate slots using the value-based prompt
            for item in validation_data.dataset_items:
                history = item['history']
                true_states = {}
                gen_states = {}
                # iterate through (slot, value) pairs and generate each slot from its value
                for slot, value in item['belief_states']:
                    true_states[slot] = value
                    # generate the slot using the value-based prompt
                    generated_slot = generate_slot_from_prompt(history=history,
                                                               value=value,
                                                               tokenizer=tokenizer,
                                                               model=model,
                                                               device=device)
                    # add the generated slot to the generated states
                    gen_states[generated_slot] = value
                    # update tqdm progress
                    validation_progress.update(1)
                # add true & generated belief states to the outputs
                outputs.append({"true_states": true_states, "gen_states": gen_states})
                # add true & generated belief states to the evaluator for computing JGA
                evaluator.add_data_item(true_states.copy(), gen_states.copy())
            validation_progress.close()
            # compute JGA & print the result
            tqdm.write('Computing Joint Goal Accuracy metric with TRUE values...')
            jga_score = evaluator.compute_joint_goal_accuracy(no_print=True)
            tqdm.write('Joint Goal Accuracy (with true values) [after Epoch-{}]: {}'.format(epoch + 1, jga_score))
            # save the outputs to the epoch's checkpoint dir
            now = datetime.now()
            datetime_str = now.strftime("%Y%m%dT%H%M%S")
            output_file = os.path.join(output_dir, 'outputs-{}.json'.format(datetime_str))
            tqdm.write('Saving validation outputs file [after Epoch-{}] :: {}'.format(epoch + 1, output_file))
            with open(output_file, 'w') as f:
                json.dump(outputs, f, indent=2)


def set_seed(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
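
# Note: set_seed above covers the NumPy and PyTorch RNGs. For fully
# reproducible GPU runs one may additionally need (an assumption, not part of
# the original recipe):
#     torch.backends.cudnn.deterministic = True
#     torch.backends.cudnn.benchmark = False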


def train_prompting(args, history, slot_value_pair, prompt_type, tokenizer, model, device):
    # slot_value_pair = (slot, value)
    # use the prompt ensemble when enabled in the args
    if prompt_type == TYPE_VALUE_BASED_PROMPT and args.with_prompt_ensemble:
        return train_prompt_ensemble(history=history,
                                     slot_value_pair=slot_value_pair,
                                     prompt_type=TYPE_PROMPT_ENSEMBLE,
                                     tokenizer=tokenizer,
                                     model=model,
                                     device=device)
    # get the prompt for training based on "type"
    prompt = get_prompt_for_training(prompt_type, slot_value_pair)
    # combine history and prompt
    input_prompt = history + prompt
    # encode the history & prompt using the tokenizer
    encoded_prompt = tokenizer(input_prompt, return_tensors="pt").to(device)
    # the last token id is the training target; it is a slot or a value
    # depending on the prompt type (already on `device`, being a slice of
    # the encoded inputs)
    last_token = encoded_prompt['input_ids'][:, -1:]
    # model outputs
    outputs = model(**encoded_prompt)
    # logits at the last-but-one position, which predict the final token
    logits = outputs.logits[:, -2, :]
    # softmax probabilities
    probs = torch.nn.functional.softmax(logits, dim=-1)
    # generation probability of the target (last) token
    last_token_prob = torch.gather(probs, 1, last_token).squeeze(-1)
    # loss is the negative log-probability of the target token
    loss = torch.negative(torch.log(last_token_prob))
    # generated word -> the one with the highest probability
    generated_word = None
    if prompt_type == TYPE_VALUE_BASED_PROMPT:
        gen_word_token = torch.argmax(logits, dim=-1)
        generated_word = tokenizer.decode(gen_word_token, skip_special_tokens=True).strip()
    return loss, generated_word
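
# Note: the softmax -> gather -> log chain in train_prompting is mathematically
# the cross-entropy loss at the final position. An equivalent, more numerically
# stable formulation (a sketch, not the original code path) would be:
#     loss = torch.nn.functional.cross_entropy(logits, last_token.squeeze(-1))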


def train_prompt_ensemble(history, slot_value_pair, prompt_type, tokenizer, model, device):
    # slot_value_pair = (slot, value)
    # get the list of prompts for training
    prompts = get_prompt_for_training(prompt_type, slot_value_pair)
    # accumulate the total loss over all prompts
    total_loss = None
    gen_probs, gen_words = [], []
    # iterate through each prompt
    for prompt in prompts:
        # combine history and prompt
        input_prompt = history + prompt
        # encode the history & prompt using the tokenizer
        encoded_prompt = tokenizer(input_prompt, return_tensors="pt").to(device)
        # the last token id is the training target (a slot or a value
        # depending on the prompt type)
        last_token = encoded_prompt['input_ids'][:, -1:]
        # model outputs
        outputs = model(**encoded_prompt)
        # logits at the last-but-one position, which predict the final token
        logits = outputs.logits[:, -2, :]
        # softmax probabilities
        probs = torch.nn.functional.softmax(logits, dim=-1)
        # generation probability of the target (last) token
        last_token_prob = torch.gather(probs, 1, last_token).squeeze(-1)
        # weighted probability (weighting inside the log only shifts the loss
        # by the constant -log(PROMPT_WEIGHT); gradients are unaffected)
        token_prob = PROMPT_WEIGHT * last_token_prob
        loss = torch.negative(torch.log(token_prob))
        total_loss = loss if total_loss is None else total_loss + loss
        # generated slot -> the token with the highest probability
        gen_word_token = torch.argmax(logits, dim=-1)
        gen_word_prob = torch.gather(probs, 1, gen_word_token[:, None]).squeeze(-1)
        gen_word = tokenizer.decode(gen_word_token, skip_special_tokens=True).strip()
        # record the generated word and its probability
        gen_probs.append(gen_word_prob.item())
        gen_words.append(gen_word)
    # pick the generated word with the highest probability across all prompts
    generated_word = gen_words[gen_probs.index(max(gen_probs))]
    return total_loss, generated_word


# Used for generating the next word (validation after each epoch)
def generate_slot_from_prompt(history, value, tokenizer, model, device):
    # build the value-based prompt
    prompt = get_value_based_prompt(value)
    # combine history and prompt
    prompt = history + prompt
    # encode the history & prompt
    encoded_prompt = tokenizer(prompt, return_tensors="pt").to(device)
    # generate exactly one token
    outputs = model.generate(**encoded_prompt, max_new_tokens=1)
    gen_sequences = outputs[:, encoded_prompt['input_ids'].shape[-1]:]
    generated_word = tokenizer.decode(gen_sequences.item(), skip_special_tokens=True)
    return generated_word.strip().lower()
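
# Note: assuming the default generation config (do_sample=False),
# model.generate with max_new_tokens=1 performs greedy decoding, matching
# the argmax over logits used during training.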


if __name__ == "__main__":
    main()