import argparse
import json
import os
from datetime import datetime

import numpy as np
import torch
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, GPT2Tokenizer

from dataset import PromptDstDataset
from metrics import PromptDSTEvaluator
from prompt_utils import get_value_based_prompt


def set_seed(args):
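    """Seed the numpy and torch RNGs so generation runs are reproducible."""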
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


# Generate the next token (the predicted slot name) for a single value
def generate_slot_from_prompt(history, value, tokenizer, model, device):
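    """Prompt the model with history + a value-based prompt and decode one token.

    The single generated token is interpreted as the slot name for `value`.
    """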
    # build the value-based prompt for this value
    prompt = get_value_based_prompt(value)
    # append the prompt to the dialogue history
    prompt = history + prompt
    # encode the combined history + prompt and move it to the target device
    encoded_prompt = tokenizer(prompt, return_tensors="pt").to(device)
    # generate exactly one new token
    outputs = model.generate(**encoded_prompt, max_new_tokens=1)
    # keep only the newly generated token, dropping the prompt tokens
    gen_sequences = outputs[:, encoded_prompt["input_ids"].shape[-1]:]
    generated_word = tokenizer.decode(gen_sequences[0], skip_special_tokens=True)
    return generated_word.strip().lower()


def main():
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--test_data_file", default=None, type=str, required=True,
                        help="The test/eval data file <JSON Path>.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The directory where the predictions should be saved.")
    parser.add_argument("--tuned_model_path", default=None, type=str, required=True,
                        help="The fine-tuned model path.")
    # Optional
    parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
    # parse the arguments
    args = parser.parse_args()
    # run on GPU if one is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.n_gpu = torch.cuda.device_count()
    # load the fine-tuned model & tokenizer
    # (GPT-2 has no pad token, so the EOS token id is reused for padding)
    tokenizer = GPT2Tokenizer.from_pretrained(args.tuned_model_path)
    model = AutoModelForCausalLM.from_pretrained(args.tuned_model_path, pad_token_id=tokenizer.eos_token_id)
    # move the model to the target device
    model.to(device)
    # set seed
    set_seed(args)
    # load the test/eval dataset
    dataset = PromptDstDataset(args.test_data_file)
    # progress bar over the test examples
    progress = tqdm(total=dataset.len(), desc="Progress")
    # set eval mode (disables dropout)
    model.eval()
    # collected predictions -> to be saved to output_dir
    outputs = []
    # JGA metric
    evaluator = PromptDSTEvaluator()
    # iterate through the test dataset and generate a slot name for each value
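    # Each dataset item is assumed (inferred from the usage below) to look like:
    #   {"history": "<dialogue history>",
    #    "belief_states": [[slot, value], ...],   # gold (slot, value) pairs
    #    "values": [value, ...]}                  # values to build prompts from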
    for item in dataset.dataset_items:
        history = item['history']
        true_states = {}
        gen_states = {}
        # add each gold (slot, value) pair to the true states
        for slot, value in item['belief_states']:
            true_states[slot] = value
        # generate a slot name for each value in the example
        for value in item['values']:
            # generate the slot using the value-based prompt
            generated_slot = generate_slot_from_prompt(history=history,
                                                       value=value,
                                                       tokenizer=tokenizer,
                                                       model=model,
                                                       device=device)
            # add the generated slot to the generated states
            gen_states[generated_slot] = value
        # update tqdm progress
        progress.update(1)
        # add true belief states & generated belief states to outputs
        outputs.append({"history": history,
                        "true_states": true_states,
                        "gen_states": gen_states})
        # add true & generated belief states to the evaluator for computing JGA
        evaluator.add_data_item(true_states.copy(), gen_states.copy())
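    # Joint goal accuracy (JGA) counts an example as correct only when the
    # predicted belief state matches the gold state on every (slot, value) pair.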
    # compute JGA & print results
    evaluator.compute_joint_goal_accuracy()
    # save the outputs to output_dir, timestamped so repeated runs do not overwrite each other
    os.makedirs(args.output_dir, exist_ok=True)
    datetime_str = datetime.now().strftime("%Y%m%dT%H%M%S")
    output_file = os.path.join(args.output_dir, 'outputs-{}.json'.format(datetime_str))
    print('Saving Outputs file :: ', output_file)
    with open(output_file, 'w') as f:
        json.dump(outputs, f, indent=2)


if __name__ == "__main__":
    main()
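
# Example invocation (script name and paths below are illustrative placeholders):
#   python evaluate_prompt_dst.py --test_data_file data/test.json \
#       --output_dir outputs/ --tuned_model_path checkpoints/gpt2-dst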