You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
190 lines
6.4 KiB
190 lines
6.4 KiB
import json
|
|
import os
|
|
from pathlib import Path
|
|
from tqdm.auto import tqdm
|
|
from corenlp import ValueExtractor
|
|
|
|
|
|
# Literal prefix that is stripped from the start of each item's 'belief'
# string (when present) before the belief states are parsed.
BELIEF_PREFIX = 'belief :'
|
|
|
|
# Lookup table: slot name as written in the belief string -> slot name
# emitted in the 'slot = value' strings used for prompting.
# Most slots keep their name; 'arriveby', 'leaveat' and 'pricerange' are
# shortened to 'arrive', 'leave' and 'price'.
MODIFIED_SLOTS = {
    'area': 'area',
    'arriveby': 'arrive',
    'day': 'day',
    'departure': 'departure',
    'destination': 'destination',
    'food': 'food',
    'internet': 'internet',
    'leaveat': 'leave',
    'name': 'name',
    'parking': 'parking',
    'people': 'people',
    'pricerange': 'price',
    'stars': 'stars',
    'stay': 'stay',
    'time': 'time',
    'type': 'type',
}
|
|
|
|
# Lookup table: value as it appears in the data -> value to use instead.
# Applied (exact, case-sensitive match) both to belief-state slot values
# and to the value candidates extracted for the test/valid datasets.
CORRECTIONS = {
    "pizza hut fenditton": "pizza hut fen ditton",
    "saint johns chop house": "saint johns chop shop house",
    "1515": "15:15",
    "center": "centre",
    "pizza express fen ditton": "pizza hut fen ditton",
    "apha-milton": "alpha-milton guest house",
    "concerthall": "concert hall",
    "oak bistro": "the oak bistro",
    "nightclub": "night club",
    "christs college": "christ college",
    "museums": "museum",
    "alexander": "alexander bed and breakfast",
    "ian hong house": "lan hong house",
    "saint catharines college": "saint catherines college",
    "gandhi": "the gandhi",
    "cambridge punte": "cambridge punter",
}
|
|
|
|
|
|
def convert_slot_for_prompting(slot_value_item):
    """Rewrite one 'slot = value' item into the form used for prompting.

    The slot name is replaced by its entry in MODIFIED_SLOTS and the value
    is replaced by its entry in CORRECTIONS when one exists.

    :param slot_value_item: string of the form 'slot = value'
    :return: the rewritten 'slot = value' string, or '' when the item does
             not contain exactly one '=' or when its value is 'none'
             (compared case-insensitively)
    :raises KeyError: if the slot name is not a key of MODIFIED_SLOTS
    """
    # a usable item has exactly one '=' between the slot and its value
    pieces = slot_value_item.split('=')
    if len(pieces) != 2:
        return ''

    # surrounding whitespace is not part of the slot / value
    slot_name = pieces[0].strip()
    slot_value = pieces[1].strip()

    # 'none' values are skipped
    if slot_value.lower() == 'none':
        return ''

    # slot name used for prompting (unknown slot names raise KeyError)
    prompt_slot = MODIFIED_SLOTS[slot_name]

    # use the corrected value when one is registered, else keep it as is
    prompt_value = CORRECTIONS.get(slot_value, slot_value)

    # re-assemble the 'slot = value' string
    return ' = '.join((prompt_slot, prompt_value))
|
|
|
|
|
|
def _extract_belief_slot_values(belief_states):
    """Turn one belief string into a list of prompt-ready 'slot = value' items.

    :param belief_states: belief string, optionally starting with
                          BELIEF_PREFIX; domains are separated by '|' and,
                          within a domain, slot-value pairs by ';'
    :return: list of 'slot = value' strings (possibly empty)
    """
    # remove 'belief :' from the beginning
    if belief_states.startswith(BELIEF_PREFIX):
        belief_states = belief_states[len(BELIEF_PREFIX):]

    belief_slot_value_list = []

    # belief states can have multiple domains separated by '|'
    for belief_state in belief_states.split('|'):
        tokens = belief_state.split()

        # skip empty / whitespace-only parts and the 'none' domain
        if not tokens or tokens[0] == 'none':
            continue

        # the first token is the domain; the rest are the slot-value pairs
        slot_value_text = ' '.join(tokens[1:])

        for list_item in slot_value_text.split(';'):
            slot_value_item = list_item.strip()
            if slot_value_item == '':
                continue

            # modify the slots for prompting (convert them to natural language)
            slot_value_item = convert_slot_for_prompting(slot_value_item)
            if slot_value_item == '':
                continue

            belief_slot_value_list.append(slot_value_item)

    return belief_slot_value_list


def create_belief_states_data_for_prompt_learning(data_tuple):
    """Build the prompt-learning belief-state dataset for one input file.

    Reads the JSON list at ``data_tuple[0]``; for every item it keeps the
    'history' and 'domains' entries, converts the 'belief' string into a
    list of 'slot = value' strings ('belief_states') and, for the 'test' and
    'valid' data types only, adds the corrected value candidates returned by
    ``ValueExtractor`` ('values'). The result is written to
    ``../data/prompt-learning/<split name>/<data type>.soloist.json``.

    Nothing is written when the input list is empty.

    :param data_tuple: tuple of (data filepath, data type, split name)
    :return: None
    :raises AssertionError: if the data filepath is not an existing file
    """
    data_path, data_type, split_name = data_tuple[0], data_tuple[1], data_tuple[2]

    print('creating belief states data for :: ', data_type, ' => ', split_name)

    # Assertion check for file availability
    assert os.path.isfile(data_path)

    # context manager so the input file is closed as soon as it is parsed
    with open(data_path, encoding='utf-8') as data_file:
        data = json.load(data_file)
    print('Opening file => ', data_path, ' [Size = ', len(data), ']')

    if len(data) <= 0:
        return

    # value candidates are only required for the test/valid dataset
    needs_values = data_type in ['test', 'valid']

    progress = tqdm(total=len(data), desc="Creating Dataset ("+split_name+")", leave=False)

    extractor = None

    # data to be saved for prompt learning
    belief_states_dataset = []

    try:
        # start the CoreNLP server for Value Extraction
        if needs_values:
            started_extractor = ValueExtractor()
            started_extractor.start()
            # only remember the extractor once start() succeeded, so the
            # cleanup below never stops something that was never started
            extractor = started_extractor

        for item in data:
            # map to be added to the list for saving
            belief_states_data_item = {}

            # add the history & domains of dialog to the data item
            belief_states_data_item['history'] = item['history']
            belief_states_data_item['domains'] = item['domains']

            # extract value candidates using stanford CoreNLP & add to test/valid dataset
            if needs_values:
                values = extractor.extract_value_candidates(item['history'])
                belief_states_data_item['values'] = [
                    CORRECTIONS.get(value, value) for value in values
                ]

            # add belief states list to data item map (will be saved)
            belief_states_data_item['belief_states'] = _extract_belief_slot_values(item['belief'])

            # update tqdm progress
            progress.update(1)

            # add to the dataset (to be saved!)
            belief_states_dataset.append(belief_states_data_item)
    finally:
        # always release the progress bar and the CoreNLP server, even when
        # an item fails to convert
        progress.close()
        if extractor is not None:
            extractor.stop()

    # save the dataset file
    save_file_path = '../data/prompt-learning/' + split_name + '/'
    save_file_name = data_type + '.soloist.json'
    Path(save_file_path).mkdir(parents=True, exist_ok=True)
    print('Saving file => ', save_file_path, ' [Size = ', len(belief_states_dataset), ']')

    # context manager guarantees the output is flushed and closed
    with open(save_file_path + save_file_name, 'w', encoding='utf-8') as save_file:
        json.dump(belief_states_dataset, save_file, indent=2)
|
|
|
|
|
|
# Input datasets to convert, one tuple per file:
#   (data filepath, data type, split name)
data_list = [
    ("../data/baseline/test/test.soloist.json", "test", "test"),
    ("../data/baseline/valid/valid.soloist.json", "valid", "valid"),
    ("../data/baseline/5-dpd/train.soloist.json", "train", "5-dpd"),
    ("../data/baseline/10-dpd/train.soloist.json", "train", "10-dpd"),
    ("../data/baseline/50-dpd/train.soloist.json", "train", "50-dpd"),
    ("../data/baseline/100-dpd/train.soloist.json", "train", "100-dpd"),
    ("../data/baseline/125-dpd/train.soloist.json", "train", "125-dpd"),
    ("../data/baseline/250-dpd/train.soloist.json", "train", "250-dpd"),
]

# build the prompt-learning dataset for every listed file, in order
for file_tuple in data_list:
    create_belief_states_data_for_prompt_learning(data_tuple=file_tuple)
|