|
|
|
@ -36,21 +36,15 @@ def make_predictions(model: Model, dataset_reader: DatasetReader, file_path: str
|
|
|
|
vocab = model.vocab
|
|
|
|
vocab = model.vocab
|
|
|
|
|
|
|
|
|
|
|
|
jsonl_reader = DataReaderJsonLines(file_path)
|
|
|
|
jsonl_reader = DataReaderJsonLines(file_path)
|
|
|
|
i = 0
|
|
|
|
|
|
|
|
for citation in jsonl_reader.read():
|
|
|
|
for citation in jsonl_reader.read():
|
|
|
|
i += 1
|
|
|
|
|
|
|
|
true_list.append(citation.intent)
|
|
|
|
true_list.append(citation.intent)
|
|
|
|
output = predictor.predict(citation.text, citation.intent)
|
|
|
|
output = predictor.predict(citation.text, citation.intent)
|
|
|
|
prediction_list.append(output['prediction'])
|
|
|
|
prediction_list.append(output['prediction'])
|
|
|
|
# prediction_list.append({vocab.get_token_from_index(label_id, 'labels'): prob
|
|
|
|
|
|
|
|
# for label_id, prob in enumerate(output['probabilities'])})
|
|
|
|
|
|
|
|
if i == 10:
|
|
|
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return prediction_list, true_list
|
|
|
|
return prediction_list, true_list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model_and_run_predictions(saved_model_dir: str):
|
|
|
|
def load_model_and_predict_test_data(saved_model_dir: str):
|
|
|
|
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
|
|
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
|
|
dev_file_path = project_root + '/data/jsonl/dev.jsonl'
|
|
|
|
dev_file_path = project_root + '/data/jsonl/dev.jsonl'
|
|
|
|
test_file_path = project_root + '/data/jsonl/test.jsonl'
|
|
|
|
test_file_path = project_root + '/data/jsonl/test.jsonl'
|
|
|
|
@ -60,5 +54,4 @@ def load_model_and_run_predictions(saved_model_dir: str):
|
|
|
|
|
|
|
|
|
|
|
|
y_pred, y_true = make_predictions(model_archive.model, citation_dataset_reader, test_file_path)
|
|
|
|
y_pred, y_true = make_predictions(model_archive.model, citation_dataset_reader, test_file_path)
|
|
|
|
|
|
|
|
|
|
|
|
print('Predictions ', y_pred)
|
|
|
|
retun y_pred,y_true
|
|
|
|
print('True Labels ', y_true)
|
|
|
|
|
|
|
|
|