You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

78 lines
2.6 KiB

import json
import os
import numpy as np
class PromptDstDataset:
    """In-memory dataset of dialog histories with their (slot, value) pairs.

    The JSON file at ``file_path`` is expected to hold a list of items. Each
    item provides a ``'history'`` list of strings and a ``'belief_states'``
    list of ``'slot = value'`` strings, and may provide a ``'values'`` entry
    that is copied through unchanged.

    Every stored item is a dict with:
      * ``'history'``: the history strings joined with ``'\\n '`` plus a
        trailing newline,
      * ``'belief_states'``: a list of lower-cased ``(slot, value)`` tuples,
      * ``'values'``: only present when the source item had it.

    Items that end up with no valid (slot, value) pair are dropped.
    """

    def __init__(self, file_path, shuffle=True):
        """Load, clean and optionally shuffle the dataset.

        Args:
            file_path: Path of the JSON file to load.
            shuffle: When true, shuffle the loaded items in place with
                ``np.random.shuffle``.

        Raises:
            AssertionError: If ``file_path`` is not an existing file.
        """
        # Kept as an assertion so the exception type seen by callers does
        # not change; note that it is skipped when Python runs with -O.
        assert os.path.isfile(file_path), \
            'Dataset file not found: {}'.format(file_path)

        # add all processed data items to this list
        self.dataset_items = []
        self.total_num_slot_value_pairs = 0

        print('Loading the dataset from :: ', file_path)
        # The context manager closes the handle deterministically instead of
        # leaving it to the garbage collector.
        with open(file_path, encoding='utf-8') as dataset_file:
            dataset_list = json.load(dataset_file)

        for item in dataset_list:
            history_str = '\n '.join(item['history'])
            # fill this with dialog history and slot-value pairs
            data_item = {'history': history_str + "\n"}

            # add extracted values for test/valid datasets
            if 'values' in item:
                data_item['values'] = item['values']

            slot_value_list = self._parse_belief_states(item['belief_states'])
            # No usable (slot, value) pair: don't add this item
            if not slot_value_list:
                continue

            data_item['belief_states'] = slot_value_list
            self.total_num_slot_value_pairs += len(slot_value_list)
            # this item is what getitem() returns
            self.dataset_items.append(data_item)

        # shuffle the data items list
        if shuffle:
            np.random.shuffle(self.dataset_items)

        # print some statistics
        print('Total data items = {}, Total (slot, value) pairs = {}'
              .format(len(self.dataset_items), self.total_num_slot_value_pairs))

    @staticmethod
    def _parse_belief_states(belief_states):
        """Turn 'slot = value' strings into a list of valid (slot, value) tuples."""
        slot_value_list = []
        for belief_state in belief_states:
            # split 'slot = value' using '=' delimiter
            slot_value_split = belief_state.split("=")
            # anything without exactly one '=' is not a valid pair
            if len(slot_value_split) != 2:
                continue
            slot = slot_value_split[0].strip().lower()
            value = slot_value_split[1].strip().lower()
            # The value is already lower-cased, so this single comparison
            # rejects both "none" and "None".
            if value == "none":
                continue
            # both sides must be non-empty
            if slot and value:
                slot_value_list.append((slot, value))
        return slot_value_list

    def getitem(self, index):
        """Return the data item stored at ``index``."""
        return self.dataset_items[index]

    def len(self):
        """Return the number of data items kept after filtering."""
        return len(self.dataset_items)

    def __getitem__(self, index):
        """Support ``dataset[index]``; same result as :meth:`getitem`."""
        return self.getitem(index)

    def __len__(self):
        """Support ``len(dataset)``; same result as :meth:`len`."""
        return self.len()

    def total_slot_value_pairs(self):
        """Return the total number of (slot, value) pairs over all items."""
        return self.total_num_slot_value_pairs