@@ -36,6 +36,10 @@ class BiLstmClassifier(Model):
             self.label_f1_metrics[vocab.get_token_from_index(index=i, namespace="label")] = \
                 F1Measure(positive_label=i)
 
+        print('No of classes :: ', self.num_classes)
+        for i in range(self.num_classes):
+            print('Token :: ', i, ' -> ', vocab.get_token_from_index(index=i, namespace="label"))
+
         self.loss = torch.nn.CrossEntropyLoss()
         self.attention = Attention(encoder.get_output_dim())
 
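For context, the per-label F1Measure objects registered in the hunk above are normally updated inside forward (e.g. metric(class_probabilities, label) for each label's metric) and reported from get_metrics. A minimal sketch of what that reporting side usually looks like in this kind of AllenNLP model; the metric key names are illustrative, not taken from the original code, and pre-1.0 AllenNLP returns a (precision, recall, f1) tuple from get_metric while newer releases return a dict:

    # Sketch only -- this method would live on BiLstmClassifier.
    def get_metrics(self, reset: bool = False):
        metrics = {}
        for label_name, f1_metric in self.label_f1_metrics.items():
            # Pre-1.0 AllenNLP: get_metric returns (precision, recall, f1).
            precision, recall, f1 = f1_metric.get_metric(reset)
            metrics[label_name + "_P"] = precision
            metrics[label_name + "_R"] = recall
            metrics[label_name + "_F1"] = f1
        return metrics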
@@ -74,11 +78,12 @@ class BiLstmClassifier(Model):
         # Attention
         attn_dist, encoded_text = self.attention(encoded_text, return_attn_distribution=True)
 
+        output_dict = {}
         if label is not None:
             logits = self.classifier_feed_forward(encoded_text)
             class_probabilities = torch.nn.functional.softmax(logits, dim=1)
 
-            output_dict = {"logits": logits}
+            output_dict["logits"] = logits
             loss = self.loss(logits, label)
             output_dict["loss"] = loss
 
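The Attention module constructed with Attention(encoder.get_output_dim()) and called with return_attn_distribution=True is not shown in this diff. A self-contained sketch of one way such a module could look (simple learned scoring over timesteps; the class name and keyword argument match the call sites above, everything else is an assumption):

    import torch


    class Attention(torch.nn.Module):
        """Hypothetical attention-pooling layer matching the call sites above."""

        def __init__(self, input_dim: int):
            super().__init__()
            # One scalar score per timestep of the encoder output.
            self.scorer = torch.nn.Linear(input_dim, 1)

        def forward(self, encoded_text: torch.Tensor, return_attn_distribution: bool = False):
            # encoded_text: (batch, seq_len, input_dim)
            scores = self.scorer(encoded_text).squeeze(-1)            # (batch, seq_len)
            attn_dist = torch.nn.functional.softmax(scores, dim=-1)   # attention weights
            # Weighted sum over timesteps -> (batch, input_dim)
            pooled = torch.bmm(attn_dist.unsqueeze(1), encoded_text).squeeze(1)
            if return_attn_distribution:
                return attn_dist, pooled
            return pooled

Under that assumption, attn_dist, encoded_text = self.attention(encoded_text, return_attn_distribution=True) replaces the per-token encodings with a single pooled vector per instance, which is what classifier_feed_forward then consumes.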