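"""Fine-tune a causal language model (e.g., a SOLOIST/GPT-2 checkpoint) for
prompt-based dialogue state tracking (DST).

For every (slot, value) pair in a dialogue, the model is trained to generate
the slot name from a value-based prompt appended to the dialogue history.
Optionally an inverse prompt and a prompt ensemble can be enabled. After each
epoch the model is checkpointed and, if a validation file is given, evaluated
with Joint Goal Accuracy (JGA).
"""
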
import argparse
import json
import os
from datetime import datetime

import numpy as np
import torch
from torch.optim import AdamW
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, GPT2Tokenizer, get_scheduler

from dataset import PromptDstDataset
from metrics import PromptDSTEvaluator
from prompt_utils import (get_value_based_prompt, get_ensemble_prompts, get_prompt_for_training,
                          TYPE_VALUE_BASED_PROMPT, TYPE_INVERSE_PROMPT, TYPE_PROMPT_ENSEMBLE,
                          PROMPT_WEIGHT, PROMPT_TEMPLATES)


def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--train_data_file", default=None, type=str, required=True,
                        help="The input training data file <JSON path>.")
    parser.add_argument("--save_model_dir", default=None, type=str, required=True,
                        help="The directory where the model should be saved.")
    parser.add_argument("--pretrained_model_path", default=None, type=str, required=True,
                        help="The pre-trained model path for fine-tuning [either the original SOLOIST "
                             "model or a saved model checkpoint].")

    # Optional
    parser.add_argument("--num_epochs", default=5, type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument('--seed', type=int, default=42, help="Random seed for initialization.")
    parser.add_argument("--batch_size", default=1, type=int, help="Batch size for training.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for the AdamW optimizer.")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight decay.")
    parser.add_argument("--with_inverse_prompt", action="store_true",
                        help="Flag for enabling/disabling the inverse prompt during training.")
    parser.add_argument("--inverse_prompt_weight", default=0.1, type=float,
                        help="Weight to adjust the influence of the inverse prompt, decimal in (0, 1).")
    parser.add_argument("--with_prompt_ensemble", action="store_true",
                        help="Flag for enabling/disabling prompt ensembling during training.")
    parser.add_argument("--with_answered_prompts", action="store_true",
                        help="Flag to enable/disable the use of answered prompts during validation.")
    parser.add_argument("--validation_file", default="", type=str,
                        help="Validation file for evaluating the model after each epoch.")
    parser.add_argument("--validation_with_true_values", action="store_true",
                        help="Flag for enabling/disabling the usage of TRUE values for slot generation "
                             "during validation.")

    # parse the arguments
    args = parser.parse_args()
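    # Illustrative invocation (script and file names below are hypothetical):
    #   python train.py \
    #       --train_data_file data/train.json \
    #       --save_model_dir checkpoints/ \
    #       --pretrained_model_path gpt2 \
    #       --validation_file data/valid.json \
    #       --with_inverse_prompt --inverse_prompt_weight 0.1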

    # set up the CUDA device for training on GPU (if available)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.n_gpu = torch.cuda.device_count()

    # print & persist the args
    print('Training Args :: ', json.dumps(vars(args), indent=2))
    save_args(args.save_model_dir, args)

    # prepare model & tokenizer -> load the pre-trained model
    tokenizer = GPT2Tokenizer.from_pretrained(args.pretrained_model_path, do_lower_case=True)
    model = AutoModelForCausalLM.from_pretrained(args.pretrained_model_path, pad_token_id=tokenizer.eos_token_id)

    # move the model to the device
    model.to(device)

    # set seed
    set_seed(args)

    # load the training dataset
    training_data = PromptDstDataset(args.train_data_file)

    # load the validation dataset
    validation_data = None
    if args.validation_file:
        validation_data = PromptDstDataset(args.validation_file)

    # create an optimizer and learning rate scheduler to fine-tune the model
    # (weight decay is intended to be skipped for biases and layer-norm weights)
    no_decay = ["bias", "layer_norm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=(args.num_epochs * training_data.total_num_slot_value_pairs)
    )

    # set tqdm progress bars for epochs & training items (the training bar
    # advances once per dialogue item)
    num_training_steps = args.num_epochs * training_data.len()
    epochs = tqdm(total=args.num_epochs, desc="Epochs", position=0)
    training_progress = tqdm(total=num_training_steps, desc="Training Progress", position=1)

    # set the model to training mode
    model.train()

    tqdm.write('Training starts now... [with_prompt_ensemble = {}]'.format(args.with_prompt_ensemble))

    validation_summary = {
        'validation_with_true_values': args.validation_with_true_values
    }

    # outputs file extension (representing usage of prompt ensemble & answered prompts)
    out_ext = "_pe" if args.with_prompt_ensemble else ""
    out_ext += "_pa" if args.with_answered_prompts else ""
# training loop
|
|
for epoch in range(args.num_epochs):
|
|
running_loss = 0.0
|
|
loss_count = 0
|
|
# set the model in training mode (after each epoch)
|
|
model.train()
|
|
for i, item in enumerate(training_data.dataset_items, start=1):
|
|
history = item['history']
|
|
# iterate through (slot, value) pairs
|
|
for slot, value in item['belief_states']:
|
|
|
|
# train/generate using value-based prompt first
|
|
loss, gen_slot = train_prompting(args=args,
|
|
history=history,
|
|
slot_value_pair=(slot, value),
|
|
prompt_type=TYPE_VALUE_BASED_PROMPT,
|
|
tokenizer=tokenizer,
|
|
model=model,
|
|
device=device)
|
|
|
|
if args.with_inverse_prompt:
|
|
# use the generated slot from value-based prompt
|
|
# clean/process the generated slot (remove whitespaces & convert to lower case)
|
|
generated_slot = gen_slot.strip().lower()
|
|
|
|
# train slot generation using inverse prompt
|
|
inv_loss, _ = train_prompting(args=args,
|
|
history=history,
|
|
slot_value_pair=(generated_slot, value),
|
|
prompt_type=TYPE_INVERSE_PROMPT,
|
|
tokenizer=tokenizer,
|
|
model=model,
|
|
device=device)
|
|
|
|
# compute total loss for this slot-value pair
|
|
loss = loss + (args.inverse_prompt_weight * inv_loss)
|
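                    # combined objective: loss = L_value_prompt + w * L_inverse_prompt,
                    # where w = --inverse_prompt_weight (expected in (0, 1))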

                # store the loss for logging
                running_loss += loss.item()
                loss_count += 1

                # backward pass & step
                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # update progress (once per dialogue item)
            training_progress.update(1)

            # print the average loss every 100 items
            if i % 100 == 0:
                last_loss = running_loss / loss_count
                tqdm.write('Training Loss [Iteration {}, Epoch {}] = {}'.format(i, (epoch + 1), last_loss))
                running_loss = 0.0
                loss_count = 0

        # save the model after finishing an epoch
        epoch_str = "{:02d}".format(epoch + 1)
        output_dir = os.path.join(args.save_model_dir, '{}-{}'.format("epoch", epoch_str))
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        # save the training args for each epoch (useful when testing/generating)
        save_args(output_dir, args)
        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
        torch.save(args, os.path.join(output_dir, 'training_args.bin'))
        tqdm.write('Saving model (after Epoch {}) to :: {}'.format((epoch + 1), output_dir))

        # update epoch progress
        epochs.update(1)

        # epoch finished -> run validation if a validation file was provided
        if args.validation_file and validation_data is not None:

            tqdm.write('Validation In Progress... [with_prompt_ensemble = {}]'.format(args.with_prompt_ensemble))

            # set a tqdm progress bar for validation progress
            validation_progress = tqdm(total=validation_data.len(), desc="Validation", leave=False)

            # set eval mode
            model.eval()

            # outputs array -> to be saved to output_dir
            outputs = []

            # JGA metric
            evaluator = PromptDSTEvaluator()
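            # During validation the gold value is fed into the value-based
            # prompt and the model has to recover the matching slot name; the
            # evaluator then compares full belief states for JGA.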
# iterate through validation dataset and generate slots using value-based prompt
|
|
for item in validation_data.dataset_items:
|
|
history = item['history']
|
|
true_states = {}
|
|
gen_states = {}
|
|
# iterate through (slot, value) pairs and generate each slot using value
|
|
for slot, value in item['belief_states']:
|
|
true_states[slot] = value
|
|
|
|
# generate slot using value-based prompt
|
|
generated_slot = generate_slot_from_prompt(args=args,
|
|
history=history,
|
|
value=value,
|
|
tokenizer=tokenizer,
|
|
model=model,
|
|
device=device)
|
|
|
|
# add the generated slot to generated states
|
|
gen_states[generated_slot] = value
|
|
|
|
# update tqdm progress
|
|
validation_progress.update(1)
|
|
|
|
# add true belief states & generated belief states to outputs
|
|
outputs.append({"true_states": true_states, "gen_states": gen_states})
|
|
|
|
# add true & generated belief states to evaluator for computing JGA
|
|
evaluator.add_data_item(true_states.copy(), gen_states.copy())
|
|
|
|
validation_progress.close()
|
|
epoch_valid_summary = {}
|
            # compute JGA & print the results
            tqdm.write('Computing Joint Goal Accuracy metric with TRUE values...')
            jga_score = evaluator.compute_joint_goal_accuracy(no_print=True)
            tqdm.write('Joint Goal Accuracy (with True Values) [after Epoch-{}]: {}'.format((epoch + 1), jga_score))

            now = datetime.now()
            datetime_str = now.strftime("%Y%m%dT%H%M%S")

            # outputs file name
            file_name = "outputs{}-{}.json".format(out_ext, datetime_str)
            epoch_valid_summary['file_name'] = file_name
            epoch_valid_summary['jga_score'] = round(jga_score, 3)
            # add the epoch summary to the validation summary
            validation_summary[epoch_str] = epoch_valid_summary

            # save the outputs to the trained epoch dir
            output_file = os.path.join(output_dir, file_name)
            tqdm.write('Saving Validation Outputs file [after Epoch-{}] :: {}'.format((epoch + 1), output_file))
            json.dump(outputs, open(output_file, 'w'), indent=2)

    # save the validation summary (if at least one epoch was validated)
    if len(validation_summary) > 1:
        valid_summary_file = os.path.join(args.save_model_dir, 'validation{}.json'.format(out_ext))
        tqdm.write('Saving Validation Summary :: {}'.format(valid_summary_file))
        json.dump(validation_summary, open(valid_summary_file, 'w'), indent=2)


def set_seed(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


def save_args(save_dir, args):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    args_file = os.path.join(save_dir, 'args.json')
    # copy the namespace dict so the extra key does not leak back into `args`
    args_dict = dict(vars(args))
    args_dict['prompt_templates'] = PROMPT_TEMPLATES
    json.dump(args_dict, open(args_file, "w"), indent=2)


def train_prompting(args, history, slot_value_pair, prompt_type, tokenizer, model, device):
    # slot_value_pair = (slot, value)

    # use the prompt ensemble when enabled in the args
    if prompt_type == TYPE_VALUE_BASED_PROMPT and args.with_prompt_ensemble:
        return train_prompt_ensemble(history=history,
                                     slot_value_pair=slot_value_pair,
                                     prompt_type=TYPE_PROMPT_ENSEMBLE,
                                     tokenizer=tokenizer,
                                     model=model,
                                     device=device)

    # get the prompt for training based on the prompt type
    prompt = get_prompt_for_training(prompt_type, slot_value_pair)

    # combine history and prompt
    input_prompt = history + prompt

    # encode the history & prompt using the tokenizer and move it to the device
    encoded_prompt = tokenizer(input_prompt, return_tensors="pt")
    encoded_prompt.to(device)

    # get the last token id (already on the device after the move above);
    # this is a slot or a value depending on the prompt type
    last_token = encoded_prompt['input_ids'][:, -1:]

    # model outputs
    outputs = model(**encoded_prompt)

    # the logits at position -2 (last but one) predict the final token of the prompt
    logits = outputs.logits[:, -2, :]

    # softmax probabilities
    probs = torch.nn.functional.softmax(logits, dim=-1)

    # generation probability of the gold last token
    last_token_prob = torch.gather(probs, 1, last_token).squeeze(-1)
    loss = torch.negative(torch.log(last_token_prob))
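    # i.e., the loss is the negative log-likelihood of the gold token:
    # loss = -log p(gold_token | history + prompt without its last token)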

    # generated word -> the one with the highest probability
    generated_word = None
    if prompt_type == TYPE_VALUE_BASED_PROMPT:
        # the token with the highest probability becomes the generated word
        gen_word_token = torch.argmax(logits, dim=-1)
        generated_word = tokenizer.decode(gen_word_token, skip_special_tokens=True).strip()

    # return the negative log-probability loss and the generated word
    return loss, generated_word


def train_prompt_ensemble(history, slot_value_pair, prompt_type, tokenizer, model, device):
    # slot_value_pair = (slot, value)
    # get the list of prompts for training
    prompts = get_prompt_for_training(prompt_type, slot_value_pair)

    # total probability over all prompt functions
    total_prob = None

    gen_probs, gen_words = [], []

    # iterate through each prompt
    for prompt in prompts:

        # combine history and prompt
        input_prompt = history + prompt

        # encode the history & prompt using the tokenizer and move it to the device
        encoded_prompt = tokenizer(input_prompt, return_tensors="pt")
        encoded_prompt.to(device)

        # get the last token id (already on the device after the move above);
        # this is a slot or a value depending on the prompt type
        last_token = encoded_prompt['input_ids'][:, -1:]

        # model outputs
        outputs = model(**encoded_prompt)

        # the logits at position -2 (last but one) predict the final token of the prompt
        logits = outputs.logits[:, -2, :]

        # softmax probabilities
        probs = torch.nn.functional.softmax(logits, dim=-1)

        # generation probability of the gold last token
        last_token_prob = torch.gather(probs, 1, last_token).squeeze(-1)

        # weighted probability | the prompt weights must sum to 1
        token_prob = PROMPT_WEIGHT * last_token_prob
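        # the ensemble mixes the per-prompt distributions:
        # p(token) = sum_k w_k * p_k(token), with a uniform weight
        # w_k = PROMPT_WEIGHT for every prompt function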

        # accumulate the probabilities over all prompt functions
        if total_prob is None:
            total_prob = token_prob
        else:
            total_prob += token_prob

        # generated slot:
        # the token with the highest probability becomes the generated word
        gen_word_token = torch.argmax(logits, dim=-1)
        gen_word_prob = torch.gather(probs, 1, gen_word_token[:, None]).squeeze(-1)
        gen_word = tokenizer.decode(gen_word_token, skip_special_tokens=True).strip()

        # record the generated word and its probability
        gen_probs.append(gen_word_prob.item())
        gen_words.append(gen_word)

    # pick the generated word with the highest probability across prompts
    generated_word = gen_words[gen_probs.index(max(gen_probs))]
    # the loss is the negative log of the ensembled probability
    loss = torch.negative(torch.log(total_prob))
    return loss, generated_word


# Use this for generating the next word (validation after each epoch)
def generate_slot_from_prompt(args, history, value, tokenizer, model, device):
    # check if the prompt ensemble is enabled in the arguments
    if args.with_prompt_ensemble:
        return generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device)

    # get the value-based prompt for generation
    prompt = get_value_based_prompt(value)

    # combine history and prompt
    prompt = history + prompt
    # encode the history & prompt and move it to the device
    encoded_prompt = tokenizer(prompt, return_tensors="pt")
    encoded_prompt.to(device)

    # generate exactly 1 token
    outputs = model.generate(**encoded_prompt, max_new_tokens=1)

    # keep only the newly generated token and decode it
    gen_token = outputs[:, encoded_prompt['input_ids'].shape[-1]:]
    generated_word = tokenizer.decode(gen_token.item(), skip_special_tokens=True)

    return generated_word.strip().lower()


def generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device):
    # get the prompts for ensemble generation
    prompts = get_ensemble_prompts(value)

    gen_probs, gen_words = [], []

    for prompt in prompts:
        # combine history and prompt
        prompt = history + prompt

        # encode the history & prompt and move it to the device
        encoded_prompt = tokenizer(prompt, return_tensors="pt")
        encoded_prompt.to(device)

        # generate 1 token (max length of a slot = 1 token)
        outputs = model.generate(**encoded_prompt,
                                 return_dict_in_generate=True,
                                 output_scores=True,
                                 max_new_tokens=1)
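        # with output_scores=True, outputs.scores is a tuple with one entry per
        # generated token; scores[0] holds the (pre-softmax) prediction scores
        # for the single step generated here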

        # keep only the newly generated token and decode it
        gen_token = outputs.sequences[:, encoded_prompt['input_ids'].shape[-1]:]
        gen_word = tokenizer.decode(gen_token.item(), skip_special_tokens=True).strip()

        # probability of the generated token under the softmax of its scores
        probs = torch.nn.functional.softmax(outputs.scores[0], dim=-1)
        gen_prob = torch.gather(probs, 1, gen_token).squeeze(-1)

        # record the generated word and its probability
        gen_probs.append(gen_prob.item())
        gen_words.append(gen_word)

    # return the word with the highest probability
    generated_word = gen_words[gen_probs.index(max(gen_probs))]
    return generated_word.strip().lower()


if __name__ == "__main__":
    main()