Merge remote-tracking branch 'origin/master'

isaac
Pavan Mandava 5 years ago
commit 0467eeee5f

@ -58,7 +58,7 @@ class FeedForward(torch.nn.Module):
self.samples1 = samples1 self.samples1 = samples1
self.samples2 = samples2 self.samples2 = samples2
model.eval() # put into eval mode self.eval() # put into eval mode
# initialize training data # initialize training data
self.shuffle() self.shuffle()
@ -79,7 +79,7 @@ class FeedForward(torch.nn.Module):
self.optimizer.zero_grad() self.optimizer.zero_grad()
# forward pass # forward pass
y_pred = model(X_train[a:b]) y_pred = self.forward(self.X_train[a:b])
loss = self.criterion(y_pred, self.y_train[a:b]) loss = self.criterion(y_pred, self.y_train[a:b])
# backward pass # backward pass

Loading…
Cancel
Save