From 804533bc231542f88296cfd7541c413aee9b0cec Mon Sep 17 00:00:00 2001 From: Pavan Mandava Date: Fri, 31 Jul 2020 17:33:03 +0200 Subject: [PATCH] Code documentation/comments for predictor --- eval/metrics.py | 11 +++++++ predict.py | 6 ++-- testing/intent_predictor.py | 59 ++++++++++++++++++++++++++++++++++--- 3 files changed, 69 insertions(+), 7 deletions(-) diff --git a/eval/metrics.py b/eval/metrics.py index f7cdb4c..741c29c 100644 --- a/eval/metrics.py +++ b/eval/metrics.py @@ -170,6 +170,9 @@ def calculate_f1_score(precision, recall): def get_confusion_matrix(y_true, y_pred): """ takes predicted labels and true labels as parameters and returns Confusion Matrix + + - uses sklearn metrics functions + :param y_true: True labels :param y_pred: Predicted labels :return: returns Confusion Matrix @@ -178,6 +181,14 @@ def get_confusion_matrix(y_true, y_pred): def plot_confusion_matrix(confusion_mat, classifier_name, plot_file_name): + """ + Saves the confusion matrix plot with the specified file name + + :param confusion_mat: takes Confusion Matrix as an argument + :param classifier_name: Classifier name + :param plot_file_name: file name (with path) to save + + """ plt.figure(figsize=(8, 6)) plt.imshow(confusion_mat, interpolation='nearest', cmap=plt.get_cmap('Blues')) diff --git a/predict.py b/predict.py index 394189c..9b5b199 100644 --- a/predict.py +++ b/predict.py @@ -3,14 +3,14 @@ import testing.intent_predictor as pred import eval.metrics as metrics -model_path = '/mount/arbeitsdaten/studenten1/team-lab-nlp/mandavsi_rileyic/saved_models/experiment_4' -y_pred, y_true = pred.load_model_and_predict_test_data(model_path) +saved_model_dir = '/mount/arbeitsdaten/studenten1/team-lab-nlp/mandavsi_rileyic/saved_models/experiment_4' +y_pred, y_true = pred.load_model_and_predict_test_data(saved_model_dir) confusion_matrix = metrics.get_confusion_matrix(y_true, y_pred) print(confusion_matrix) -plot_file_path = model_path+'/confusion_matrix_plot.png' +plot_file_path = 
saved_model_dir+'/confusion_matrix_plot.png' metrics.plot_confusion_matrix(confusion_matrix, "BiLSTM Classifier + Attention with ELMo", plot_file_path) print('Confusion Matrix Plot saved to :: ', plot_file_path) diff --git a/testing/intent_predictor.py b/testing/intent_predictor.py index 2c68414..54a82ee 100644 --- a/testing/intent_predictor.py +++ b/testing/intent_predictor.py @@ -14,41 +14,92 @@ import os @Predictor.register('citation_intent_predictor') class IntentClassificationPredictor(Predictor): - """"Predictor for Citation Intent Classifier""" + """ + ~~~Predictor for Citation Intent Classifier~~~ + + - This is just a wrapper class around AllenNLP Model + used for making predictions from the trained/saved model + + """ def predict(self, text: str, intent: str): + """ + This function can be called for each data point from the test dataset, + takes citation text and the target intent as parameters and + returns output dictionary from :func:`~classifier.nn.BiLstmClassifier.forward` method + + :param text: Citation text from test data + :param intent: target intent of the data point + :return: returns output dictionary from Model's forward method + """ return self.predict_json({"citation_text": text, "intent": intent}) @overrides def _json_to_instance(self, json_dict: JsonDict) -> Instance: + """ + we get a callback to this method from AllenNLP Predictor, + passes JsonDict as a parameter with the data that we passed to the predict_json function earlier. + + And this callback should return the AllenNLP Instance with tokens and target label. 
+ + :param json_dict: json dictionary data with text and intent label + :return: returns AllenNLP Instance with tokens(ELMo) and target label + """ return self._dataset_reader.text_to_instance(json_dict["citation_text"], json_dict["intent"]) -def make_predictions(model: Model, dataset_reader: DatasetReader, file_path: str) -> Tuple[List[Dict[str, float]], list]: - """Make predictions using the given model and dataset reader""" +def make_predictions(model: Model, dataset_reader: DatasetReader, dataset_file_path: str) -> Tuple[list, list]: + """ + This function takes the pre-trained(saved) Model and DatasetReader(and dataset file path) as arguments + and returns a Tuple of prediction list and gold/true list. + + - Creates a predictor object with the pre-trained model and dataset reader. + - Read the data from the passed dataset file path and for each data point, use predictor to predict the intent + + :param model: a trained/saved AllenNLP Model + :param dataset_reader: Dataset reader object (for tokenizing text and creating Instances) + :param dataset_file_path: a dataset file path to make predictions + :return: returns a Tuple of prediction list and true labels list + """ + + # Create predictor class object predictor = IntentClassificationPredictor(model, dataset_reader) prediction_list = [] true_list = [] - jsonl_reader = DataReaderJsonLines(file_path) + # read JSON Lines file and Iterate through each datapoint to predict + jsonl_reader = DataReaderJsonLines(dataset_file_path) for citation in jsonl_reader.read(): true_list.append(citation.intent) output = predictor.predict(citation.text, citation.intent) prediction_list.append(output['prediction']) + # returns prediction list and gold labels list - Tuple return prediction_list, true_list def load_model_and_predict_test_data(saved_model_dir: str): + """ + + This function loads the saved model from the specified directory and calls make_predictions function. 
+ + :param saved_model_dir: path of the saved AllenNLP model (typically from IMS common space) + + :return: returns a tuple of prediction list and true list + """ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) dev_file_path = project_root + '/data/jsonl/dev.jsonl' test_file_path = project_root + '/data/jsonl/test.jsonl' + # load the archived/saved model model_archive = load_archive(os.path.join(saved_model_dir, 'model.tar.gz')) + + # create dataset reader object citation_dataset_reader = CitationDataSetReader() + # make predictions y_pred, y_true = make_predictions(model_archive.model, citation_dataset_reader, test_file_path) return y_pred, y_true