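"""Slot-generation script for prompt-based dialogue state tracking (DST).

Loads a fine-tuned causal LM with its GPT-2 tokenizer, builds a value-based
prompt (optionally a prompt ensemble) for every extracted value in each test
dialogue, generates the slot name for that value, reports JGA and JGA* via
PromptDSTEvaluator, and saves the predictions to the output directory.

Example invocation (the script name and paths below are placeholders):
    python generate.py \
        --test_data_file data/test.json \
        --output_dir outputs/ \
        --tuned_model_path checkpoints/gpt2-dst/ \
        --with_prompt_ensemble
"""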
import argparse
import json
import os
from datetime import datetime

import numpy as np
import torch
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, GPT2Tokenizer

from dataset import PromptDstDataset
from metrics import PromptDSTEvaluator
from prompt_utils import get_value_based_prompt, get_ensemble_prompts


def set_seed(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


# Use this for generating the next word (testing)
def generate_slot_from_prompt(args, history, value, tokenizer, model, device):
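    """Generate the slot name for a single extracted value.

    Builds a value-based prompt, appends it to the dialogue history, and lets the
    model generate one token as the slot name. Delegates to the prompt-ensemble
    variant when args.with_prompt_ensemble is set.
    """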
    # if prompt ensembling is enabled, delegate to the ensemble variant
    if args.with_prompt_ensemble:
        return generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device)
    # get the value-based prompt for generating the slot
    prompt = get_value_based_prompt(value)
    # combine history and prompt
    prompt = history + prompt
    # encode the history & prompt
    encoded_prompt = tokenizer(prompt, return_tensors="pt")
    encoded_prompt = encoded_prompt.to(device)
    # generate 1 token (max length of slot = 1)
    outputs = model.generate(**encoded_prompt, max_new_tokens=1)
    # keep only the newly generated token (drop the prompt tokens)
    gen_token_id = outputs[:, encoded_prompt['input_ids'].shape[-1]:]
    generated_word = tokenizer.decode(gen_token_id.item(), skip_special_tokens=True)
    return generated_word.strip().lower()


def generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device):
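    """Generate the slot name for a value using an ensemble of prompts.

    Runs one-token generation for every ensemble prompt, records the probability
    of each generated token, and returns the word produced with the highest
    probability.
    """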
    # get the prompts used for ensemble generation
    prompts = get_ensemble_prompts(value)
    gen_probs, gen_words = [], []
    for prompt in prompts:
        # combine history and prompt
        prompt = history + prompt
        # encode the history & prompt
        encoded_prompt = tokenizer(prompt, return_tensors="pt")
        encoded_prompt = encoded_prompt.to(device)
        # generate 1 token (max length of slot = 1) and keep its scores
        outputs = model.generate(**encoded_prompt,
                                 return_dict_in_generate=True,
                                 output_scores=True,
                                 max_new_tokens=1)
        gen_token_id = outputs.sequences[:, encoded_prompt['input_ids'].shape[-1]:]
        gen_word = tokenizer.decode(gen_token_id.item(), skip_special_tokens=True).strip()
        # probability of the generated token under this prompt
        probs = torch.nn.functional.softmax(outputs.scores[0], dim=-1)
        gen_prob = torch.gather(probs, 1, gen_token_id).squeeze(-1)
        # add the generated word and its probability to the lists
        gen_probs.append(gen_prob.item())
        gen_words.append(gen_word)
    # return the word generated with the highest probability
    generated_word = gen_words[gen_probs.index(max(gen_probs))]
    return generated_word.strip().lower()


def main():
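    """Parse arguments, load the fine-tuned model & test data, generate slots,
    evaluate with JGA / JGA*, and save the predictions to the output directory."""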
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--test_data_file", default=None, type=str, required=True,
                        help="The test/eval data file <JSON path>.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The directory where the predictions should be saved.")
    parser.add_argument("--tuned_model_path", default=None, type=str, required=True,
                        help="Path to the fine-tuned model.")
    # Optional parameters
    parser.add_argument('--seed', type=int, default=42,
                        help="Random seed for initialization.")
    parser.add_argument("--with_prompt_ensemble", action="store_true",
                        help="Enable prompt ensembling while generating.")
    parser.add_argument("--with_answered_prompts", action="store_true",
                        help="Enable the use of answered prompts while generating.")
    # parse the arguments
    args = parser.parse_args()
    # check args.json in the saved model path to see if the model was trained with prompt ensembling
    args_file = os.path.join(args.tuned_model_path, 'args.json')
    if os.path.isfile(args_file):
        with open(args_file) as f:
            args_dict = json.load(f)
        if 'with_prompt_ensemble' in args_dict:
            args.with_prompt_ensemble = args_dict['with_prompt_ensemble']
    else:
        print("No 'args.json' file found in the saved model dir!")
    # set up the CUDA device for generating on GPU (if available)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.n_gpu = torch.cuda.device_count()
    print('Generation Config :: ', json.dumps(vars(args), indent=2))
    # prepare model & tokenizer -> load the fine-tuned model
    tokenizer = GPT2Tokenizer.from_pretrained(args.tuned_model_path)
    model = AutoModelForCausalLM.from_pretrained(args.tuned_model_path, pad_token_id=tokenizer.eos_token_id)
    # move the model to the device
    model.to(device)
    # set seed
    set_seed(args)
    # load the test/eval dataset
    dataset = PromptDstDataset(args.test_data_file)
    # tqdm progress bar over the test examples
    progress = tqdm(total=dataset.len(), desc="Progress")
    # set eval mode
    model.eval()
    # outputs array -> to be saved to output_dir
    outputs = []
    # JGA metric
    evaluator = PromptDSTEvaluator()
    tqdm.write('Generating slots now...')
    # iterate through the test dataset and generate slots
    for item in dataset.dataset_items:
        history = item['history']
        true_states = {}
        gen_states = {}
        # iterate through (slot, value) pairs and add them to the true states
        for slot, value in item['belief_states']:
            true_states[slot] = value
        # iterate through the extracted values and generate a slot name for each one
        for value in item['values']:
            # generate the slot using a value-based prompt
            generated_slot = generate_slot_from_prompt(args=args,
                                                       history=history,
                                                       value=value,
                                                       tokenizer=tokenizer,
                                                       model=model,
                                                       device=device)
            # add the generated slot to the generated states
            gen_states[generated_slot] = value
        # update the tqdm progress
        progress.update(1)
        # add the true & generated belief states to the outputs
        outputs.append({"history": history,
                        "extracted_values": item['values'],
                        "true_states": true_states,
                        "gen_states": gen_states})
        # add the true & generated belief states to the evaluator for computing JGA
        evaluator.add_data_item(true_states.copy(), gen_states.copy())
    progress.close()
# compute JGA & print results
evaluator.compute_joint_goal_accuracy()
# compute JGA* & print results (JGA* -> consider values that are extracted correctly)
evaluator.compute_jga_for_correct_values()
# output file extension (for prompt ensemble & answered prompts)
out_ext = "_pe" if args.with_prompt_ensemble else ""
out_ext += "_pa" if args.with_answered_prompts else ""
# save the outputs to output_dir
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
now = datetime.now()
datetime_str = now.strftime("%Y%m%dT%H%M%S")
output_file = os.path.join(args.output_dir, 'outputs{}-{}.json'.format(out_ext, datetime_str))
print('Saving Outputs file :: ', output_file)
json.dump(outputs, open(output_file, 'w'), indent=2)
if __name__ == "__main__":
main()