You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
citation-analysis/testing/intent_predictor.py

55 lines
1.9 KiB

from typing import Dict, List, Tuple
from allennlp.common import JsonDict
from allennlp.data import Instance
from allennlp.predictors import Predictor
from overrides import overrides
from allennlp.models import Model
from allennlp.data.dataset_readers import DatasetReader
from allennlp.models.archival import load_archive
from utils.reader import DataReaderJsonLines, CitationDataSetReader
import os
@Predictor.register('citation_intent_predictor')
class IntentClassificationPredictor(Predictor):
""""Predictor for Citation Intent Classifier"""
def predict(self, text: str, intent: str):
return self.predict_json({"citation_text": text, "intent": intent})
@overrides
def _json_to_instance(self, json_dict: JsonDict) -> Instance:
return self._dataset_reader.text_to_instance(json_dict["citation_text"], json_dict["intent"])
def make_predictions(model: Model, dataset_reader: DatasetReader, file_path: str) -> Tuple[List[Dict[str, float]], list]:
"""Make predictions using the given model and dataset reader"""
predictor = IntentClassificationPredictor(model, dataset_reader)
prediction_list = []
true_list = []
jsonl_reader = DataReaderJsonLines(file_path)
for citation in jsonl_reader.read():
true_list.append(citation.intent)
output = predictor.predict(citation.text, citation.intent)
prediction_list.append(output['prediction'])
return prediction_list, true_list
def load_model_and_predict_test_data(saved_model_dir: str):
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
dev_file_path = project_root + '/data/jsonl/dev.jsonl'
test_file_path = project_root + '/data/jsonl/test.jsonl'
model_archive = load_archive(os.path.join(saved_model_dir, 'model.tar.gz'))
citation_dataset_reader = CitationDataSetReader()
y_pred, y_true = make_predictions(model_archive.model, citation_dataset_reader, test_file_path)
return y_pred, y_true