minor fixes

isaac
Pavan Mandava 6 years ago
parent b8fd2e047f
commit a6793e1585

@@ -36,6 +36,10 @@ class BiLstmClassifier(Model):
self.label_f1_metrics[vocab.get_token_from_index(index=i, namespace="label")] = \
F1Measure(positive_label=i)
print('No of classes :: ', self.num_classes)
for i in range(self.num_classes):
print('Token :: ', i, ' -> ', vocab.get_token_from_index(index=i, namespace="label"))
self.loss = torch.nn.CrossEntropyLoss()
self.attention = Attention(encoder.get_output_dim())
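For context, a minimal hypothetical sketch (AllenNLP-style API; the label names are invented for illustration, not taken from this repository) of what the per-label F1 bookkeeping in the hunk above amounts to: one F1Measure per class index, keyed by the human-readable label string, plus the newly added debug prints of the index-to-label mapping.

# Hypothetical sketch, assuming an AllenNLP-style Vocabulary/F1Measure API.
from allennlp.data import Vocabulary
from allennlp.training.metrics import F1Measure

vocab = Vocabulary(non_padded_namespaces=["label"])  # keep "label" free of padding/OOV entries
for name in ["background", "method", "result"]:      # hypothetical label set, for illustration only
    vocab.add_token_to_namespace(name, namespace="label")

num_classes = vocab.get_vocab_size("label")
label_f1_metrics = {}
for i in range(num_classes):
    token = vocab.get_token_from_index(index=i, namespace="label")
    label_f1_metrics[token] = F1Measure(positive_label=i)  # one binary F1 per class
    print('Token :: ', i, ' -> ', token)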
@@ -74,11 +78,12 @@ class BiLstmClassifier(Model):
# Attention
attn_dist, encoded_text = self.attention(encoded_text, return_attn_distribution=True)
output_dict = {}
if label is not None:
logits = self.classifier_feed_forward(encoded_text)
class_probabilities = torch.nn.functional.softmax(logits, dim=1)
- output_dict = {"logits": logits}
+ output_dict["logits"] = logits
loss = self.loss(logits, label)
output_dict["loss"] = loss

@@ -3,7 +3,7 @@
"type": "citation_dataset_reader"
},
"train_data_path": "data/jsonl/train.jsonl",
"validation_data_path": "data/jsonl/test.jsonl",
"validation_data_path": "data/jsonl/dev.jsonl",
"test_data_path": "data/jsonl/test.jsonl",
"model": {
"type": "basic_bilstm_classifier",
