Merge remote-tracking branch 'origin/master'

isaac (yelircaasi) committed 5 years ago
commit 02baf00dc5

@@ -104,10 +104,10 @@ class BiLstmClassifier(Model):
         output_dict['probabilities'] = class_probabilities
         output_dict['positive_label'] = label
         output_dict['prediction'] = label
-        citation_text = []
-        for batch_text in output_dict['tokens']:
-            citation_text.append([self.vocab.get_token_from_index(token_id.item()) for token_id in batch_text])
-        output_dict['tokens'] = citation_text
+        # citation_text = []
+        # for batch_text in output_dict['tokens']:
+        #     citation_text.append([self.vocab.get_token_from_index(token_id.item()) for token_id in batch_text])
+        # output_dict['tokens'] = citation_text
         return output_dict

@ -48,10 +48,10 @@
}, },
"trainer": { "trainer": {
"optimizer": { "optimizer": {
"type": "adam", "type": "adagrad",
"lr": 0.001 "lr": 0.005
}, },
"num_epochs": 2, "num_epochs": 10,
"cuda_device": 1 "cuda_device": 3
} }
} }
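For readers less familiar with AllenNLP trainer configs, the optimizer block above just names a registered PyTorch optimizer and its learning rate. A minimal sketch of the equivalent plain-PyTorch construction, assuming the "adagrad" key resolves to torch.optim.Adagrad and using a placeholder module instead of the real BiLstmClassifier:

import torch
import torch.nn as nn

# Placeholder module; the project would pass the BiLstmClassifier's parameters instead.
model = nn.Linear(10, 3)

# "type": "adagrad", "lr": 0.005 in the trainer block corresponds to:
optimizer = torch.optim.Adagrad(model.parameters(), lr=0.005)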

@@ -1,4 +1,8 @@
 import utils.constants as const
+from sklearn.metrics import confusion_matrix
+import matplotlib.pyplot as plt
+import numpy as np
+import itertools
 
 
 def f1_score(y_true, y_pred, labels, average):
@@ -163,6 +167,41 @@ def calculate_f1_score(precision, recall):
     return 2 * (precision * recall) / (precision + recall)
 
 
+def get_confusion_matrix(y_true, y_pred):
+    """
+    Takes true and predicted labels and returns the confusion matrix.
+    :param y_true: true labels
+    :param y_pred: predicted labels
+    :return: confusion matrix, with rows/columns ordered as const.CLASS_LABELS_LIST
+    """
+    return confusion_matrix(y_true, y_pred, labels=const.CLASS_LABELS_LIST)
+
+
+def plot_confusion_matrix(confusion_mat, classifier_name, plot_file_name):
+    plt.figure(figsize=(8, 6))
+    plt.imshow(confusion_mat, interpolation='nearest', cmap=plt.get_cmap('Blues'))
+    plt.title(classifier_name)
+    plt.colorbar()
+    target_names = const.CLASS_LABELS_LIST
+    if target_names is not None:
+        tick_marks = np.arange(len(target_names))
+        plt.xticks(tick_marks, target_names, rotation=45)
+        plt.yticks(tick_marks, target_names)
+    # Annotate each cell with its count, using white text on dark cells.
+    thresh = confusion_mat.max() / 2
+    for i, j in itertools.product(range(confusion_mat.shape[0]), range(confusion_mat.shape[1])):
+        plt.text(j, i, "{:,}".format(confusion_mat[i, j]),
+                 horizontalalignment="center",
+                 color="white" if confusion_mat[i, j] > thresh else "black")
+    plt.tight_layout(pad=1.5)
+    plt.ylabel('True/Gold')
+    plt.xlabel('Predicted')
+    plt.savefig(plot_file_name)
+
+
 class Result:
     """
     Model Class for carrying Evaluation Data (F1 Score, Precision, Recall, ....)
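To make the row/column convention concrete, here is a tiny sketch (toy labels, not project data) of what get_confusion_matrix computes via sklearn:

from sklearn.metrics import confusion_matrix

labels = ['background', 'method', 'result']
y_true = ['background', 'method', 'result', 'method']
y_pred = ['background', 'result', 'result', 'method']

# Rows follow the true labels and columns the predicted labels, both in `labels` order:
# [[1 0 0]
#  [0 1 1]
#  [0 0 1]]
print(confusion_matrix(y_true, y_pred, labels=labels))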

@@ -0,0 +1,16 @@
+import classifier
+import testing.intent_predictor as pred
+import eval.metrics as metrics
+
+model_path = '/mount/arbeitsdaten/studenten1/team-lab-nlp/mandavsi_rileyic/saved_models/experiment_4'
+
+y_pred, y_true = pred.load_model_and_predict_test_data(model_path)
+
+confusion_matrix = metrics.get_confusion_matrix(y_true, y_pred)
+print(confusion_matrix)
+
+plot_file_path = model_path + '/confusion_matrix_plot.png'
+metrics.plot_confusion_matrix(confusion_matrix, "BiLSTM Classifier + Attention with ELMo", plot_file_path)
+print('Confusion Matrix Plot saved to :: ', plot_file_path)

@@ -0,0 +1,2 @@
+from utils.reader import *
+from classifier.nn import *

@@ -1,17 +1,54 @@
+from typing import Dict, List, Tuple
+
 from allennlp.common import JsonDict
 from allennlp.data import Instance
 from allennlp.predictors import Predictor
 from overrides import overrides
+from allennlp.models import Model
+from allennlp.data.dataset_readers import DatasetReader
+from allennlp.models.archival import load_archive
+
+from utils.reader import DataReaderJsonLines, CitationDataSetReader
+import os
 
 
 @Predictor.register('citation_intent_predictor')
 class IntentClassificationPredictor(Predictor):
     """Predictor for Citation Intent Classifier"""
 
+    def predict(self, text: str, intent: str):
+        return self.predict_json({"citation_text": text, "intent": intent})
+
     @overrides
     def _json_to_instance(self, json_dict: JsonDict) -> Instance:
-        pass
+        return self._dataset_reader.text_to_instance(json_dict["citation_text"], json_dict["intent"])
 
-    @overrides
-    def predict_json(self, inputs: JsonDict) -> JsonDict:
-        pass
+
+def make_predictions(model: Model, dataset_reader: DatasetReader, file_path: str) -> Tuple[List[Dict[str, float]], list]:
+    """Make predictions using the given model and dataset reader"""
+    predictor = IntentClassificationPredictor(model, dataset_reader)
+
+    prediction_list = []
+    true_list = []
+
+    jsonl_reader = DataReaderJsonLines(file_path)
+    for citation in jsonl_reader.read():
+        true_list.append(citation.intent)
+        output = predictor.predict(citation.text, citation.intent)
+        prediction_list.append(output['prediction'])
+
+    return prediction_list, true_list
+
+
+def load_model_and_predict_test_data(saved_model_dir: str):
+    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+    dev_file_path = project_root + '/data/jsonl/dev.jsonl'
+    test_file_path = project_root + '/data/jsonl/test.jsonl'
+
+    model_archive = load_archive(os.path.join(saved_model_dir, 'model.tar.gz'))
+    citation_dataset_reader = CitationDataSetReader()
+
+    y_pred, y_true = make_predictions(model_archive.model, citation_dataset_reader, test_file_path)
+    return y_pred, y_true

@@ -34,3 +34,4 @@ REGEX_CONSTANTS = {
 }
 
 CLASS_LABELS = {"background": 0, "method": 1, "result": 2}
+CLASS_LABELS_LIST = ['background', 'method', 'result']

@@ -75,7 +75,8 @@ class DataReaderJsonLines:
         This method opens the file, reads every line and returns a collection of lines
         :return: collection of Citation Objects, with the required data
         """
-        for line in jsonlines.open(self.file_path):
-            yield read_json_line(line)
+        with jsonlines.open(self.file_path) as jl_reader:
+            for line in jl_reader:
+                yield read_json_line(line)
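The reader change above swaps the bare jsonlines.open() iteration for a with-block so the file handle is closed deterministically once the generator is exhausted. A minimal standalone sketch of the same pattern (the file path here is illustrative):

import jsonlines

# Each `line` yielded below is already a parsed dict (field names depend on the dataset).
with jsonlines.open('data/jsonl/dev.jsonl') as jl_reader:
    for line in jl_reader:
        print(sorted(line.keys()))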
