@@ -10,6 +10,7 @@ from torch.optim import AdamW
 from transformers import get_scheduler
 from tqdm.auto import tqdm
 from prompt_utils import get_value_based_prompt
+from prompt_utils import get_ensemble_prompts
 from prompt_utils import get_prompt_for_training
 from prompt_utils import TYPE_VALUE_BASED_PROMPT
 from prompt_utils import TYPE_INVERSE_PROMPT
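
Note: get_ensemble_prompts is imported above but its body is not part of this
diff. A hypothetical sketch of its shape, assuming it returns one paraphrased
prompt per ensemble member, each asking the model to name the slot that
`value` fills (the real prompt_utils helper may differ):

    # hypothetical sketch -- not the actual prompt_utils implementation
    def get_ensemble_prompts(value):
        return [
            f' The value "{value}" belongs to slot',
            f' "{value}" is a value of the slot',
            f' The slot for the value "{value}" is',
        ]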
@@ -102,6 +103,8 @@ def main():
     # set the model in training mode
     model.train()
 
+    tqdm.write(f'Training starts now... [with_prompt_ensemble = {args.with_prompt_ensemble}]')
+
     # training loop
     for epoch in range(args.num_epochs):
        running_loss = 0.0
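
Note: args.with_prompt_ensemble is read here but defined elsewhere. Assuming
a standard argparse setup in this script, a plausible declaration would be:

    # assumed flag definition -- not shown in this diff
    parser.add_argument('--with_prompt_ensemble', action='store_true',
                        help='generate with multiple paraphrased prompts and '
                             'keep the most probable answer')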
@@ -175,6 +178,8 @@ def main():
 
         # if validation file is provided, run evaluation here (after each epoch)
         if args.validation_file and validation_data is not None:
+            tqdm.write(f'Validation in progress... [with_prompt_ensemble = {args.with_prompt_ensemble}]')
+
             # set tqdm progress bars for testing progress
             validation_progress = tqdm(total=validation_data.len(), desc="Validation", leave=False)
 
@@ -197,7 +202,8 @@ def main():
             true_states[slot] = value
 
             # generate slot using value-based prompt
-            generated_slot = generate_slot_from_prompt(history=history,
+            generated_slot = generate_slot_from_prompt(args=args,
+                                                       history=history,
                                                        value=value,
                                                        tokenizer=tokenizer,
                                                        model=model,
@@ -338,17 +344,21 @@ def train_prompt_ensemble(history, slot_value_pair, prompt_type, tokenizer, mode
         gen_word = tokenizer.decode(gen_word_token, skip_special_tokens=True).strip()
 
         # add the generated word and probs to list
-        gen_probs.append(gen_word_prob)
+        gen_probs.append(gen_word_prob.item())
         gen_words.append(gen_word)
 
-    generated_word = gen_words.index(max(gen_probs))
+    generated_word = gen_words[gen_probs.index(max(gen_probs))]
     # loss is the log of probability
     return total_loss, generated_word
 
 
 # Use this for generating next word (Validation after each epoch)
-def generate_slot_from_prompt(history, value, tokenizer, model, device):
-    # get prompt for training based on "type"
+def generate_slot_from_prompt(args, history, value, tokenizer, model, device):
+    # check if prompt ensemble is enabled in arguments
+    if args.with_prompt_ensemble:
+        return generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device)
+
+    # get value-based prompt for generation
     prompt = get_value_based_prompt(value)
 
     # combine history and prompt
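
Note: the gen_words change above fixes a real bug, not just style.
gen_words.index(max(gen_probs)) looks up a probability inside the list of
words, which raises ValueError (or, if a probability happened to equal a word,
would silently pick the wrong item); the corrected line takes the position of
the highest probability and uses it to index the words. A toy illustration:

    gen_words = ["price", "area"]
    gen_probs = [0.12, 0.85]
    # old: gen_words.index(max(gen_probs)) -> gen_words.index(0.85) -> ValueError
    best = gen_words[gen_probs.index(max(gen_probs))]  # "area"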
@@ -360,9 +370,44 @@ def generate_slot_from_prompt(history, value, tokenizer, model, device):
 
     # generate 1 token
     outputs = model.generate(**encoded_prompt, max_new_tokens=1)
 
-    gen_sequences = outputs[:, encoded_prompt['input_ids'].shape[-1]:]
-    generated_word = tokenizer.decode(gen_sequences.item(), skip_special_tokens=True)
+    gen_token = outputs[:, encoded_prompt['input_ids'].shape[-1]:]
+    generated_word = tokenizer.decode(gen_token.item(), skip_special_tokens=True)
 
     return generated_word.strip().lower()
 
+
+def generate_slot_with_prompt_ensemble(history, value, tokenizer, model, device):
+    # get prompts for ensemble generation
+    prompts = get_ensemble_prompts(value)
+
+    gen_probs, gen_words = [], []
+
+    for prompt in prompts:
+        # combine history and prompt
+        prompt = history + prompt
+
+        # encode the history & prompt
+        encoded_prompt = tokenizer(prompt, return_tensors="pt")
+        encoded_prompt.to(device)
+
+        # generate 1 token
+        outputs = model.generate(**encoded_prompt,
+                                 return_dict_in_generate=True,
+                                 output_scores=True,
+                                 max_new_tokens=1)
+
+        gen_token = outputs.sequences[:, encoded_prompt['input_ids'].shape[-1]:]
+        gen_word = tokenizer.decode(gen_token.item(), skip_special_tokens=True).strip()
+
+        # probability of the generated token under the model
+        probs = torch.nn.functional.softmax(outputs.scores[0], dim=-1)
+        gen_prob = torch.gather(probs, 1, gen_token).squeeze(-1)
+
+        # add the generated word and probs to list
+        gen_probs.append(gen_prob.item())
+        gen_words.append(gen_word)
+
+    # return word with the highest probability
+    generated_word = gen_words[gen_probs.index(max(gen_probs))]
+    return generated_word.strip().lower()
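
Note: the score-to-probability step in generate_slot_with_prompt_ensemble is
easy to verify in isolation. With return_dict_in_generate=True and
output_scores=True, outputs.scores[0] holds the logits for the single
generated step; softmax turns them into probabilities and gather picks out the
probability of the token that generate() actually chose. A toy check with
made-up logits:

    import torch

    scores = torch.tensor([[2.0, 0.5, 1.0]])  # logits for one step (batch=1, vocab=3)
    gen_token = torch.tensor([[0]])           # id of the token generate() picked
    probs = torch.nn.functional.softmax(scores, dim=-1)
    gen_prob = torch.gather(probs, 1, gen_token).squeeze(-1)
    print(round(gen_prob.item(), 2))          # 0.63: model probability of token 0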