Added majority picking for generated slots with prompt ensemble

Pavan Mandava 3 years ago
parent 99800d13f8
commit 3aaa1288f2

@@ -1,4 +1,6 @@
 import argparse
+import collections
 import numpy as np
 import os
 import json
@@ -11,6 +13,9 @@ from prompt_utils import get_ensemble_prompts
 from metrics import PromptDSTEvaluator
 from datetime import datetime

+SLOT_PICK_HIGHEST_PROB = "highest"
+SLOT_PICK_SIMPLE_MAJORITY = "majority"
+
 def set_seed(args):
     np.random.seed(args.seed)
@@ -19,11 +24,11 @@ def set_seed(args):
     torch.cuda.manual_seed_all(args.seed)

-# Use this for generating next word (Testing)
+# Use this for generating slot <- Inference (Testing)
 def generate_slot_from_prompt(args, history, value, tokenizer, model, device):
     # check if prompt ensemble is enabled in arguments
     if args.with_prompt_ensemble:
-        return generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device)
+        return generate_slot_with_prompt_ensemble(args, history, value, tokenizer, model, device)

     # get value-based prompt for generating slots
     prompt = get_value_based_prompt(value)
@@ -43,7 +48,7 @@ def generate_slot_from_prompt(args, history, value, tokenizer, model, device):
     return generated_word.strip().lower()

-def generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device):
+def generate_slot_with_prompt_ensemble(args, history, value, tokenizer, model, device):
     # get prompts for ensemble generation
     prompts = get_ensemble_prompts(value)
@@ -73,9 +78,27 @@ def generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device):
         gen_probs.append(gen_prob.item())
         gen_words.append(gen_word)

-    # return word with the highest probability
-    generated_word = gen_words[gen_probs.index(max(gen_probs))]
-    return generated_word.strip().lower()
+    if args.ensemble_pick_slot == SLOT_PICK_HIGHEST_PROB:
+        # return the slot with the highest probability
+        generated_word = gen_words[gen_probs.index(max(gen_probs))]
+        return generated_word.strip().lower()
+    elif args.ensemble_pick_slot == SLOT_PICK_SIMPLE_MAJORITY:
+        # do simple majority voting over the generated slots;
+        # if there is no simple majority, pick the slot with the highest probability
+        word_counter = collections.Counter(gen_words)
+        max_slot = max(word_counter, key=word_counter.get)
+        max_slot_count = word_counter[max_slot]
+        if max_slot_count >= 3 or (max_slot_count == 2 and len(word_counter) > 2):
+            # winner slot (simple majority)
+            generated_word = max_slot
+        else:
+            # no clear majority: fall back to the slot with the highest probability
+            generated_word = gen_words[gen_probs.index(max(gen_probs))]
+        # return the generated slot
+        return generated_word.strip().lower()

 def main():
@@ -93,6 +116,8 @@ def main():
     parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
     parser.add_argument("--with_prompt_ensemble", action="store_true",
                         help="Flag for enabling/disabling prompt ensembling while generating")
+    parser.add_argument("--ensemble_pick_slot", type=str, default=SLOT_PICK_SIMPLE_MAJORITY,
+                        help="Flag for setting the algorithm used to pick the generated slot from ensemble outputs")
     parser.add_argument("--with_answered_prompts", action="store_true",
                         help="Flag to enable/disable the use of answered prompts while generating")

@@ -6,13 +6,18 @@ TYPE_PROMPT_ENSEMBLE = "prompt-ensemble"
 PROMPT_WEIGHT = 0.25

+INVERSE_PROMPTS = {
+    "i1": "belief states: $slot = $value",
+    "i2": "belief states: slot = $slot, value = $value",
+}
+
 PROMPT_TEMPLATES = {
     "value-based": {
         "training": "belief states: value = $value, slot = $slot",
         "generate": "belief states: value = $value, slot ="
     },
     "inverse-prompt": {
-        "training": "belief states: $slot = $value",
+        "training": INVERSE_PROMPTS["i2"],
     },
     "prompt-ensemble": {
         "training": {

@@ -49,10 +49,11 @@ SAVE_DIR="${SAVED_MODELS_PROMPT}"/"${experiment_folder}"
 echo "Trained Models (epochs) will be saved in ${SAVE_DIR}"

 # different number of epochs for different training sets
+# when using prompt ensemble for training, preferably use a larger number of epochs
 if [ "$data_split" = "5-dpd" ] || [ "$data_split" = "10-dpd" ]; then
-    epochs=7
+    epochs=5
 else
-    epochs=10
+    epochs=8
 fi

 python prompt_train.py \
