-
Notifications
You must be signed in to change notification settings - Fork 5
/
model.py
106 lines (68 loc) · 2.42 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import torch
import torch.nn as nn
import wavemix
from wavemix import Level1Waveblock,Level2Waveblock, Level3Waveblock, DWTForward
class WaveMixModule(nn.Module):
    """One WaveMix encoder/decoder stage used by WavePaint.

    Pipeline: (B, 3, H, W) image -> strided conv encoder (channels -> final_dim,
    spatial /2) -> `depth` residual Level1Waveblock layers -> depthwise conv ->
    transposed-conv decoder (spatial x2) with two skip connections: the
    post-encoder feature map and the raw input image.

    NOTE(review): `mask` is accepted by forward() for interface compatibility
    with WavePaint, but this module does not use it (see forward()).
    """

    def __init__(
        self,
        *,
        depth,
        mult = 2,
        ff_channel = 16,
        final_dim = 16,
        dropout = 0.,
    ):
        super().__init__()

        # Encoder: 3 -> final_dim channels, halving spatial resolution (stride 2).
        self.conv = nn.Sequential(
            nn.Conv2d(3, int(final_dim/2), 3, 1, 1),
            nn.Conv2d(int(final_dim/2), final_dim, 3, 2, 1),
        )

        # `depth` wavelet token-mixing blocks, applied residually in forward().
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(Level1Waveblock(mult = mult, ff_channel = ff_channel, final_dim = final_dim, dropout = dropout))

        # Depthwise (grouped) refinement conv after the waveblocks.
        self.depthconv = nn.Sequential(
            nn.Conv2d(final_dim, final_dim, 5, groups=final_dim, padding="same"),
            nn.GELU(),
            nn.BatchNorm2d(final_dim)
        )

        # Decoder stage 1: input has final_dim*2 channels because forward()
        # concatenates the encoder skip before this layer; upsamples x2.
        self.decoder1 = nn.Sequential(
            nn.ConvTranspose2d(final_dim*2, int(final_dim/2), 4, stride = 2, padding = 1),
            nn.BatchNorm2d(int(final_dim/2))
        )

        # Decoder stage 2: 1x1 conv back to 3 channels; +3 input channels
        # come from concatenating the raw input image in forward().
        self.decoder2 = nn.Sequential(
            nn.Conv2d(int(final_dim/2) + 3, 3, 1),
        )

    def forward(self, img, mask):
        """Run one refinement stage.

        Args:
            img:  (B, 3, H, W) input image tensor.
            mask: unused; kept so WavePaint's call signature keeps working.

        Returns:
            (B, 3, H, W) refined image tensor.
        """
        # BUG FIX: the original computed `x = torch.cat([img, mask], dim=1)`
        # and then immediately overwrote `x` with `self.conv(img)`, so the
        # concatenation was dead code (self.conv expects 3 channels anyway).
        # The dead statement is removed; `mask` remains accepted but unused.
        x = self.conv(img)
        skip1 = x

        for attn in self.layers:
            x = attn(x) + x  # residual waveblock

        x = self.depthconv(x)
        x = torch.cat([x, skip1], dim=1)  # skip connection: encoder features
        x = self.decoder1(x)
        x = torch.cat([x, img], dim=1)  # skip connection: raw input image
        x = self.decoder2(x)
        return x
class WavePaint(nn.Module):
    """Image inpainting network built from stacked WaveMixModule stages.

    Each stage is applied residually; the final output blends the stack's
    result with the original image as ``x * mask + img``.
    NOTE(review): the mask convention (which value marks known vs. missing
    pixels) is not visible here — confirm against the training code.
    """

    def __init__(
        self,
        *,
        num_modules= 1,
        blocks_per_module = 7,
        mult = 4,
        ff_channel = 16,
        final_dim = 16,
        dropout = 0.,
    ):
        super().__init__()
        # One WaveMixModule per refinement stage, each `blocks_per_module` deep.
        self.wavemodules = nn.ModuleList(
            WaveMixModule(
                depth = blocks_per_module,
                mult = mult,
                ff_channel = ff_channel,
                final_dim = final_dim,
                dropout = dropout,
            )
            for _ in range(num_modules)
        )

    def forward(self, img, mask):
        """Inpaint `img` under `mask`.

        Args:
            img:  input image tensor.
            mask: mask tensor, broadcast-multiplied with the network output.

        Returns:
            Tensor of the same shape as `img`.
        """
        out = img
        for stage in self.wavemodules:
            # Residual refinement; each stage receives the inverted mask.
            out = stage(out, 1 - mask) + out
        # Blend the refined result with the original image.
        return out * mask + img