diff --git a/eval/metrics.py b/eval/metrics.py index 01597f1..9d2d520 100644 --- a/eval/metrics.py +++ b/eval/metrics.py @@ -179,9 +179,6 @@ def get_confusion_matrix(y_true, y_pred): def plot_confusion_matrix(confusion_mat, classifier_name): - accuracy = np.trace(confusion_mat) / float(np.sum(confusion_mat)) - mis_class = 1 - accuracy - plt.figure(figsize=(8, 6)) plt.imshow(confusion_mat, interpolation='nearest', cmap=plt.get_cmap('Blues')) plt.title(classifier_name) @@ -201,8 +198,8 @@ def plot_confusion_matrix(confusion_mat, classifier_name): plt.tight_layout() plt.ylabel('True/Gold') - plt.xlabel('Predicted \nAccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, mis_class)) - plt.show() + plt.xlabel('Predicted') + plt.show(block=True) class Result: