added a few comments

isaac
yelircaasi 5 years ago
parent eb8a225c9b
commit 12c9610f0b

@ -8,7 +8,12 @@ from utils.nn_reader import read_csv_nn
class Feedforward(torch.nn.Module):
"""
Creates and trains a basic feedforward neural network.
"""
def __init__(self, input_size, hidden_size, output_size):
""" Sets up all basic elements of NN. """
super(Feedforward, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
@ -20,6 +25,7 @@ class Feedforward(torch.nn.Module):
self.softmax = torch.nn.Softmax(dim=1)
def forward(self, x):
""" Computes output from a given input x. """
hidden = self.fc1(x)
relu = self.relu(hidden)
output = self.fc2(relu)
@ -29,6 +35,8 @@ class Feedforward(torch.nn.Module):
if __name__=='__main__':
""" Reads in the data, then trains and evaluates the neural network. """
X_train, y_train, X_test = read_csv_nn()
X_train = torch.FloatTensor(X_train)

Loading…
Cancel
Save