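"""Generate dialogue-state slot names from value-based prompts with a fine-tuned causal
language model and evaluate the predictions with Joint Goal Accuracy (JGA).

Example invocation (the script name and paths below are illustrative):
    python generate.py \
        --test_data_file data/test.json \
        --output_dir outputs/ \
        --tuned_model_path checkpoints/best_epoch \
        --with_prompt_ensemble
"""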
import argparse
import json
import os
from datetime import datetime

import numpy as np
import torch
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, GPT2Tokenizer

from dataset import PromptDstDataset
from metrics import PromptDSTEvaluator
from prompt_utils import get_ensemble_prompts, get_value_based_prompt


def set_seed(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


# Use this for generating the next word (testing)
def generate_slot_from_prompt(args, history, value, tokenizer, model, device):
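    """Generate a single slot token for `value` given the dialogue `history`.

    Appends a value-based prompt to the history and lets the model generate exactly
    one token; delegates to prompt ensembling when `args.with_prompt_ensemble` is set.
    """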
    # check if prompt ensembling is enabled in the arguments
    if args.with_prompt_ensemble:
        return generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device)

    # get the value-based prompt for generating the slot
    prompt = get_value_based_prompt(value)

    # combine history and prompt
    prompt = history + prompt

    # encode the history & prompt
    encoded_prompt = tokenizer(prompt, return_tensors="pt")
    encoded_prompt = encoded_prompt.to(device)

    # generate 1 token (max length of slot = 1)
    outputs = model.generate(**encoded_prompt, max_new_tokens=1)

    # keep only the newly generated token and decode it
    gen_token_id = outputs[:, encoded_prompt['input_ids'].shape[-1]:]
    generated_word = tokenizer.decode(gen_token_id.item(), skip_special_tokens=True)

    return generated_word.strip().lower()


def generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device):
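    """Generate a slot token for `value` using an ensemble of prompts.

    Runs one-token generation for each ensemble prompt and returns the word that
    was generated with the highest probability.
    """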
    # get the prompts for ensemble generation
    prompts = get_ensemble_prompts(value)

    gen_probs, gen_words = [], []

    for prompt in prompts:
        # combine history and prompt
        prompt = history + prompt

        # encode the history & prompt
        encoded_prompt = tokenizer(prompt, return_tensors="pt")
        encoded_prompt = encoded_prompt.to(device)

        # generate 1 token (max length of slot = 1) and keep the scores
        outputs = model.generate(**encoded_prompt,
                                 return_dict_in_generate=True,
                                 output_scores=True,
                                 max_new_tokens=1)

        # decode the newly generated token
        gen_token_id = outputs.sequences[:, encoded_prompt['input_ids'].shape[-1]:]
        gen_word = tokenizer.decode(gen_token_id.item(), skip_special_tokens=True).strip()

        # probability of the generated token under the model
        probs = torch.nn.functional.softmax(outputs.scores[0], dim=-1)
        gen_prob = torch.gather(probs, 1, gen_token_id).squeeze(-1)

        # add the generated word and its probability to the lists
        gen_probs.append(gen_prob.item())
        gen_words.append(gen_word)

    # return the word generated with the highest probability
    generated_word = gen_words[gen_probs.index(max(gen_probs))]
    return generated_word.strip().lower()


def main():
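    """Load the fine-tuned model, generate a slot for every extracted value in the
    test set, compute JGA / JGA*, and save the predictions to `output_dir`."""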
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--test_data_file", default=None, type=str, required=True,
                        help="The test/eval data file <JSON Path>.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The directory where the predictions should be saved.")
    parser.add_argument("--tuned_model_path", default=None, type=str, required=True,
                        help="The fine-tuned model path.")

    # Optional
    parser.add_argument('--seed', type=int, default=42, help="Random seed for initialization.")
    parser.add_argument("--with_prompt_ensemble", action="store_true",
                        help="Enable prompt ensembling while generating.")
    parser.add_argument("--with_answered_prompts", action="store_true",
                        help="Enable the use of answered prompts while generating.")

    # parse the arguments
    args = parser.parse_args()

    # check for an args.json file in the saved model path & whether the model was trained with prompt ensembling
    args_file = os.path.join(args.tuned_model_path, 'args.json')
    if os.path.isfile(args_file):
        args_dict = json.load(open(args_file))
        if 'with_prompt_ensemble' in args_dict:
            args.with_prompt_ensemble = args_dict['with_prompt_ensemble']
    else:
        print("No 'args.json' file found in the saved epoch dir!")

    # set up the CUDA device for running on GPU (if available)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.n_gpu = torch.cuda.device_count()

    print('Generation Config :: ', json.dumps(vars(args), indent=2))

    # prepare the model & tokenizer -> load the fine-tuned model
    tokenizer = GPT2Tokenizer.from_pretrained(args.tuned_model_path)
    model = AutoModelForCausalLM.from_pretrained(args.tuned_model_path, pad_token_id=tokenizer.eos_token_id)

    # move the model to the device
    model.to(device)

    # set seed
    set_seed(args)

    # load the test/eval dataset
    dataset = PromptDstDataset(args.test_data_file)

    # set up a tqdm progress bar over the dataset items
    progress = tqdm(total=dataset.len(), desc="Progress")

    # set eval mode
    model.eval()

    # outputs list -> to be saved to output_dir
    outputs = []

    # JGA metric
    evaluator = PromptDSTEvaluator()

    tqdm.write('Generating slots now...')

    # iterate through the test dataset and generate slots
    for item in dataset.dataset_items:
        history = item['history']
        true_states = {}
        gen_states = {}

        # iterate through the (slot, value) pairs and add them to the true states
        for slot, value in item['belief_states']:
            true_states[slot] = value

        # iterate through the extracted values and generate a slot for each value
        for value in item['values']:
            # generate the slot using a value-based prompt
            generated_slot = generate_slot_from_prompt(args=args,
                                                       history=history,
                                                       value=value,
                                                       tokenizer=tokenizer,
                                                       model=model,
                                                       device=device)

            # add the generated slot to the generated states
            gen_states[generated_slot] = value

        # update the tqdm progress bar
        progress.update(1)

        # add the true & generated belief states to the outputs
        outputs.append({"history": history,
                        "extracted_values": item['values'],
                        "true_states": true_states,
                        "gen_states": gen_states})

        # add the true & generated belief states to the evaluator for computing JGA
        evaluator.add_data_item(true_states.copy(), gen_states.copy())

    progress.close()

    # compute JGA & print the results
    evaluator.compute_joint_goal_accuracy()
    # compute JGA* & print the results (JGA* -> only consider values that were extracted correctly)
    evaluator.compute_jga_for_correct_values()

    # output file name suffix (for prompt ensemble & answered prompts)
    out_ext = "_pe" if args.with_prompt_ensemble else ""
    out_ext += "_pa" if args.with_answered_prompts else ""

    # save the outputs to output_dir
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    now = datetime.now()
    datetime_str = now.strftime("%Y%m%dT%H%M%S")
    output_file = os.path.join(args.output_dir, 'outputs{}-{}.json'.format(out_ext, datetime_str))
    print('Saving Outputs file :: ', output_file)
    json.dump(outputs, open(output_file, 'w'), indent=2)


if __name__ == "__main__":
    main()