import torch.nn as nn

from Basic_blocks import *


class Conv_residual_conv(nn.Module):
    """Conv -> residual triple-conv -> conv block used throughout FusionNet."""

    def __init__(self, in_dim, out_dim, act_fn):
        super(Conv_residual_conv, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.act_fn = act_fn
        self.conv_1 = conv_block(self.in_dim, self.out_dim, act_fn)
        self.conv_2 = conv_block_3(self.out_dim, self.out_dim, act_fn)
        self.conv_3 = conv_block(self.out_dim, self.out_dim, act_fn)

    def forward(self, input):
        conv_1 = self.conv_1(input)
        conv_2 = self.conv_2(conv_1)
        # Residual connection: conv_1 and conv_2 share the same channel
        # count, so they can be summed elementwise.
        res = conv_1 + conv_2
        conv_3 = self.conv_3(res)
        return conv_3
class FusionGenerator(nn.Module):

    def __init__(self, input_nc, output_nc, ngf):
        super(FusionGenerator, self).__init__()
        self.in_dim = input_nc
        self.out_dim = ngf
        self.final_out_dim = output_nc
        act_fn = nn.LeakyReLU(0.2, inplace=True)
        act_fn_2 = nn.ReLU()

        print("\n------Initiating FusionNet------\n")

        # encoder: each stage doubles the channels, each pool halves the resolution
        self.down_1 = Conv_residual_conv(self.in_dim, self.out_dim, act_fn)
        self.pool_1 = maxpool()
        self.down_2 = Conv_residual_conv(self.out_dim, self.out_dim * 2, act_fn)
        self.pool_2 = maxpool()
        self.down_3 = Conv_residual_conv(self.out_dim * 2, self.out_dim * 4, act_fn)
        self.pool_3 = maxpool()
        self.down_4 = Conv_residual_conv(self.out_dim * 4, self.out_dim * 8, act_fn)
        self.pool_4 = maxpool()

        # bridge
        self.bridge = Conv_residual_conv(self.out_dim * 8, self.out_dim * 16, act_fn)

        # decoder: transposed convolutions upsample back toward the input resolution
        self.deconv_1 = conv_trans_block(self.out_dim * 16, self.out_dim * 8, act_fn_2)
        self.up_1 = Conv_residual_conv(self.out_dim * 8, self.out_dim * 8, act_fn_2)
        self.deconv_2 = conv_trans_block(self.out_dim * 8, self.out_dim * 4, act_fn_2)
        self.up_2 = Conv_residual_conv(self.out_dim * 4, self.out_dim * 4, act_fn_2)
        self.deconv_3 = conv_trans_block(self.out_dim * 4, self.out_dim * 2, act_fn_2)
        self.up_3 = Conv_residual_conv(self.out_dim * 2, self.out_dim * 2, act_fn_2)
        self.deconv_4 = conv_trans_block(self.out_dim * 2, self.out_dim, act_fn_2)
        self.up_4 = Conv_residual_conv(self.out_dim, self.out_dim, act_fn_2)

        # output: 3x3 conv to the target channel count, then tanh into [-1, 1]
        self.out = nn.Conv2d(self.out_dim, self.final_out_dim, kernel_size=3, stride=1, padding=1)
        self.out_2 = nn.Tanh()

        # initialization: DCGAN-style normal init for conv and batch-norm weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0.0, 0.02)
                m.bias.data.fill_(0)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.normal_(1.0, 0.02)
                m.bias.data.fill_(0)
    def forward(self, input):
        # encoder
        down_1 = self.down_1(input)
        pool_1 = self.pool_1(down_1)
        down_2 = self.down_2(pool_1)
        pool_2 = self.pool_2(down_2)
        down_3 = self.down_3(pool_2)
        pool_3 = self.pool_3(down_3)
        down_4 = self.down_4(pool_3)
        pool_4 = self.pool_4(down_4)

        bridge = self.bridge(pool_4)

        # decoder: FusionNet merges each skip connection by elementwise
        # averaging, rather than the channel concatenation used in U-Net
        deconv_1 = self.deconv_1(bridge)
        skip_1 = (deconv_1 + down_4) / 2
        up_1 = self.up_1(skip_1)
        deconv_2 = self.deconv_2(up_1)
        skip_2 = (deconv_2 + down_3) / 2
        up_2 = self.up_2(skip_2)
        deconv_3 = self.deconv_3(up_2)
        skip_3 = (deconv_3 + down_2) / 2
        up_3 = self.up_3(skip_3)
        deconv_4 = self.deconv_4(up_3)
        skip_4 = (deconv_4 + down_1) / 2
        up_4 = self.up_4(skip_4)

        out = self.out(up_4)
        out = self.out_2(out)
        # out = torch.clamp(out, min=-1, max=1)
        return out
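
# A minimal smoke-test sketch, not part of the original file. It assumes
# Basic_blocks provides conv_block, conv_block_3, conv_trans_block, and
# maxpool with the signatures used above, and that maxpool halves the
# spatial dimensions. The channel counts, ngf, and the 256x256 input size
# below are illustrative, not values fixed by this file.
if __name__ == "__main__":
    import torch

    # hypothetical configuration: 1 input channel, 1 output channel, 64 base filters
    model = FusionGenerator(input_nc=1, output_nc=1, ngf=64)

    # spatial size must be divisible by 16 to survive the four 2x poolings
    x = torch.randn(1, 1, 256, 256)
    y = model(x)
    print(y.shape)  # expected: torch.Size([1, 1, 256, 256])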