From 931c99602d13bca28ec250ef5fc875485296ff00 Mon Sep 17 00:00:00 2001 From: Pavan Mandava Date: Sun, 28 Jun 2020 22:31:49 +0200 Subject: [PATCH] Added basic model class for LSTMs and config file for basic classifier --- classifier/nn.py | 11 +++++++++++ configs/basic_model.json | 30 ++++++++++++++++++++++++++++++ utils/reader.py | 1 + 3 files changed, 42 insertions(+) create mode 100644 classifier/nn.py create mode 100644 configs/basic_model.json diff --git a/classifier/nn.py b/classifier/nn.py new file mode 100644 index 0000000..2757837 --- /dev/null +++ b/classifier/nn.py @@ -0,0 +1,11 @@ +from typing import Dict + +import torch +from allennlp.models import Model + + +@Model.register("basic_bilstm_classifier") +class BiLstmClassifier(Model): + + def forward(self, *inputs) -> Dict[str, torch.Tensor]: + pass diff --git a/configs/basic_model.json b/configs/basic_model.json new file mode 100644 index 0000000..f7cb2a7 --- /dev/null +++ b/configs/basic_model.json @@ -0,0 +1,30 @@ +{ + "dataset_reader": { + "type": "citation_dataset_reader" + }, + "train_data_path": "data/jsonl/train.jsonl", + "validation_data_path": "data/jsonl/test.jsonl", + "test_data_path": "data/jsonl/test.jsonl", + "model": { + "type": "basic_bilstm_classifier", + "elmo_text_field_embedder": { + "tokens": { + "type": "embedding", + "pretrained_file": "https://s3-us-west-2.amazonaws.com/allennlp/datasets/glove/glove.6B.100d.txt.gz", + "embedding_dim": 100, + "trainable": false + }, + "elmo": { + "type": "elmo_token_embedder", + "options_file": "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json", + "weight_file": "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5", + "do_layer_norm": true, + "dropout": 0.5 + } + } + }, + "trainer": { + "optimizer": "adam", + "num_epochs": 10 + } +} \ No newline at end of file diff --git a/utils/reader.py b/utils/reader.py index b19d6e5..7fc75e6 100644 --- a/utils/reader.py +++ b/utils/reader.py @@ -11,6 +11,7 @@ from allennlp.data.token_indexers import SingleIdTokenIndexer, ELMoTokenCharacte import utils.constants as const +@DatasetReader.register("citation_dataset_reader") # type for config files class CitationDataSetReader(DatasetReader): def __init__(self): super().__init__()