Added Predictor code

isaac
Pavan Mandava 5 years ago
parent 0faf344a00
commit 7d52d0629b

@ -0,0 +1,4 @@
import classifier
import testing.intent_predictor as pred
pred.load_model_and_run_predictions("/mount/arbeitsdaten/studenten1/team-lab-nlp/mandavsi_rileyic/saved_models/experiment_4")

@ -0,0 +1,2 @@
from utils.reader import *
from classifier.nn import *

@ -1,17 +1,62 @@
from typing import Dict, List, Tuple
from allennlp.common import JsonDict from allennlp.common import JsonDict
from allennlp.data import Instance from allennlp.data import Instance
from allennlp.predictors import Predictor from allennlp.predictors import Predictor
from overrides import overrides from overrides import overrides
from allennlp.models import Model
from allennlp.data.dataset_readers import DatasetReader
from allennlp.models.archival import load_archive
from utils.reader import DataReaderJsonLines, CitationDataSetReader
import os
@Predictor.register('citation_intent_predictor') @Predictor.register('citation_intent_predictor')
class IntentClassificationPredictor(Predictor): class IntentClassificationPredictor(Predictor):
""""Predictor for Citation Intent Classifier""" """"Predictor for Citation Intent Classifier"""
def predict(self, text: str, intent: str):
return self.predict_json({"citation_text": text, "intent": intent})
@overrides @overrides
def _json_to_instance(self, json_dict: JsonDict) -> Instance: def _json_to_instance(self, json_dict: JsonDict) -> Instance:
pass return self._dataset_reader.text_to_instance(json_dict["citation_text"], json_dict["intent"])
@overrides
def predict_json(self, inputs: JsonDict) -> JsonDict: def make_predictions(model: Model, dataset_reader: DatasetReader, file_path: str) -> Tuple[
pass List[Dict[str, float]], list]:
"""Make predictions using the given model and dataset reader"""
predictor = IntentClassificationPredictor(model, dataset_reader)
prediction_list = []
true_list = []
vocab = model.vocab
jsonl_reader = DataReaderJsonLines(file_path)
i = 0
for citation in jsonl_reader.read():
i += 1
true_list.append(citation.intent)
output = predictor.predict(citation.text, citation.intent)
prediction_list.append({vocab.get_token_from_index(label_id, 'labels'): prob
for label_id, prob in enumerate(output['probs'])})
if i == 10:
break
return prediction_list, true_list
def load_model_and_run_predictions(saved_model_dir: str):
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
train_file_path = project_root + '/data/tsv/train.tsv'
test_file_path = project_root + '/data/tsv/test.tsv'
model_archive = load_archive(os.path.join(saved_model_dir, 'model.tar.gz'))
citation_dataset_reader = CitationDataSetReader()
y_pred, y_true = make_predictions(model_archive.model, citation_dataset_reader, test_file_path)
print(y_pred)

Loading…
Cancel
Save