From ce5bf91cd15d3406575a65867185ed56aec78adb Mon Sep 17 00:00:00 2001
From: Pavan Mandava
Date: Wed, 16 Nov 2022 23:05:03 +0100
Subject: [PATCH] Add prompt ensembling for slot generation

Add a --with_prompt_ensemble flag to training and decoding. When the
flag is set, each slot is generated once per ensemble prompt and the
candidate whose token has the highest probability is kept. Also fix
the ensemble word selection in train_prompt_ensemble and look up the
value-based template via the TYPE_VALUE_BASED_PROMPT constant.
---
 prompt-learning/prompt_decode.py   | 58 +++++++++++++++++++++++++----
 prompt-learning/prompt_train.py    | 59 ++++++++++++++++++++++++++----
 prompt-learning/prompt_utils.py    | 12 +++++-
 prompt-learning/test_prompting.sh  |  3 +-
 prompt-learning/train_prompting.sh |  3 +-
 5 files changed, 118 insertions(+), 17 deletions(-)

diff --git a/prompt-learning/prompt_decode.py b/prompt-learning/prompt_decode.py
index 931be32..6ca5b52 100644
--- a/prompt-learning/prompt_decode.py
+++ b/prompt-learning/prompt_decode.py
@@ -7,6 +7,7 @@ from transformers import AutoModelForCausalLM, GPT2Tokenizer
 from dataset import PromptDstDataset
 from tqdm.auto import tqdm
 from prompt_utils import get_value_based_prompt
+from prompt_utils import get_ensemble_prompts
 from metrics import PromptDSTEvaluator
 from datetime import datetime
 
@@ -18,9 +19,12 @@ def set_seed(args):
         torch.cuda.manual_seed_all(args.seed)
 
 
-# Use this for generating next word (Validation after each epoch)
-def generate_slot_from_prompt(history, value, tokenizer, model, device):
-    # get prompt for training based on "type"
+# Use this for generating next word (Testing)
+def generate_slot_from_prompt(args, history, value, tokenizer, model, device):
+    # check if prompt ensemble is enabled in arguments
+    if args.with_prompt_ensemble:
+        return generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device)
+    # get prompt for generating slots
     prompt = get_value_based_prompt(value)
 
     # combine history and prompt
@@ -32,12 +36,47 @@
     # generate 1 token
     outputs = model.generate(**encoded_prompt, max_new_tokens=1)
 
-    gen_sequences = outputs[:, encoded_prompt['input_ids'].shape[-1]:]
-    generated_word = tokenizer.decode(gen_sequences.item(), skip_special_tokens=True)
+    gen_token_id = outputs[:, encoded_prompt['input_ids'].shape[-1]:]
+    generated_word = tokenizer.decode(gen_token_id.item(), skip_special_tokens=True)
 
     return generated_word.strip().lower()
 
 
+def generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device):
+    # get prompts for ensemble generation
+    prompts = get_ensemble_prompts(value)
+
+    gen_probs, gen_words = [], []
+
+    for prompt in prompts:
+        # combine history and prompt
+        prompt = history + prompt
+
+        # encode the history & prompt
+        encoded_prompt = tokenizer(prompt, return_tensors="pt")
+        encoded_prompt.to(device)
+
+        # generate 1 token
+        outputs = model.generate(**encoded_prompt,
+                                 return_dict_in_generate=True,
+                                 output_scores=True,
+                                 max_new_tokens=1)
+
+        gen_token_id = outputs.sequences[:, encoded_prompt['input_ids'].shape[-1]:]
+        gen_word = tokenizer.decode(gen_token_id.item(), skip_special_tokens=True).strip()
+
+        probs = torch.nn.functional.softmax(outputs.scores[0], dim=-1)
+        gen_prob = torch.gather(probs, 1, gen_token_id).squeeze(-1)
+
+        # add the generated word and probs to list
+        gen_probs.append(gen_prob.item())
+        gen_words.append(gen_word)
+
+    # return word with the highest probability
+    generated_word = gen_words[gen_probs.index(max(gen_probs))]
+    return generated_word.strip().lower()
+
+
 def main():
     parser = argparse.ArgumentParser()
 
@@ -51,6 +90,8 @@
     # Optional
     parser.add_argument('--seed', type=int, default=42,
                         help="random seed for initialization")
+    parser.add_argument("--with_prompt_ensemble", action="store_true",
+                        help="Enable prompt ensembling when generating slots")
 
     # parse the arguments
     args = parser.parse_args()
@@ -78,6 +119,9 @@
     # set eval mode
     model.eval()
 
+    tqdm.write('Generating Slots...')
+    tqdm.write('Args: [with_prompt_ensemble = ' + str(args.with_prompt_ensemble) + ']')
+
     # outputs array -> to be saved to output_dir
     outputs = []
 
@@ -96,9 +140,9 @@
 
             # iterate through (slot, value) pairs and generate each slot using value
             for value in item['values']:
 
-                # generate slot using value-based prompt
-                generated_slot = generate_slot_from_prompt(history=history,
+                generated_slot = generate_slot_from_prompt(args=args,
+                                                           history=history,
                                                            value=value,
                                                            tokenizer=tokenizer,
                                                            model=model,
diff --git a/prompt-learning/prompt_train.py b/prompt-learning/prompt_train.py
index ed0fa62..db9cde6 100644
--- a/prompt-learning/prompt_train.py
+++ b/prompt-learning/prompt_train.py
@@ -10,6 +10,7 @@ from torch.optim import AdamW
 from transformers import get_scheduler
 from tqdm.auto import tqdm
 from prompt_utils import get_value_based_prompt
+from prompt_utils import get_ensemble_prompts
 from prompt_utils import get_prompt_for_training
 from prompt_utils import TYPE_VALUE_BASED_PROMPT
 from prompt_utils import TYPE_INVERSE_PROMPT
@@ -102,6 +103,8 @@
     # set the model in training mode
     model.train()
 
+    tqdm.write('Training starts now... [with_prompt_ensemble = ' + str(args.with_prompt_ensemble) + ']')
+
     # training loop
     for epoch in range(args.num_epochs):
         running_loss = 0.0
@@ -175,6 +178,8 @@
     # if validation file is provided, run evaluation here (after each epoch)
     if args.validation_file and validation_data is not None:
 
+        tqdm.write('Validation in progress... [with_prompt_ensemble = ' + str(args.with_prompt_ensemble) + ']')
+
         # set tqdm progress bars for testing progress
         validation_progress = tqdm(total=validation_data.len(), desc="Validation", leave=False)
 
@@ -197,7 +202,8 @@
                     true_states[slot] = value
 
                     # generate slot using value-based prompt
-                    generated_slot = generate_slot_from_prompt(history=history,
+                    generated_slot = generate_slot_from_prompt(args=args,
+                                                               history=history,
                                                                value=value,
                                                                tokenizer=tokenizer,
                                                                model=model,
@@ -338,17 +344,21 @@ def train_prompt_ensemble(history, slot_value_pair, prompt_type, tokenizer, mode
         gen_word = tokenizer.decode(gen_word_token, skip_special_tokens=True).strip()
 
         # add the generated word and probs to list
-        gen_probs.append(gen_word_prob)
+        gen_probs.append(gen_word_prob.item())
         gen_words.append(gen_word)
 
-    generated_word = gen_words.index(max(gen_probs))
+    generated_word = gen_words[gen_probs.index(max(gen_probs))]
 
     # loss is the log of probability
     return total_loss, generated_word
 
 # Use this for generating next word (Validation after each epoch)
-def generate_slot_from_prompt(history, value, tokenizer, model, device):
-    # get prompt for training based on "type"
+def generate_slot_from_prompt(args, history, value, tokenizer, model, device):
+    # check if prompt ensemble is enabled in arguments
+    if args.with_prompt_ensemble:
+        return generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device)
+
+    # get value-based prompt for generation
     prompt = get_value_based_prompt(value)
 
     # combine history and prompt
@@ -360,11 +370,46 @@
     # generate 1 token
     outputs = model.generate(**encoded_prompt, max_new_tokens=1)
 
-    gen_sequences = outputs[:, encoded_prompt['input_ids'].shape[-1]:]
-    generated_word = tokenizer.decode(gen_sequences.item(), skip_special_tokens=True)
+    gen_token = outputs[:, encoded_prompt['input_ids'].shape[-1]:]
+    generated_word = tokenizer.decode(gen_token.item(), skip_special_tokens=True)
 
     return generated_word.strip().lower()
 
 
+def generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device):
+    # get prompts for ensemble generation
+    prompts = get_ensemble_prompts(value)
+
+    gen_probs, gen_words = [], []
+
+    for prompt in prompts:
+        # combine history and prompt
+        prompt = history + prompt
+
+        # encode the history & prompt
+        encoded_prompt = tokenizer(prompt, return_tensors="pt")
+        encoded_prompt.to(device)
+
+        # generate 1 token
+        outputs = model.generate(**encoded_prompt,
+                                 return_dict_in_generate=True,
+                                 output_scores=True,
+                                 max_new_tokens=1)
+
+        gen_token = outputs.sequences[:, encoded_prompt['input_ids'].shape[-1]:]
+        gen_word = tokenizer.decode(gen_token.item(), skip_special_tokens=True).strip()
+
+        probs = torch.nn.functional.softmax(outputs.scores[0], dim=-1)
+        gen_prob = torch.gather(probs, 1, gen_token).squeeze(-1)
+
+        # add the generated word and probs to list
+        gen_probs.append(gen_prob.item())
+        gen_words.append(gen_word)
+
+    # return word with the highest probability
+    generated_word = gen_words[gen_probs.index(max(gen_probs))]
+    return generated_word.strip().lower()
+
+
 if __name__ == "__main__":
     main()
diff --git a/prompt-learning/prompt_utils.py b/prompt-learning/prompt_utils.py
index b6a39c6..19ccaae 100644
--- a/prompt-learning/prompt_utils.py
+++ b/prompt-learning/prompt_utils.py
@@ -47,5 +47,15 @@ def get_prompt_for_training(typ, slot_value):
 
 
 def get_value_based_prompt(value):
-    template = Template(PROMPT_TEMPLATES["value-based"]["generate"])
+    template = Template(PROMPT_TEMPLATES[TYPE_VALUE_BASED_PROMPT]["generate"])
     return template.substitute(value=value)
+
+
+def get_ensemble_prompts(value):
+    templates = PROMPT_TEMPLATES[TYPE_PROMPT_ENSEMBLE]['generate']
+    template_list = templates.values()
+    prompt_list = []
+    for template_str in template_list:
+        template = Template(template_str)
+        prompt_list.append(template.substitute(value=value))
+    return prompt_list
diff --git a/prompt-learning/test_prompting.sh b/prompt-learning/test_prompting.sh
index 8ff2d8d..6e1543d 100644
--- a/prompt-learning/test_prompting.sh
+++ b/prompt-learning/test_prompting.sh
@@ -51,4 +51,5 @@ mkdir -p "${OUTPUTS_DIR}"
 python prompt_decode.py \
 --output_dir="${OUTPUTS_DIR}" \
 --tuned_model_path="${FINE_TUNED_MODEL_PATH}" \
---test_data_file="${TEST_DATA_FILE}"
\ No newline at end of file
+--test_data_file="${TEST_DATA_FILE}" \
+--with_prompt_ensemble
\ No newline at end of file
diff --git a/prompt-learning/train_prompting.sh b/prompt-learning/train_prompting.sh
index 5c12053..ab9c689 100644
--- a/prompt-learning/train_prompting.sh
+++ b/prompt-learning/train_prompting.sh
@@ -56,4 +56,5 @@ python prompt_train.py \
 --num_epochs 10 \
 --learning_rate 5e-5 \
 --with_inverse_prompt \
----inverse_prompt_weight 0.1
\ No newline at end of file
+--inverse_prompt_weight 0.1 \
+--with_prompt_ensemble
\ No newline at end of file
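
Note: get_ensemble_prompts() relies on PROMPT_TEMPLATES[TYPE_PROMPT_ENSEMBLE]['generate']
being a dict of string.Template strings, each containing a $value placeholder. Those
templates are defined outside this patch; the sketch below only illustrates the assumed
shape, with made-up template texts and an assumed value for the TYPE_PROMPT_ENSEMBLE
constant:

    from string import Template

    TYPE_PROMPT_ENSEMBLE = "prompt-ensemble"  # assumed name; the real constant lives in prompt_utils.py

    PROMPT_TEMPLATES = {
        TYPE_PROMPT_ENSEMBLE: {
            "generate": {
                # illustrative templates only, not the project's actual prompt set
                "v1": ' The slot with the value "$value" is',
                "v2": ' "$value" is the value of the slot',
            }
        }
    }

    def get_ensemble_prompts(value):
        # one concrete prompt per template, as the patched functions expect
        templates = PROMPT_TEMPLATES[TYPE_PROMPT_ENSEMBLE]['generate']
        return [Template(t).substitute(value=value) for t in templates.values()]

    print(get_ensemble_prompts("cheap"))
    # [' The slot with the value "cheap" is', ' "cheap" is the value of the slot']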
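For reference, here is a minimal, self-contained sketch of the max-probability selection
that generate_slot_with_prompt_ensemble() performs, runnable against stock GPT-2; the
history and prompt strings are illustrative placeholders, not project data:

    import torch
    from transformers import AutoModelForCausalLM, GPT2Tokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
    model.eval()

    history = "user: i am looking for a cheap hotel in the centre ."
    prompts = [" The slot with the value cheap is", " cheap is the value of the slot"]

    gen_probs, gen_words = [], []
    for prompt in prompts:
        encoded = tokenizer(history + prompt, return_tensors="pt").to(device)

        # greedy-decode exactly one token and keep that step's logits
        outputs = model.generate(**encoded,
                                 return_dict_in_generate=True,
                                 output_scores=True,
                                 max_new_tokens=1)

        # probability the model assigned to the token it actually generated
        token_id = outputs.sequences[:, encoded['input_ids'].shape[-1]:]
        probs = torch.nn.functional.softmax(outputs.scores[0], dim=-1)

        gen_probs.append(torch.gather(probs, 1, token_id).item())
        gen_words.append(tokenizer.decode(token_id.item(), skip_special_tokens=True).strip())

    # the ensemble answer is the candidate generated with the highest probability
    print(gen_words[gen_probs.index(max(gen_probs))])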