parent e9b1f31c49
commit 5daea1a2a8
@@ -0,0 +1 @@
classifier
@@ -1,2 +0,0 @@
from .nn import *
from utils.reader import *
@@ -1,157 +0,0 @@
from typing import Dict

import numpy as np
import torch
from allennlp.common.checks import ConfigurationError
from allennlp.data import Vocabulary
from allennlp.models import Model
from allennlp.modules import TextFieldEmbedder, Seq2SeqEncoder, FeedForward, Elmo
from allennlp.nn import util
from allennlp.training.metrics import CategoricalAccuracy, F1Measure
from overrides import overrides
from torch.nn import Parameter


@Model.register("basic_bilstm_classifier")
class BiLstmClassifier(Model):

    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 classifier_feedforward: FeedForward,
                 elmo: Elmo = None,
                 use_input_elmo: bool = False):
        super().__init__(vocab)
        self.elmo = elmo
        self.use_elmo = use_input_elmo
        self.text_field_embedder = text_field_embedder
        self.num_classes = self.vocab.get_vocab_size("label")
        self.encoder = encoder
        self.classifier_feed_forward = classifier_feedforward
        self.label_accuracy = CategoricalAccuracy()

        self.label_f1_metrics = {}

        for i in range(self.num_classes):
            self.label_f1_metrics[vocab.get_token_from_index(index=i, namespace="label")] = \
                F1Measure(positive_label=i)

        self.loss = torch.nn.CrossEntropyLoss()

        self.attention = Attention(encoder.get_output_dim())

    @overrides
    def forward(self, tokens: Dict[str, torch.LongTensor],
                label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:

        input_elmo = None
        elmo_tokens = tokens.pop("elmo", None)

        embedded_text = self.text_field_embedder(tokens)
        text_mask = util.get_text_field_mask(tokens)

        if elmo_tokens is not None:
            tokens["elmo"] = elmo_tokens

        # Create ELMo embeddings if applicable
        if self.elmo:
            if elmo_tokens is not None:
                elmo_representations = self.elmo(elmo_tokens["elmo_tokens"])["elmo_representations"]
                if self.use_elmo:
                    input_elmo = elmo_representations.pop()
                    assert not elmo_representations
            else:
                raise ConfigurationError("Model was built to use Elmo, but input text is not tokenized for Elmo.")

        if self.use_elmo:
            if embedded_text is not None:
                embedded_text = torch.cat([embedded_text, input_elmo], dim=-1)
            else:
                embedded_text = input_elmo

        encoded_text = self.encoder(embedded_text, text_mask)

        # Attention: pool the encoded sequence into a single vector per instance
        attn_dist, encoded_text = self.attention(encoded_text, return_attn_distribution=True)

        logits = self.classifier_feed_forward(encoded_text)
        class_probabilities = torch.nn.functional.softmax(logits, dim=-1)

        output_dict = {"logits": logits}

        if label is not None:
            loss = self.loss(logits, label)
            output_dict["loss"] = loss

            # compute F1 per label
            for i in range(self.num_classes):
                metric = self.label_f1_metrics[self.vocab.get_token_from_index(index=i, namespace="label")]
                metric(class_probabilities, label)
            output_dict['label'] = label

        # keep the raw token ids so they can be mapped back to text later
        output_dict['tokens'] = tokens["tokens"]["tokens"]

        return output_dict

    @overrides
    def make_output_human_readable(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        class_probabilities = torch.nn.functional.softmax(output_dict['logits'], dim=-1)
        predictions = class_probabilities.cpu().data.numpy()
        argmax_indices = np.argmax(predictions, axis=-1)
        label = [self.vocab.get_token_from_index(x, namespace="label")
                 for x in argmax_indices]
        output_dict['probabilities'] = class_probabilities
        output_dict['positive_label'] = label
        output_dict['prediction'] = label
        citation_text = []
        for batch_text in output_dict['tokens']:
            citation_text.append([self.vocab.get_token_from_index(token_id.item()) for token_id in batch_text])
        output_dict['tokens'] = citation_text

        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metric_dict = {}

        sum_f1 = 0.0
        for name, metric in self.label_f1_metrics.items():
            # F1Measure.get_metric returns a dict with "precision", "recall" and "f1" keys
            metric_val = metric.get_metric(reset)
            metric_dict[name + '_P'] = metric_val['precision']
            metric_dict[name + '_R'] = metric_val['recall']
            metric_dict[name + '_F1'] = metric_val['f1']
            if name != 'none':  # do not consider `none` label in averaging F1
                sum_f1 += metric_val['f1']

        names = list(self.label_f1_metrics.keys())
        total_len = len(names) if 'none' not in names else len(names) - 1
        average_f1 = sum_f1 / total_len
        metric_dict['average_F1'] = average_f1

        return metric_dict


def new_parameter(*size):
    out = Parameter(torch.FloatTensor(*size))
    torch.nn.init.xavier_normal_(out)
    return out


class Attention(torch.nn.Module):
    """Simple multiplicative attention"""

    def __init__(self, attention_size):
        super(Attention, self).__init__()
        self.attention = new_parameter(attention_size, 1)

    def forward(self, x_in, reduction_dim=-2, return_attn_distribution=False):
        # calculate attn weights
        attn_score = torch.matmul(x_in, self.attention).squeeze()
        # add one dimension at the end and get a distribution out of scores
        attn_distrib = torch.nn.functional.softmax(attn_score.squeeze(), dim=-1).unsqueeze(-1)
        scored_x = x_in * attn_distrib
        weighted_sum = torch.sum(scored_x, dim=reduction_dim)
        if return_attn_distribution:
            return attn_distrib.reshape(x_in.shape[0], -1), weighted_sum
        else:
            return weighted_sum
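The Attention module pools the encoder output from (batch, seq_len, hidden) down to a single (batch, hidden) vector per instance. A minimal shape sketch, assuming the class above is in scope and the 200-dimensional bidirectional encoder output used in the config below:

import torch

attn = Attention(200)
x = torch.randn(4, 30, 200)  # (batch, seq_len, hidden)
dist, pooled = attn(x, return_attn_distribution=True)
assert dist.shape == (4, 30)     # one attention weight per token
assert pooled.shape == (4, 200)  # attention-weighted sum over the sequence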
@@ -1,56 +0,0 @@
{
  "dataset_reader": {
    "type": "citation_dataset_reader"
  },
  "train_data_path": "data/jsonl/train.jsonl",
  "validation_data_path": "data/jsonl/test.jsonl",
  "test_data_path": "data/jsonl/test.jsonl",
  "model": {
    "type": "basic_bilstm_classifier",
    "text_field_embedder": {
      "token_embedders": {
        "tokens": {
          "pretrained_file": "https://s3-us-west-2.amazonaws.com/allennlp/datasets/glove/glove.6B.100d.txt.gz",
          "type": "embedding",
          "embedding_dim": 100,
          "trainable": false
        }
      }
    },
    "encoder": {
      "type": "lstm",
      "input_size": 1124,
      "hidden_size": 100,
      "num_layers": 1,
      "bidirectional": true
    },
    "elmo": {
      "options_file": "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json",
      "weight_file": "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5",
      "do_layer_norm": true,
      "dropout": 0.5,
      "num_output_representations": 1
    },
    "use_input_elmo": true,
    "classifier_feedforward": {
      "input_dim": 200,
      "num_layers": 2,
      "hidden_dims": [20, 3],
      "activations": ["linear", "linear"]
    }
  },
  "data_loader": {
    "batch_sampler": {
      "type": "bucket",
      "batch_size": 16
    }
  },
  "trainer": {
    "optimizer": {
      "type": "adam",
      "lr": 0.001
    },
    "num_epochs": 2,
    "cuda_device": -1
  }
}
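A note on the dimensions above: the encoder "input_size" of 1124 is the 100-dimensional GloVe embedding concatenated with the 1024-dimensional ELMo representation (100 + 1024 = 1124), and the classifier "input_dim" of 200 matches the output of the bidirectional LSTM (2 × 100 hidden units).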
@@ -1,57 +0,0 @@
from typing import Iterable

import jsonlines
from allennlp.data import Instance
from allennlp.data.dataset_readers import DatasetReader
from allennlp.data.fields import TextField, LabelField
from allennlp.data.token_indexers import SingleIdTokenIndexer, ELMoTokenCharactersIndexer
from allennlp.data.tokenizers import SpacyTokenizer
from overrides import overrides

from utils.data import Citation


@DatasetReader.register("citation_dataset_reader")  # type for config files
class CitationDataSetReader(DatasetReader):
    def __init__(self):
        super().__init__()
        self.tokenizer = SpacyTokenizer()

    @overrides
    def _read(self, file_path: str) -> Iterable[Instance]:
        ds_reader = DataReaderJsonLines(file_path)
        for citation in ds_reader.read():
            yield self.text_to_instance(citation_text=citation.text, intent=citation.intent)

    @overrides
    def text_to_instance(self, citation_text: str,
                         intent: str) -> Instance:
        citation_tokens = self.tokenizer.tokenize(citation_text)
        token_indexers = {"elmo": ELMoTokenCharactersIndexer(),
                          "tokens": SingleIdTokenIndexer()}

        fields = {'tokens': TextField(citation_tokens, token_indexers),
                  # use the "label" namespace so it matches the model's vocabulary lookups
                  'label': LabelField(intent, label_namespace="label")}

        return Instance(fields)


class DataReaderJsonLines:
    def __init__(self, file_path):
        self.file_path = file_path

    def read(self):
        for line in jsonlines.open(self.file_path):
            yield read_json_line(line)


def read_json_line(line):
    citation = Citation(
        text=line['string'],
        citing_paper_id=line['citingPaperId'],
        cited_paper_id=line['citedPaperId'],
        section_title=line['sectionName'],
        intent=line['label'],
        citation_id=line['id'])

    return citation
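Each line of the JSONL files is expected to carry at least the keys read above (string, citingPaperId, citedPaperId, sectionName, label, id). A minimal, made-up example record, with illustrative values only:

{"string": "We follow the evaluation protocol of ...", "citingPaperId": "abc123", "citedPaperId": "def456", "sectionName": "Experiments", "label": "method", "id": "abc123>def456"}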
@@ -0,0 +1,2 @@
from .nn import *
from utils.reader import *