Added Predictor code

isaac
Pavan Mandava 5 years ago
parent 0faf344a00
commit 7d52d0629b

@ -0,0 +1,4 @@
import classifier
import testing.intent_predictor as pred
pred.load_model_and_run_predictions("/mount/arbeitsdaten/studenten1/team-lab-nlp/mandavsi_rileyic/saved_models/experiment_4")

@ -0,0 +1,2 @@
from utils.reader import *
from classifier.nn import *

@ -1,17 +1,62 @@
from typing import Dict, List, Tuple
from allennlp.common import JsonDict
from allennlp.data import Instance
from allennlp.predictors import Predictor
from overrides import overrides
from allennlp.models import Model
from allennlp.data.dataset_readers import DatasetReader
from allennlp.models.archival import load_archive
from utils.reader import DataReaderJsonLines, CitationDataSetReader
import os
@Predictor.register('citation_intent_predictor')
class IntentClassificationPredictor(Predictor):
""""Predictor for Citation Intent Classifier"""
def predict(self, text: str, intent: str):
return self.predict_json({"citation_text": text, "intent": intent})
@overrides
def _json_to_instance(self, json_dict: JsonDict) -> Instance:
pass
return self._dataset_reader.text_to_instance(json_dict["citation_text"], json_dict["intent"])
@overrides
def predict_json(self, inputs: JsonDict) -> JsonDict:
pass
def make_predictions(model: Model, dataset_reader: DatasetReader, file_path: str) -> Tuple[
List[Dict[str, float]], list]:
"""Make predictions using the given model and dataset reader"""
predictor = IntentClassificationPredictor(model, dataset_reader)
prediction_list = []
true_list = []
vocab = model.vocab
jsonl_reader = DataReaderJsonLines(file_path)
i = 0
for citation in jsonl_reader.read():
i += 1
true_list.append(citation.intent)
output = predictor.predict(citation.text, citation.intent)
prediction_list.append({vocab.get_token_from_index(label_id, 'labels'): prob
for label_id, prob in enumerate(output['probs'])})
if i == 10:
break
return prediction_list, true_list
def load_model_and_run_predictions(saved_model_dir: str):
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
train_file_path = project_root + '/data/tsv/train.tsv'
test_file_path = project_root + '/data/tsv/test.tsv'
model_archive = load_archive(os.path.join(saved_model_dir, 'model.tar.gz'))
citation_dataset_reader = CitationDataSetReader()
y_pred, y_true = make_predictions(model_archive.model, citation_dataset_reader, test_file_path)
print(y_pred)

Loading…
Cancel
Save