import collections
import json
import os

import numpy as np


class PromptDstDataset:
    """Dialogue items with their (slot, value) pairs, loaded from a JSON file.

    Attributes:
        dataset_items: list of dicts, one per kept item, with the keys
            'history' (the item's history strings joined into one string),
            'belief_states' (list of lower-cased ``(slot, value)`` tuples)
            and, when present in the input, 'values' (copied unchanged).
        total_num_slot_value_pairs: number of (slot, value) pairs summed
            over all kept items.
    """

    def __init__(self, file_path, shuffle=True):
        """Load the JSON list at ``file_path`` and keep the usable items.

        An item is dropped when none of its 'belief_states' entries yields
        a valid (slot, value) pair.

        Args:
            file_path: path to a JSON file containing a list of items. Each
                item is read for 'history' (strings to join),
                'belief_states' ('slot = value' strings) and, optionally,
                'values'.
            shuffle: when true, shuffle the kept items in place with
                ``np.random.shuffle`` (NumPy's global random state).

        Raises:
            AssertionError: if ``file_path`` is not an existing file.
        """
        # Kept as an assert (rather than raising another exception type) so
        # the error seen by existing callers does not change.
        assert os.path.isfile(file_path), \
            'Dataset file not found: {}'.format(file_path)

        # add all processed data items to this list
        self.dataset_items = []
        self.total_num_slot_value_pairs = 0

        print('Loading the dataset from :: ', file_path)
        # Context manager closes the handle even if parsing fails.
        with open(file_path, encoding="utf-8") as dataset_file:
            dataset_list = json.load(dataset_file)

        for item in dataset_list:
            history_str = '\n '.join(item['history'])
            # fill this with dialog history and slot-value pairs
            data_item = {'history': history_str + "\n"}
            # add extracted values for test/valid datasets
            if 'values' in item:
                data_item['values'] = item['values']

            slot_value_list = self._parse_belief_states(item['belief_states'])
            # If (slot, value) pairs are empty, don't add this item
            if not slot_value_list:
                continue

            data_item['belief_states'] = slot_value_list
            self.total_num_slot_value_pairs += len(slot_value_list)
            # this item is what getitem() hands back
            self.dataset_items.append(data_item)

        # shuffle the data items list
        if shuffle:
            np.random.shuffle(self.dataset_items)

        # print some statistics
        print('Total data items = {}, Total (slot, value) pairs = {}'
              .format(len(self.dataset_items),
                      self.total_num_slot_value_pairs))

    @staticmethod
    def _parse_belief_states(belief_states):
        """Turn 'slot = value' strings into lower-cased (slot, value) tuples.

        Entries are skipped when they do not split into exactly two parts
        on '=', when either part is empty after stripping, or when the
        value is "none" (case-insensitive).
        """
        slot_value_list = []
        for belief_state in belief_states:
            # split 'slot = value' using '=' delimiter
            parts = belief_state.split("=")
            if len(parts) != 2:
                continue
            slot = parts[0].strip().lower()
            value = parts[1].strip().lower()
            # value is already lower-cased, so one comparison covers
            # "none", "None", "NONE", ...
            if value == "none":
                continue
            if slot and value:
                slot_value_list.append((slot, value))
        return slot_value_list

    def getitem(self, index):
        """Return the processed item dict stored at ``index``."""
        return self.dataset_items[index]

    def len(self):
        """Return the number of kept items."""
        return len(self.dataset_items)

    def __getitem__(self, index):
        """Support ``dataset[index]``; same result as ``getitem(index)``."""
        return self.getitem(index)

    def __len__(self):
        """Support ``len(dataset)``; same result as ``len()``."""
        return self.len()

    def total_slot_value_pairs(self):
        """Return the number of (slot, value) pairs over all kept items."""
        return self.total_num_slot_value_pairs

    def compute_value_extraction_accuracy(self):
        """Print turn-level and value-level accuracy of the extracted values.

        A turn is correct when the multiset of its 'values' equals the
        multiset of values in its 'belief_states'. A true value counts as
        correct when it was extracted the same number of times as it
        occurs. Items without a 'values' key are treated as having no
        extracted values. Mismatching turns are printed. Both accuracies
        are reported as 0.0 when the dataset is empty.
        """
        correct_values, correct_turns = 0, 0
        for item in self.dataset_items:
            # 'values' is optional at load time, so it may be absent here.
            extracted_values = collections.Counter(item.get('values', ()))
            true_values = collections.Counter(
                value for _, value in item['belief_states'])

            if extracted_values == true_values:
                correct_turns += 1
            else:
                print('Extracted: ', extracted_values.keys())
                print('True Values: ', true_values.keys())
                print("")

            # Counter returns 0 for missing keys and true counts are >= 1,
            # so equality also implies the key was extracted.
            for key, count in true_values.items():
                if extracted_values[key] == count:
                    correct_values += count

        num_turns = self.len()
        num_pairs = self.total_slot_value_pairs()
        # Guard the divisions: an empty dataset has nothing to score.
        turn_accuracy = (correct_turns / num_turns) * 100 if num_turns else 0.0
        value_accuracy = (correct_values / num_pairs) * 100 if num_pairs else 0.0
        print('Accuracy :: ', turn_accuracy)
        print('Slot-Value Accuracy :: ', value_accuracy)