diff --git a/classifier/nn.py b/classifier/nn.py
index 8550ccb..1f0f51e 100644
--- a/classifier/nn.py
+++ b/classifier/nn.py
@@ -104,10 +104,6 @@ class BiLstmClassifier(Model):
         output_dict['probabilities'] = class_probabilities
         output_dict['positive_label'] = label
         output_dict['prediction'] = label
-        citation_text = []
-        for batch_text in output_dict['tokens']:
-            citation_text.append([self.vocab.get_token_from_index(token_id.item()) for token_id in batch_text])
-        output_dict['tokens'] = citation_text
 
         return output_dict
 
diff --git a/configs/basic_model.json b/configs/basic_model.json
index 55d9af8..e802261 100644
--- a/configs/basic_model.json
+++ b/configs/basic_model.json
@@ -48,10 +48,10 @@
   },
   "trainer": {
     "optimizer": {
-      "type": "adam",
-      "lr": 0.001
+      "type": "adagrad",
+      "lr": 0.005
     },
-    "num_epochs": 2,
-    "cuda_device": 1
+    "num_epochs": 10,
+    "cuda_device": 3
   }
 }
diff --git a/eval/metrics.py b/eval/metrics.py
index f844a3d..f7cdb4c 100644
--- a/eval/metrics.py
+++ b/eval/metrics.py
@@ -1,4 +1,8 @@
 import utils.constants as const
+from sklearn.metrics import confusion_matrix
+import matplotlib.pyplot as plt
+import numpy as np
+import itertools
 
 
 def f1_score(y_true, y_pred, labels, average):
@@ -163,6 +167,51 @@ def calculate_f1_score(precision, recall):
     return 2 * (precision * recall) / (precision + recall)
 
 
+def get_confusion_matrix(y_true, y_pred):
+    """
+    takes predicted labels and true labels as parameters and returns Confusion Matrix
+    :param y_true: True labels
+    :param y_pred: Predicted labels
+    :return: returns Confusion Matrix
+    """
+    return confusion_matrix(y_true, y_pred, labels=const.CLASS_LABELS_LIST)
+
+
+def plot_confusion_matrix(confusion_mat, classifier_name, plot_file_name):
+    """
+    Draws the Confusion Matrix as an annotated heat map and saves it to a file
+    :param confusion_mat: Confusion Matrix (rows are plotted on the True/Gold axis, columns on the Predicted axis)
+    :param classifier_name: text used as the plot title
+    :param plot_file_name: path the plot is saved to
+    """
+    figure = plt.figure(figsize=(8, 6))
+    plt.imshow(confusion_mat, interpolation='nearest', cmap=plt.get_cmap('Blues'))
+    plt.title(classifier_name)
+    plt.colorbar()
+
+    target_names = const.CLASS_LABELS_LIST
+    if target_names is not None:
+        tick_marks = np.arange(len(target_names))
+        plt.xticks(tick_marks, target_names, rotation=45)
+        plt.yticks(tick_marks, target_names)
+
+    # cells whose count exceeds half the maximum get white text, all others black
+    thresh = confusion_mat.max() / 2
+    for i, j in itertools.product(range(confusion_mat.shape[0]), range(confusion_mat.shape[1])):
+        plt.text(j, i, "{:,}".format(confusion_mat[i, j]),
+                 horizontalalignment="center",
+                 color="white" if confusion_mat[i, j] > thresh else "black")
+
+    plt.ylabel('True/Gold')
+    plt.xlabel('Predicted')
+    # tight_layout runs after the axis labels are set so they are part of the layout;
+    # pad is passed by keyword because newer Matplotlib releases reject positional arguments
+    plt.tight_layout(pad=1.5)
+    plt.savefig(plot_file_name)
+    # release the figure so repeated calls do not accumulate open figures
+    plt.close(figure)
+
+
 class Result:
     """
     Model Class for carrying Evaluation Data (F1 Score, Precision, Recall, ....)
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..394189c
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,27 @@
+import os
+import sys
+
+import classifier  # NOTE(review): not referenced below; presumably imported for its side effects - confirm
+import testing.intent_predictor as pred
+
+import eval.metrics as metrics
+
+DEFAULT_MODEL_PATH = '/mount/arbeitsdaten/studenten1/team-lab-nlp/mandavsi_rileyic/saved_models/experiment_4'
+
+
+def main(model_path=DEFAULT_MODEL_PATH):
+    """Predict on the test data with the saved model, then print and plot the Confusion Matrix"""
+    y_pred, y_true = pred.load_model_and_predict_test_data(model_path)
+
+    confusion_matrix = metrics.get_confusion_matrix(y_true, y_pred)
+    print(confusion_matrix)
+
+    plot_file_path = os.path.join(model_path, 'confusion_matrix_plot.png')
+    metrics.plot_confusion_matrix(confusion_matrix, "BiLSTM Classifier + Attention with ELMo", plot_file_path)
+
+    print('Confusion Matrix Plot saved to :: ', plot_file_path)
+
+
+if __name__ == '__main__':
+    # an optional first command line argument overrides the default model directory
+    main(sys.argv[1] if len(sys.argv) > 1 else DEFAULT_MODEL_PATH)
diff --git a/testing/__init__.py b/testing/__init__.py
index e69de29..3fd94c3 100644
--- a/testing/__init__.py
+++ b/testing/__init__.py
@@ -0,0 +1,2 @@
+from utils.reader import *
+from classifier.nn import *
\ No newline at end of file
diff --git a/testing/intent_predictor.py b/testing/intent_predictor.py
index dd401e9..2c68414 100644
--- a/testing/intent_predictor.py
+++ b/testing/intent_predictor.py
@@ -1,17 +1,64 @@
+from typing import List, Tuple
+
 from allennlp.common import JsonDict
 from allennlp.data import Instance
 from allennlp.predictors import Predictor
 from overrides import overrides
+from allennlp.models import Model
+from allennlp.data.dataset_readers import DatasetReader
+from allennlp.models.archival import load_archive
+from utils.reader import DataReaderJsonLines, CitationDataSetReader
+
+import os
 
 
 @Predictor.register('citation_intent_predictor')
 class IntentClassificationPredictor(Predictor):
     """"Predictor for Citation Intent Classifier"""
 
+    def predict(self, text: str, intent: str):
+        """Wrap the citation text and its intent in a JSON dict and run predict_json on it"""
+        return self.predict_json({"citation_text": text, "intent": intent})
+
     @overrides
     def _json_to_instance(self, json_dict: JsonDict) -> Instance:
-        pass
+        return self._dataset_reader.text_to_instance(json_dict["citation_text"], json_dict["intent"])
 
-    @overrides
-    def predict_json(self, inputs: JsonDict) -> JsonDict:
-        pass
+
+def make_predictions(model: Model, dataset_reader: DatasetReader, file_path: str) -> Tuple[List, List]:
+    """
+    Make predictions using the given model and dataset reader
+    :param model: model handed to the predictor
+    :param dataset_reader: dataset reader handed to the predictor
+    :param file_path: path of the JSON Lines file with the citations to predict on
+    :return: the predicted labels and the true labels, in file order
+    """
+    predictor = IntentClassificationPredictor(model, dataset_reader)
+
+    prediction_list = []
+    true_list = []
+
+    jsonl_reader = DataReaderJsonLines(file_path)
+    for citation in jsonl_reader.read():
+        true_list.append(citation.intent)
+        output = predictor.predict(citation.text, citation.intent)
+        prediction_list.append(output['prediction'])
+
+    return prediction_list, true_list
+
+
+def load_model_and_predict_test_data(saved_model_dir: str) -> Tuple[List, List]:
+    """
+    Load the archived model from the given directory and predict on the test data
+    :param saved_model_dir: directory that contains the model.tar.gz archive
+    :return: the predicted labels and the true labels of the test data
+    """
+    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+    test_file_path = os.path.join(project_root, 'data', 'jsonl', 'test.jsonl')
+
+    model_archive = load_archive(os.path.join(saved_model_dir, 'model.tar.gz'))
+    citation_dataset_reader = CitationDataSetReader()
+
+    y_pred, y_true = make_predictions(model_archive.model, citation_dataset_reader, test_file_path)
+
+    return y_pred, y_true
diff --git a/utils/constants.py b/utils/constants.py
index 5389eae..091ead2 100644
--- a/utils/constants.py
+++ b/utils/constants.py
@@ -34,3 +34,5 @@ REGEX_CONSTANTS = {
 }
 
 CLASS_LABELS = {"background": 0, "method": 1, "result": 2}
+# label names ordered by their index in CLASS_LABELS, so the two constants cannot drift apart
+CLASS_LABELS_LIST = sorted(CLASS_LABELS, key=CLASS_LABELS.get)
diff --git a/utils/reader.py b/utils/reader.py
index ce606da..c208eac 100644
--- a/utils/reader.py
+++ b/utils/reader.py
@@ -75,8 +75,9 @@ class DataReaderJsonLines:
         This method opens the file, reads every line and returns a collection of lines
         :return: collection of Citation Objects, with the required data
         """
-        for line in jsonlines.open(self.file_path):
-            yield read_json_line(line)
+        with jsonlines.open(self.file_path) as jl_reader:
+            for line in jl_reader:
+                yield read_json_line(line)
 
 
 def read_json_line(line):