@@ -13,6 +13,7 @@ from prompt_utils import get_value_based_prompt
 from prompt_utils import get_prompt_for_training
 from prompt_utils import TYPE_VALUE_BASED_PROMPT
 from prompt_utils import TYPE_INVERSE_PROMPT
+from prompt_utils import TYPE_PROMPT_ENSEMBLE, PROMPT_WEIGHT
 from metrics import PromptDSTEvaluator
 from datetime import datetime
@@ -42,6 +43,8 @@ def main():
                         help="Flag for enabling/disabling inverse prompt during training")
     parser.add_argument("--inverse_prompt_weight", default=0.1, type=float,
                         help="Weight to adjust the influence of Inverse Prompt, decimal (0,1)")
+    parser.add_argument("--with_prompt_ensemble", action="store_true",
+                        help="Flag for enabling/disabling prompt ensembling during training")
     parser.add_argument("--validation_file", default="", type=str,
                         help="Validation file for evaluating model after each epoch")
@@ -111,7 +114,8 @@ def main():
             for slot, value in item['belief_states']:
                 # train/generate using value-based prompt first
-                loss, gen_slot = train_prompting(history=history,
+                loss, gen_slot = train_prompting(args=args,
+                                                 history=history,
                                                  slot_value_pair=(slot, value),
                                                  prompt_type=TYPE_VALUE_BASED_PROMPT,
                                                  tokenizer=tokenizer,
@@ -124,7 +128,8 @@ def main():
                 generated_slot = gen_slot.strip().lower()
                 # train slot generation using inverse prompt
-                inv_loss, _ = train_prompting(history=history,
+                inv_loss, _ = train_prompting(args=args,
+                                              history=history,
                                               slot_value_pair=(generated_slot, value),
                                               prompt_type=TYPE_INVERSE_PROMPT,
                                               tokenizer=tokenizer,
@@ -231,7 +236,18 @@ def set_seed(args):
     torch.cuda.manual_seed_all(args.seed)
-def train_prompting(history, slot_value_pair, prompt_type, tokenizer, model, device):
+def train_prompting(args, history, slot_value_pair, prompt_type, tokenizer, model, device):
+    # slot_value_pair = (slot, value)
+    # use prompt ensemble when set in the args
+    if prompt_type is TYPE_VALUE_BASED_PROMPT and args.with_prompt_ensemble:
+        return train_prompt_ensemble(history=history,
+                                     slot_value_pair=slot_value_pair,
+                                     prompt_type=TYPE_PROMPT_ENSEMBLE,
+                                     tokenizer=tokenizer,
+                                     model=model,
+                                     device=device)
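+    # otherwise fall back to single-prompt training below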
     # get prompt for training based on "type"
     prompt = get_prompt_for_training(prompt_type, slot_value_pair)
@@ -271,6 +287,65 @@ def train_prompting(history, slot_value_pair, prompt_type, tokenizer, model, device):
     return loss, generated_word
+def train_prompt_ensemble(history, slot_value_pair, prompt_type, tokenizer, model, device):
+    # slot_value_pair = (slot, value)
+    # get list of prompts for training
+    prompts = get_prompt_for_training(prompt_type, slot_value_pair)
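+    # for TYPE_PROMPT_ENSEMBLE this is expected to return a list of prompt
+    # variants (one per ensemble member) rather than a single prompt string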
+    # accumulate the total loss across prompts
+    total_loss = None
+    gen_probs, gen_words = [], []
+    # iterate through each prompt
+    for prompt in prompts:
+        # combine history and prompt
+        input_prompt = history + prompt
+        # encode the history & prompt using tokenizer
+        encoded_prompt = tokenizer(input_prompt, return_tensors="pt")
+        encoded_prompt = encoded_prompt.to(device)
+        # get the last token id
+        # this could be a slot or value depending on prompt type
+        last_token = encoded_prompt['input_ids'][:, -1:]
+        last_token = last_token.to(device)
+        # model outputs
+        outputs = model(**encoded_prompt)
+        # get last token logits [-2 -> last but one position]
+        logits = outputs.logits[:, -2, :]
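+        # position -2 is used because the logits at position i predict the token at
+        # position i+1, so the last-but-one position scores the target last token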
+        # softmax probabilities
+        probs = torch.nn.functional.softmax(logits, dim=-1)
+        # last token generation probability
+        last_token_prob = torch.gather(probs, 1, last_token).squeeze(-1)
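+        # gather selects, per batch row, the probability assigned to the
+        # ground-truth last token id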
+        # weighted probability
+        token_prob = PROMPT_WEIGHT * last_token_prob
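+        # PROMPT_WEIGHT (imported from prompt_utils) scales every prompt's
+        # contribution to the ensemble loss by the same factor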
+        # loss is the negative log of the weighted probability
+        loss = torch.negative(torch.log(token_prob))
+        if total_loss is None:
+            total_loss = loss
+        else:
+            total_loss += loss
+        # generated slot
+        # find the token with the highest probability, this will be the generated word
+        gen_word_token = torch.argmax(logits, dim=-1)
+        gen_word_prob = torch.gather(probs, 1, gen_word_token[:, None]).squeeze(-1)
+        gen_word = tokenizer.decode(gen_word_token, skip_special_tokens=True).strip()
+        # add the generated word and probs to list
+        gen_probs.append(gen_word_prob)
+        gen_words.append(gen_word)
+    # across all prompts, keep the generated word with the highest probability
+    generated_word = gen_words[gen_probs.index(max(gen_probs))]
+    return total_loss, generated_word
 # Use this for generating next word (Validation after each epoch)
 def generate_slot_from_prompt(history, value, tokenizer, model, device):
     # get prompt for training based on "type"