diff --git a/classifier/nn.py b/classifier/nn.py index 5f1addf..7abab20 100644 --- a/classifier/nn.py +++ b/classifier/nn.py @@ -36,6 +36,10 @@ class BiLstmClassifier(Model): self.label_f1_metrics[vocab.get_token_from_index(index=i, namespace="label")] = \ F1Measure(positive_label=i) + print('No of classes :: ', self.num_classes) + for i in range(self.num_classes): + print('Token :: ', i, ' -> ', vocab.get_token_from_index(index=i, namespace="label")) + self.loss = torch.nn.CrossEntropyLoss() self.attention = Attention(encoder.get_output_dim()) @@ -74,11 +78,12 @@ class BiLstmClassifier(Model): # Attention attn_dist, encoded_text = self.attention(encoded_text, return_attn_distribution=True) + output_dict = {} if label is not None: logits = self.classifier_feed_forward(encoded_text) class_probabilities = torch.nn.functional.softmax(logits, dim=1) - output_dict = {"logits": logits} + output_dict["logits"] = logits loss = self.loss(logits, label) output_dict["loss"] = loss diff --git a/configs/basic_model.json b/configs/basic_model.json index 4887b66..369ef62 100644 --- a/configs/basic_model.json +++ b/configs/basic_model.json @@ -3,7 +3,7 @@ "type": "citation_dataset_reader" }, "train_data_path": "data/jsonl/train.jsonl", - "validation_data_path": "data/jsonl/test.jsonl", + "validation_data_path": "data/jsonl/dev.jsonl", "test_data_path": "data/jsonl/test.jsonl", "model": { "type": "basic_bilstm_classifier",