commit
02baf00dc5
@ -0,0 +1,16 @@
|
||||
import classifier
|
||||
import testing.intent_predictor as pred
|
||||
|
||||
import eval.metrics as metrics
|
||||
|
||||
model_path = '/mount/arbeitsdaten/studenten1/team-lab-nlp/mandavsi_rileyic/saved_models/experiment_4'
|
||||
y_pred, y_true = pred.load_model_and_predict_test_data(model_path)
|
||||
|
||||
confusion_matrix = metrics.get_confusion_matrix(y_true, y_pred)
|
||||
|
||||
print(confusion_matrix)
|
||||
|
||||
plot_file_path = model_path+'/confusion_matrix_plot.png'
|
||||
metrics.plot_confusion_matrix(confusion_matrix, "BiLSTM Classifier + Attention with ELMo", plot_file_path)
|
||||
|
||||
print('Confusion Matrix Plot saved to :: ', plot_file_path)
|
||||
@ -0,0 +1,2 @@
|
||||
from utils.reader import *
|
||||
from classifier.nn import *
|
||||
@ -1,17 +1,54 @@
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from allennlp.common import JsonDict
|
||||
from allennlp.data import Instance
|
||||
from allennlp.predictors import Predictor
|
||||
from overrides import overrides
|
||||
from allennlp.models import Model
|
||||
from allennlp.data.dataset_readers import DatasetReader
|
||||
from allennlp.models.archival import load_archive
|
||||
from utils.reader import DataReaderJsonLines, CitationDataSetReader
|
||||
|
||||
import os
|
||||
|
||||
|
||||
@Predictor.register('citation_intent_predictor')
|
||||
class IntentClassificationPredictor(Predictor):
|
||||
""""Predictor for Citation Intent Classifier"""
|
||||
|
||||
def predict(self, text: str, intent: str):
|
||||
return self.predict_json({"citation_text": text, "intent": intent})
|
||||
|
||||
@overrides
|
||||
def _json_to_instance(self, json_dict: JsonDict) -> Instance:
|
||||
pass
|
||||
return self._dataset_reader.text_to_instance(json_dict["citation_text"], json_dict["intent"])
|
||||
|
||||
@overrides
|
||||
def predict_json(self, inputs: JsonDict) -> JsonDict:
|
||||
pass
|
||||
|
||||
def make_predictions(model: Model, dataset_reader: DatasetReader, file_path: str) -> Tuple[List[Dict[str, float]], list]:
|
||||
"""Make predictions using the given model and dataset reader"""
|
||||
|
||||
predictor = IntentClassificationPredictor(model, dataset_reader)
|
||||
|
||||
prediction_list = []
|
||||
true_list = []
|
||||
|
||||
jsonl_reader = DataReaderJsonLines(file_path)
|
||||
for citation in jsonl_reader.read():
|
||||
true_list.append(citation.intent)
|
||||
output = predictor.predict(citation.text, citation.intent)
|
||||
prediction_list.append(output['prediction'])
|
||||
|
||||
return prediction_list, true_list
|
||||
|
||||
|
||||
def load_model_and_predict_test_data(saved_model_dir: str):
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
dev_file_path = project_root + '/data/jsonl/dev.jsonl'
|
||||
test_file_path = project_root + '/data/jsonl/test.jsonl'
|
||||
|
||||
model_archive = load_archive(os.path.join(saved_model_dir, 'model.tar.gz'))
|
||||
citation_dataset_reader = CitationDataSetReader()
|
||||
|
||||
y_pred, y_true = make_predictions(model_archive.model, citation_dataset_reader, test_file_path)
|
||||
|
||||
return y_pred, y_true
|
||||
|
||||
Loading…
Reference in new issue