diff --git a/prompt-learning/prompt_decode.py b/prompt-learning/prompt_decode.py
index 6ca5b52..8f14523 100644
--- a/prompt-learning/prompt_decode.py
+++ b/prompt-learning/prompt_decode.py
@@ -24,7 +24,8 @@ def generate_slot_from_prompt(args, history, value, tokenizer, model, device):
     # check if prompt ensemble is enabled in arguments
     if args.with_prompt_ensemble:
         return generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device)
-    # get prompt for generating slots
+
+    # get value-based prompt for generating slots
     prompt = get_value_based_prompt(value)
 
     # combine history and prompt
@@ -33,7 +34,7 @@
     encoded_prompt = tokenizer(prompt, return_tensors="pt")
     encoded_prompt.to(device)
 
-    # generate 1 token
+    # generate 1 token (max length of slot = 1)
     outputs = model.generate(**encoded_prompt, max_new_tokens=1)
 
     gen_token_id = outputs[:, encoded_prompt['input_ids'].shape[-1]:]
@@ -56,7 +57,7 @@ def generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device)
         encoded_prompt = tokenizer(prompt, return_tensors="pt")
         encoded_prompt.to(device)
 
-        # generate 1 token
+        # generate 1 token (max length of slot = 1)
         outputs = model.generate(**encoded_prompt,
                                  return_dict_in_generate=True,
                                  output_scores=True,
@@ -92,14 +93,27 @@ def main():
     parser.add_argument('--seed', type=int, default=42,
                         help="random seed for initialization")
     parser.add_argument("--with_prompt_ensemble", action="store_true",
                         help="Flag for enabling/disabling prompt ensembling while generating")
+    parser.add_argument("--with_answered_prompts", action="store_true",
+                        help="Flag to enable/disable the use of answered prompts while generating")
     # parse the arguments
     args = parser.parse_args()
 
+    # check for args.json file in the saved model path & check if trained with prompt ensemble
+    args_file = os.path.join(args.tuned_model_path, 'args.json')
+    if os.path.isfile(args_file):
+        args_dict = json.load(open(args_file))
+        if 'with_prompt_ensemble' in args_dict:
+            args.with_prompt_ensemble = args_dict['with_prompt_ensemble']
+    else:
+        print("No 'args.json' file found in the saved epoch dir!")
+
     # setup CUDA device for training on GPU (if available)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     args.n_gpu = torch.cuda.device_count()
 
+    print('Generation Config :: ', json.dumps(vars(args), indent=2))
+
     # prepare model & tokenizer -> load pre-trained model
     tokenizer = GPT2Tokenizer.from_pretrained(args.tuned_model_path)
     model = AutoModelForCausalLM.from_pretrained(args.tuned_model_path, pad_token_id=tokenizer.eos_token_id)
@@ -119,15 +133,14 @@ def main():
     # set eval mode
     model.eval()
 
-    tqdm.write(str('Generating Slots...'))
-    tqdm.write(str('Args: [with_prompt_ensemble = '+str(args.with_prompt_ensemble)+']'))
-
     # outputs array -> to be saved to output_dir
     outputs = []
 
     # JGA metric
     evaluator = PromptDSTEvaluator()
 
+    tqdm.write(str('Generating slots now...'))
+
     # iterate through test dataset and generate slots
     for item in dataset.dataset_items:
         history = item['history']
@@ -165,12 +178,15 @@ def main():
     # compute JGA & print results
     evaluator.compute_joint_goal_accuracy()
 
+    # output file extension (for prompt ensemble & answered prompts)
+    out_ext = "_pe" if args.with_prompt_ensemble else ""
+    out_ext += "_pa" if args.with_answered_prompts else ""
     # save the outputs to output_dir
     if not os.path.exists(args.output_dir):
         os.makedirs(args.output_dir)
     now = datetime.now()
     datetime_str = now.strftime("%Y%m%dT%H%M%S")
-    output_file = os.path.join(args.output_dir, 'outputs-{}.json'.format(datetime_str))
+    output_file = os.path.join(args.output_dir, 'outputs{}-{}.json'.format(out_ext, datetime_str))
     print('Saving Outputs file :: ', output_file)
     json.dump(outputs, open(output_file, 'w'), indent=2)
 
diff --git a/prompt-learning/prompt_train.py b/prompt-learning/prompt_train.py
index db9cde6..4de020a 100644
--- a/prompt-learning/prompt_train.py
+++ b/prompt-learning/prompt_train.py
@@ -46,6 +46,8 @@
                         help="Weight to adjust the influence of Inverse Prompt, decimal (0,1)")
     parser.add_argument("--with_prompt_ensemble", action="store_true",
                         help="Flag for enabling/disabling prompt ensembling during training")
+    parser.add_argument("--with_answered_prompts", action="store_true",
+                        help="Flag to enable/disable the use of answered prompts during validation")
     parser.add_argument("--validation_file", default="", type=str,
                         help="Validation file for evaluating model after each epoch")
 
@@ -56,6 +58,10 @@
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     args.n_gpu = torch.cuda.device_count()
 
+    # Print args
+    print('Training Args :: ', json.dumps(vars(args), indent=2))
+    save_args(args.save_model_dir, args)
+
     # prepare model & tokenizer -> load pre-trained model
     tokenizer = GPT2Tokenizer.from_pretrained(args.pretrained_model_path, do_lower_case=True)
     model = AutoModelForCausalLM.from_pretrained(args.pretrained_model_path, pad_token_id=tokenizer.eos_token_id)
@@ -166,6 +172,8 @@ def main():
         output_dir = os.path.join(args.save_model_dir, '{}-{}'.format("epoch", "{:02d}".format(epoch + 1)))
         if not os.path.exists(output_dir):
             os.makedirs(output_dir)
+        # save training args for each epoch (useful when testing/generating)
+        save_args(output_dir, args)
         model.save_pretrained(output_dir)
         tokenizer.save_pretrained(output_dir)
         torch.save(args, os.path.join(output_dir, 'training_args.bin'))
@@ -227,10 +235,14 @@ def main():
         jga_score = evaluator.compute_joint_goal_accuracy(no_print=True)
         tqdm.write(str('Joint Goal Accuracy(with True Values) [after Epoch-{}]: {}'.format((epoch + 1), jga_score)))
 
+        # output file extension (for prompt ensemble & answered prompts)
+        out_ext = "_pe" if args.with_prompt_ensemble else ""
+        out_ext += "_pa" if args.with_answered_prompts else ""
+
         # save the outputs to trained epoch dir
         now = datetime.now()
         datetime_str = now.strftime("%Y%m%dT%H%M%S")
-        output_file = os.path.join(output_dir, 'outputs-{}.json'.format(datetime_str))
+        output_file = os.path.join(output_dir, 'outputs{}-{}.json'.format(out_ext, datetime_str))
         tqdm.write(str('Saving Validation Outputs file [after Epoch-{}] :: {}'.format((epoch + 1), output_file)))
         json.dump(outputs, open(output_file, 'w'), indent=2)
 
@@ -242,6 +254,13 @@ def set_seed(args):
         torch.cuda.manual_seed_all(args.seed)
 
 
+def save_args(save_dir, args):
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    args_file = os.path.join(save_dir, 'args.json')
+    json.dump(vars(args), open(args_file, "w"), indent=2)
+
+
 def train_prompting(args, history, slot_value_pair, prompt_type, tokenizer, model, device):
 
     # slot_value_pair = (slot, value)
@@ -298,8 +317,8 @@ def train_prompt_ensemble(history, slot_value_pair, prompt_type, tokenizer, mode
     # get list of prompts for training
     prompts = get_prompt_for_training(prompt_type, slot_value_pair)
 
-    # return total loss
-    total_loss = None
+    # total probability of all prompt functions
+    total_prob = None
 
     gen_probs, gen_words = [], []
 
@@ -329,13 +348,15 @@ def train_prompt_ensemble(history, slot_value_pair, prompt_type, tokenizer, mode
         # last token generation probability
         last_token_prob = torch.gather(probs, 1, last_token).squeeze(-1)
 
-        # weighted probability
+
+        # weighted probability | sum of all prompt weights must equal 1
         token_prob = PROMPT_WEIGHT * last_token_prob
-        loss = torch.negative(torch.log(token_prob))
-        if total_loss is None:
-            total_loss = loss
+
+        # sum the probs for all prompt functions
+        if total_prob is None:
+            total_prob = token_prob
         else:
-            total_loss += loss
+            total_prob += token_prob
 
         # generated slot
         # find the token with the highest probability, this will be the generated word
@@ -349,7 +370,8 @@ def train_prompt_ensemble(history, slot_value_pair, prompt_type, tokenizer, mode
     generated_word = gen_words[gen_probs.index(max(gen_probs))]
 
     # loss is the log of probability
-    return total_loss, generated_word
+    loss = torch.negative(torch.log(total_prob))
+    return loss, generated_word
 
 
 # Use this for generating next word (Validation after each epoch)
@@ -390,7 +412,7 @@ def generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device)
         encoded_prompt = tokenizer(prompt, return_tensors="pt")
         encoded_prompt.to(device)
 
-        # generate 1 token
+        # generate 1 token (max length of slot = 1)
         outputs = model.generate(**encoded_prompt,
                                  return_dict_in_generate=True,
                                  output_scores=True,
@@ -406,9 +428,9 @@ def generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device)
         gen_probs.append(gen_prob.item())
         gen_words.append(gen_word)
 
-        # return word with the highest probability
-        generated_word = gen_words[gen_probs.index(max(gen_probs))]
-        return generated_word.strip().lower()
+    # return word with the highest probability
+    generated_word = gen_words[gen_probs.index(max(gen_probs))]
+    return generated_word.strip().lower()
 
 
 if __name__ == "__main__":
diff --git a/prompt-learning/test_prompting.sh b/prompt-learning/test_prompting.sh
index 6e1543d..8ff2d8d 100644
--- a/prompt-learning/test_prompting.sh
+++ b/prompt-learning/test_prompting.sh
@@ -51,5 +51,4 @@ mkdir -p "${OUTPUTS_DIR}"
 python prompt_decode.py \
 --output_dir="${OUTPUTS_DIR}" \
 --tuned_model_path="${FINE_TUNED_MODEL_PATH}" \
---test_data_file="${TEST_DATA_FILE}" \
---with_prompt_ensemble
\ No newline at end of file
+--test_data_file="${TEST_DATA_FILE}"
\ No newline at end of file