diff --git a/prompt-learning/prompt_decode.py b/prompt-learning/prompt_decode.py
index de2eef9..22aa1ec 100644
--- a/prompt-learning/prompt_decode.py
+++ b/prompt-learning/prompt_decode.py
@@ -1,4 +1,6 @@
 import argparse
+import collections
+
 import numpy as np
 import os
 import json
@@ -11,6 +13,9 @@ from prompt_utils import get_ensemble_prompts
 from metrics import PromptDSTEvaluator
 from datetime import datetime
 
+SLOT_PICK_HIGHEST_PROB = "highest"
+SLOT_PICK_SIMPLE_MAJORITY = "majority"
+
 
 def set_seed(args):
     np.random.seed(args.seed)
@@ -19,11 +24,11 @@ def set_seed(args):
         torch.cuda.manual_seed_all(args.seed)
 
 
-# Use this for generating next word (Testing)
+# Use this for generating a slot at inference (testing) time
 def generate_slot_from_prompt(args, history, value, tokenizer, model, device):
     # check if prompt ensemble is enabled in arguments
     if args.with_prompt_ensemble:
-        return generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device)
+        return generate_slot_with_prompt_ensemble(args, history, value, tokenizer, model, device)
 
     # get value-based prompt for generating slots
     prompt = get_value_based_prompt(value)
@@ -43,7 +48,7 @@ def generate_slot_from_prompt(args, history, value, tokenizer, model, device):
     return generated_word.strip().lower()
 
 
-def generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device):
+def generate_slot_with_prompt_ensemble(args, history, value, tokenizer, model, device):
     # get prompts for ensemble generation
     prompts = get_ensemble_prompts(value)
 
@@ -73,9 +78,26 @@ def generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device)
         gen_probs.append(gen_prob.item())
         gen_words.append(gen_word)
 
-    # return word with the highest probability
-    generated_word = gen_words[gen_probs.index(max(gen_probs))]
-    return generated_word.strip().lower()
+    if args.ensemble_pick_slot == SLOT_PICK_HIGHEST_PROB:
+        # return the slot with the highest probability
+        generated_word = gen_words[gen_probs.index(max(gen_probs))]
+        return generated_word.strip().lower()
+    elif args.ensemble_pick_slot == SLOT_PICK_SIMPLE_MAJORITY:
+        # do simple majority voting over the generated slots;
+        # if there is no simple majority, pick the slot with the highest probability
+        word_counter = collections.Counter(gen_words)
+        max_slot = max(word_counter, key=word_counter.get)
+        max_slot_count = word_counter[max_slot]
+
+        if max_slot_count >= 3 or (max_slot_count == 2 and len(word_counter) > 2):
+            # winning slot (simple majority)
+            generated_word = max_slot
+        else:
+            # no majority: fall back to the slot with the highest probability
+            generated_word = gen_words[gen_probs.index(max(gen_probs))]
+
+        # return the generated slot
+        return generated_word.strip().lower()
 
 
 def main():
@@ -93,6 +116,8 @@ def main():
     parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
     parser.add_argument("--with_prompt_ensemble", action="store_true",
                         help="Flag for enabling/disabling prompt ensembling while generating")
+    parser.add_argument("--ensemble_pick_slot", type=str, default=SLOT_PICK_SIMPLE_MAJORITY,
+                        help="Strategy for picking the generated slot from the ensemble outputs ('highest' or 'majority')")
     parser.add_argument("--with_answered_prompts", action="store_true",
                         help="Flag to enable/disable the use of answered prompts while generating")
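
The selection rule in generate_slot_with_prompt_ensemble can be read in isolation: a slot wins with three or more votes, or with two votes when the remaining votes are split across more than one other slot; a 2-2 tie falls back to the highest-probability candidate. A minimal standalone sketch of that rule, assuming a four-prompt ensemble (pick_slot is a hypothetical helper, not part of this diff):

import collections


def pick_slot(gen_words, gen_probs):
    # count the votes each generated slot received across the ensemble
    word_counter = collections.Counter(gen_words)
    max_slot = max(word_counter, key=word_counter.get)
    max_slot_count = word_counter[max_slot]

    # simple majority: 3+ votes, or 2 votes with the rest split (e.g. 2-1-1)
    if max_slot_count >= 3 or (max_slot_count == 2 and len(word_counter) > 2):
        return max_slot.strip().lower()

    # no majority (e.g. a 2-2 tie): fall back to the highest-probability slot
    return gen_words[gen_probs.index(max(gen_probs))].strip().lower()


# 2-1-1 split: "area" wins by simple majority
print(pick_slot(["area", "area", "price", "name"], [0.2, 0.3, 0.4, 0.1]))   # area
# 2-2 tie: falls back to the highest probability ("price" at 0.4)
print(pick_slot(["area", "area", "price", "price"], [0.2, 0.3, 0.4, 0.1]))  # price
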
diff --git a/prompt-learning/prompt_utils.py b/prompt-learning/prompt_utils.py
index 87e589f..568d3ee 100644
--- a/prompt-learning/prompt_utils.py
+++ b/prompt-learning/prompt_utils.py
@@ -6,13 +6,18 @@ TYPE_PROMPT_ENSEMBLE = "prompt-ensemble"
 
 PROMPT_WEIGHT = 0.25
 
+INVERSE_PROMPTS = {
+    "i1": "belief states: $slot = $value",
+    "i2": "belief states: slot = $slot, value = $value",
+}
+
 PROMPT_TEMPLATES = {
     "value-based": {
         "training": "belief states: value = $value, slot = $slot",
         "generate": "belief states: value = $value, slot ="
     },
     "inverse-prompt": {
-        "training": "belief states: $slot = $value",
+        "training": INVERSE_PROMPTS["i2"],
     },
     "prompt-ensemble": {
         "training": {
diff --git a/prompt-learning/train_prompting.sh b/prompt-learning/train_prompting.sh
index 13a8f8e..b2963b1 100644
--- a/prompt-learning/train_prompting.sh
+++ b/prompt-learning/train_prompting.sh
@@ -49,10 +49,11 @@ SAVE_DIR="${SAVED_MODELS_PROMPT}"/"${experiment_folder}"
 echo "Trained Models (epochs) will be saved in ${SAVE_DIR}"
 
 # different number of epoch for different training sets
+# when training with prompt ensembling, prefer a higher number of epochs.
 if [ "$data_split" = "5-dpd" ] || [ "$data_split" = "10-dpd" ]; then
-    epochs=7
+    epochs=5
 else
-    epochs=10
+    epochs=8
 fi
 
 python prompt_train.py \
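
The $slot/$value placeholders in INVERSE_PROMPTS use Python template-string syntax. A minimal sketch of how such a template could be filled in, assuming string.Template substitution (the actual fill-in code lives elsewhere in prompt_utils.py and is not shown in this diff; the slot/value pair is illustrative only):

from string import Template

# INVERSE_PROMPTS["i2"], copied from the diff above
template = Template("belief states: slot = $slot, value = $value")

prompt = template.substitute(slot="hotel-area", value="north")
print(prompt)  # belief states: slot = hotel-area, value = north
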