diff --git a/classifier/nn_ff.py b/classifier/nn_ff.py index f478f1d..3ccda08 100644 --- a/classifier/nn_ff.py +++ b/classifier/nn_ff.py @@ -58,7 +58,7 @@ class FeedForward(torch.nn.Module): self.samples1 = samples1 self.samples2 = samples2 - model.eval() # put into eval mode + self.eval() # put into eval mode # initialize training data self.shuffle() @@ -79,7 +79,7 @@ class FeedForward(torch.nn.Module): self.optimizer.zero_grad() # forward pass - y_pred = model(X_train[a:b]) + y_pred = self.forward(self.X_train[a:b]) loss = self.criterion(y_pred, self.y_train[a:b]) # backward pass