From 69f913c801b0610556759eb9cbae71591277711d Mon Sep 17 00:00:00 2001 From: yelircaasi Date: Sun, 2 Aug 2020 22:52:02 +0200 Subject: [PATCH] cleaned up FeedForward class to pass test --- classifier/nn_ff.py | 6 ++++-- testing/ff_model_testing.py | 3 +++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/classifier/nn_ff.py b/classifier/nn_ff.py index 3ccda08..24accaf 100644 --- a/classifier/nn_ff.py +++ b/classifier/nn_ff.py @@ -24,6 +24,7 @@ class FeedForward(torch.nn.Module): self.fc2 = torch.nn.Linear(self.hidden_size, self.output_size) self.sigmoid = torch.nn.Sigmoid() self.softmax = torch.nn.Softmax(dim=1) + self.read_data() def forward(self, x): """ Computes output from a given input x. """ @@ -35,7 +36,8 @@ class FeedForward(torch.nn.Module): def read_data(self): """" Reads in training and test data and converts it to proper format. """ - self.X_train_, self.y_train_, self.X_test_ = read_csv_nn() + self.X_train_, self.y_train_, self.X_test = read_csv_nn() + self.X_test = torch.FloatTensor(self.X_test) yclass = np.array([(x[1] == 1) + 2 * (x[2] == 1) for x in self.y_train_]) is0 = yclass == 0 is1 = yclass == 1 @@ -118,7 +120,7 @@ class FeedForward(torch.nn.Module): p0 = torch.randperm(self.l0) p1 = torch.randperm(self.l1) p2 = torch.randperm(self.l2) - n = self.l0 + self.l1 + self.l2 + n = self.samples0 + self.samples1 + self.samples2 p = torch.randperm(n) # sample and shuffle data diff --git a/testing/ff_model_testing.py b/testing/ff_model_testing.py index 3313924..a290a61 100644 --- a/testing/ff_model_testing.py +++ b/testing/ff_model_testing.py @@ -1,3 +1,6 @@ +import sys +import os +sys.path.append(os.getcwd()) from classifier.nn_ff import FeedForward