You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
100 lines
3.6 KiB
100 lines
3.6 KiB
import json
|
|
import os
|
|
import numpy as np
|
|
import collections
|
|
|
|
|
|
class PromptDstDataset:
    """In-memory dataset of dialog histories and their (slot, value) pairs.

    The constructor reads a JSON file containing a list of items. From each
    item it reads ``'history'`` (joined into one string), ``'belief_states'``
    (strings of the form ``'slot = value'``) and, when present, ``'values'``.
    Items that yield no valid (slot, value) pair are dropped.

    Attributes:
        dataset_items: list of processed dicts with the keys ``'history'``,
            ``'belief_states'`` (list of ``(slot, value)`` tuples) and,
            when the raw item had it, ``'values'``.
        total_num_slot_value_pairs: number of (slot, value) pairs summed over
            all kept items.
    """

    def __init__(self, file_path, shuffle=True):
        """Load and preprocess the dataset stored at ``file_path``.

        Args:
            file_path: path of the JSON file to load.
            shuffle: when true, the processed items are shuffled in place
                with ``np.random.shuffle`` after loading.

        Raises:
            AssertionError: if ``file_path`` is not an existing file.
        """
        # Assertion check for file availability
        assert os.path.isfile(file_path), \
            'Dataset file not found: {}'.format(file_path)

        # all processed data items are collected in this list
        self.dataset_items = []
        self.total_num_slot_value_pairs = 0

        print('Loading the dataset from :: ', file_path)
        # the context manager closes the file handle once parsing is done
        with open(file_path) as dataset_file:
            dataset_list = json.load(dataset_file)

        for item in dataset_list:
            history_str = '\n '.join(item['history'])
            # holds the dialog history and the slot-value pairs
            data_item = {'history': history_str + "\n"}

            # extracted values are carried over when the raw item has them
            if 'values' in item:
                data_item['values'] = item['values']

            slot_value_list = self._parse_belief_states(item['belief_states'])
            # no usable (slot, value) pair -> the item is not added
            if not slot_value_list:
                continue

            data_item['belief_states'] = slot_value_list
            self.total_num_slot_value_pairs += len(slot_value_list)
            # this item is what getitem() hands back
            self.dataset_items.append(data_item)

        # shuffle the data items list
        if shuffle:
            np.random.shuffle(self.dataset_items)

        # print some statistics
        print('Total data items = {}, Total (slot, value) pairs = {}'
              .format(len(self.dataset_items), self.total_num_slot_value_pairs))

    @staticmethod
    def _parse_belief_states(belief_states):
        """Convert 'slot = value' strings into lower-cased (slot, value) tuples.

        Entries that do not split into exactly two parts on '=', that have an
        empty slot or value after stripping, or whose value is "none"
        (case-insensitive) are skipped.
        """
        slot_value_list = []
        for belief_state in belief_states:
            # split 'slot = value' using '=' delimiter
            slot_value_split = belief_state.split("=")
            if len(slot_value_split) != 2:
                continue

            slot = slot_value_split[0].strip().lower()
            value = slot_value_split[1].strip().lower()

            # value is already lower-cased, so "none" also covers "None"
            if not slot or not value or value == "none":
                continue

            slot_value_list.append((slot, value))
        return slot_value_list

    def getitem(self, index):
        """Return the processed data item stored at ``index``."""
        return self.dataset_items[index]

    def len(self):
        """Return the number of processed data items."""
        return len(self.dataset_items)

    def total_slot_value_pairs(self):
        """Return the total number of (slot, value) pairs over all items."""
        return self.total_num_slot_value_pairs

    def compute_value_extraction_accuracy(self):
        """Print turn-level and value-level extraction accuracy in percent.

        For every item, the multiset of ``item['values']`` is compared with
        the multiset of values taken from ``item['belief_states']``. A turn
        is correct when both multisets are equal; a true value is correct
        when it occurs the same number of times among the extracted values.
        When there are no items (or no pairs) the matching accuracy is
        reported as 0.0 instead of dividing by zero.

        Raises:
            KeyError: if an item has no ``'values'`` entry.
        """
        correct_values, correct_turns = 0, 0
        for item in self.dataset_items:
            extracted_values = collections.Counter(item['values'])
            true_values = collections.Counter(
                value for _, value in item['belief_states'])

            if extracted_values == true_values:
                correct_turns += 1

            # A Counter returns 0 for a missing key and every true count is
            # at least 1, so the equality also covers "key not extracted".
            for key, count in true_values.items():
                if extracted_values[key] == count:
                    correct_values += count

        num_turns = self.len()
        num_pairs = self.total_slot_value_pairs()
        # guard against empty datasets: avoid ZeroDivisionError
        turn_accuracy = (correct_turns / num_turns) * 100 if num_turns else 0.0
        value_accuracy = (correct_values / num_pairs) * 100 if num_pairs else 0.0

        print('Accuracy :: ', turn_accuracy)
        print('Slot-Value Accuracy :: ', value_accuracy)
|