|
|
|
|
@ -1,4 +1,6 @@
|
|
|
|
|
import argparse
|
|
|
|
|
import collections
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
import os
|
|
|
|
|
import json
|
|
|
|
|
@ -11,6 +13,9 @@ from prompt_utils import get_ensemble_prompts
|
|
|
|
|
from metrics import PromptDSTEvaluator
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
|
|
|
|
|
# Value of --ensemble_pick_slot that picks the ensemble output whose
# generation probability is the highest.
SLOT_PICK_HIGHEST_PROB = "highest"
|
|
|
|
|
# Value of --ensemble_pick_slot (the argparse default) that picks the slot by
# majority vote over the ensemble outputs; when no slot wins the vote, the
# highest-probability output is used instead.
SLOT_PICK_SIMPLE_MAJORITY = "majority"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def set_seed(args):
|
|
|
|
|
np.random.seed(args.seed)
|
|
|
|
|
@ -19,11 +24,11 @@ def set_seed(args):
|
|
|
|
|
torch.cuda.manual_seed_all(args.seed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Use this for generating next word (Testing)
|
|
|
|
|
# Use this for generating a slot from a value-based prompt (inference/testing)
|
|
|
|
|
def generate_slot_from_prompt(args, history, value, tokenizer, model, device):
|
|
|
|
|
# check if prompt ensemble is enabled in arguments
|
|
|
|
|
if args.with_prompt_ensemble:
|
|
|
|
|
return generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device)
|
|
|
|
|
return generate_slot_with_prompt_ensemble(args, history, value, tokenizer, model, device)
|
|
|
|
|
|
|
|
|
|
# get value-based prompt for generating slots
|
|
|
|
|
prompt = get_value_based_prompt(value)
|
|
|
|
|
@ -43,7 +48,7 @@ def generate_slot_from_prompt(args, history, value, tokenizer, model, device):
|
|
|
|
|
return generated_word.strip().lower()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device):
|
|
|
|
|
def generate_slot_with_prompt_ensemble(args, history, value, tokenizer, model, device):
|
|
|
|
|
# get prompts for ensemble generation
|
|
|
|
|
prompts = get_ensemble_prompts(value)
|
|
|
|
|
|
|
|
|
|
@ -73,8 +78,26 @@ def generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device)
|
|
|
|
|
gen_probs.append(gen_prob.item())
|
|
|
|
|
gen_words.append(gen_word)
|
|
|
|
|
|
|
|
|
|
# return word with the highest probability
|
|
|
|
|
if args.ensemble_pick_slot == SLOT_PICK_HIGHEST_PROB:
|
|
|
|
|
# return the slot with the highest probability
|
|
|
|
|
generated_word = gen_words[gen_probs.index(max(gen_probs))]
|
|
|
|
|
return generated_word.strip().lower()
|
|
|
|
|
elif args.ensemble_pick_slot == SLOT_PICK_SIMPLE_MAJORITY:
|
|
|
|
|
|
|
|
|
|
# do a simple majority voting for generated slots
|
|
|
|
|
# if there's no simple majority, pick the slot with the highest probability
|
|
|
|
|
word_counter = collections.Counter(gen_words)
|
|
|
|
|
max_slot = max(word_counter, key=word_counter.get)
|
|
|
|
|
max_slot_count = word_counter[max_slot]
|
|
|
|
|
|
|
|
|
|
if max_slot_count >= 3 or (max_slot_count == 2 and len(word_counter) > 2):
|
|
|
|
|
# winner slot (simple majority)
|
|
|
|
|
generated_word = max_slot
|
|
|
|
|
else:
|
|
|
|
|
# generated slot with the highest probability
|
|
|
|
|
generated_word = gen_words[gen_probs.index(max(gen_probs))]
|
|
|
|
|
|
|
|
|
|
# return the generated slot
|
|
|
|
|
return generated_word.strip().lower()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -93,6 +116,8 @@ def main():
|
|
|
|
|
parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
|
|
|
|
|
parser.add_argument("--with_prompt_ensemble", action="store_true",
|
|
|
|
|
help="Flag for enabling/disabling prompt ensembling while generating")
|
|
|
|
|
parser.add_argument("--ensemble_pick_slot", type=str, default=SLOT_PICK_SIMPLE_MAJORITY,
|
|
|
|
|
help="Flag for setting the algorithm to pick the generated slot from ensemble outputs")
|
|
|
|
|
parser.add_argument("--with_answered_prompts", action="store_true",
|
|
|
|
|
help="Flag to enable/disable the use of answered prompts while generating")
|
|
|
|
|
|
|
|
|
|
|