From 7d52d0629b6fe3e890f46c2ed4bc85c23d6aa60b Mon Sep 17 00:00:00 2001 From: Pavan Mandava Date: Wed, 29 Jul 2020 19:05:31 +0200 Subject: [PATCH] Added Predictor code --- run.py | 4 +++ testing/__init__.py | 2 ++ testing/intent_predictor.py | 53 ++++++++++++++++++++++++++++++++++--- 3 files changed, 55 insertions(+), 4 deletions(-) create mode 100644 run.py diff --git a/run.py b/run.py new file mode 100644 index 0000000..8e616e0 --- /dev/null +++ b/run.py @@ -0,0 +1,4 @@ +import classifier +import testing.intent_predictor as pred + +pred.load_model_and_run_predictions("/mount/arbeitsdaten/studenten1/team-lab-nlp/mandavsi_rileyic/saved_models/experiment_4") \ No newline at end of file diff --git a/testing/__init__.py b/testing/__init__.py index e69de29..3fd94c3 100644 --- a/testing/__init__.py +++ b/testing/__init__.py @@ -0,0 +1,2 @@ +from utils.reader import * +from classifier.nn import * \ No newline at end of file diff --git a/testing/intent_predictor.py b/testing/intent_predictor.py index dd401e9..16d2603 100644 --- a/testing/intent_predictor.py +++ b/testing/intent_predictor.py @@ -1,17 +1,62 @@ +from typing import Dict, List, Tuple + from allennlp.common import JsonDict from allennlp.data import Instance from allennlp.predictors import Predictor from overrides import overrides +from allennlp.models import Model +from allennlp.data.dataset_readers import DatasetReader +from allennlp.models.archival import load_archive +from utils.reader import DataReaderJsonLines, CitationDataSetReader + +import os @Predictor.register('citation_intent_predictor') class IntentClassificationPredictor(Predictor): """"Predictor for Citation Intent Classifier""" + def predict(self, text: str, intent: str): + return self.predict_json({"citation_text": text, "intent": intent}) + @overrides def _json_to_instance(self, json_dict: JsonDict) -> Instance: - pass + return self._dataset_reader.text_to_instance(json_dict["citation_text"], json_dict["intent"]) - @overrides - def predict_json(self, inputs: JsonDict) -> JsonDict: - pass + +def make_predictions(model: Model, dataset_reader: DatasetReader, file_path: str) -> Tuple[ + List[Dict[str, float]], list]: + """Make predictions using the given model and dataset reader""" + + predictor = IntentClassificationPredictor(model, dataset_reader) + + prediction_list = [] + true_list = [] + + vocab = model.vocab + + jsonl_reader = DataReaderJsonLines(file_path) + i = 0 + for citation in jsonl_reader.read(): + i += 1 + true_list.append(citation.intent) + output = predictor.predict(citation.text, citation.intent) + prediction_list.append({vocab.get_token_from_index(label_id, 'labels'): prob + for label_id, prob in enumerate(output['probs'])}) + if i == 10: + break + + return prediction_list, true_list + + +def load_model_and_run_predictions(saved_model_dir: str): + project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + train_file_path = project_root + '/data/tsv/train.tsv' + test_file_path = project_root + '/data/tsv/test.tsv' + + model_archive = load_archive(os.path.join(saved_model_dir, 'model.tar.gz')) + citation_dataset_reader = CitationDataSetReader() + + y_pred, y_true = make_predictions(model_archive.model, citation_dataset_reader, test_file_path) + + print(y_pred)