From 4a1893ee5037d899f07c2ef5f2eb689464a6fc0a Mon Sep 17 00:00:00 2001
From: Pavan Mandava
Date: Wed, 16 Nov 2022 14:51:28 +0100
Subject: [PATCH] Added Prompt Ensemble templates

---
 README.md                       |  4 +-
 prompt-learning/prompt_train.py | 81 +++++++++++++++++++++++++++++++--
 prompt-learning/prompt_utils.py | 41 +++++++++++++----
 3 files changed, 111 insertions(+), 15 deletions(-)

diff --git a/README.md b/README.md
index 3a13d33..eecff79 100644
--- a/README.md
+++ b/README.md
@@ -117,7 +117,7 @@ After following the environment setup steps in the previous [section](#environme
 Change directory to `baseline` and install the requirements. Make sure the correct baseline conda environment is activated before installing the requirements.
 ```shell
 cd baseline
-pip install requirements.txt
+pip install -r requirements.txt
 ```
 
 ### Train the baseline model
@@ -178,7 +178,7 @@ After following the environment setup steps in the previous [section](#environme
 Change directory to `prompt-learning` and install the requirements. Make sure the correct prompt-learning `conda` environment is activated before installing the requirements.
 ```shell
 cd prompt-learning
-pip install requirements.txt
+pip install -r requirements.txt
 ```
 
 ### Train the prompt model
diff --git a/prompt-learning/prompt_train.py b/prompt-learning/prompt_train.py
index 0721727..ed0fa62 100644
--- a/prompt-learning/prompt_train.py
+++ b/prompt-learning/prompt_train.py
@@ -13,6 +13,7 @@
 from prompt_utils import get_value_based_prompt
 from prompt_utils import get_prompt_for_training
 from prompt_utils import TYPE_VALUE_BASED_PROMPT
 from prompt_utils import TYPE_INVERSE_PROMPT
+from prompt_utils import TYPE_PROMPT_ENSEMBLE, PROMPT_WEIGHT
 from metrics import PromptDSTEvaluator
 from datetime import datetime
@@ -42,6 +43,8 @@
                         help="Flag for enabling/disabling inverse prompt during training")
     parser.add_argument("--inverse_prompt_weight", default=0.1, type=float,
                         help="Weight to adjust the influence of Inverse Prompt, decimal (0,1)")
+    parser.add_argument("--with_prompt_ensemble", action="store_true",
+                        help="Flag for enabling/disabling prompt ensembling during training")
 
     parser.add_argument("--validation_file", default="", type=str,
                         help="Validation file for evaluating model after each epoch")
@@ -111,7 +114,8 @@
 
             for slot, value in item['belief_states']:
                 # train/generate using value-based prompt first
-                loss, gen_slot = train_prompting(history=history,
+                loss, gen_slot = train_prompting(args=args,
+                                                 history=history,
                                                  slot_value_pair=(slot, value),
                                                  prompt_type=TYPE_VALUE_BASED_PROMPT,
                                                  tokenizer=tokenizer,
@@ -124,7 +128,8 @@
                 generated_slot = gen_slot.strip().lower()
 
                 # train slot generation using inverse prompt
-                inv_loss, _ = train_prompting(history=history,
+                inv_loss, _ = train_prompting(args=args,
+                                              history=history,
                                               slot_value_pair=(generated_slot, value),
                                               prompt_type=TYPE_INVERSE_PROMPT,
                                               tokenizer=tokenizer,
@@ -231,7 +236,18 @@ def set_seed(args):
         torch.cuda.manual_seed_all(args.seed)
 
 
-def train_prompting(history, slot_value_pair, prompt_type, tokenizer, model, device):
+def train_prompting(args, history, slot_value_pair, prompt_type, tokenizer, model, device):
+    # slot_value_pair = (slot, value)
+
+    # use prompt ensemble when set in the args
+    if prompt_type == TYPE_VALUE_BASED_PROMPT and args.with_prompt_ensemble:
+        return train_prompt_ensemble(history=history,
+                                     slot_value_pair=slot_value_pair,
+                                     prompt_type=TYPE_PROMPT_ENSEMBLE,
+                                     tokenizer=tokenizer,
+                                     model=model,
+                                     device=device)
+
     # get prompt for training based on "type"
     prompt = get_prompt_for_training(prompt_type, slot_value_pair)
 
@@ -271,6 +287,65 @@ def train_prompting(history, slot_value_pair, prompt_type, tokenizer, model, dev
     return loss, generated_word
 
 
+def train_prompt_ensemble(history, slot_value_pair, prompt_type, tokenizer, model, device):
+    # slot_value_pair = (slot, value)
+    # get the list of ensemble prompts for training
+    prompts = get_prompt_for_training(prompt_type, slot_value_pair)
+
+    # accumulated loss over all prompts
+    total_loss = None
+
+    gen_probs, gen_words = [], []
+
+    # iterate through each prompt
+    for prompt in prompts:
+
+        # combine history and prompt
+        input_prompt = history + prompt
+
+        # encode the history & prompt and move the tensors to the device
+        encoded_prompt = tokenizer(input_prompt, return_tensors="pt").to(device)
+
+        # get the last token id
+        # this could be a slot or value depending on prompt type
+        last_token = encoded_prompt['input_ids'][:, -1:]
+
+        # model outputs
+        outputs = model(**encoded_prompt)
+
+        # logits at the last-but-one position, which predict the last token
+        logits = outputs.logits[:, -2, :]
+
+        # softmax probabilities
+        probs = torch.nn.functional.softmax(logits, dim=-1)
+
+        # generation probability of the last token
+        last_token_prob = torch.gather(probs, 1, last_token).squeeze(-1)
+        # weighted probability
+        token_prob = PROMPT_WEIGHT * last_token_prob
+        loss = torch.negative(torch.log(token_prob))
+        if total_loss is None:
+            total_loss = loss
+        else:
+            total_loss = total_loss + loss
+
+        # generated slot
+        # the token with the highest probability is the generated word
+        gen_word_token = torch.argmax(logits, dim=-1)
+        gen_word_prob = torch.gather(probs, 1, gen_word_token[:, None]).squeeze(-1)
+        gen_word = tokenizer.decode(gen_word_token, skip_special_tokens=True).strip()
+
+        # record the generated word and its probability
+        gen_probs.append(gen_word_prob.item())
+        gen_words.append(gen_word)
+
+    # pick the word generated with the highest probability across prompts
+    generated_word = gen_words[gen_probs.index(max(gen_probs))]
+
+    # total loss is the sum of the negative logs of the weighted probabilities
+    return total_loss, generated_word
+
+
 # Use this for generating next word (Validation after each epoch)
 def generate_slot_from_prompt(history, value, tokenizer, model, device):
     # get prompt for training based on "type"
diff --git a/prompt-learning/prompt_utils.py b/prompt-learning/prompt_utils.py
index da18ede..b6a39c6 100644
--- a/prompt-learning/prompt_utils.py
+++ b/prompt-learning/prompt_utils.py
@@ -2,29 +2,50 @@ from string import Template
 
 TYPE_VALUE_BASED_PROMPT = "value-based"
 TYPE_INVERSE_PROMPT = "inverse-prompt"
+TYPE_PROMPT_ENSEMBLE = "prompt-ensemble"
+
+PROMPT_WEIGHT = 0.25
 
 PROMPT_TEMPLATES = {
     "value-based": {
         "training": "belief states: value = $value, slot = $slot",
         "generate": "belief states: value = $value, slot ="
     },
-    "inverse-prompt": {
-        "training": "belief states: slot = $slot, value = $value",
-        "generate": "belief states: slot = $slot, value ="
+    "inverse-prompt": "belief states: slot = $slot, value = $value",
+    "prompt-ensemble": {
+        "training": {
+            "p1": "belief states: value = $value, slot = $slot",
+            "p2": "belief states: $value = $slot",
+            "p3": "$value is of slot type $slot",
+            "p4": "$value is the value of $slot"
+        },
+        "generate": {
+            "p1": "belief states: value = $value, slot =",
+            "p2": "belief states: $value =",
+            "p3": "$value is of slot type",
+            "p4": "$value is the value of"
+        }
     }
 }
 
 
 def get_prompt_for_training(typ, slot_value):
-    template = Template(PROMPT_TEMPLATES[typ]["training"])
-    return template.substitute(slot=slot_value[0], value=slot_value[1])
+    if typ == TYPE_INVERSE_PROMPT:
+        template = Template(PROMPT_TEMPLATES[typ])
+        return template.substitute(slot=slot_value[0], value=slot_value[1])
+    else:
+        template = PROMPT_TEMPLATES[typ]['training']
+        if isinstance(template, str):
+            return Template(template).substitute(slot=slot_value[0], value=slot_value[1])
+        elif isinstance(template, dict):
+            # expand every ensemble template into a filled-in prompt
+            prompt_list = []
+            for template_str in template.values():
+                prompt = Template(template_str)
+                prompt_list.append(prompt.substitute(slot=slot_value[0], value=slot_value[1]))
+            return prompt_list
 
 
 def get_value_based_prompt(value):
     template = Template(PROMPT_TEMPLATES["value-based"]["generate"])
     return template.substitute(value=value)
-
-
-def get_inverse_prompt(slot):
-    template = Template(PROMPT_TEMPLATES["inverse-prompt"]["generate"])
-    return template.substitute(slot=slot)
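For reference, a minimal sketch (not part of the patch itself) of what the new ensemble path produces, assuming the patched `prompt_utils.py` is on the import path; the slot/value pair `("area", "centre")` is an invented example:

```python
from prompt_utils import get_prompt_for_training, TYPE_PROMPT_ENSEMBLE

# expand all four ensemble training templates for a hypothetical slot/value pair
prompts = get_prompt_for_training(TYPE_PROMPT_ENSEMBLE, ("area", "centre"))
for p in prompts:
    print(p)
# belief states: value = centre, slot = area
# belief states: centre = area
# centre is of slot type area
# centre is the value of area
```

`train_prompt_ensemble` then scores each of these prompts independently and sums the weighted negative log-probabilities of the target token; the fixed `PROMPT_WEIGHT = 0.25` gives each of the four templates equal influence on the total loss.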