-
Notifications
You must be signed in to change notification settings - Fork 1
/
model.py
65 lines (52 loc) · 2.17 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# model.py
import torch
import torch.nn as nn
class UNet(nn.Module):
def __init__(self, in_channels, out_channels):
super(UNet, self).__init__()
def CBR(in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.enc1 = CBR(in_channels, 64)
self.enc2 = CBR(64, 128)
self.enc3 = CBR(128, 256)
self.enc4 = CBR(256, 512)
self.pool = nn.MaxPool2d(2)
self.bottleneck = CBR(512, 1024)
self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
self.dec4 = CBR(1024, 512)
self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
self.dec3 = CBR(512, 256)
self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
self.dec2 = CBR(256, 128)
self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
self.dec1 = CBR(128, 64)
self.conv = nn.Conv2d(64, out_channels, kernel_size=1)
def forward(self, x):
enc1 = self.enc1(x)
enc2 = self.enc2(self.pool(enc1))
enc3 = self.enc3(self.pool(enc2))
enc4 = self.enc4(self.pool(enc3))
bottleneck = self.bottleneck(self.pool(enc4))
dec4 = self.upconv4(bottleneck)
dec4 = torch.cat((dec4, enc4), dim=1)
dec4 = self.dec4(dec4)
dec3 = self.upconv3(dec4)
dec3 = torch.cat((dec3, enc3), dim=1)
dec3 = self.dec3(dec3)
dec2 = self.upconv2(dec3)
dec2 = torch.cat((dec2, enc2), dim=1)
dec2 = self.dec2(dec2)
dec1 = self.upconv1(dec2)
dec1 = torch.cat((dec1, enc1), dim=1)
dec1 = self.dec1(dec1)
return torch.sigmoid(self.conv(dec1))
# If you want to make the model creation easier, you can add a function like this:
def create_unet(in_channels=3, out_channels=1):
return UNet(in_channels, out_channels)