import argparse
import json
import os
from datetime import datetime

import numpy as np
import torch
from torch.optim import AdamW
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, GPT2Tokenizer, get_scheduler

from dataset import PromptDstDataset
from metrics import PromptDSTEvaluator
from prompt_utils import (
    PROMPT_WEIGHT,
    TYPE_INVERSE_PROMPT,
    TYPE_PROMPT_ENSEMBLE,
    TYPE_VALUE_BASED_PROMPT,
    get_ensemble_prompts,
    get_prompt_for_training,
    get_value_based_prompt,
)


def main():
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--train_data_file", default=None, type=str, required=True,
                        help="The input training data file.")
    parser.add_argument("--save_model_dir", default=None, type=str, required=True,
                        help="The directory where the model should be saved.")
    parser.add_argument("--pretrained_model_path", default=None, type=str, required=True,
                        help="The pre-trained model path for fine-tuning [either the original SOLOIST "
                             "or a saved model checkpoint].")
    # Optional parameters
    parser.add_argument("--num_epochs", default=5, type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--seed", default=42, type=int,
                        help="Random seed for initialization.")
    parser.add_argument("--batch_size", default=1, type=int,
                        help="Batch size for training.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for the AdamW optimizer.")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight decay.")
    parser.add_argument("--with_inverse_prompt", action="store_true",
                        help="Enable the inverse prompt during training.")
    parser.add_argument("--inverse_prompt_weight", default=0.1, type=float,
                        help="Weight to adjust the influence of the inverse prompt, a decimal in (0, 1).")
    parser.add_argument("--with_prompt_ensemble", action="store_true",
                        help="Enable prompt ensembling during training.")
    parser.add_argument("--validation_file", default="", type=str,
                        help="Validation file for evaluating the model after each epoch.")

    # parse the arguments
    args = parser.parse_args()

    # set up the CUDA device for training on GPU (if available)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.n_gpu = torch.cuda.device_count()

    # prepare model & tokenizer -> load the pre-trained model
    tokenizer = GPT2Tokenizer.from_pretrained(args.pretrained_model_path, do_lower_case=True)
    model = AutoModelForCausalLM.from_pretrained(args.pretrained_model_path,
                                                 pad_token_id=tokenizer.eos_token_id)
    # move the model to the device
    model.to(device)

    # set the seed
    set_seed(args)

    # load the training dataset
    training_data = PromptDstDataset(args.train_data_file)

    # load the validation dataset
    validation_data = None
    if args.validation_file:
        validation_data = PromptDstDataset(args.validation_file)

    # create an optimizer and learning-rate scheduler to fine-tune the model
    no_decay = ["bias", "layer_norm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=(args.num_epochs * training_data.total_num_slot_value_pairs),
    )
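    # NOTE: one optimizer/scheduler step is taken per (slot, value) pair, so the
    # linear schedule runs for num_epochs * total_num_slot_value_pairs steps in total.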

    # set tqdm progress bars for epochs & number of training steps
    # (one training step per (slot, value) pair, matching the scheduler length)
    num_training_steps = args.num_epochs * training_data.total_num_slot_value_pairs
    epochs = tqdm(total=args.num_epochs, desc="Epochs", position=0)
    training_progress = tqdm(total=num_training_steps, desc="Training Progress", position=1)

    # set the model in training mode
    model.train()

    tqdm.write('Training starts now... [with_prompt_ensemble = {}]'.format(args.with_prompt_ensemble))

    # training loop
    for epoch in range(args.num_epochs):
        running_loss = 0.0
        loss_count = 0

        # set the model back in training mode (validation may have switched it to eval mode)
        model.train()

        for i, item in enumerate(training_data.dataset_items, start=1):
            history = item['history']

            # iterate through (slot, value) pairs
            for slot, value in item['belief_states']:
                # train/generate using the value-based prompt first
                loss, gen_slot = train_prompting(args=args, history=history,
                                                 slot_value_pair=(slot, value),
                                                 prompt_type=TYPE_VALUE_BASED_PROMPT,
                                                 tokenizer=tokenizer, model=model, device=device)

                if args.with_inverse_prompt:
                    # use the slot generated by the value-based prompt;
                    # clean/process it (strip whitespace & convert to lower case)
                    generated_slot = gen_slot.strip().lower()
                    # train slot generation using the inverse prompt
                    inv_loss, _ = train_prompting(args=args, history=history,
                                                  slot_value_pair=(generated_slot, value),
                                                  prompt_type=TYPE_INVERSE_PROMPT,
                                                  tokenizer=tokenizer, model=model, device=device)
                    # compute the total loss for this slot-value pair
                    loss = loss + (args.inverse_prompt_weight * inv_loss)

                # accumulate the loss for logging
                running_loss += loss.item()
                loss_count += 1

                # backward pass & optimization step
                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                # update progress
                training_progress.update(1)

            # print the average loss every 100 items (guard against an empty window)
            if i % 100 == 0 and loss_count > 0:
                last_loss = running_loss / loss_count
                tqdm.write('Training Loss [Iteration {}, Epoch {}] = {}'.format(i, epoch + 1, last_loss))
                running_loss = 0.0
                loss_count = 0

        # save the model after finishing an epoch
        output_dir = os.path.join(args.save_model_dir, 'epoch-{:02d}'.format(epoch + 1))
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
        torch.save(args, os.path.join(output_dir, 'training_args.bin'))
        tqdm.write('Saving model (after Epoch {}) to :: {}'.format(epoch + 1, output_dir))

        # update epoch progress
        epochs.update(1)
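
        # NOTE: Joint Goal Accuracy (JGA) below counts a dialogue turn as correct
        # only when the generated belief state matches the ground-truth belief
        # state exactly, i.e. every (slot, value) pair agrees.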
        # Epoch finished -> run evaluation if a validation file was provided
        if args.validation_file and validation_data is not None:
            tqdm.write('Validation in progress... [with_prompt_ensemble = {}]'.format(args.with_prompt_ensemble))

            # set a tqdm progress bar for validation progress
            validation_progress = tqdm(total=validation_data.len(), desc="Validation", leave=False)

            # set eval mode
            model.eval()

            # outputs array -> to be saved to output_dir
            outputs = []

            # JGA metric
            evaluator = PromptDSTEvaluator()

            # iterate through the validation dataset and generate slots using the value-based prompt
            for item in validation_data.dataset_items:
                history = item['history']
                true_states = {}
                gen_states = {}

                # iterate through (slot, value) pairs and generate each slot from its value
                for slot, value in item['belief_states']:
                    true_states[slot] = value
                    # generate the slot using the value-based prompt
                    generated_slot = generate_slot_from_prompt(args=args, history=history, value=value,
                                                               tokenizer=tokenizer, model=model, device=device)
                    # add the generated slot to the generated states
                    gen_states[generated_slot] = value

                # update tqdm progress
                validation_progress.update(1)

                # add true & generated belief states to the outputs
                outputs.append({"true_states": true_states, "gen_states": gen_states})
                # add true & generated belief states to the evaluator for computing JGA
                evaluator.add_data_item(true_states.copy(), gen_states.copy())

            validation_progress.close()

            # compute JGA & print the result
            tqdm.write('Computing Joint Goal Accuracy metric with TRUE values...')
            jga_score = evaluator.compute_joint_goal_accuracy(no_print=True)
            tqdm.write('Joint Goal Accuracy (with true values) [after Epoch-{}]: {}'.format(epoch + 1, jga_score))

            # save the outputs to the trained epoch dir
            now = datetime.now()
            datetime_str = now.strftime("%Y%m%dT%H%M%S")
            output_file = os.path.join(output_dir, 'outputs-{}.json'.format(datetime_str))
            tqdm.write('Saving validation outputs [after Epoch-{}] :: {}'.format(epoch + 1, output_file))
            with open(output_file, 'w') as f:
                json.dump(outputs, f, indent=2)

    epochs.close()
    training_progress.close()


def set_seed(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


def train_prompting(args, history, slot_value_pair, prompt_type, tokenizer, model, device):
    # slot_value_pair = (slot, value)
    # use the prompt ensemble when enabled in the args
    if prompt_type == TYPE_VALUE_BASED_PROMPT and args.with_prompt_ensemble:
        return train_prompt_ensemble(history=history, slot_value_pair=slot_value_pair,
                                     prompt_type=TYPE_PROMPT_ENSEMBLE,
                                     tokenizer=tokenizer, model=model, device=device)

    # get the prompt for training based on its type
    prompt = get_prompt_for_training(prompt_type, slot_value_pair)

    # combine history and prompt, then encode with the tokenizer
    input_prompt = history + prompt
    encoded_prompt = tokenizer(input_prompt, return_tensors="pt").to(device)

    # get the last token id; this is the training target
    # (a slot or a value token, depending on the prompt type)
    last_token = encoded_prompt['input_ids'][:, -1:]

    # model outputs
    outputs = model(**encoded_prompt)

    # logits at the second-to-last position predict the last token
    logits = outputs.logits[:, -2, :]

    # softmax probabilities
    probs = torch.nn.functional.softmax(logits, dim=-1)

    # generation probability of the last (target) token
    last_token_prob = torch.gather(probs, 1, last_token).squeeze(-1)

    # loss is the negative log of that probability
    loss = torch.negative(torch.log(last_token_prob))

    # generated word -> the one with the highest probability
    generated_word = None
    if prompt_type == TYPE_VALUE_BASED_PROMPT:
        # the token with the highest probability is the generated word
        gen_word_token = torch.argmax(logits, dim=-1)
        generated_word = tokenizer.decode(gen_word_token, skip_special_tokens=True).strip()

    return loss, generated_word
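

# Prompt ensembling, as implemented below: each prompt in the ensemble
# contributes a negative log-likelihood of the target token, with the token
# probability scaled by PROMPT_WEIGHT; the per-prompt losses are summed into
# the training loss, and the word whose generation probability is highest
# across all prompts is returned as the prediction.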
def train_prompt_ensemble(history, slot_value_pair, prompt_type, tokenizer, model, device):
    # slot_value_pair = (slot, value)
    # get the list of prompts for training
    prompts = get_prompt_for_training(prompt_type, slot_value_pair)

    # total loss accumulated over all prompts
    total_loss = None
    gen_probs, gen_words = [], []

    # iterate through each prompt
    for prompt in prompts:
        # combine history and prompt, then encode with the tokenizer
        input_prompt = history + prompt
        encoded_prompt = tokenizer(input_prompt, return_tensors="pt").to(device)

        # get the last token id; this is the training target
        # (a slot or a value token, depending on the prompt type)
        last_token = encoded_prompt['input_ids'][:, -1:]

        # model outputs
        outputs = model(**encoded_prompt)

        # logits at the second-to-last position predict the last token
        logits = outputs.logits[:, -2, :]

        # softmax probabilities
        probs = torch.nn.functional.softmax(logits, dim=-1)

        # generation probability of the last (target) token
        last_token_prob = torch.gather(probs, 1, last_token).squeeze(-1)

        # weighted probability
        token_prob = PROMPT_WEIGHT * last_token_prob
        loss = torch.negative(torch.log(token_prob))
        if total_loss is None:
            total_loss = loss
        else:
            total_loss += loss

        # generated slot: the token with the highest probability is the generated word
        gen_word_token = torch.argmax(logits, dim=-1)
        gen_word_prob = torch.gather(probs, 1, gen_word_token[:, None]).squeeze(-1)
        gen_word = tokenizer.decode(gen_word_token, skip_special_tokens=True).strip()

        # record the generated word and its probability
        gen_probs.append(gen_word_prob.item())
        gen_words.append(gen_word)

    # return the word generated with the highest probability across prompts
    generated_word = gen_words[gen_probs.index(max(gen_probs))]
    return total_loss, generated_word


# Use this for generating the next word (validation after each epoch)
def generate_slot_from_prompt(args, history, value, tokenizer, model, device):
    # use the prompt ensemble when enabled in the args
    if args.with_prompt_ensemble:
        return generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device)

    # get the value-based prompt for generation
    prompt = get_value_based_prompt(value)
    # combine history and prompt
    prompt = history + prompt

    # encode the history & prompt
    encoded_prompt = tokenizer(prompt, return_tensors="pt").to(device)

    # generate 1 token
    outputs = model.generate(**encoded_prompt, max_new_tokens=1)
    gen_token = outputs[:, encoded_prompt['input_ids'].shape[-1]:]
    generated_word = tokenizer.decode(gen_token.item(), skip_special_tokens=True)
    return generated_word.strip().lower()


def generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device):
    # get the prompts for ensemble generation
    prompts = get_ensemble_prompts(value)

    gen_probs, gen_words = [], []
    for prompt in prompts:
        # combine history and prompt, then encode with the tokenizer
        input_prompt = history + prompt
        encoded_prompt = tokenizer(input_prompt, return_tensors="pt").to(device)

        # generate 1 token, keeping the scores to recover its probability
        outputs = model.generate(**encoded_prompt, return_dict_in_generate=True,
                                 output_scores=True, max_new_tokens=1)
        gen_token = outputs.sequences[:, encoded_prompt['input_ids'].shape[-1]:]
        gen_word = tokenizer.decode(gen_token.item(), skip_special_tokens=True).strip()
        probs = torch.nn.functional.softmax(outputs.scores[0], dim=-1)
        gen_prob = torch.gather(probs, 1, gen_token).squeeze(-1)

        # record the generated word and its probability
        gen_probs.append(gen_prob.item())
        gen_words.append(gen_word)

    # return the word with the highest probability
    generated_word = gen_words[gen_probs.index(max(gen_probs))]
    return generated_word.strip().lower()


if __name__ == "__main__":
    main()
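

# Example invocation (a sketch; the script name and paths are placeholders):
#   python train.py \
#       --train_data_file data/train.json \
#       --save_model_dir checkpoints/ \
#       --pretrained_model_path <SOLOIST-or-checkpoint-path> \
#       --validation_file data/valid.json \
#       --with_inverse_prompt --inverse_prompt_weight 0.1 \
#       --with_prompt_ensemble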