Prompt ensembling for generation

main
Pavan Mandava 3 years ago
parent 4a1893ee50
commit ce5bf91cd1
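
This commit adds a --with_prompt_ensemble flag to both decoding and training. When the flag is set, generate_slot_from_prompt delegates to the new generate_slot_with_prompt_ensemble helper, which builds one prompt per template via get_ensemble_prompts, generates a single token for each prompt, reads that token's probability out of the generation scores, and returns the word generated with the highest probability. It also fixes two bugs in train_prompt_ensemble: a tensor was appended to gen_probs where a float was expected, and the best word was looked up with gen_words.index(max(gen_probs)) instead of indexing gen_words by the position of the best probability.

The decision rule, as a minimal standalone sketch (function and variable names here are illustrative, not taken from the repo):

import torch
from transformers import AutoModelForCausalLM, GPT2Tokenizer

def pick_word_by_max_prob(history, prompts, tokenizer, model, device):
    best_word, best_prob = None, -1.0
    for prompt in prompts:
        # encode the dialogue history followed by one candidate prompt
        enc = tokenizer(history + prompt, return_tensors="pt").to(device)
        out = model.generate(**enc, return_dict_in_generate=True,
                             output_scores=True, max_new_tokens=1)
        # the generated token is everything past the input length
        token_id = out.sequences[:, enc["input_ids"].shape[-1]:]
        # probability of that token under the model's next-token distribution
        prob = torch.softmax(out.scores[0], dim=-1).gather(1, token_id).item()
        if prob > best_prob:
            best_prob = prob
            best_word = tokenizer.decode(token_id.item(), skip_special_tokens=True)
    return best_word.strip().lower()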

prompt_decode.py

@@ -7,6 +7,7 @@ from transformers import AutoModelForCausalLM, GPT2Tokenizer
 from dataset import PromptDstDataset
 from tqdm.auto import tqdm
 from prompt_utils import get_value_based_prompt
+from prompt_utils import get_ensemble_prompts
 from metrics import PromptDSTEvaluator
 from datetime import datetime
@@ -18,9 +19,12 @@ def set_seed(args):
     torch.cuda.manual_seed_all(args.seed)

-# Use this for generating next word (Validation after each epoch)
-def generate_slot_from_prompt(history, value, tokenizer, model, device):
-    # get prompt for training based on "type"
+# Use this for generating next word (Testing)
+def generate_slot_from_prompt(args, history, value, tokenizer, model, device):
+    # check if prompt ensemble is enabled in arguments
+    if args.with_prompt_ensemble:
+        return generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device)
+    # get prompt for generating slots
     prompt = get_value_based_prompt(value)
     # combine history and prompt
@@ -32,12 +36,47 @@ def generate_slot_from_prompt(history, value, tokenizer, model, device):
     # generate 1 token
     outputs = model.generate(**encoded_prompt, max_new_tokens=1)
-    gen_sequences = outputs[:, encoded_prompt['input_ids'].shape[-1]:]
-    generated_word = tokenizer.decode(gen_sequences.item(), skip_special_tokens=True)
+    gen_token_id = outputs[:, encoded_prompt['input_ids'].shape[-1]:]
+    generated_word = tokenizer.decode(gen_token_id.item(), skip_special_tokens=True)
     return generated_word.strip().lower()

+def generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device):
+    # get prompts for ensemble generation
+    prompts = get_ensemble_prompts(value)
+    gen_probs, gen_words = [], []
+    for prompt in prompts:
+        # combine history and prompt
+        prompt = history + prompt
+        # encode the history & prompt
+        encoded_prompt = tokenizer(prompt, return_tensors="pt")
+        encoded_prompt.to(device)
+        # generate 1 token
+        outputs = model.generate(**encoded_prompt,
+                                 return_dict_in_generate=True,
+                                 output_scores=True,
+                                 max_new_tokens=1)
+        gen_token_id = outputs.sequences[:, encoded_prompt['input_ids'].shape[-1]:]
+        gen_word = tokenizer.decode(gen_token_id.item(), skip_special_tokens=True).strip()
+        probs = torch.nn.functional.softmax(outputs.scores[0], dim=-1)
+        gen_prob = torch.gather(probs, 1, gen_token_id).squeeze(-1)
+        # add the generated word and probs to list
+        gen_probs.append(gen_prob.item())
+        gen_words.append(gen_word)
+    # return word with the highest probability
+    generated_word = gen_words[gen_probs.index(max(gen_probs))]
+    return generated_word.strip().lower()

 def main():
     parser = argparse.ArgumentParser()
@@ -51,6 +90,8 @@ def main():
     # Optional
     parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
+    parser.add_argument("--with_prompt_ensemble", action="store_true",
+                        help="Flag for enabling/disabling prompt ensembling while generating")

     # parse the arguments
     args = parser.parse_args()
@@ -78,6 +119,9 @@ def main():
     # set eval mode
     model.eval()

     tqdm.write(str('Generating Slots...'))
+    tqdm.write(str('Args: [with_prompt_ensemble = '+str(args.with_prompt_ensemble)+']'))

     # outputs array -> to be saved to output_dir
     outputs = []
@@ -96,9 +140,9 @@ def main():
             # iterate through (slot, value) pairs and generate each slot using value
             for value in item['values']:
                 # generate slot using value-based prompt
-                generated_slot = generate_slot_from_prompt(history=history,
+                generated_slot = generate_slot_from_prompt(args=args,
+                                                           history=history,
                                                            value=value,
                                                            tokenizer=tokenizer,
                                                            model=model,

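Note on the score handling above: model.generate only returns per-step scores when return_dict_in_generate=True and output_scores=True are passed. outputs.scores is then a tuple with one (batch_size, vocab_size) tensor per generated token, so with max_new_tokens=1 the single entry outputs.scores[0] holds the next-token logits, and torch.gather picks out the probability of the token that was actually emitted. A minimal shape check (gpt2 is just a stand-in model here):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

enc = tokenizer("The capital of France is", return_tensors="pt")
out = model.generate(**enc, return_dict_in_generate=True,
                     output_scores=True, max_new_tokens=1)

token_id = out.sequences[:, enc["input_ids"].shape[-1]:]    # shape (1, 1)
probs = torch.nn.functional.softmax(out.scores[0], dim=-1)  # shape (1, vocab_size)
token_prob = torch.gather(probs, 1, token_id).squeeze(-1)   # shape (1,)
print(tokenizer.decode(token_id.item()), token_prob.item())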
prompt_train.py

@@ -10,6 +10,7 @@ from torch.optim import AdamW
 from transformers import get_scheduler
 from tqdm.auto import tqdm
 from prompt_utils import get_value_based_prompt
+from prompt_utils import get_ensemble_prompts
 from prompt_utils import get_prompt_for_training
 from prompt_utils import TYPE_VALUE_BASED_PROMPT
 from prompt_utils import TYPE_INVERSE_PROMPT
@@ -102,6 +103,8 @@ def main():
     # set the model in training mode
     model.train()

+    tqdm.write(str('Training starts now... [with_prompt_ensemble = '+str(args.with_prompt_ensemble)+']'))

     # training loop
     for epoch in range(args.num_epochs):
         running_loss = 0.0
@@ -175,6 +178,8 @@ def main():
         # if validation file is provided, run evaluation here (after each epoch)
         if args.validation_file and validation_data is not None:
+            tqdm.write(str('Validation In Progress...[with_prompt_ensemble = '+str(args.with_prompt_ensemble)+']'))

             # set tqdm progress bars for testing progress
             validation_progress = tqdm(total=validation_data.len(), desc="Validation", leave=False)
@@ -197,7 +202,8 @@ def main():
                     true_states[slot] = value
                     # generate slot using value-based prompt
-                    generated_slot = generate_slot_from_prompt(history=history,
+                    generated_slot = generate_slot_from_prompt(args=args,
+                                                               history=history,
                                                                value=value,
                                                                tokenizer=tokenizer,
                                                                model=model,
@@ -338,17 +344,21 @@ def train_prompt_ensemble(history, slot_value_pair, prompt_type, tokenizer, mode
         gen_word = tokenizer.decode(gen_word_token, skip_special_tokens=True).strip()
         # add the generated word and probs to list
-        gen_probs.append(gen_word_prob)
+        gen_probs.append(gen_word_prob.item())
         gen_words.append(gen_word)

-    generated_word = gen_words.index(max(gen_probs))
+    generated_word = gen_words[gen_probs.index(max(gen_probs))]
     # loss is the log of probability
     return total_loss, generated_word

 # Use this for generating next word (Validation after each epoch)
-def generate_slot_from_prompt(history, value, tokenizer, model, device):
-    # get prompt for training based on "type"
+def generate_slot_from_prompt(args, history, value, tokenizer, model, device):
+    # check if prompt ensemble is enabled in arguments
+    if args.with_prompt_ensemble:
+        return generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device)
+    # get value-based prompt for generation
     prompt = get_value_based_prompt(value)
     # combine history and prompt
@@ -360,9 +370,44 @@ def generate_slot_from_prompt(history, value, tokenizer, model, device):
     # generate 1 token
     outputs = model.generate(**encoded_prompt, max_new_tokens=1)
-    gen_sequences = outputs[:, encoded_prompt['input_ids'].shape[-1]:]
-    generated_word = tokenizer.decode(gen_sequences.item(), skip_special_tokens=True)
+    gen_token = outputs[:, encoded_prompt['input_ids'].shape[-1]:]
+    generated_word = tokenizer.decode(gen_token.item(), skip_special_tokens=True)
     return generated_word.strip().lower()

+def generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device):
+    # get prompts for ensemble generation
+    prompts = get_ensemble_prompts(value)
+    gen_probs, gen_words = [], []
+    for prompt in prompts:
+        # combine history and prompt
+        prompt = history + prompt
+        # encode the history & prompt
+        encoded_prompt = tokenizer(prompt, return_tensors="pt")
+        encoded_prompt.to(device)
+        # generate 1 token
+        outputs = model.generate(**encoded_prompt,
+                                 return_dict_in_generate=True,
+                                 output_scores=True,
+                                 max_new_tokens=1)
+        gen_token = outputs.sequences[:, encoded_prompt['input_ids'].shape[-1]:]
+        gen_word = tokenizer.decode(gen_token.item(), skip_special_tokens=True).strip()
+        probs = torch.nn.functional.softmax(outputs.scores[0], dim=-1)
+        gen_prob = torch.gather(probs, 1, gen_token).squeeze(-1)
+        # add the generated word and probs to list
+        gen_probs.append(gen_prob.item())
+        gen_words.append(gen_word)
+    # return word with the highest probability
+    generated_word = gen_words[gen_probs.index(max(gen_probs))]
+    return generated_word.strip().lower()

prompt_utils.py

@@ -47,5 +47,15 @@ def get_prompt_for_training(typ, slot_value):
 def get_value_based_prompt(value):
-    template = Template(PROMPT_TEMPLATES["value-based"]["generate"])
+    template = Template(PROMPT_TEMPLATES[TYPE_VALUE_BASED_PROMPT]["generate"])
     return template.substitute(value=value)

+def get_ensemble_prompts(value):
+    templates = PROMPT_TEMPLATES[TYPE_PROMPT_ENSEMBLE]['generate']
+    template_list = templates.values()
+    prompt_list = []
+    for template_str in template_list:
+        template = Template(template_str)
+        prompt_list.append(template.substitute(value=value))
+    return prompt_list
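
For get_ensemble_prompts to work, PROMPT_TEMPLATES[TYPE_PROMPT_ENSEMBLE]['generate'] must map template names to string.Template strings containing a $value placeholder. The snippet below is a hypothetical sketch of that shape (the real constant value and templates live elsewhere in prompt_utils.py and are not part of this diff):

from string import Template

TYPE_PROMPT_ENSEMBLE = 'prompt-ensemble'  # assumed value; only the constant name appears in the diff
PROMPT_TEMPLATES = {
    TYPE_PROMPT_ENSEMBLE: {
        'generate': {
            'belongs_to': 'The value "$value" belongs to slot',
            'value_of': '"$value" is the value of slot',
        }
    }
}

# get_ensemble_prompts('cheap') would then return one prompt per template,
# with $value replaced by 'cheap'.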

@@ -51,4 +51,5 @@ mkdir -p "${OUTPUTS_DIR}"
 python prompt_decode.py \
     --output_dir="${OUTPUTS_DIR}" \
     --tuned_model_path="${FINE_TUNED_MODEL_PATH}" \
-    --test_data_file="${TEST_DATA_FILE}"
+    --test_data_file="${TEST_DATA_FILE}" \
+    --with_prompt_ensemble

@@ -56,4 +56,5 @@ python prompt_train.py \
     --num_epochs 10 \
     --learning_rate 5e-5 \
     --with_inverse_prompt \
-    --inverse_prompt_weight 0.1
+    --inverse_prompt_weight 0.1 \
+    --with_prompt_ensemble