import json
import re
import string

# Properties sent with every CoreNLP request: tokenization, POS tagging,
# NER and deterministic coreference resolution, with JSON output.
corenlp_props = {
    'annotators': 'tokenize, pos, ner, dcoref',
    'pipelineLanguage': 'en',
    'outputFormat': 'json',
    'parse.maxlen': '1000',
    'timeout': '500000'
}
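
# A minimal sketch of how these properties might be used (assumptions: a
# CoreNLP server already running on localhost:9000 and the `stanfordcorenlp`
# client package; this module itself only consumes the parsed JSON):
#
#     from stanfordcorenlp import StanfordCoreNLP
#     corenlp = StanfordCoreNLP('http://localhost', port=9000)
#     annotation = json.loads(
#         corenlp.annotate('i need a cheap hotel', properties=corenlp_props))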

STOPWORDS_FILE = "../data/resource/stopwords.txt"

# Domain and slot names are never slot values themselves; both sets are
# used below to filter the extracted candidates.
DOMAINS = ["hotel", "train", "restaurant", "attraction", "taxi", "book"]

SLOTS = {'area', 'arrive', 'day', 'departure', 'destination', 'food', 'internet', 'leave',
         'name', 'parking', 'people', 'price', 'stars', 'stay', 'time', 'type'}

# Normalization applied to extracted values: spelled-out digits to numerals,
# wifi variants to 'internet', and similar.
VALUES_CONVERT = {
    'zero': '0',
    'one': '1',
    'two': '2',
    'three': '3',
    'four': '4',
    'five': '5',
    'six': '6',
    'seven': '7',
    'eight': '8',
    'nine': '9',
    'wifi': 'internet',
    'wlan': 'internet',
    'wi-fi': 'internet',
    'moderately': 'moderate',
}


def bad_entity(text):
    """Return True for vague or junk mentions that must not represent a
    coreference chain."""
    return text in {"this", "that", "there", "here", "|", "less", "more"}


def fix_stanford_coref(stanford_json):
    """Convert CoreNLP's 1-based coreference indices to 0-based, attach the
    head word and NER label to each mention, and keep only chains that end
    up with at least one usable representative mention."""
    true_corefs = {}
    # iterate over coreference chains
    for key, coref in stanford_json["corefs"].items():
        true_coref = []
        # iterate over the entity mentions in this chain
        for entity in coref:
            sent_num = entity["sentNum"] - 1        # starting from 0
            start_index = entity["startIndex"] - 1  # starting from 0
            end_index = entity["endIndex"] - 1      # starting from 0
            head_index = entity["headIndex"] - 1    # starting from 0
            entity_label = stanford_json["sentences"][sent_num]["tokens"][head_index]["ner"]
            entity["sentNum"] = sent_num
            entity["startIndex"] = start_index
            entity["endIndex"] = end_index
            entity["headIndex"] = head_index
            entity["headWord"] = entity["text"].split(" ")[head_index - start_index]
            entity["entityType"] = entity_label
            true_coref.append(entity)

        # check that the chain is not empty
        if len(true_coref) > 0:
            no_representative = True
            has_representative = False
            for idx, entity in enumerate(true_coref):
                if entity["isRepresentativeMention"]:
                    if not (entity["type"] == "PRONOMINAL" or
                            bad_entity(entity["text"].lower()) or
                            len(entity["text"].split(" ")) > 10):
                        no_representative = False
                        has_representative = True
                    # remove bad representative assignments
                    else:
                        true_coref[idx]["isRepresentativeMention"] = False
            # if no representative survived, promote every usable mention
            if no_representative:
                for idx, entity in enumerate(true_coref):
                    if not (entity["type"] == "PRONOMINAL" or
                            bad_entity(entity["text"].lower()) or
                            len(entity["text"].split(" ")) > 10):
                        true_coref[idx]["isRepresentativeMention"] = True
                        has_representative = True
            # drop chains that still have no representative mention
            if has_representative:
                true_corefs[key] = true_coref
    return true_corefs
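
# Illustrative example of the shift above (values are made up, not from a
# real annotation): a mention with sentNum=1, startIndex=1, endIndex=4,
# headIndex=1 in CoreNLP's 1-based scheme becomes sentNum=0, startIndex=0,
# endIndex=3, headIndex=0, and headWord is taken from position
# headIndex - startIndex of the mention's own text.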


def clean(corefs: list, stopwords_list: list):
    """Drop redundant coref mentions and trim stopwords off the edges of the
    surviving ones. Each mention is [head_idx, head, [start, end], text]."""
    dup_ids = []
    for i, coref1 in enumerate(corefs):
        consist_num = 0
        short = []
        for j, coref2 in enumerate(corefs):
            # coref2's span is contained in coref1's span
            if coref1[2][0] <= coref2[2][0] and coref1[2][1] >= coref2[2][1] and (not i == j):
                consist_num += 1
                short.append(j)
        if consist_num > 1:
            # a long span covering several mentions: drop the long span
            dup_ids.append(i)
        elif consist_num == 1:
            # exactly one contained mention: drop the contained one
            dup_ids.extend(short)
    corefs = [corefs[i] for i in range(len(corefs)) if i not in dup_ids]

    temp = []
    for coref in corefs:
        seq = coref[-1].split()
        # strip stopwords from both ends of the mention text
        while seq and (seq[0] in stopwords_list or seq[-1] in stopwords_list):
            if seq[0] in stopwords_list:
                del seq[0]
            if seq and seq[-1] in stopwords_list:  # guard: seq may now be empty
                del seq[-1]
        if not seq:
            # nothing left after trimming; mark the mention for removal
            temp.append(coref)
        else:
            coref[-1] = ' '.join(seq)
    for t in temp:
        corefs.remove(t)

    return corefs


def get_candidates(user_annotation, stopwords_list):
    """Candidates include adjectives, entities and corefs."""
    tokens = []
    candidates = {}
    entities = []
    postags = []
    corefs = []
    base_index = [0]

    read_annotation(user_annotation, base_index, stopwords_list, tokens, entities, postags, corefs, 0)

    candidates['postag'] = postags
    candidates['coref'] = clean(corefs, stopwords_list)
    candidates['coref'].extend(entities)

    return candidates


def is_stop(text: str, stopwords_list: list):
    """Return True if text contains at least one non-stopword token
    (despite the name, True means the text is worth keeping)."""
    return any(w.lower() not in stopwords_list for w in text.split())


def read_annotation(annotation, base_index, stopwords_list, tokens, entities, postags, corefs, num_sen):
    """Walk one CoreNLP annotation, appending entity mentions, adjective or
    adverb candidates and coref mentions to the given accumulators.
    base_index maps a sentence number to the global index of its first
    token."""
    sentences = annotation["sentences"]
    for i, sentence in enumerate(sentences):

        for entity in sentence['entitymentions']:
            head_idx = base_index[i + num_sen] + entity['tokenBegin']
            head = sentence['tokens'][entity['tokenBegin']]['originalText']
            mention = entity['text']
            mention_start_idx = base_index[i + num_sen] + entity['tokenBegin']
            mention_end_idx = base_index[i + num_sen] + entity['tokenEnd']
            mention_idx = [mention_start_idx, mention_end_idx]
            entities.append([head_idx, head, mention_idx, mention])

        for j, token in enumerate(sentence['tokens']):
            tokens.append(token['word'])
            pos = token['pos']
            lemma = token['lemma']
            text = token['originalText']
            # adjectives and adverbs are candidate values
            if pos in ['JJ', 'RB']:
                # the preceding token, or '' at the start of the sentence
                # (j - 1 would otherwise wrap around to the last token)
                prev = sentence['tokens'][j - 1]['originalText'] if j > 0 else ''
                # skip abbreviations, stopwords and negated modifiers
                if (not re.search(r"([a-z]\.[a-z])", lemma)) \
                        and lemma not in stopwords_list and prev != 'not':
                    head_idx = base_index[i + num_sen] + token['index'] - 1
                    postags.append([head_idx, text])

        base_index.append(base_index[-1] + len(sentence['tokens']))

    for coref in annotation['corefs'].values():
        for realization in coref:
            sent_num = realization['sentNum']
            head_index = realization['headIndex']
            head_idx = base_index[sent_num + num_sen] + head_index
            head = sentences[sent_num]['tokens'][head_index]['originalText']
            text_start_index = realization['startIndex']
            text_start_idx = base_index[sent_num + num_sen] + text_start_index
            text_end_index = realization['endIndex']
            text_end_idx = base_index[sent_num + num_sen] + text_end_index
            text_lemma = sentences[sent_num]['tokens'][text_start_index:text_end_index]
            text_lemma = ' '.join(list(map(lambda x: x['originalText'], text_lemma)))
            # the two tokens before the mention, or '' when they do not exist
            # (negative indices would otherwise wrap around the sentence)
            sent_tokens = sentences[sent_num]['tokens']
            prev1 = sent_tokens[text_start_index - 1]['originalText'] if text_start_index >= 1 else ''
            prev2 = sent_tokens[text_start_index - 2]['originalText'] if text_start_index >= 2 else ''

            # keep mentions that contain a content word and are not negated
            if is_stop(text_lemma, stopwords_list) and prev1 != 'not' and prev2 != 'not':
                corefs.append([head_idx, head, [text_start_idx, text_end_idx], text_lemma])


def get_value_candidates_from_history(corenlp, history):
    """Extract candidate slot values from the user turns of a dialogue
    history, using CoreNLP entity mentions, adjectives and corefs."""
    if len(history) == 0:
        return []

    stopwords = []
    with open(STOPWORDS_FILE, 'r') as fin:
        for line in fin:
            stopwords.append(line.strip())

    value_candidates = set()

    # concatenate the user turns, stripping the 'user :' prefix
    user_utterance = ' '.join(utterance[len('user :'):] for utterance in history if utterance.startswith('user :'))
    annotation = json.loads(corenlp.annotate(user_utterance, properties=corenlp_props))
    annotation['corefs'] = fix_stanford_coref(annotation)

    candidates = get_candidates(annotation, stopwords)
    for candidate in candidates.values():
        for c in candidate:
            if len(c) == 2:
                # [head_idx, text] from the POS-tag candidates
                value_candidates.add(c[1].strip().lower())
            else:
                # [head_idx, head, span, text]: fall back to the head word
                # when the full mention is longer than five tokens
                if len(c[3].split()) > 5:
                    value_candidates.add(c[1].strip().lower())
                else:
                    value_candidates.add(c[3].strip().lower())

    # clean value candidates
    values = set()
    for value in value_candidates:
        if value in VALUES_CONVERT:
            value = VALUES_CONVERT[value]
        if value not in DOMAINS \
                and value not in SLOTS \
                and value not in string.punctuation \
                and value not in stopwords \
                and not value.startswith("'"):
            # remove spaces before punctuation
            value = re.sub(r"\s+([?.!'])", r"\1", value).strip()
            if value and value[0].isdigit():
                # drop trailing non-digit characters after a number
                value = re.sub(r'\D+$', '', value)
            if value.strip() and len(value.split()) <= 4:
                values.add(value.strip())
    return list(values)
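

if __name__ == '__main__':
    # Usage sketch, not part of the module proper. Assumptions: a CoreNLP
    # server reachable on localhost:9000, the `stanfordcorenlp` client
    # package, and a made-up dialogue history in the 'user : ...' /
    # 'system : ...' format this module expects.
    from stanfordcorenlp import StanfordCoreNLP

    corenlp = StanfordCoreNLP('http://localhost', port=9000)
    history = [
        'user : i am looking for a cheap hotel with free wifi',
        'system : sure , which area of town do you prefer ?',
        'user : the north , and it should have 4 stars',
    ]
    print(get_value_candidates_from_history(corenlp, history))
    corenlp.close()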