|
|
|
@ -51,8 +51,8 @@ def make_predictions(model: Model, dataset_reader: DatasetReader, file_path: str
|
|
|
|
|
|
|
|
|
|
|
|
def load_model_and_run_predictions(saved_model_dir: str):
|
|
|
|
def load_model_and_run_predictions(saved_model_dir: str):
|
|
|
|
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
|
|
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
|
|
train_file_path = project_root + '/data/tsv/train.tsv'
|
|
|
|
dev_file_path = project_root + '/data/jsonl/dev.jsonl'
|
|
|
|
test_file_path = project_root + '/data/tsv/test.tsv'
|
|
|
|
test_file_path = project_root + '/data/jsonl/test.jsonl'
|
|
|
|
|
|
|
|
|
|
|
|
model_archive = load_archive(os.path.join(saved_model_dir, 'model.tar.gz'))
|
|
|
|
model_archive = load_archive(os.path.join(saved_model_dir, 'model.tar.gz'))
|
|
|
|
citation_dataset_reader = CitationDataSetReader()
|
|
|
|
citation_dataset_reader = CitationDataSetReader()
|
|
|
|
|