|
|
|
@ -177,7 +177,7 @@ def get_confusion_matrix(y_true, y_pred):
|
|
|
|
return confusion_matrix(y_true, y_pred, labels=const.CLASS_LABELS_LIST)
|
|
|
|
return confusion_matrix(y_true, y_pred, labels=const.CLASS_LABELS_LIST)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_confusion_matrix(confusion_mat, classifier_name):
|
|
|
|
def plot_confusion_matrix(confusion_mat, classifier_name, plot_file_name):
|
|
|
|
|
|
|
|
|
|
|
|
plt.figure(figsize=(8, 6))
|
|
|
|
plt.figure(figsize=(8, 6))
|
|
|
|
plt.imshow(confusion_mat, interpolation='nearest', cmap=plt.get_cmap('Blues'))
|
|
|
|
plt.imshow(confusion_mat, interpolation='nearest', cmap=plt.get_cmap('Blues'))
|
|
|
|
@ -196,10 +196,10 @@ def plot_confusion_matrix(confusion_mat, classifier_name):
|
|
|
|
horizontalalignment="center",
|
|
|
|
horizontalalignment="center",
|
|
|
|
color="white" if confusion_mat[i, j] > thresh else "black")
|
|
|
|
color="white" if confusion_mat[i, j] > thresh else "black")
|
|
|
|
|
|
|
|
|
|
|
|
plt.tight_layout()
|
|
|
|
plt.tight_layout(1.5)
|
|
|
|
plt.ylabel('True/Gold')
|
|
|
|
plt.ylabel('True/Gold')
|
|
|
|
plt.xlabel('Predicted')
|
|
|
|
plt.xlabel('Predicted')
|
|
|
|
plt.show(block=True)
|
|
|
|
plt.savefig(plot_file_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Result:
|
|
|
|
class Result:
|
|
|
|
|