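"""Inference script for value-based prompt DST.

For each dialogue in the test set, a fine-tuned causal LM generates a slot name
for every extracted value, optionally using prompt ensembling and answered-prompt
demonstrations. Predictions are scored with PromptDSTEvaluator (JGA / JGA*) and
saved as JSON.
"""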
import argparse
import collections
import numpy as np
import os
import json
import torch
from transformers import AutoModelForCausalLM, GPT2Tokenizer
from dataset import PromptDstDataset
from tqdm.auto import tqdm
from prompt_utils import get_value_based_prompt
from prompt_utils import get_ensemble_prompts
from prompt_utils import get_answered_prompts
from metrics import PromptDSTEvaluator
from datetime import datetime
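
# Strategies for selecting the final slot from prompt-ensemble outputs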
SLOT_PICK_HIGHEST_PROB = "highest"
SLOT_PICK_SIMPLE_MAJORITY = "majority"
def set_seed(args):
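    """Seed numpy and torch (and all CUDA devices, if present) for reproducibility."""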
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

# Generate a slot name from a value-based prompt (used at inference/test time)
def generate_slot_from_prompt(args, history, value, tokenizer, model, device):
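    """Generate the slot name for a single extracted value given the dialogue history."""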
    # delegate to the ensemble path if prompt ensembling is enabled
    if args.with_prompt_ensemble:
        return generate_slot_with_prompt_ensemble(args, history, value, tokenizer, model, device)
    # prepend answered-prompt demonstrations if enabled
    answered_prompts = ''
    if args.with_answered_prompts:
        answered_prompts = get_answered_prompts()
    # get the value-based prompt for generating the slot
    prompt = get_value_based_prompt(value)
    # combine demonstrations, history and prompt
    prompt = answered_prompts + history + prompt
    # encode the combined prompt and move it to the target device
    encoded_prompt = tokenizer(prompt, return_tensors="pt")
    encoded_prompt = encoded_prompt.to(device)
    # generate a single token (slot names are one token long)
    outputs = model.generate(**encoded_prompt, max_new_tokens=1)
    gen_token_id = outputs[:, encoded_prompt['input_ids'].shape[-1]:]
    generated_word = tokenizer.decode(gen_token_id.item(), skip_special_tokens=True)
    return generated_word.strip().lower()

def generate_slot_with_prompt_ensemble(args, history, value, tokenizer, model, device):
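    """Generate a slot with multiple ensemble prompts and pick one by highest probability or majority vote."""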
    # get the prompts for ensemble generation
    prompts = get_ensemble_prompts(value)
    gen_probs, gen_words = [], []
    answered_prompts = ''
    if args.with_answered_prompts:
        answered_prompts = get_answered_prompts()
    for prompt in prompts:
        # combine demonstrations, history and prompt
        prompt = answered_prompts + history + prompt
        # encode the combined prompt and move it to the target device
        encoded_prompt = tokenizer(prompt, return_tensors="pt")
        encoded_prompt = encoded_prompt.to(device)
        # generate a single token (slot names are one token long) and keep the scores
        outputs = model.generate(**encoded_prompt,
                                 return_dict_in_generate=True,
                                 output_scores=True,
                                 max_new_tokens=1)
        gen_token_id = outputs.sequences[:, encoded_prompt['input_ids'].shape[-1]:]
        gen_word = tokenizer.decode(gen_token_id.item(), skip_special_tokens=True).strip()
        probs = torch.nn.functional.softmax(outputs.scores[0], dim=-1)
        gen_prob = torch.gather(probs, 1, gen_token_id).squeeze(-1)
        # add the generated word and its probability to the lists
        gen_probs.append(gen_prob.item())
        gen_words.append(gen_word)
    if args.ensemble_pick_slot == SLOT_PICK_HIGHEST_PROB:
        # return the slot with the highest probability
        generated_word = gen_words[gen_probs.index(max(gen_probs))]
        return generated_word.strip().lower()
    elif args.ensemble_pick_slot == SLOT_PICK_SIMPLE_MAJORITY:
        # do simple majority voting over the generated slots;
        # if there is no simple majority, pick the slot with the highest probability
        word_counter = collections.Counter(gen_words)
        max_slot = max(word_counter, key=word_counter.get)
        max_slot_count = word_counter[max_slot]
        if max_slot_count >= 3 or (max_slot_count == 2 and len(word_counter) > 2):
            # winning slot (simple majority)
            generated_word = max_slot
        else:
            # slot with the highest probability
            generated_word = gen_words[gen_probs.index(max(gen_probs))]
        # return the generated slot
        return generated_word.strip().lower()

def main():
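    """Run slot generation over the test set, compute JGA/JGA*, and save the outputs."""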
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--test_data_file", default=None, type=str, required=True,
                        help="The test/eval data file <JSON path>.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The directory where the predictions should be saved")
    parser.add_argument("--tuned_model_path", default=None, type=str, required=True,
                        help="The fine-tuned model path")
    # Optional parameters
    parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
    parser.add_argument("--with_prompt_ensemble", action="store_true",
                        help="Enable prompt ensembling while generating")
    parser.add_argument("--ensemble_pick_slot", type=str, default=SLOT_PICK_SIMPLE_MAJORITY,
                        help="Strategy for picking the generated slot from ensemble outputs ('highest' or 'majority')")
    parser.add_argument("--with_answered_prompts", action="store_true",
                        help="Prepend answered-prompt demonstrations while generating")
    # parse the arguments
    args = parser.parse_args()
    # check for an args.json file in the saved model path to see if the model was trained with prompt ensembling
    args_file = os.path.join(args.tuned_model_path, 'args.json')
    if os.path.isfile(args_file):
        with open(args_file) as f:
            args_dict = json.load(f)
        if 'with_prompt_ensemble' in args_dict:
            args.with_prompt_ensemble = args_dict['with_prompt_ensemble']
    else:
        print("No 'args.json' file found in the saved model dir!")
    # set up the CUDA device for inference on GPU (if available)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.n_gpu = torch.cuda.device_count()
    print('Generation Config :: ', json.dumps(vars(args), indent=2))
    # prepare model & tokenizer -> load the fine-tuned model
    tokenizer = GPT2Tokenizer.from_pretrained(args.tuned_model_path)
    model = AutoModelForCausalLM.from_pretrained(args.tuned_model_path, pad_token_id=tokenizer.eos_token_id)
    # move the model to the device
    model.to(device)
    # set seed
    set_seed(args)
    # load the test/eval dataset
    dataset = PromptDstDataset(args.test_data_file)
    # tqdm progress bar over the test dialogues
    progress = tqdm(total=dataset.len(), desc="Progress")
    # set eval mode
    model.eval()
    # outputs list -> to be saved to output_dir
    outputs = []
    # JGA metric
    evaluator = PromptDSTEvaluator()
    tqdm.write('Generating slots now...')
    # iterate through the test dataset and generate slots
    for item in dataset.dataset_items:
        history = item['history']
        true_states = {}
        gen_states = {}
        # iterate through (slot, value) pairs and add them to the true states
        for slot, value in item['belief_states']:
            true_states[slot] = value
        # iterate through the extracted values and generate a slot for each
        for value in item['values']:
            # generate a slot using the value-based prompt
            generated_slot = generate_slot_from_prompt(args=args,
                                                       history=history,
                                                       value=value,
                                                       tokenizer=tokenizer,
                                                       model=model,
                                                       device=device)
            # add the generated slot to the generated states
            gen_states[generated_slot] = value
        # update the tqdm progress
        progress.update(1)
        # add true belief states & generated belief states to the outputs
        outputs.append({"history": history,
                        "extracted_values": item['values'],
                        "true_states": true_states,
                        "gen_states": gen_states})
        # add true & generated belief states to the evaluator for computing JGA
        evaluator.add_data_item(true_states.copy(), gen_states.copy())
    progress.close()
    # compute JGA & print results
    evaluator.compute_joint_goal_accuracy()
    # compute JGA* & print results (JGA* -> only consider values that were extracted correctly)
    evaluator.compute_jga_for_correct_values()
    # output file suffix (marks prompt ensembling & answered prompts)
    out_ext = "_pe" if args.with_prompt_ensemble else ""
    out_ext += "_pa" if args.with_answered_prompts else ""
    # save the outputs to output_dir
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    now = datetime.now()
    datetime_str = now.strftime("%Y%m%dT%H%M%S")
    output_file = os.path.join(args.output_dir, 'outputs{}-{}.json'.format(out_ext, datetime_str))
    print('Saving Outputs file :: ', output_file)
    with open(output_file, 'w') as f:
        json.dump(outputs, f, indent=2)


if __name__ == "__main__":
    main()
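
# Example invocation (the script and file names below are illustrative assumptions, not from the repo):
#   python generate_slots.py \
#       --test_data_file data/test.json \
#       --output_dir outputs/ \
#       --tuned_model_path checkpoints/best_model \
#       --with_prompt_ensemble --ensemble_pick_slot majority --with_answered_prompts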