-
Notifications
You must be signed in to change notification settings - Fork 0
/
ConvNN.py
152 lines (129 loc) · 5.14 KB
/
ConvNN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# Michael Segev
# Pierre Jacquier
# Albert Faucher
# Group 70
# COMP 551 MP3
# March 18 2019
import torch
import torch.nn as nn
import torch.optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import math
from helpers import *
class ConvNN(torch.nn.Module):
    """Two-conv / three-linear image classifier with built-in training helpers.

    ``linear1`` expects ``64*13*13`` flattened features; following the shape
    comments in ``__init__`` (60x60 -> 30x30 -> 26x26 -> 13x13) this
    corresponds to a single-channel 64x64 input image. The network returns 10
    unnormalised class scores (logits) per sample.

    Training history kept on the instance:
      * ``losses``         -- loss of every batch passed to ``train_batch``
      * ``accuracies``     -- accuracy of every batch passed to ``train_batch``
      * ``val_accuracies`` -- validation accuracy recorded by ``train_all_batches``
      * ``val_steps``      -- index into ``accuracies`` at which each
        ``val_accuracies`` entry was recorded (used to align ``plot_acc``)
    """

    def __init__(self):
        super(ConvNN, self).__init__()  # call the inherited class constructor
        print("Model: ConvNN")

        # define the architecture of the neural network
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5),  # output is 60x60
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),  # output is 30x30
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5),  # output is 26x26
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),  # output is 13x13
        )
        self.linear1 = nn.Sequential(
            torch.nn.Linear(64 * 13 * 13, 1000),
            nn.ReLU(True),
        )
        self.linear2 = nn.Sequential(
            torch.nn.Linear(1000, 200),
            nn.ReLU(True),
        )
        self.linear3 = torch.nn.Linear(200, 10)

        self.losses = []
        self.accuracies = []
        self.val_accuracies = []
        self.val_steps = []
        # Exponentially smoothed training loss used for early stopping.
        # 2.3 ~= ln(10), presumably the cross-entropy of a uniform guess over
        # the 10 output classes.
        self.loss_LPF = 2.3
        # Both are created by init_optimizer(); training fails until then.
        self.criterion = None
        self.optimizer = None

    def init_optimizer(self, lr=1e-2, momentum=0.9):
        """Create the loss function and the SGD optimizer.

        Must be called before ``train_batch`` / ``train_all_batches``.

        Args:
            lr: SGD learning rate (default 1e-2, the previously hard-coded value).
            momentum: SGD momentum (default 0.9, the previously hard-coded value).
        """
        # CrossEntropyLoss takes raw logits and integer class labels.
        self.criterion = torch.nn.CrossEntropyLoss()
        print("Learning rate: {}".format(lr))
        self.optimizer = torch.optim.SGD(self.parameters(), lr=lr, momentum=momentum)

    def forward(self, x):
        """Return class logits of shape ``(batch, 10)`` for the input batch ``x``."""
        h = self.conv1(x)
        h = self.conv2(h)
        h = h.reshape(h.size(0), -1)  # flatten to (batch, 64*13*13)
        h = self.linear1(h)
        h = self.linear2(h)
        return self.linear3(h)

    def train_batch(self, x, y):
        """Run one optimisation step on a single batch.

        Args:
            x: input batch tensor.
            y: integer class labels (``torch.long``), one per sample.

        Returns:
            ``(loss, acc)``: the loss tensor of this batch, and the fraction of
            samples whose arg-max prediction equals ``y``. Both are also
            appended to ``self.losses`` / ``self.accuracies``.

        Raises:
            RuntimeError: if ``init_optimizer`` has not been called yet.
        """
        if self.criterion is None or self.optimizer is None:
            raise RuntimeError("init_optimizer() must be called before training")

        # Forward pass and loss
        y_pred = self(x)
        loss = self.criterion(y_pred, y)
        self.losses.append(loss.item())

        # Accuracy is bookkeeping only, so keep it out of the autograd graph.
        with torch.no_grad():
            predicted = y_pred.argmax(dim=1)
            acc = (predicted == y).sum().item() / y.size(0)
        self.accuracies.append(acc)

        # Reset gradients to zero, perform a backward pass, and update the weights.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss, acc

    def train_all_batches(self, x, y, batch_size, num_epochs, loss_target, device,
                          x_val=None, y_val=None, val_skip=0):
        """Train for up to ``num_epochs`` epochs over ``(x, y)`` in mini-batches.

        Args:
            x: inputs, sliceable along the first axis and accepted by
                ``torch.as_tensor`` (e.g. a NumPy array or tensor); assumed to
                have the same number of samples as ``y``.
            y: integer class labels; ``y.shape[0]`` is the sample count.
            batch_size: samples per batch; the final batch holds the remainder.
            num_epochs: maximum number of passes over the data.
            loss_target: training stops at the start of an epoch once the
                smoothed loss ``self.loss_LPF`` is below this value.
            device: device the batches are created on.
            x_val, y_val: optional validation data passed to ``validate_data``;
                validation is skipped when they are empty or differ in length.
            val_skip: validate every ``(val_skip + 1) * 40`` batches of an epoch.

        Raises:
            ValueError: if ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be positive, got {}".format(batch_size))
        if x_val is None:
            x_val = []
        if y_val is None:
            y_val = []
        has_val = len(x_val) > 0 and len(x_val) == len(y_val)
        val_every = (val_skip + 1) * 40
        if not hasattr(self, "val_steps"):
            # instances created before val_steps existed (e.g. unpickled models)
            self.val_steps = []

        num_samples = y.shape[0]
        # ceil division: the final batch may be smaller than batch_size
        num_batches = (num_samples + batch_size - 1) // batch_size
        print("Number of batches = {}".format(num_batches))

        for epoch in range(num_epochs):
            if self.loss_LPF < loss_target:
                print("reached loss target, ending early!")
                break
            for batch_num in range(num_batches):
                # Offsets always advance by the full batch_size; slicing clips
                # the last (possibly partial) batch at the end of the data.
                start = batch_num * batch_size
                stop = start + batch_size
                # Inputs need no gradient; only the weights are trained.
                x_batch = torch.as_tensor(x[start:stop], dtype=torch.float32, device=device)
                y_batch = torch.as_tensor(y[start:stop], dtype=torch.long, device=device)

                loss, acc = self.train_batch(x_batch, y_batch)
                # exponential moving average of the loss (drives early stopping)
                self.loss_LPF = 0.01 * loss.item() + 0.99 * self.loss_LPF

                val_acc = None  # None means "not validated on this batch"
                if has_val and batch_num % val_every == 0:
                    # validate_data comes from helpers and may change the
                    # train/eval mode; restore whatever mode we were in.
                    was_training = self.training
                    val_acc = validate_data(self, x_val, y_val, device)
                    self.train(was_training)
                    self.val_accuracies.append(val_acc)
                    self.val_steps.append(len(self.accuracies) - 1)

                if batch_num % 40 == 0:
                    msg = "Epoch: {}, Loss: {}, Acc: {}%".format(epoch, self.loss_LPF, round(acc * 100, 3))
                    if val_acc is not None:
                        msg += ", ValAcc: {}%".format(round(val_acc * 100, 3))
                    print(msg)

    def plot_loss(self):
        """Plot the per-batch training loss recorded by ``train_batch``."""
        plt.title('Loss over time')
        plt.xlabel('Batch')
        plt.ylabel('Loss')
        plt.plot(self.losses)
        plt.show()

    def plot_acc(self):
        """Plot per-batch training accuracy and the sparser validation accuracy."""
        plt.title('Accuracy over time')
        plt.xlabel('Batch')
        plt.ylabel('Accuracy')
        plt.plot(self.accuracies, label='train')
        # Validation is sampled far less often than training batches; place
        # each point at the batch where it was measured when that is known.
        val_steps = getattr(self, 'val_steps', None)
        if val_steps is not None and len(val_steps) == len(self.val_accuracies):
            plt.plot(val_steps, self.val_accuracies, label='validation')
        else:
            plt.plot(self.val_accuracies, label='validation')
        plt.legend()
        plt.show()