You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

187 lines
6.3 KiB

import json
import os
from pathlib import Path
from tqdm.auto import tqdm
from corenlp import ValueExtractor
# Prefix that may open a raw belief string; it is stripped before the
# '|'-separated domain segments are parsed.
BELIEF_PREFIX = 'belief :'
# Maps slot names as they appear in raw belief items ('slot = value') to the
# natural-language slot form used in prompts. Looking up a slot that is not a
# key here raises KeyError during conversion.
MODIFIED_SLOTS = {
    'area': 'area',
    'arriveby': 'arrive',
    'day': 'day',
    'departure': 'departure',
    'destination': 'destination',
    'food': 'food',
    'internet': 'internet',
    'leaveat': 'leave',
    'name': 'name',
    'parking': 'parking',
    'people': 'people',
    'pricerange': 'price',
    'stars': 'stars',
    'stay': 'stay',
    'time': 'time',
    'type': 'type'
}
# Value normalisation table. A value that exactly equals a key (whole string,
# case-sensitive) is replaced by the mapped value. It is applied both to gold
# belief-state values and to the extracted value candidates.
CORRECTIONS = {
    # NOTE(review): the next entries appear to map spelling variants onto a
    # canonical form used elsewhere in the data -- confirm against the ontology.
    "pizza hut fenditton": "pizza hut fen ditton",
    # NOTE(review): this inserts 'shop' into the name; confirm the target
    # spelling really is what the reference data uses.
    "saint johns chop house": "saint johns chop shop house",
    "1515": "15:15",
    "center": "centre",
    # NOTE(review): this maps one restaurant name onto a different one
    # ('express' -> 'hut'); confirm this is intended and not a data bug.
    "pizza express fen ditton": "pizza hut fen ditton",
    "apha-milton": "alpha-milton guest house",
    "concerthall": "concert hall",
    "oak bistro": "the oak bistro",
    "nightclub": "night club",
    "christs college": "christ college",
    "museums": "museum",
    "alexander": "alexander bed and breakfast",
    # presumably fixes an 'i'/'l' confusion in the name -- TODO confirm
    "ian hong house": "lan hong house"
}
def convert_slot_for_prompting(slot_value_item, *, slot_map=None, corrections=None):
    """Convert one raw ``'slot = value'`` belief item into its prompt form.

    The slot name is replaced by its natural-language form from the slot map
    and the value is normalised through the correction table, which matches
    only exact, whole values.

    Args:
        slot_value_item: A string of the form ``'slot = value'``.
        slot_map: Optional slot-name mapping. It defaults to ``MODIFIED_SLOTS``.
        corrections: Optional value-correction table. It defaults to
            ``CORRECTIONS``.

    Returns:
        ``'<modified slot> = <value>'``. Returns ``''`` when the item is
        malformed (not exactly one ``'='``, an empty slot or an empty value)
        or when the value is ``'none'`` (case-insensitive).

    Raises:
        KeyError: If a non-empty slot name is not present in ``slot_map``.
    """
    if slot_map is None:
        slot_map = MODIFIED_SLOTS
    if corrections is None:
        corrections = CORRECTIONS

    # A valid item contains exactly one '=' sign; split once and reuse.
    parts = slot_value_item.split('=')
    if len(parts) != 2:
        return ''
    slot, value = parts[0].strip(), parts[1].strip()

    # Empty slots/values and explicit 'none' values carry no belief information.
    if not slot or not value or value.lower() == 'none':
        return ''

    # Unknown (non-empty) slot names still raise KeyError rather than being
    # silently dropped, so unexpected slots in the data are noticed.
    modified_slot = slot_map[slot]
    value = corrections.get(value, value)
    return modified_slot + ' = ' + value
def _parse_belief_states(belief_states):
    """Turn a raw belief string into a list of prompt-ready 'slot = value' strings.

    The string may start with ``BELIEF_PREFIX`` and holds domain segments
    separated by ``'|'``. In each segment the first whitespace token is the
    domain and the rest is a ``';'``-separated list of ``'slot = value'``
    items. Blank segments, segments whose domain token is ``'none'``, and
    items rejected by ``convert_slot_for_prompting`` are skipped.
    """
    if belief_states.startswith(BELIEF_PREFIX):
        belief_states = belief_states[len(BELIEF_PREFIX):]

    slot_values = []
    for segment in belief_states.split('|'):
        tokens = segment.split()
        if not tokens or tokens[0] == 'none':
            continue
        # Drop the domain token. Re-joining with single spaces also
        # normalises whitespace inside slot names and values.
        for raw_item in ' '.join(tokens[1:]).split(';'):
            raw_item = raw_item.strip()
            if not raw_item:
                continue
            converted = convert_slot_for_prompting(raw_item)
            if converted:
                slot_values.append(converted)
    return slot_values


def create_belief_states_data_for_prompt_learning(data_tuple):
    """Build and save the belief-state prompt-learning dataset for one input file.

    Args:
        data_tuple: ``(data filepath, data type, split name)``. A data type of
            ``'test'`` or ``'valid'`` also triggers value-candidate extraction
            with ``ValueExtractor``. The split name selects the output
            sub-directory.

    The input is a JSON array of dialog items, each read for its
    ``'history'``, ``'domains'`` and ``'belief'`` keys. The result is written
    to ``../data/prompt-learning/<split name>/<data type>.soloist.json``,
    relative to the current working directory. Nothing is written when the
    input array is empty.

    Raises:
        FileNotFoundError: If the input file does not exist.
    """
    data_path, data_type, split_name = data_tuple[0], data_tuple[1], data_tuple[2]
    print('creating belief states data for :: ', data_type, ' => ', split_name)

    # Explicit check instead of `assert`, so it is not stripped under `python -O`.
    if not os.path.isfile(data_path):
        raise FileNotFoundError('Input data file not found: ' + str(data_path))
    with open(data_path, encoding='utf-8') as data_file:
        data = json.load(data_file)
    print('Opening file => ', data_path, ' [Size = ', len(data), ']')
    if len(data) <= 0:
        return

    # Start the CoreNLP-backed extractor only for the test/valid data types.
    extractor = None
    if data_type in ['test', 'valid']:
        extractor = ValueExtractor()
        extractor.start()

    belief_states_dataset = []
    try:
        with tqdm(total=len(data), desc="Creating Dataset (" + split_name + ")", leave=False) as progress:
            for item in data:
                # Key insertion order is kept stable for the JSON output.
                data_item = {'history': item['history'], 'domains': item['domains']}
                if extractor is not None:
                    values = extractor.extract_value_candidates(item['history'])
                    data_item['values'] = [CORRECTIONS.get(value, value) for value in values]
                data_item['belief_states'] = _parse_belief_states(item['belief'])
                belief_states_dataset.append(data_item)
                progress.update(1)
    finally:
        # Stop the extractor even if processing raised part-way through.
        if extractor is not None:
            extractor.stop()

    save_file_path = '../data/prompt-learning/' + split_name + '/'
    save_file_name = data_type + '.soloist.json'
    Path(save_file_path).mkdir(parents=True, exist_ok=True)
    print('Saving file => ', save_file_path, ' [Size = ', len(belief_states_dataset), ']')
    # The `with` block guarantees the output is flushed and closed.
    with open(Path(save_file_path) / save_file_name, 'w', encoding='utf-8') as out_file:
        json.dump(belief_states_dataset, out_file, indent=2)
# List of inputs to convert. Each tuple is (data filepath, data type, split name):
# the data type decides whether value candidates are extracted ('test'/'valid'),
# and the split name is the output sub-directory.
data_list = [
    ("../data/baseline/test/test.soloist.json", "test", "test"),
    ("../data/baseline/valid/valid.soloist.json", "valid", "valid"),
    ("../data/baseline/5-dpd/train.soloist.json", "train", "5-dpd"),
    ("../data/baseline/10-dpd/train.soloist.json", "train", "10-dpd"),
    ("../data/baseline/50-dpd/train.soloist.json", "train", "50-dpd"),
    ("../data/baseline/100-dpd/train.soloist.json", "train", "100-dpd"),
    ("../data/baseline/125-dpd/train.soloist.json", "train", "125-dpd"),
    ("../data/baseline/250-dpd/train.soloist.json", "train", "250-dpd")
]

if __name__ == '__main__':
    # Process the datasets only when run as a script, not when imported.
    # Processing reads and writes files and may start an external server.
    for file_tuple in data_list:
        create_belief_states_data_for_prompt_learning(data_tuple=file_tuple)