-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
173 lines (153 loc) · 6.07 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""
rlsn 2024
"""
from transformers import PreTrainedModel
from transformers.utils import ModelOutput
from transformers.models.vit.modeling_vit import ViTPooler, ViTEncoder
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import ViTConfig
import torch
import torch.nn as nn
import numpy as np
class ResBlock(nn.Module):
    """Basic 3-D residual block: two 3x3x3 conv + batch-norm stages and a skip path.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by both convolutions.
        stride: Stride of the first convolution (the second always uses 1).
        downsample: Optional module applied to the input before it is added
            back to the main path. When ``None`` the raw input is added, so it
            must already have the same shape as the main-path output.
    """

    def __init__(self, in_channels, out_channels, stride, downsample=None):
        super().__init__()
        # Attribute assignment order is kept as-is: it determines the order of
        # entries in the module's state dict.
        self.conv1 = nn.Conv3d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.downsample = downsample
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(
            out_channels, out_channels, kernel_size=3, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm3d(out_channels)

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        # Main path: conv -> bn -> relu -> conv -> bn.
        features = self.relu(self.bn1(self.conv1(x)))
        features = self.bn2(self.conv2(features))

        # Skip path: identity unless a projection module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        features += shortcut
        return self.relu(features)
class CNNFeatureExtractor(nn.Module):
    """ResNet-style 3-D convolutional stem that turns a volume into a feature map.

    The network is a 7x7x7 stride-2 convolution, a stride-2 max-pool and three
    residual stages (64, 128 and 256 channels; the last two with stride 2), so
    every spatial axis is reduced by four stride-2 steps.

    Args:
        config: Configuration object; only ``config.num_channels`` (number of
            input channels) is read here.

    Attributes:
        in_channels: Running channel count used while building the stages.
            After construction it holds the channel count of the final stage
            (256) and is read by ``VitDet3D`` to size the token projection.
        out_size: Spatial size ``[3, 8, 8]`` that ``VitDet3D`` passes to
            ``PosEmbedding`` as the feature-map size.
    """

    def __init__(self, config):
        super().__init__()
        self.in_channels = 64
        # No adaptive pooling is applied in forward(), so this size is not
        # enforced here; the input volume has to be chosen so that the final
        # feature map actually has this spatial size.
        self.out_size = [3, 8, 8]
        self.conv1 = nn.Conv3d(
            config.num_channels,
            self.in_channels,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )
        self.bn1 = nn.BatchNorm3d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(64, 2)
        self.layer2 = self._make_layer(128, 2, stride=2)
        self.layer3 = self._make_layer(256, 2, stride=2)

    def _make_layer(self, num_channels, num_layers, stride=1):
        """Build one residual stage and advance ``self.in_channels``.

        Args:
            num_channels: Output channel count of every block in the stage.
            num_layers: Number of ``ResBlock`` modules in the stage.
            stride: Stride of the first block; remaining blocks use stride 1.

        Returns:
            An ``nn.Sequential`` holding the stage's blocks.
        """
        downsample = None
        # The identity shortcut only works when the block keeps both the
        # spatial size and the channel count; otherwise project the input with
        # a 1x1x1 convolution so the residual addition has matching shapes.
        if stride != 1 or self.in_channels != num_channels:
            downsample = nn.Sequential(
                nn.Conv3d(self.in_channels, num_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(num_channels),
            )

        layers = [ResBlock(self.in_channels, num_channels, stride, downsample)]
        self.in_channels = num_channels
        for _ in range(1, num_layers):
            layers.append(ResBlock(self.in_channels, num_channels, 1))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the stem and the three residual stages on ``x``.

        ``x`` is fed to an ``nn.Conv3d`` with ``config.num_channels`` input
        channels; the result has 256 channels.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x
class PosEmbedding(nn.Module):
    """Turn a CNN feature map into a token sequence with CLS token and positions.

    Each spatial location of the feature map becomes one token: its channel
    vector is linearly projected to ``config.hidden_size``. A learned CLS token
    is prepended, learned position embeddings are added and dropout is applied.

    Args:
        config: Configuration object; ``hidden_size`` and
            ``hidden_dropout_prob`` are read.
        in_channels: Channel count of the incoming feature map.
        in_size: Iterable with the spatial size of the feature map; its product
            is the number of tokens (``seq_len``).
    """

    def __init__(self, config, in_channels, in_size):
        super().__init__()
        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        # np.prod returns a NumPy scalar; keep a plain Python int so the value
        # is safe to use as a tensor size and to print/serialise.
        self.seq_len = int(np.prod(in_size))
        self.projection = nn.Linear(in_channels, config.hidden_size)
        self.position_embeddings = nn.Parameter(
            torch.randn(1, self.seq_len + 1, config.hidden_size)
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x):
        """Embed a feature map.

        Args:
            x: Tensor whose first axis is the batch, second axis the channels
                and remaining axes the spatial dimensions.

        Returns:
            Tensor of shape ``(batch, seq_len + 1, hidden_size)``.

        Raises:
            ValueError: If the number of spatial locations in ``x`` differs
                from ``seq_len``.
        """
        batch_size = x.shape[0]
        # (batch, channels, *spatial) -> (batch, tokens, channels)
        tokens = x.flatten(2).transpose(1, 2)
        if tokens.shape[1] != self.seq_len:
            # Fail early with a readable message instead of a broadcasting
            # error when the position embeddings are added below.
            raise ValueError(
                f"Feature map has {tokens.shape[1]} spatial locations but the "
                f"position embeddings were built for {self.seq_len}."
            )
        tokens = self.projection(tokens)

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, tokens), dim=1)
        embeddings = embeddings + self.position_embeddings
        return self.dropout(embeddings)
class MLP(nn.Module):
    """Small fully connected head.

    Consists of ``num_layers - 1`` hidden ``Linear(in_dim, in_dim)`` layers,
    each followed by a ReLU, and a final ``Linear(in_dim, out_dim)``.

    Args:
        in_dim: Size of the input and of every hidden layer.
        out_dim: Size of the output.
        num_layers: Total number of linear layers; values below 1 still yield
            the single output layer.
    """

    def __init__(self, in_dim, out_dim, num_layers):
        super().__init__()
        # Hidden stages first (Linear then ReLU, in that order), output last.
        stack = [
            module
            for _ in range(num_layers - 1)
            for module in (nn.Linear(in_dim, in_dim), nn.ReLU(inplace=True))
        ]
        stack.append(nn.Linear(in_dim, out_dim))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        return self.layers(x)
class VitDet3D(PreTrainedModel):
    """CNN + ViT-encoder model with a classification head and a box head.

    Pipeline: ``CNNFeatureExtractor`` -> ``PosEmbedding`` -> ``ViTEncoder`` ->
    layer norm -> optional ``ViTPooler`` -> two ``MLP`` heads. One head
    produces ``config.num_labels`` logits, the other six regression values
    (presumably a 3-D box; the exact encoding is not defined in this file -
    TODO confirm against the data pipeline).

    Args:
        config: A ``ViTConfig``; ``hidden_size``, ``layer_norm_eps`` and
            ``num_labels`` are read here, further fields by the submodules.
        add_pooling_layer: When true a ``ViTPooler`` is applied before the
            heads. When false the heads read the CLS token's hidden state
            directly.
    """

    config_class = ViTConfig

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.cnn = CNNFeatureExtractor(config)
        self.embeddings = PosEmbedding(config, self.cnn.in_channels, self.cnn.out_size)
        self.encoder = ViTEncoder(config)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pooler = ViTPooler(config) if add_pooling_layer else None
        self.classification_head = MLP(config.hidden_size, config.num_labels, 3)
        self.bbox_head = MLP(config.hidden_size, 6, 3)
        self.config = config

    def forward(self, pixel_values, labels=None, bbox=None):
        """Run the model and optionally compute the training loss.

        Args:
            pixel_values: Input volume batch passed to the CNN stem.
            labels: Optional classification targets. With ``num_labels == 1``
                they are used as binary targets, otherwise as class indices.
            bbox: Optional regression targets matching ``bbox_pred``.

        Returns:
            ``ModelOutput`` with ``loss`` (``None`` unless both ``labels`` and
            ``bbox`` are given), ``logits``, ``bbox``, ``last_hidden_state``
            and ``pooler_output`` (``None`` when the model has no pooler).
        """
        feature_maps = self.cnn(pixel_values)
        embeddings = self.embeddings(feature_maps)
        encoder_outputs = self.encoder(embeddings)
        sequence_output = self.layernorm(encoder_outputs[0])

        if self.pooler is not None:
            pooled_output = self.pooler(sequence_output)
            head_input = pooled_output
        else:
            # Without a pooler there is no pooled vector; feed the heads the
            # CLS token (position 0, prepended by PosEmbedding) instead of
            # passing None, which would crash inside the first Linear layer.
            pooled_output = None
            head_input = sequence_output[:, 0]

        logits = self.classification_head(head_input)
        bbox_pred = self.bbox_head(head_input)

        loss = None
        if labels is not None and bbox is not None:
            if self.config.num_labels == 1:
                loss = BCEWithLogitsLoss()(logits.view(-1), labels.float())
            else:
                loss = CrossEntropyLoss()(logits, labels)

            # Only samples with a non-zero label contribute a box error. The
            # mean is still taken over all elements, masked ones included.
            mask = labels.unsqueeze(-1).bool()
            mse_loss = MSELoss(reduction='none')(bbox_pred, bbox) * mask
            loss = loss + mse_loss.mean()

        return ModelOutput(
            loss=loss,
            logits=logits,
            bbox=bbox_pred,
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
        )