Saving Confusion Matrix Plot PNG

isaac
Pavan Mandava 5 years ago
parent 244e27bee6
commit 87efce8f82

@ -177,7 +177,7 @@ def get_confusion_matrix(y_true, y_pred):
return confusion_matrix(y_true, y_pred, labels=const.CLASS_LABELS_LIST) return confusion_matrix(y_true, y_pred, labels=const.CLASS_LABELS_LIST)
def plot_confusion_matrix(confusion_mat, classifier_name): def plot_confusion_matrix(confusion_mat, classifier_name, plot_file_name):
plt.figure(figsize=(8, 6)) plt.figure(figsize=(8, 6))
plt.imshow(confusion_mat, interpolation='nearest', cmap=plt.get_cmap('Blues')) plt.imshow(confusion_mat, interpolation='nearest', cmap=plt.get_cmap('Blues'))
@ -196,10 +196,10 @@ def plot_confusion_matrix(confusion_mat, classifier_name):
horizontalalignment="center", horizontalalignment="center",
color="white" if confusion_mat[i, j] > thresh else "black") color="white" if confusion_mat[i, j] > thresh else "black")
plt.tight_layout() plt.tight_layout(1.5)
plt.ylabel('True/Gold') plt.ylabel('True/Gold')
plt.xlabel('Predicted') plt.xlabel('Predicted')
plt.show(block=True) plt.savefig(plot_file_name)
class Result: class Result:

@ -3,10 +3,11 @@ import testing.intent_predictor as pred
import eval.metrics as metrics import eval.metrics as metrics
y_pred, y_true = pred.load_model_and_predict_test_data("/mount/arbeitsdaten/studenten1/team-lab-nlp/mandavsi_rileyic/saved_models/experiment_4") model_path = '/mount/arbeitsdaten/studenten1/team-lab-nlp/mandavsi_rileyic/saved_models/experiment_4'
y_pred, y_true = pred.load_model_and_predict_test_data(model_path)
confusion_matrix = metrics.get_confusion_matrix(y_true, y_pred) confusion_matrix = metrics.get_confusion_matrix(y_true, y_pred)
print(confusion_matrix) print(confusion_matrix)
metrics.plot_confusion_matrix(confusion_matrix, "BiLSTM Classifier + Attention with ELMo") metrics.plot_confusion_matrix(confusion_matrix, "BiLSTM Classifier + Attention with ELMo", model_path+'/confusion_matrix_plot.png')

Loading…
Cancel
Save