-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
347 lines (307 loc) · 11.5 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
import math
from dataclasses import dataclass
import torch
import yaml
from torch import nn
from torch.nn import functional as F
@dataclass
class Config:
    """Hyper-parameters for the model and its training loop.

    Instances can be built directly or loaded from a YAML file with ``from_file``.
    """
    # model
    blocks: tuple[int, ...]  # number of ResidualBlocks in each CNN stage
    channels: tuple[int, ...]  # output width of each CNN stage (channels[0] is also the stem width)
    stride: tuple[int, ...]  # stride of the first ResidualBlock of each stage
    context_size: int  # length of the positional-embedding table, i.e. max frames per clip
    dim: int  # transformer width
    depth: int  # number of TransformerBlocks
    num_heads: int
    bias: bool  # whether Linear/Conv/norm layers carry an additive bias
    mlp_ratio: int  # MLP hidden width = mlp_ratio * dim
    dropout: float  # residual/embedding dropout; also used as the patch-dropout probability
    attn_dropout: float  # dropout on attention probabilities
    # training
    steps: int
    eval_interval: int
    batch_size: int
    learning_rate: float
    min_learning_rate: float
    weight_decay: float  # applied only to Linear/Conv2d weights
    warmup_steps: int

    @classmethod
    def from_file(cls, file, **kwargs):
        """Create a Config from a YAML file, letting ``kwargs`` override file values.

        Args:
            file: path to a YAML file whose top level is a mapping of field names
                to values. An empty file is treated as an empty mapping.
            **kwargs: field values that take precedence over those in the file.

        Returns:
            A new ``Config``.

        Raises:
            TypeError: if the file's top level is not a mapping, or if fields are
                missing/unknown (raised by the dataclass constructor).
        """
        with open(file, encoding='utf-8') as fh:
            kwargs_from_file = yaml.safe_load(fh)
        if kwargs_from_file is None:  # safe_load returns None for an empty document
            kwargs_from_file = {}
        if not isinstance(kwargs_from_file, dict):
            raise TypeError(
                f'expected a mapping at the top level of {file!r}, '
                f'got {type(kwargs_from_file).__name__}'
            )
        # override file config with kwargs
        return cls(**{**kwargs_from_file, **kwargs})
class MaskedVisionTransformer(nn.Module):
    """Per-frame CNN encoder followed by a causal transformer over the frame sequence.

    Each selected frame of a clip is encoded by ``FrameEncoder`` into one token.
    Its metadata embedding is added to that token. Unselected frames are replaced
    by a learned mask token. A stack of ``TransformerBlock``s then predicts one
    scalar per frame.
    """

    num_meta_features = 4  # input width of meta_embed (per-frame metadata vector)
    num_outputs = 1  # one scalar prediction per frame

    def __init__(self, config: Config, flash_attn=True):
        super().__init__()
        self.config = config
        self.frame_encoder = FrameEncoder(config)
        # NOTE: there is no projection from config.channels[-1] to config.dim. The
        # frame features are added directly to the dim-wide metadata embedding, so
        # forward() requires channels[-1] == dim.
        self.meta_embed = nn.Linear(self.num_meta_features, config.dim, bias=True)
        # Fixed sinusoidal table of shape (1, context_size, dim); not trained.
        self.pos_embed = nn.Parameter(get_1d_pos_embed(config.context_size, config.dim), requires_grad=False)
        self.dropout = nn.Dropout(p=config.dropout)
        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=config.dim,
                num_heads=config.num_heads,
                mlp_ratio=config.mlp_ratio,
                dropout=config.dropout,
                attn_dropout=config.attn_dropout,
                bias=config.bias,
                flash_attn=flash_attn
            ) for _ in range(config.depth)
        ])
        self.norm = LayerNorm(config.dim, bias=config.bias)
        self.fc = nn.Linear(config.dim, self.num_outputs, bias=True)
        # Learned token that stands in for frames excluded by targets_mask.
        self.mask_token = nn.Parameter(torch.zeros(1, config.dim))
        nn.init.trunc_normal_(self.mask_token, mean=0., std=0.02)
        self.init_weights()

    def forward(self, video, metadata, targets=None, targets_mask=None, drop_ratio=0.):
        """Predict one scalar per frame and optionally compute the loss.

        Args:
            video: 5-D tensor (B, N, C, H, W), i.e. B clips of N frames each.
            metadata: per-frame features. Reshaped to (B * N, -1) and fed to
                ``meta_embed``, whose input width is ``num_meta_features``.
                Presumably (B, N, 4); TODO confirm against callers.
            targets: optional per-frame targets, selected with ``targets_mask``.
                Presumably (B, N); confirm.
            targets_mask: optional boolean (B, N) mask of frames that are encoded
                and scored. Defaults to all frames. Frames where it is False are
                replaced by the learned mask token before the transformer.
            drop_ratio: probability of additionally un-selecting each frame.
                Applied whenever > 0, regardless of ``self.training``.

        Returns:
            ``logits`` of shape (B, N), or ``(logits, loss)`` when ``targets`` is
            given. The loss covers only frames still selected after dropping.
        """
        B, N, C, H, W = video.shape
        if targets_mask is None:
            targets_mask = torch.ones(B, N, dtype=torch.bool, device=video.device)
        if drop_ratio > 0:
            # Un-selected frames are masked below and excluded from the loss.
            targets_mask = torch.bernoulli((1 - drop_ratio) * targets_mask.float()).bool()

        keep_ids = targets_mask.flatten().nonzero().squeeze(-1)
        drop_ids = (~targets_mask).flatten().nonzero().squeeze(-1)
        # Rows are laid out as [kept tokens; mask tokens]. restore_ids[i] is the
        # row that belongs to flat frame i.
        restore_ids = torch.argsort(torch.cat([keep_ids, drop_ids]))

        # Only kept frames go through the CNN. Use reshape rather than view so
        # non-contiguous inputs are accepted.
        tokens = self.frame_encoder(video.reshape(B * N, C, H, W)[keep_ids])
        tokens = tokens + self.meta_embed(metadata.reshape(B * N, -1)[keep_ids])

        tokens = torch.cat([tokens, self.mask_token.repeat(len(drop_ids), 1)])
        tokens = tokens[restore_ids].view(B, N, self.config.dim)

        x = self.dropout(tokens + self.pos_embed[:, :N])
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        logits = self.fc(x).squeeze(dim=-1)
        if targets is None:
            return logits

        logits_ = torch.masked_select(logits, targets_mask)
        targets_ = torch.masked_select(targets, targets_mask)
        # Compare sin/cos instead of raw values, so the loss is invariant to
        # 2*pi shifts between prediction and target.
        loss = (
            F.mse_loss(logits_.sin(), targets_.sin()) +
            F.mse_loss(logits_.cos(), targets_.cos())
        )
        return logits, loss

    def get_optimizer(self, fused=False):
        """Build AdamW with weight decay only on ``nn.Linear``/``nn.Conv2d`` weights.

        Biases, norm parameters, the mask token and the frozen position table go
        into a zero-decay group.

        Args:
            fused: forwarded to ``torch.optim.AdamW``.

        Returns:
            The configured ``torch.optim.AdamW`` instance.
        """
        decay_modules = (nn.Linear, nn.Conv2d)
        decay = set()
        for module_name, module in self.named_modules():
            if not isinstance(module, decay_modules):
                continue
            # recurse=False: each parameter is visited once, by its owning module.
            for param_name, _ in module.named_parameters(recurse=False):
                if param_name.endswith('weight'):
                    decay.add(f'{module_name}.{param_name}' if module_name else param_name)
        params = dict(self.named_parameters())
        optim_groups = [
            {'params': [param for name, param in params.items() if name in decay], 'weight_decay': self.config.weight_decay},
            {'params': [param for name, param in params.items() if name not in decay], 'weight_decay': 0.}
        ]
        optimizer = torch.optim.AdamW(
            params=optim_groups,
            lr=self.config.learning_rate,
            fused=fused)
        return optimizer

    def init_weights(self):
        """Re-initialize conv, norm and linear layers.

        Conv2d uses Kaiming normal (fan_out, relu). Norms use weight=1, bias=0.
        Linear uses truncated normal with std 0.02. All biases are zeroed.
        ``mask_token`` and ``pos_embed`` are not touched.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, (GroupNorm, LayerNorm)):
                nn.init.ones_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.trunc_normal_(module.weight, mean=0., std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
def get_1d_pos_embed(context_size, dim):
    """Return a fixed sinusoidal positional-embedding table of shape (1, context_size, dim).

    Even feature indices hold ``sin(pos * f_i)`` and odd indices hold
    ``cos(pos * f_i)``, with frequencies ``f_i = 10000 ** (-2i / dim)``.
    ``dim`` must be even.
    """
    assert dim % 2 == 0
    freqs = torch.exp(torch.arange(0, dim, 2) * (-math.log(10000.0) / dim))
    angles = torch.arange(context_size).unsqueeze(1) * freqs  # (context_size, dim // 2)
    table = torch.zeros(context_size, dim)
    table[:, 0::2] = angles.sin()
    table[:, 1::2] = angles.cos()
    return table.unsqueeze(0)
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention where each position attends only to itself and earlier ones.

    Args:
        dim: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout applied after the output projection.
        attn_dropout: dropout on the attention probabilities, active only in
            training mode.
        bias: whether the qkv and output projections have biases.
        flash_attn: if True, use ``F.scaled_dot_product_attention``. Otherwise
            use an explicit implementation that stores the attention matrix in
            ``self.attn_weights`` (e.g. for visualization).
    """

    def __init__(self, dim, num_heads, dropout=0.,
                 attn_dropout=0., bias=True, flash_attn=True):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.attn_dropout = attn_dropout
        self.flash_attn = flash_attn
        self.W_qkv = nn.Linear(dim, 3 * dim, bias=bias)
        self.W_o = nn.Linear(dim, dim, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply causal attention to ``x`` of shape (B, N, dim); the output has the same shape."""
        B, N, D = x.size()
        # (B, N, 3 * D) -> (3, B, heads, N, head_dim)
        qkv = self.W_qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        # scaled_dot_product_attention applies dropout_p regardless of module mode,
        # so it must be zeroed outside training to keep evaluation deterministic.
        dropout_p = self.attn_dropout if self.training else 0.
        if self.flash_attn:
            attn = F.scaled_dot_product_attention(
                q, k, v, dropout_p=dropout_p, is_causal=True)
        else:  # use standard implementation of attention to get attention weights for visualization
            causal = torch.ones(N, N, dtype=torch.bool, device=q.device).tril()
            scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
            scores = scores.masked_fill(~causal, float('-inf'))
            attn_weight = torch.softmax(scores, dim=-1)
            attn_weight = F.dropout(attn_weight, p=dropout_p, training=self.training)
            attn = attn_weight @ v
            self.attn_weights = attn_weight
        attn = attn.transpose(1, 2).reshape(B, N, D)
        attn = self.dropout(self.W_o(attn))
        return attn
class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout.

    The hidden layer is ``mlp_ratio`` times wider than ``dim``.
    """

    def __init__(self, dim, mlp_ratio, dropout=0., bias=True):
        super().__init__()
        hidden = dim * mlp_ratio
        self.W_1 = nn.Linear(dim, hidden, bias=bias)
        self.W_2 = nn.Linear(hidden, dim, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        h = F.gelu(self.W_1(x))
        out = self.W_2(h)
        return self.dropout(out)
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` whose additive bias can be disabled with ``bias=False``.

    All other positional and keyword arguments go straight to ``nn.LayerNorm``.
    """

    def __init__(self, *args, bias=True, **kwargs):
        super().__init__(*args, **kwargs)
        if bias:
            return
        # Replace the bias parameter with None so no bias is stored or trained.
        self.register_parameter('bias', None)
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: causal self-attention, then an MLP, each in a residual branch."""

    def __init__(self, dim, num_heads, mlp_ratio, dropout=0.,
                 attn_dropout=0., bias=True, flash_attn=True):
        super().__init__()
        self.mha_norm = LayerNorm(dim, bias=bias)
        self.mha = CausalSelfAttention(
            dim=dim, num_heads=num_heads, dropout=dropout,
            attn_dropout=attn_dropout, bias=bias, flash_attn=flash_attn,
        )
        self.mlp_norm = LayerNorm(dim, bias=bias)
        self.mlp = MLP(dim=dim, mlp_ratio=mlp_ratio, dropout=dropout, bias=bias)

    def forward(self, x):
        # Pre-norm residual units as per: https://arxiv.org/abs/1906.01787
        attn_out = self.mha(self.mha_norm(x))
        x = x + attn_out
        mlp_out = self.mlp(self.mlp_norm(x))
        return x + mlp_out
class FrameEncoder(nn.Module):
    """Convolutional encoder that maps each 3-channel frame to one feature vector.

    Pipeline: patch dropout -> 7x7 stride-2 stem conv -> stages of ResidualBlocks
    -> norm + ReLU -> global average pool. The output width is the last stage's
    channel count.
    """

    def __init__(self, config: Config):
        super().__init__()
        self.patch_dropout = PatchDropout(patch_size=16, p=config.dropout)
        stem_channels = config.channels[0]
        self.conv1 = conv2d(
            in_channels=3, out_channels=stem_channels,
            kernel_size=7, stride=2, bias=config.bias,
        )

        layers = []
        prev_channels = stem_channels
        # NOTE(review): zip() silently truncates if channels/blocks/stride differ in length.
        for width, stage_blocks, stage_stride in zip(config.channels, config.blocks, config.stride):
            # The first block of a stage may change width and/or downsample ...
            layers.append(ResidualBlock(
                in_channels=prev_channels, out_channels=width,
                stride=stage_stride, kernel_size=3, bias=config.bias,
            ))
            # ... the remaining blocks of the stage keep the shape unchanged.
            layers.extend(
                ResidualBlock(in_channels=width, out_channels=width, kernel_size=3, bias=config.bias)
                for _ in range(stage_blocks - 1)
            )
            prev_channels = width
        self.blocks = nn.ModuleList(layers)

        self.norm1 = norm2d(prev_channels, bias=config.bias)
        self.relu1 = nn.ReLU()
        self.global_pool = nn.AdaptiveAvgPool2d(output_size=1)

    def forward(self, x):
        h = self.conv1(self.patch_dropout(x))
        for block in self.blocks:
            h = block(h)
        h = self.relu1(self.norm1(h))
        # (batch, channels, 1, 1) -> (batch, channels)
        return self.global_pool(h).squeeze(dim=(-1, -2))
class ResidualBlock(nn.Module):
    """Pre-activation residual block: (norm -> ReLU -> conv) twice, plus a shortcut.

    The shortcut is a 1x1 convolution (using the block's stride) when the block
    downsamples or changes width, and the identity otherwise. It is applied to
    the raw input.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True):
        super().__init__()
        needs_projection = stride > 1 or in_channels != out_channels
        self.shortcut = (
            conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=bias)
            if needs_projection
            else nn.Identity()
        )
        self.norm1 = norm2d(in_channels, bias=bias)
        self.relu1 = nn.ReLU()
        self.conv1 = conv2d(in_channels, out_channels, kernel_size, stride, bias=bias)
        self.norm2 = norm2d(out_channels, bias=bias)
        self.relu2 = nn.ReLU()
        self.conv2 = conv2d(out_channels, out_channels, kernel_size, bias=bias)

    def forward(self, x):
        residual = self.shortcut(x)
        h = self.conv1(self.relu1(self.norm1(x)))
        h = self.conv2(self.relu2(self.norm2(h)))
        return h + residual
class PatchDropout(nn.Module):
    """Dropout that zeroes whole ``patch_size`` x ``patch_size`` patches instead of single values.

    In training mode each patch is dropped with probability ``p``, using one
    mask shared by all channels. Surviving pixels are scaled by ``1 / (1 - p)``,
    mirroring ``nn.Dropout``. In eval mode, or when ``p == 0``, the input is
    returned unchanged.

    Raises:
        ValueError: if ``p`` is outside [0, 1].
    """

    def __init__(self, patch_size, p=0.5):
        super().__init__()
        if not 0.0 <= p <= 1.0:
            raise ValueError(f'dropout probability has to be between 0 and 1, but got {p}')
        self.patch_size = patch_size
        self.p = p

    def forward(self, x):
        if not self.training or self.p == 0:
            return x
        if self.p >= 1:
            # Everything is dropped; this also avoids dividing by (1 - p) == 0.
            return torch.zeros_like(x)
        mask = drop_patches(x, self.patch_size, self.p)
        # Cast the (float32) mask so half-precision inputs are not promoted.
        return x * mask.to(x.dtype) * (1.0 / (1 - self.p))
def drop_patches(x, patch_size, drop_ratio):
    """Sample a patch-wise binary keep-mask for a batch of images.

    Args:
        x: 4-D tensor (B, C, H, W); only its shape and device are used.
        patch_size: side length of the square patches. H and W must both be
            multiples of it (H and W need not be equal).
        drop_ratio: probability of dropping (zeroing) each patch.

    Returns:
        A float32 tensor (B, 1, H, W) of 0s and 1s. It is constant inside each
        patch and broadcasts over the channel dimension.

    Raises:
        ValueError: if H or W is not divisible by ``patch_size``.
    """
    B, _, H, W = x.size()
    if H % patch_size or W % patch_size:
        raise ValueError(
            f'image size ({H}, {W}) must be divisible by patch_size={patch_size}')
    keep_ratio = 1 - drop_ratio
    rows, cols = H // patch_size, W // patch_size
    # One Bernoulli(keep_ratio) draw per patch. The singleton axes are where each
    # draw is replicated patch_size times along height and width.
    patches = torch.bernoulli(
        torch.full((B, 1, rows, 1, cols, 1), fill_value=keep_ratio, device=x.device)
    )
    patches = patches.expand(B, 1, rows, patch_size, cols, patch_size)  # upscale patches
    return patches.reshape(B, 1, H, W)  # reshape the patches into a mask
class GroupNorm(nn.GroupNorm):
    """``nn.GroupNorm`` whose additive bias can be disabled with ``bias=False``.

    All other positional and keyword arguments go straight to ``nn.GroupNorm``.
    """

    def __init__(self, *args, bias=True, **kwargs):
        super().__init__(*args, **kwargs)
        if bias:
            return
        # Replace the bias parameter with None so no bias is stored or trained.
        self.register_parameter('bias', None)
def conv2d(in_channels, out_channels, kernel_size, stride=1, bias=True, **kwargs):
    """Build an ``nn.Conv2d`` with symmetric "same-style" padding for an odd kernel.

    Extra keyword arguments are forwarded to ``nn.Conv2d``.
    """
    assert kernel_size % 2 == 1  # odd kernels only, so padding can be symmetric
    half_kernel = kernel_size // 2  # equals (kernel_size - 1) // 2 for odd sizes
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        stride=stride, padding=half_kernel, bias=bias, **kwargs,
    )
def norm2d(in_channels, num_groups=None, bias=True, group_size=16):
    """Build the 2-D normalization layer used by the CNN (a ``GroupNorm``).

    Batch normalization is replaced with group normalization as per:
    https://arxiv.org/abs/2003.00295

    Args:
        in_channels: number of channels to normalize.
        num_groups: explicit group count; used unchanged when given.
        bias: whether the layer has an additive bias.
        group_size: target number of channels per group when ``num_groups``
            is None.

    Returns:
        A ``GroupNorm`` over ``in_channels`` channels.
    """
    if num_groups is None:
        # Aim for ~group_size channels per group. Never use fewer than one group:
        # in_channels < group_size used to give 0 groups and crash. Step down to
        # the nearest group count that divides in_channels, since GroupNorm
        # requires that. When in_channels // group_size already divides evenly,
        # the result is unchanged.
        num_groups = max(1, in_channels // group_size)
        while in_channels % num_groups:
            num_groups -= 1
    return GroupNorm(num_groups, in_channels, bias=bias)