@@ -25,7 +25,7 @@ class BiLstmClassifier(Model):
         self.elmo = elmo
         self.use_elmo = use_input_elmo
         self.text_field_embedder = text_field_embedder
-        self.num_classes = self.vocab.get_vocab_size("label")
+        self.num_classes = self.vocab.get_vocab_size("labels")
         self.encoder = encoder
         self.classifier_feed_forward = classifier_feedforward
         self.label_accuracy = CategoricalAccuracy()
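
Each of the first four hunks makes the same one-token fix: the vocabulary namespace `"label"` is corrected to `"labels"`. In AllenNLP, `LabelField` registers its labels under the `"labels"` namespace by default, so every lookup against `"label"` queries a namespace the dataset reader never populated. A minimal sketch of the failure mode (hypothetical two-class label set, assuming the allennlp 0.x-era API this model targets):

    from allennlp.data import Instance, Vocabulary
    from allennlp.data.fields import LabelField

    # LabelField defaults to label_namespace="labels".
    instances = [Instance({"label": LabelField("positive")}),
                 Instance({"label": LabelField("negative")})]
    vocab = Vocabulary.from_instances(instances)

    vocab.get_vocab_size("labels")  # 2 -- the real number of classes
    vocab.get_vocab_size("label")   # a never-populated namespace, not the class count
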
@@ -33,7 +33,7 @@ class BiLstmClassifier(Model):
         self.label_f1_metrics = {}

         for i in range(self.num_classes):
-            self.label_f1_metrics[vocab.get_token_from_index(index=i, namespace="label")] = \
+            self.label_f1_metrics[vocab.get_token_from_index(index=i, namespace="labels")] = \
                 F1Measure(positive_label=i)

         self.loss = torch.nn.CrossEntropyLoss()
@@ -86,7 +86,7 @@ class BiLstmClassifier(Model):

         # compute F1 per label
         for i in range(self.num_classes):
-            metric = self.label_f1_metrics[self.vocab.get_token_from_index(index=i, namespace="label")]
+            metric = self.label_f1_metrics[self.vocab.get_token_from_index(index=i, namespace="labels")]
             metric(class_probabilities, label)
         output_dict['label'] = label

@@ -99,7 +99,7 @@ class BiLstmClassifier(Model):
         class_probabilities = torch.nn.functional.softmax(output_dict['logits'], dim=-1)
         predictions = class_probabilities.cpu().data.numpy()
         argmax_indices = np.argmax(predictions, axis=-1)
-        label = [self.vocab.get_token_from_index(x, namespace="label")
+        label = [self.vocab.get_token_from_index(x, namespace="labels")
                  for x in argmax_indices]
         output_dict['probabilities'] = class_probabilities
         output_dict['positive_label'] = label
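
The decode hunk above maps argmax indices back to label strings, so it has to read from the same namespace the labels were indexed into. Continuing from the `vocab` built in the sketch above, the round trip only holds when both directions name `"labels"`:

    idx = vocab.get_token_index("positive", namespace="labels")
    assert vocab.get_token_from_index(idx, namespace="labels") == "positive"
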
@@ -118,8 +118,6 @@ class BiLstmClassifier(Model):
         sum_f1 = 0.0
         for name, metric in self.label_f1_metrics.items():
             metric_val = metric.get_metric(reset)
-            metric_dict[name + '_P'] = metric_val[0]
-            metric_dict[name + '_R'] = metric_val[1]
             metric_dict[name + '_F1'] = metric_val[2]
             if name != 'none':  # do not consider `none` label in averaging F1
                 sum_f1 += metric_val[2]
@@ -127,7 +125,7 @@ class BiLstmClassifier(Model):
         names = list(self.label_f1_metrics.keys())
         total_len = len(names) if 'none' not in names else len(names) - 1
         average_f1 = sum_f1 / total_len
-        metric_dict['average_F1'] = average_f1
+        metric_dict['AVG_F1_Score'] = average_f1

         return metric_dict

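For reference, a sketch of how the metrics method reads once the last two hunks are applied. The method name and signature follow AllenNLP's `Model.get_metrics` API; the `metric_dict` initialization and the `Dict` import are assumed, since the diff does not show them:

    from typing import Dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metric_dict = {}  # assumed to be created here; not shown in the diff
        sum_f1 = 0.0
        for name, metric in self.label_f1_metrics.items():
            # In allennlp 0.x, F1Measure.get_metric returns a
            # (precision, recall, f1) tuple, hence the positional indexing.
            metric_val = metric.get_metric(reset)
            metric_dict[name + '_F1'] = metric_val[2]
            if name != 'none':  # do not consider `none` label in averaging F1
                sum_f1 += metric_val[2]

        names = list(self.label_f1_metrics.keys())
        total_len = len(names) if 'none' not in names else len(names) - 1
        average_f1 = sum_f1 / total_len
        metric_dict['AVG_F1_Score'] = average_f1

        return metric_dict

Per-label precision (`_P`) and recall (`_R`) entries are dropped, only per-label F1 is kept, and the macro average is exported under the new `AVG_F1_Score` key.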