diff --git a/eval/metrics.py b/eval/metrics.py index 9d2d520..1b9cf73 100644 --- a/eval/metrics.py +++ b/eval/metrics.py @@ -177,7 +177,7 @@ def get_confusion_matrix(y_true, y_pred): return confusion_matrix(y_true, y_pred, labels=const.CLASS_LABELS_LIST) -def plot_confusion_matrix(confusion_mat, classifier_name): +def plot_confusion_matrix(confusion_mat, classifier_name, plot_file_name): plt.figure(figsize=(8, 6)) plt.imshow(confusion_mat, interpolation='nearest', cmap=plt.get_cmap('Blues')) @@ -196,10 +196,10 @@ def plot_confusion_matrix(confusion_mat, classifier_name): horizontalalignment="center", color="white" if confusion_mat[i, j] > thresh else "black") - plt.tight_layout() + plt.tight_layout(1.5) plt.ylabel('True/Gold') plt.xlabel('Predicted') - plt.show(block=True) + plt.savefig(plot_file_name) class Result: diff --git a/predict.py b/predict.py index 1cbccd2..c5c3a37 100644 --- a/predict.py +++ b/predict.py @@ -3,10 +3,11 @@ import testing.intent_predictor as pred import eval.metrics as metrics -y_pred, y_true = pred.load_model_and_predict_test_data("/mount/arbeitsdaten/studenten1/team-lab-nlp/mandavsi_rileyic/saved_models/experiment_4") +model_path = '/mount/arbeitsdaten/studenten1/team-lab-nlp/mandavsi_rileyic/saved_models/experiment_4' +y_pred, y_true = pred.load_model_and_predict_test_data(model_path) confusion_matrix = metrics.get_confusion_matrix(y_true, y_pred) print(confusion_matrix) -metrics.plot_confusion_matrix(confusion_matrix, "BiLSTM Classifier + Attention with ELMo") +metrics.plot_confusion_matrix(confusion_matrix, "BiLSTM Classifier + Attention with ELMo", model_path+'/confusion_matrix_plot.png')