Improved plot confusion matrix

isaac
Pavan Mandava 5 years ago
parent 52efebe53e
commit 3089662a0a

@ -1,6 +1,8 @@
import utils.constants as const import utils.constants as const
from sklearn.metrics import confusion_matrix from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np
import itertools
def f1_score(y_true, y_pred, labels, average): def f1_score(y_true, y_pred, labels, average):
@ -176,22 +178,31 @@ def get_confusion_matrix(y_true, y_pred):
def plot_confusion_matrix(confusion_mat, classifier_name): def plot_confusion_matrix(confusion_mat, classifier_name):
"""
Takes Confusion Matrix as a parameter and plots the matrix using matplotlib accuracy = np.trace(confusion_mat) / float(np.sum(confusion_mat))
:param confusion_mat: Confusion Matrix mis_class = 1 - accuracy
:param classifier_name: Classifier Name to show it on the Top
""" plt.figure(figsize=(8, 6))
fig, ax = plt.subplots(2, 2) plt.imshow(confusion_mat, interpolation='nearest', cmap=plt.get_cmap('Blues'))
plt.title(classifier_name)
plt.colorbar()
target_names = const.CLASS_LABELS_LIST
if target_names is not None:
tick_marks = np.arange(len(target_names))
plt.xticks(tick_marks, target_names, rotation=45)
plt.yticks(tick_marks, target_names)
thresh = confusion_mat.max() / confusion_mat.max() / 2
for i, j in itertools.product(range(confusion_mat.shape[0]), range(confusion_mat.shape[1])):
plt.text(j, i, "{:,}".format(confusion_mat[i, j]),
horizontalalignment="center",
color="white" if confusion_mat[i, j] > thresh else "black")
plt.tight_layout()
plt.ylabel('True/Gold')
plt.xlabel('Predicted \nAccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, mis_class))
plt.show() plt.show()
ax.matshow(confusion_mat, cmap='Greens')
for x in (0, 2):
for y in (0, 2):
ax.text(x, y, confusion_mat[y, x])
ax.set_xlabel('Predicted')
ax.set_ylabel('True/Gold')
ax.set_xticklabels([''] + const.CLASS_LABELS_LIST)
ax.set_yticklabels([''] + const.CLASS_LABELS_LIST)
ax.set_title(classifier_name)
class Result: class Result:

@ -9,4 +9,4 @@ confusion_matrix = metrics.get_confusion_matrix(y_true, y_pred)
print(confusion_matrix) print(confusion_matrix)
# metrics.plot_confusion_matrix(confusion_matrix, "BiLSTM Classifier + Attention with ELMo") metrics.plot_confusion_matrix(confusion_matrix, "BiLSTM Classifier + Attention with ELMo")
Loading…
Cancel
Save