"""Generate dialogue-state slots from value-based prompts with a fine-tuned model.

For every item of the test dataset, each extracted value is turned into a
prompt (or an ensemble of prompts), a single token is generated as the slot
name, and the generated belief states are evaluated and saved as JSON.
"""
import argparse
import collections
import json
import os
from datetime import datetime

import numpy as np
import torch
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, GPT2Tokenizer

from dataset import PromptDstDataset
from metrics import PromptDSTEvaluator
from prompt_utils import get_ensemble_prompts
from prompt_utils import get_value_based_prompt

# Strategies for picking the final slot out of the prompt-ensemble outputs.
SLOT_PICK_HIGHEST_PROB = "highest"
SLOT_PICK_SIMPLE_MAJORITY = "majority"


def set_seed(args):
    """Seed NumPy and PyTorch (and all CUDA devices when ``args.n_gpu > 0``).

    Args:
        args: namespace providing ``seed`` and ``n_gpu``.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


# Use this for generating slot <- Inference (Testing)
def generate_slot_from_prompt(args, history, value, tokenizer, model, device):
    """Generate the slot name for ``value`` given the dialogue ``history``.

    Delegates to :func:`generate_slot_with_prompt_ensemble` when
    ``args.with_prompt_ensemble`` is set; otherwise a single value-based
    prompt is appended to the history and exactly one new token is generated.

    Args:
        args: namespace providing ``with_prompt_ensemble`` (and, for the
            ensemble path, ``ensemble_pick_slot``).
        history: dialogue history text the prompt is appended to.
        value: the value whose slot should be generated.
        tokenizer: tokenizer used to encode the prompt and decode the token.
        model: causal language model exposing ``generate``.
        device: torch device the encoded inputs are moved to.

    Returns:
        The generated slot, stripped of surrounding whitespace and lowercased.
    """
    # check if prompt ensemble is enabled in arguments
    if args.with_prompt_ensemble:
        return generate_slot_with_prompt_ensemble(
            args, history, value, tokenizer, model, device
        )

    # combine the history with the value-based prompt
    prompt = history + get_value_based_prompt(value)

    # encode the history & prompt
    encoded_prompt = tokenizer(prompt, return_tensors="pt")
    encoded_prompt.to(device)

    # generate 1 token (max length of slot = 1)
    outputs = model.generate(**encoded_prompt, max_new_tokens=1)

    # keep only what was generated after the prompt tokens
    gen_token_id = outputs[:, encoded_prompt['input_ids'].shape[-1]:]
    generated_word = tokenizer.decode(gen_token_id.item(), skip_special_tokens=True)

    return generated_word.strip().lower()


def generate_slot_with_prompt_ensemble(args, history, value, tokenizer, model, device):
    """Generate a slot for ``value`` from several prompts and pick one result.

    One token is generated per ensemble prompt, together with the probability
    the model assigned to it. The final slot is chosen according to
    ``args.ensemble_pick_slot``:

    * ``SLOT_PICK_HIGHEST_PROB``: the word generated with the highest
      probability.
    * ``SLOT_PICK_SIMPLE_MAJORITY``: the most frequent word when it satisfies
      the majority rule below, otherwise the highest-probability word.

    Args:
        args: namespace providing ``ensemble_pick_slot``.
        history: dialogue history text each prompt is appended to.
        value: the value whose slot should be generated.
        tokenizer: tokenizer used to encode the prompts and decode the tokens.
        model: causal language model exposing ``generate``.
        device: torch device the encoded inputs are moved to.

    Returns:
        The selected slot, stripped of surrounding whitespace and lowercased.

    Raises:
        ValueError: if ``args.ensemble_pick_slot`` is not a supported strategy
            or no ensemble prompts are available for ``value``.
    """
    pick_mode = args.ensemble_pick_slot
    # Fail fast (before running the model) instead of silently falling
    # through both branches below with no slot selected.
    if pick_mode not in (SLOT_PICK_HIGHEST_PROB, SLOT_PICK_SIMPLE_MAJORITY):
        raise ValueError(
            "Unsupported ensemble_pick_slot '{}'; expected '{}' or '{}'".format(
                pick_mode, SLOT_PICK_HIGHEST_PROB, SLOT_PICK_SIMPLE_MAJORITY
            )
        )

    # get prompts for ensemble generation
    prompts = get_ensemble_prompts(value)
    if not prompts:
        raise ValueError("No ensemble prompts available for value: {!r}".format(value))

    gen_probs, gen_words = [], []

    for ensemble_prompt in prompts:
        # combine history and prompt, then encode
        encoded_prompt = tokenizer(history + ensemble_prompt, return_tensors="pt")
        encoded_prompt.to(device)

        # generate 1 token (max length of slot = 1) and keep its scores
        outputs = model.generate(
            **encoded_prompt,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=1,
        )

        # keep only what was generated after the prompt tokens
        gen_token_id = outputs.sequences[:, encoded_prompt['input_ids'].shape[-1]:]
        gen_word = tokenizer.decode(gen_token_id.item(), skip_special_tokens=True).strip()

        # probability of the generated token under the first (only) step's scores
        probs = torch.nn.functional.softmax(outputs.scores[0], dim=-1)
        gen_prob = torch.gather(probs, 1, gen_token_id).squeeze(-1)

        # add the generated word and probs to list
        gen_probs.append(gen_prob.item())
        gen_words.append(gen_word)

    # word generated with the highest probability (first one wins on ties)
    highest_prob_word = gen_words[gen_probs.index(max(gen_probs))]

    if pick_mode == SLOT_PICK_HIGHEST_PROB:
        return highest_prob_word.strip().lower()

    # Simple majority voting over the generated slots; if there is no simple
    # majority, fall back to the slot with the highest probability.
    word_counter = collections.Counter(gen_words)
    max_slot = max(word_counter, key=word_counter.get)
    max_slot_count = word_counter[max_slot]

    # NOTE(review): these thresholds are hard-coded and look tuned for a
    # fixed ensemble size -- confirm against get_ensemble_prompts.
    if max_slot_count >= 3 or (max_slot_count == 2 and len(word_counter) > 2):
        # winner slot (simple majority)
        generated_word = max_slot
    else:
        # generated slot with the highest probability
        generated_word = highest_prob_word

    return generated_word.strip().lower()


def main():
    """Run slot generation over a test file, report JGA and save the outputs.

    Parses the command-line arguments, loads the fine-tuned model and
    tokenizer from ``--tuned_model_path``, generates a slot for every value
    of every dataset item, feeds true/generated states to the evaluator and
    writes all outputs to a timestamped JSON file inside ``--output_dir``.
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--test_data_file", default=None, type=str, required=True,
                        help="The test/eval data file .")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The directory where the predictions should be saved")
    parser.add_argument("--tuned_model_path", default=None, type=str, required=True,
                        help="The fine-tuned model path")

    # Optional
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument("--with_prompt_ensemble", action="store_true",
                        help="Flag for enabling/disabling prompt ensembling while generating")
    # restrict to the supported strategies so typos are rejected at parse time
    parser.add_argument("--ensemble_pick_slot", type=str,
                        default=SLOT_PICK_SIMPLE_MAJORITY,
                        choices=[SLOT_PICK_HIGHEST_PROB, SLOT_PICK_SIMPLE_MAJORITY],
                        help="Flag for setting the algorithm to pick the generated slot "
                             "from ensemble outputs")
    parser.add_argument("--with_answered_prompts", action="store_true",
                        help="Flag to enable/disable the use of answered prompts while generating")

    # parse the arguments
    args = parser.parse_args()

    # check for args.json file in the saved model path & check if trained with prompt ensemble
    args_file = os.path.join(args.tuned_model_path, 'args.json')
    if os.path.isfile(args_file):
        # context manager so the file handle is closed deterministically
        with open(args_file, encoding="utf-8") as args_fp:
            args_dict = json.load(args_fp)
        if 'with_prompt_ensemble' in args_dict:
            # the saved training setting overrides the command-line flag
            args.with_prompt_ensemble = args_dict['with_prompt_ensemble']
    else:
        print("No 'args.json' file found in the saved epoch dir!")

    # setup CUDA device (if available)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.n_gpu = torch.cuda.device_count()

    print('Generation Config :: ', json.dumps(vars(args), indent=2))

    # prepare model & tokenizer -> load the fine-tuned model
    tokenizer = GPT2Tokenizer.from_pretrained(args.tuned_model_path)
    model = AutoModelForCausalLM.from_pretrained(args.tuned_model_path,
                                                 pad_token_id=tokenizer.eos_token_id)

    # set the device to the model
    model.to(device)

    # set seed
    set_seed(args)

    # load testing/eval dataset
    dataset = PromptDstDataset(args.test_data_file)

    # progress bar over the dataset items
    progress = tqdm(total=dataset.len(), desc="Progress")

    # set eval mode
    model.eval()

    # outputs array -> to be saved to output_dir
    outputs = []

    # JGA metric
    evaluator = PromptDSTEvaluator()

    tqdm.write(str('Generating slots now...'))

    # iterate through test dataset and generate slots
    for item in dataset.dataset_items:
        history = item['history']

        # (slot, value) pairs of the item form the true states
        true_states = {slot: value for slot, value in item['belief_states']}
        gen_states = {}

        # generate a slot for each extracted value using value-based prompts
        for value in item['values']:
            generated_slot = generate_slot_from_prompt(args=args,
                                                       history=history,
                                                       value=value,
                                                       tokenizer=tokenizer,
                                                       model=model,
                                                       device=device)

            # add the generated slot to generated states
            gen_states[generated_slot] = value

        # update tqdm progress (one step per dataset item, matching `total`)
        progress.update(1)

        # add true belief states & generated belief states to outputs
        outputs.append({"history": history,
                        "extracted_values": item['values'],
                        "true_states": true_states,
                        "gen_states": gen_states})

        # add true & generated belief states to evaluator for computing JGA
        evaluator.add_data_item(true_states.copy(), gen_states.copy())

    progress.close()

    # compute JGA & print results
    evaluator.compute_joint_goal_accuracy()

    # compute JGA* & print results (JGA* -> consider values that are extracted correctly)
    evaluator.compute_jga_for_correct_values()

    # output file extension (for prompt ensemble & answered prompts)
    out_ext = "_pe" if args.with_prompt_ensemble else ""
    out_ext += "_pa" if args.with_answered_prompts else ""

    # save the outputs to output_dir (exist_ok avoids the check-then-create race)
    os.makedirs(args.output_dir, exist_ok=True)

    datetime_str = datetime.now().strftime("%Y%m%dT%H%M%S")
    output_file = os.path.join(args.output_dir,
                               'outputs{}-{}.json'.format(out_ext, datetime_str))
    print('Saving Outputs file :: ', output_file)

    # context manager guarantees the JSON is flushed and the handle closed
    with open(output_file, 'w', encoding="utf-8") as out_fp:
        json.dump(outputs, out_fp, indent=2)


if __name__ == "__main__":
    main()