diff --git a/eval/metrics.py b/eval/metrics.py
index f10ee8f..01597f1 100644
--- a/eval/metrics.py
+++ b/eval/metrics.py
@@ -1,6 +1,8 @@
 import utils.constants as const
 from sklearn.metrics import confusion_matrix
 import matplotlib.pyplot as plt
+import numpy as np
+import itertools
 
 
 def f1_score(y_true, y_pred, labels, average):
@@ -176,22 +178,36 @@ def get_confusion_matrix(y_true, y_pred):
 
 
 def plot_confusion_matrix(confusion_mat, classifier_name):
-    """
-    Takes Confusion Matrix as a parameter and plots the matrix using matplotlib
-    :param confusion_mat: Confusion Matrix
-    :param classifier_name: Classifier Name to show it on the Top
-    """
-    fig, ax = plt.subplots(2, 2)
+    """
+    Plots the given confusion matrix as a heatmap using matplotlib.
+    :param confusion_mat: Confusion Matrix
+    :param classifier_name: Classifier name, shown as the plot title
+    """
+    accuracy = np.trace(confusion_mat) / float(np.sum(confusion_mat))
+    mis_class = 1 - accuracy
+
+    plt.figure(figsize=(8, 6))
+    plt.imshow(confusion_mat, interpolation='nearest', cmap=plt.get_cmap('Blues'))
+    plt.title(classifier_name)
+    plt.colorbar()
+
+    target_names = const.CLASS_LABELS_LIST
+    if target_names is not None:
+        tick_marks = np.arange(len(target_names))
+        plt.xticks(tick_marks, target_names, rotation=45)
+        plt.yticks(tick_marks, target_names)
+
+    # Cells holding more than half the maximum count get white text for contrast
+    thresh = confusion_mat.max() / 2
+    for i, j in itertools.product(range(confusion_mat.shape[0]), range(confusion_mat.shape[1])):
+        plt.text(j, i, "{:,}".format(confusion_mat[i, j]),
+                 horizontalalignment="center",
+                 color="white" if confusion_mat[i, j] > thresh else "black")
+
+    plt.tight_layout()
+    plt.ylabel('True/Gold')
+    plt.xlabel('Predicted\nAccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, mis_class))
     plt.show()
-    ax.matshow(confusion_mat, cmap='Greens')
-    for x in (0, 2):
-        for y in (0, 2):
-            ax.text(x, y, confusion_mat[y, x])
-    ax.set_xlabel('Predicted')
-    ax.set_ylabel('True/Gold')
-    ax.set_xticklabels([''] + const.CLASS_LABELS_LIST)
-    ax.set_yticklabels([''] + const.CLASS_LABELS_LIST)
-    ax.set_title(classifier_name)
 
 
 class Result:
diff --git a/run.py b/predict.py
similarity index 77%
rename from run.py
rename to predict.py
index 1ee5a95..1cbccd2 100644
--- a/run.py
+++ b/predict.py
@@ -9,4 +9,4 @@
 confusion_matrix = metrics.get_confusion_matrix(y_true, y_pred)
 print(confusion_matrix)
 
-# metrics.plot_confusion_matrix(confusion_matrix, "BiLSTM Classifier + Attention with ELMo")
+metrics.plot_confusion_matrix(confusion_matrix, "BiLSTM Classifier + Attention with ELMo")
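
For reference, a minimal sketch of how the renamed predict.py exercises these helpers end to end. The label values below are hypothetical placeholders (the real ones must match const.CLASS_LABELS_LIST), and the import path assumes eval/metrics.py is importable as a package:

    import eval.metrics as metrics

    # Hypothetical gold and predicted labels; in the repo these come from
    # the BiLSTM + Attention + ELMo classifier.
    y_true = ['pos', 'neg', 'pos', 'neu']
    y_pred = ['pos', 'neg', 'neu', 'neu']

    confusion_matrix = metrics.get_confusion_matrix(y_true, y_pred)
    print(confusion_matrix)
    metrics.plot_confusion_matrix(confusion_matrix, "BiLSTM Classifier + Attention with ELMo")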