You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

137 lines
4.7 KiB

import json
import os
from pathlib import Path
# Literal prefix stripped from the start of every raw belief string.
BELIEF_PREFIX = 'belief :'

ALL_SLOTS = [
    'area', 'arriveby', 'day', 'departure', 'destination',
    'food', 'internet', 'leaveat', 'name', 'parking',
    'people', 'pricerange', 'stars', 'time', 'type',
]

# Only these slots get a different, more natural-language name in prompts;
# every other slot maps to itself.
_SLOT_RENAMES = {'arriveby': 'arrive', 'leaveat': 'leave', 'pricerange': 'price'}

# Raw slot name -> slot name used when building prompts.
# 'stay' is mapped as well even though it is not listed in ALL_SLOTS.
MODIFIED_SLOTS = {
    slot_name: _SLOT_RENAMES.get(slot_name, slot_name)
    for slot_name in sorted([*ALL_SLOTS, 'stay'])
}
def convert_slot_for_prompting(slot_value_item, slot_map=None):
    """Rewrite a single 'slot = value' item using the prompt-friendly slot name.

    Args:
        slot_value_item: One belief-state item of the form 'slot = value'
            (whitespace around '=' is optional).
        slot_map: Mapping from raw slot names to the names used in prompts.
            Defaults to the module-level ``MODIFIED_SLOTS``.

    Returns:
        The item re-composed as '<mapped slot> = <value>'. Returns '' when the
        item does not contain exactly one '=' or when the slot name is empty,
        so callers can skip it. Slots missing from ``slot_map`` are kept
        unchanged instead of raising ``KeyError``.
    """
    if slot_map is None:
        slot_map = MODIFIED_SLOTS
    parts = slot_value_item.split('=')
    # exactly one '=' is required; anything else is treated as malformed
    if len(parts) != 2:
        return ''
    slot, value = parts[0].strip(), parts[1].strip()
    # an item like ' = cheap' has no slot to map
    if not slot:
        return ''
    # unknown slots pass through as-is rather than aborting the whole run
    return slot_map.get(slot, slot) + ' = ' + value
def _parse_belief_states(belief_states):
    """Turn one raw belief string into a flat list of prompt-ready 'slot = value' items.

    The string is split on '|' into domain segments. The first whitespace token
    of each segment is the domain name: segments whose domain is 'none' are
    skipped, and the domain token itself is dropped. The remainder is split on
    ';' into 'slot = value' items, each passed through
    ``convert_slot_for_prompting``. Empty or rejected items are skipped.
    """
    # strip the leading BELIEF_PREFIX marker if present
    if belief_states.startswith(BELIEF_PREFIX):
        belief_states = belief_states[len(BELIEF_PREFIX):]
    slot_values = []
    for segment in belief_states.split('|'):
        tokens = segment.split()
        # blank segment, or a domain with no active belief
        if not tokens or tokens[0] == 'none':
            continue
        # drop the domain token; the rest holds ';'-separated slot-value pairs
        for raw_item in ' '.join(tokens[1:]).split(';'):
            raw_item = raw_item.strip()
            if not raw_item:
                continue
            converted = convert_slot_for_prompting(raw_item)
            # '' signals a malformed item that should be skipped
            if converted:
                slot_values.append(converted)
    return slot_values


def create_belief_states_data_for_prompt_learning(data_tuple, output_root='../data/prompt-learning'):
    """Convert one baseline JSON file into a belief-states dataset for prompt learning.

    Args:
        data_tuple: ``(data filepath, data type, split name)``. The output is
            written to ``<output_root>/<split name>/<data type>.soloist.json``.
        output_root: Directory under which per-split output folders are
            created. Defaults to the previously hard-coded location.

    Each input item must provide 'history', 'domains' and 'belief' keys. Each
    output item keeps 'history' and 'domains' and adds 'belief_states', a list
    of 'slot = value' strings. Nothing is written when the input is empty.

    Raises:
        FileNotFoundError: if the input file does not exist.
    """
    input_path, data_type, split_name = data_tuple[0], data_tuple[1], data_tuple[2]
    print('creating belief states data for :: ', data_type, ' => ', split_name)
    # explicit check instead of `assert`, which is stripped under `python -O`
    if not os.path.isfile(input_path):
        raise FileNotFoundError('Data file not found: ' + str(input_path))
    # context manager guarantees the handle is closed
    with open(input_path, encoding='utf-8') as in_file:
        data = json.load(in_file)
    print('Opening file => ', input_path, ' [Size = ', len(data), ']')
    if not data:
        return
    belief_states_dataset = [
        {
            'history': item['history'],
            'domains': item['domains'],
            'belief_states': _parse_belief_states(item['belief']),
        }
        for item in data
    ]
    save_dir = Path(output_root) / split_name
    save_dir.mkdir(parents=True, exist_ok=True)
    save_path = save_dir / (data_type + '.soloist.json')
    print('Saving file => ', save_path, ' [Size = ', len(belief_states_dataset), ']')
    # closing the file deterministically ensures the JSON is fully flushed
    with open(save_path, 'w', encoding='utf-8') as out_file:
        json.dump(belief_states_dataset, out_file, indent=2)
# Conversion jobs, one tuple per output file:
# (data filepath, data type, split name)
data_list = [
    ("../data/baseline/test/test.soloist.json", "test", "test"),
    ("../data/baseline/valid/valid.soloist.json", "valid", "valid"),
] + [
    # every training split follows the same '<split>/train.soloist.json' layout
    ("../data/baseline/" + split + "/train.soloist.json", "train", split)
    for split in ("50-dpd", "100-dpd", "125-dpd", "250-dpd")
]

# Runs at import time: convert every configured file in order.
for file_tuple in data_list:
    create_belief_states_data_for_prompt_learning(data_tuple=file_tuple)