Added majority picking for generated slots with prompt ensemble

main
Pavan Mandava 3 years ago
parent 99800d13f8
commit 3aaa1288f2

@@ -1,4 +1,6 @@
import argparse
import collections
import numpy as np
import os
import json
@@ -11,6 +13,9 @@ from prompt_utils import get_ensemble_prompts
from metrics import PromptDSTEvaluator
from datetime import datetime
SLOT_PICK_HIGHEST_PROB = "highest"
SLOT_PICK_SIMPLE_MAJORITY = "majority"
def set_seed(args):
np.random.seed(args.seed)
@@ -19,11 +24,11 @@ def set_seed(args):
torch.cuda.manual_seed_all(args.seed)
# Use this for generating next word (Testing)
# Use this for generating slot <- Inference (Testing)
def generate_slot_from_prompt(args, history, value, tokenizer, model, device):
# check if prompt ensemble is enabled in arguments
if args.with_prompt_ensemble:
return generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device)
return generate_slot_with_prompt_ensemble(args, history, value, tokenizer, model, device)
# get value-based prompt for generating slots
prompt = get_value_based_prompt(value)
@@ -43,7 +48,7 @@ def generate_slot_from_prompt(args, history, value, tokenizer, model, device):
return generated_word.strip().lower()
def generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device):
def generate_slot_with_prompt_ensemble(args, history, value, tokenizer, model, device):
# get prompts for ensemble generation
prompts = get_ensemble_prompts(value)
@@ -73,8 +78,26 @@ def generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device)
gen_probs.append(gen_prob.item())
gen_words.append(gen_word)
# return word with the highest probability
if args.ensemble_pick_slot == SLOT_PICK_HIGHEST_PROB:
# return the slot with the highest probability
generated_word = gen_words[gen_probs.index(max(gen_probs))]
return generated_word.strip().lower()
elif args.ensemble_pick_slot == SLOT_PICK_SIMPLE_MAJORITY:
# do a simple majority voting for generated slots
# if there's no simple majority, pick the slot with the highest probability
word_counter = collections.Counter(gen_words)
max_slot = max(word_counter, key=word_counter.get)
max_slot_count = word_counter[max_slot]
if max_slot_count >= 3 or (max_slot_count == 2 and len(word_counter) > 2):
# winner slot (simple majority)
generated_word = max_slot
else:
# generated slot with the highest probability
generated_word = gen_words[gen_probs.index(max(gen_probs))]
# return the generated slot
return generated_word.strip().lower()
@@ -93,6 +116,8 @@ def main():
parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
parser.add_argument("--with_prompt_ensemble", action="store_true",
help="Flag for enabling/disabling prompt ensembling while generating")
parser.add_argument("--ensemble_pick_slot", type=str, default=SLOT_PICK_SIMPLE_MAJORITY,
help="Flag for setting the algorithm to pick the generated slot from ensemble outputs")
parser.add_argument("--with_answered_prompts", action="store_true",
help="Flag to enable/disable the use of answered prompts while generating")

@@ -6,13 +6,18 @@ TYPE_PROMPT_ENSEMBLE = "prompt-ensemble"
PROMPT_WEIGHT = 0.25
# Inverse-prompt templates mapping template ids to format strings.
# $slot and $value are textual placeholders substituted elsewhere —
# presumably via string.Template or str.replace at the prompt-building
# site (TODO confirm; the substitution code is outside this view).
INVERSE_PROMPTS = {
"i1": "belief states: $slot = $value",
"i2": "belief states: slot = $slot, value = $value",
}
PROMPT_TEMPLATES = {
"value-based": {
"training": "belief states: value = $value, slot = $slot",
"generate": "belief states: value = $value, slot ="
},
"inverse-prompt": {
"training": "belief states: $slot = $value",
"training": INVERSE_PROMPTS["i2"],
},
"prompt-ensemble": {
"training": {

@@ -49,10 +49,11 @@ SAVE_DIR="${SAVED_MODELS_PROMPT}"/"${experiment_folder}"
echo "Trained Models (epochs) will be saved in ${SAVE_DIR}"
# different number of epoch for different training sets
# when using prompt ensemble for training, preferably use more number of epochs.
if [ "$data_split" = "5-dpd" ] || [ "$data_split" = "10-dpd" ]; then
epochs=7
epochs=5
else
epochs=10
epochs=8
fi
python prompt_train.py \

Loading…
Cancel
Save