Model Testing Code added

isaac
Pavan Mandava 6 years ago
parent 89f6cfdf88
commit ce8b6684f7

@ -130,7 +130,7 @@ class MultiClassPerceptron:
for epoch in range(self.epochs): for epoch in range(self.epochs):
# get a random number within the size of training set # get a random number within the size of training set
rand_num = random.randint(0, train_len) rand_num = random.randint(0, train_len-1)
# pick a random data instance with the generated random number # pick a random data instance with the generated random number
inst = X_train[rand_num] inst = X_train[rand_num]

@ -4,6 +4,8 @@ from utils.csv import read_csv_file
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
train_file_path = project_root+'/data/tsv/train.tsv' train_file_path = project_root+'/data/tsv/train.tsv'
test_file_path = project_root+'/data/tsv/test.tsv'
print(train_file_path) print(train_file_path)
data = read_csv_file(csv_file_path=train_file_path, delimiter='\t') data = read_csv_file(csv_file_path=train_file_path, delimiter='\t')

@ -1,3 +1,31 @@
from classifier.linear_model import get_sample_weights_with_features from classifier.linear_model import MultiClassPerceptron
from utils.csv import read_csv_file
from eval.metrics import f1_score
import utils.constants as const
import os
print(get_sample_weights_with_features()) project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
train_file_path = project_root+'/data/tsv/train.tsv'
test_file_path = project_root+'/data/tsv/test.tsv'
X_train_inst = read_csv_file(train_file_path, '\t')
labels = set([inst.true_label for inst in X_train_inst])
X_test_inst = read_csv_file(test_file_path, '\t')
epochs = int(len(X_train_inst)*0.75)
clf = MultiClassPerceptron(epochs, 1)
clf.fit(X_train=X_train_inst, labels=list(labels))
y_test = clf.predict(X_test_inst)
y_true = [inst.true_label for inst in X_test_inst]
f1_score_list = f1_score(y_true, y_test, labels, const.AVG_MICRO)
for result in f1_score_list:
result.print_result()

Loading…
Cancel
Save