commit
02baf00dc5
@ -0,0 +1,16 @@
|
|||||||
|
import classifier
|
||||||
|
import testing.intent_predictor as pred
|
||||||
|
|
||||||
|
import eval.metrics as metrics
|
||||||
|
|
||||||
|
model_path = '/mount/arbeitsdaten/studenten1/team-lab-nlp/mandavsi_rileyic/saved_models/experiment_4'
|
||||||
|
y_pred, y_true = pred.load_model_and_predict_test_data(model_path)
|
||||||
|
|
||||||
|
confusion_matrix = metrics.get_confusion_matrix(y_true, y_pred)
|
||||||
|
|
||||||
|
print(confusion_matrix)
|
||||||
|
|
||||||
|
plot_file_path = model_path+'/confusion_matrix_plot.png'
|
||||||
|
metrics.plot_confusion_matrix(confusion_matrix, "BiLSTM Classifier + Attention with ELMo", plot_file_path)
|
||||||
|
|
||||||
|
print('Confusion Matrix Plot saved to :: ', plot_file_path)
|
||||||
@ -0,0 +1,2 @@
|
|||||||
|
from utils.reader import *
|
||||||
|
from classifier.nn import *
|
||||||
@ -1,17 +1,54 @@
|
|||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
from allennlp.common import JsonDict
|
from allennlp.common import JsonDict
|
||||||
from allennlp.data import Instance
|
from allennlp.data import Instance
|
||||||
from allennlp.predictors import Predictor
|
from allennlp.predictors import Predictor
|
||||||
from overrides import overrides
|
from overrides import overrides
|
||||||
|
from allennlp.models import Model
|
||||||
|
from allennlp.data.dataset_readers import DatasetReader
|
||||||
|
from allennlp.models.archival import load_archive
|
||||||
|
from utils.reader import DataReaderJsonLines, CitationDataSetReader
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
@Predictor.register('citation_intent_predictor')
|
@Predictor.register('citation_intent_predictor')
|
||||||
class IntentClassificationPredictor(Predictor):
|
class IntentClassificationPredictor(Predictor):
|
||||||
""""Predictor for Citation Intent Classifier"""
|
""""Predictor for Citation Intent Classifier"""
|
||||||
|
|
||||||
|
def predict(self, text: str, intent: str):
|
||||||
|
return self.predict_json({"citation_text": text, "intent": intent})
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def _json_to_instance(self, json_dict: JsonDict) -> Instance:
|
def _json_to_instance(self, json_dict: JsonDict) -> Instance:
|
||||||
pass
|
return self._dataset_reader.text_to_instance(json_dict["citation_text"], json_dict["intent"])
|
||||||
|
|
||||||
@overrides
|
|
||||||
def predict_json(self, inputs: JsonDict) -> JsonDict:
|
def make_predictions(model: Model, dataset_reader: DatasetReader, file_path: str) -> Tuple[List[Dict[str, float]], list]:
|
||||||
pass
|
"""Make predictions using the given model and dataset reader"""
|
||||||
|
|
||||||
|
predictor = IntentClassificationPredictor(model, dataset_reader)
|
||||||
|
|
||||||
|
prediction_list = []
|
||||||
|
true_list = []
|
||||||
|
|
||||||
|
jsonl_reader = DataReaderJsonLines(file_path)
|
||||||
|
for citation in jsonl_reader.read():
|
||||||
|
true_list.append(citation.intent)
|
||||||
|
output = predictor.predict(citation.text, citation.intent)
|
||||||
|
prediction_list.append(output['prediction'])
|
||||||
|
|
||||||
|
return prediction_list, true_list
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_and_predict_test_data(saved_model_dir: str):
|
||||||
|
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
dev_file_path = project_root + '/data/jsonl/dev.jsonl'
|
||||||
|
test_file_path = project_root + '/data/jsonl/test.jsonl'
|
||||||
|
|
||||||
|
model_archive = load_archive(os.path.join(saved_model_dir, 'model.tar.gz'))
|
||||||
|
citation_dataset_reader = CitationDataSetReader()
|
||||||
|
|
||||||
|
y_pred, y_true = make_predictions(model_archive.model, citation_dataset_reader, test_file_path)
|
||||||
|
|
||||||
|
return y_pred, y_true
|
||||||
|
|||||||
Loading…
Reference in new issue