diff --git a/classifier/nn_ff.py b/classifier/nn_ff.py index 21c9024..1177c51 100644 --- a/classifier/nn_ff.py +++ b/classifier/nn_ff.py @@ -8,7 +8,12 @@ from utils.nn_reader import read_csv_nn class Feedforward(torch.nn.Module): + """ + Creates and trains a basic feedforward neural network. + """ + def __init__(self, input_size, hidden_size, output_size): + """ Sets up all basic elements of NN. """ super(Feedforward, self).__init__() self.input_size = input_size self.hidden_size = hidden_size @@ -20,6 +25,7 @@ class Feedforward(torch.nn.Module): self.softmax = torch.nn.Softmax(dim=1) def forward(self, x): + """ Computes output from a given input x. """ hidden = self.fc1(x) relu = self.relu(hidden) output = self.fc2(relu) @@ -29,6 +35,8 @@ class Feedforward(torch.nn.Module): if __name__=='__main__': + """ Reads in the data, then trains and evaluates the neural network. """ + X_train, y_train, X_test = read_csv_nn() X_train = torch.FloatTensor(X_train)