Added Prompt Ensemble templates

main
Pavan Mandava 3 years ago
parent f2255db776
commit 4a1893ee50

@ -117,7 +117,7 @@ After following the environment setup steps in the previous [section](#environme
Change directory to `baseline` and install the requirements. Make sure the correct baseline conda environment is activated before installing the requirements.
```shell
cd baseline
pip install requirements.txt
pip install -r requirements.txt
```
### Train the baseline model
@ -178,7 +178,7 @@ After following the environment setup steps in the previous [section](#environme
Change directory to `prompt-learning` and install the requirements. Make sure the correct prompt-learning `conda` environment is activated before installing the requirements.
```shell
cd prompt-learning
pip install requirements.txt
pip install - requirements.txt
```
### Train the prompt model

@ -13,6 +13,7 @@ from prompt_utils import get_value_based_prompt
from prompt_utils import get_prompt_for_training
from prompt_utils import TYPE_VALUE_BASED_PROMPT
from prompt_utils import TYPE_INVERSE_PROMPT
from prompt_utils import TYPE_PROMPT_ENSEMBLE, PROMPT_WEIGHT
from metrics import PromptDSTEvaluator
from datetime import datetime
@ -42,6 +43,8 @@ def main():
help="Flag for enabling/disabling inverse prompt during training")
parser.add_argument("--inverse_prompt_weight", default=0.1, type=float,
help="Weight to adjust the influence of Inverse Prompt, decimal (0,1)")
parser.add_argument("--with_prompt_ensemble", action="store_true",
help="Flag for enabling/disabling prompt ensembling during training")
parser.add_argument("--validation_file", default="", type=str,
help="Validation file for evaluating model after each epoch")
@ -111,7 +114,8 @@ def main():
for slot, value in item['belief_states']:
# train/generate using value-based prompt first
loss, gen_slot = train_prompting(history=history,
loss, gen_slot = train_prompting(args=args,
history=history,
slot_value_pair=(slot, value),
prompt_type=TYPE_VALUE_BASED_PROMPT,
tokenizer=tokenizer,
@ -124,7 +128,8 @@ def main():
generated_slot = gen_slot.strip().lower()
# train slot generation using inverse prompt
inv_loss, _ = train_prompting(history=history,
inv_loss, _ = train_prompting(args=args,
history=history,
slot_value_pair=(generated_slot, value),
prompt_type=TYPE_INVERSE_PROMPT,
tokenizer=tokenizer,
@ -231,7 +236,18 @@ def set_seed(args):
torch.cuda.manual_seed_all(args.seed)
def train_prompting(history, slot_value_pair, prompt_type, tokenizer, model, device):
def train_prompting(args, history, slot_value_pair, prompt_type, tokenizer, model, device):
# slot_value_pair = (slot, value)
# use prompt ensemble when set in the args
if prompt_type is TYPE_VALUE_BASED_PROMPT and args.with_prompt_ensemble:
return train_prompt_ensemble(history=history,
slot_value_pair=slot_value_pair,
prompt_type=TYPE_PROMPT_ENSEMBLE,
tokenizer=tokenizer,
model=model,
device=device)
# get prompt for training based on "type"
prompt = get_prompt_for_training(prompt_type, slot_value_pair)
@ -271,6 +287,65 @@ def train_prompting(history, slot_value_pair, prompt_type, tokenizer, model, dev
return loss, generated_word
def train_prompt_ensemble(history, slot_value_pair, prompt_type, tokenizer, model, device):
# slot_value_pair = (slot, value)
# get list of prompts for training
prompts = get_prompt_for_training(prompt_type, slot_value_pair)
# return total loss
total_loss = None
gen_probs, gen_words = [], []
# iterate through each prompt
for prompt in prompts:
# combine history and prompt
input_prompt = history + prompt
# encode the history & prompt using tokenizer
encoded_prompt = tokenizer(input_prompt, return_tensors="pt")
encoded_prompt.to(device)
# get the last token id
# this could be a slot or value depending on prompt type
last_token = encoded_prompt['input_ids'][:, -1:]
last_token.to(device)
# model outputs
outputs = model(**encoded_prompt)
# get last token logits [-2 -> for last but one item]
logits = outputs.logits[:, -2, :]
# softmax probabilities
probs = torch.nn.functional.softmax(logits, dim=-1)
# last token generation probability
last_token_prob = torch.gather(probs, 1, last_token).squeeze(-1)
# weighted probability
token_prob = PROMPT_WEIGHT * last_token_prob
loss = torch.negative(torch.log(token_prob))
if total_loss is None:
total_loss = loss
else:
total_loss += loss
# generated slot
# find the token with the highest probability, this will be the generated word
gen_word_token = torch.argmax(logits, dim=-1)
gen_word_prob = torch.gather(probs, 1, gen_word_token[:, None]).squeeze(-1)
gen_word = tokenizer.decode(gen_word_token, skip_special_tokens=True).strip()
# add the generated word and probs to list
gen_probs.append(gen_word_prob)
gen_words.append(gen_word)
generated_word = gen_words.index(max(gen_probs))
# loss is the log of probability
return total_loss, generated_word
# Use this for generating next word (Validation after each epoch)
def generate_slot_from_prompt(history, value, tokenizer, model, device):
# get prompt for training based on "type"

@ -2,29 +2,50 @@ from string import Template
TYPE_VALUE_BASED_PROMPT = "value-based"
TYPE_INVERSE_PROMPT = "inverse-prompt"
TYPE_PROMPT_ENSEMBLE = "prompt-ensemble"
PROMPT_WEIGHT = 0.25
PROMPT_TEMPLATES = {
"value-based": {
"training": "belief states: value = $value, slot = $slot",
"generate": "belief states: value = $value, slot ="
},
"inverse-prompt": {
"training": "belief states: slot = $slot, value = $value",
"generate": "belief states: slot = $slot, value ="
"inverse-prompt": "belief states: slot = $slot, value = $value",
"prompt-ensemble": {
"training": {
"p1": "belief states: value = $value, slot = $slot",
"p2": "belief states: $value = $slot",
"p3": "$value is of slot type $slot",
"p4": "$value is the value of $slot"
},
"generate": {
"p1": "belief states: value = $value, slot =",
"p2": "belief states: $value =",
"p3": "$value is of slot type",
"p4": "$value is the value of"
}
}
}
def get_prompt_for_training(typ, slot_value):
template = Template(PROMPT_TEMPLATES[typ]["training"])
if typ == TYPE_INVERSE_PROMPT:
template = Template(PROMPT_TEMPLATES[typ])
return template.substitute(slot=slot_value[0], value=slot_value[1])
else:
template = PROMPT_TEMPLATES[typ]['training']
if isinstance(template, str):
return Template(template).substitute(slot=slot_value[0], value=slot_value[1])
elif isinstance(template, dict):
template_list = template.values()
prompt_list = []
for template_str in template_list:
template = Template(template_str)
prompt_list.append(template.substitute(slot=slot_value[0], value=slot_value[1]))
return prompt_list
def get_value_based_prompt(value):
template = Template(PROMPT_TEMPLATES["value-based"]["generate"])
return template.substitute(value=value)
def get_inverse_prompt(slot):
template = Template(PROMPT_TEMPLATES["inverse-prompt"]["generate"])
return template.substitute(slot=slot)

Loading…
Cancel
Save