import json
import os

import numpy as np


class PromptDstDataset:
    """In-memory dataset of dialog histories paired with (slot, value) belief states.

    The JSON file at ``file_path`` is expected to hold a list of records. Each
    record is read through the keys ``'history'`` (an iterable of strings),
    ``'belief_states'`` (an iterable of ``'slot = value'`` strings) and,
    optionally, ``'values'``.

    Each stored item is a dict with:
      * ``'history'``: the history strings joined with ``'\\n '`` and
        terminated by ``'\\n'``;
      * ``'values'``: copied verbatim from the record when present;
      * ``'belief_states'``: a list of lower-cased ``(slot, value)`` tuples.

    Records that yield no valid (slot, value) pair are dropped.
    """

    def __init__(self, file_path, shuffle=True):
        """Load, filter and optionally shuffle the dataset.

        Args:
            file_path: Path to the JSON dataset file.
            shuffle: When true, shuffle the loaded items in place using
                ``np.random.shuffle``.

        Raises:
            AssertionError: If ``file_path`` is not an existing regular file.
            KeyError: If a record lacks the ``'history'`` or
                ``'belief_states'`` key.
        """
        # Assertion check for file availability. Note that asserts are
        # stripped under ``python -O``; open() below still fails loudly then.
        assert os.path.isfile(file_path), \
            'Dataset file not found: {}'.format(file_path)

        # add all processed data items to this list
        self.dataset_items = []
        self.total_num_slot_value_pairs = 0

        print('Loading the dataset from :: ', file_path)
        # Use a context manager so the file handle is closed deterministically.
        with open(file_path, encoding='utf-8') as dataset_file:
            dataset_list = json.load(dataset_file)

        for item in dataset_list:
            history_str = '\n '.join(item['history'])
            # fill this with dialog history and slot-value pairs
            data_item = {'history': history_str + "\n"}

            # add extracted values for test/valid datasets
            if 'values' in item:
                data_item['values'] = item['values']

            belief_states = item['belief_states']
            if not belief_states:
                continue

            slot_value_list = self._parse_belief_states(belief_states)
            # If (slot, value) pairs are empty, don't add this item
            if not slot_value_list:
                continue

            data_item['belief_states'] = slot_value_list
            self.total_num_slot_value_pairs += len(slot_value_list)

            # this item is what getitem() returns
            self.dataset_items.append(data_item)

        # shuffle the data items list
        if shuffle:
            np.random.shuffle(self.dataset_items)

        # print some statistics
        print('Total data items = {}, Total (slot, value) pairs = {}'
              .format(len(self.dataset_items), self.total_num_slot_value_pairs))

    @staticmethod
    def _parse_belief_states(belief_states):
        """Convert ``'slot = value'`` strings into lower-cased (slot, value) tuples.

        Entries are skipped when they do not contain exactly one ``'='``,
        when the slot or the value is empty after stripping, or when the
        value is ``"none"`` (compared case-insensitively).
        """
        slot_value_list = []
        for belief_state in belief_states:
            # split 'slot = value' using '=' delimiter
            parts = belief_state.split("=")
            # a valid entry has exactly one delimiter
            if len(parts) != 2:
                continue

            slot = parts[0].strip().lower()
            value = parts[1].strip().lower()

            # value is already lower-cased, so this also covers "None"
            if value == "none":
                continue

            if slot and value:
                slot_value_list.append((slot, value))
        return slot_value_list

    def getitem(self, index):
        """Return the data item stored at ``index``."""
        return self.dataset_items[index]

    def len(self):
        """Return the number of data items kept after filtering."""
        return len(self.dataset_items)

    def total_slot_value_pairs(self):
        """Return the total count of (slot, value) pairs across all kept items."""
        return self.total_num_slot_value_pairs