-
Notifications
You must be signed in to change notification settings - Fork 1
/
joint_bilateral_filter_layer.py
212 lines (182 loc) · 11.4 KB
/
joint_bilateral_filter_layer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
"""
Trainable joint bilateral filter layer.
Author: Fabian Wagner
Contact: fabian.wagner@fau.de
"""
import torch
import torch.nn as nn
import jointbilateralfilter_cpu_lib
import jointbilateralfilter_gpu_lib
class JointBilateralFilterFunction3dCPU(torch.autograd.Function):
    """
    3D Differentiable joint bilateral filter to remove noise while preserving edges. C++ accelerated layer (CPU).

    See:
        Paris, S. (2007). A gentle introduction to bilateral filtering and its applications: https://dl.acm.org/doi/pdf/10.1145/1281500.1281604

    Args:
        input_img: input tensor: [B, C, X, Y, Z]
        guidance_img: guidance tensor: [B, C, X, Y, Z]
        sigma_x: standard deviation of the spatial blur in x direction.
        sigma_y: standard deviation of the spatial blur in y direction.
        sigma_z: standard deviation of the spatial blur in z direction.
        color_sigma: standard deviation of the range kernel.

    Returns:
        output (torch.Tensor): Filtered tensor.
    """

    @staticmethod
    def forward(ctx, input_img, guidance_img, sigma_x, sigma_y, sigma_z, color_sigma):
        """Filter ``input_img`` with the C++ CPU kernel and save what ``backward`` needs.

        Raises:
            AssertionError: if either tensor is not 5-D, the channel dimension is not 1,
                or the input and guidance shapes differ.
        """
        assert len(input_img.shape) == 5, "Input tensor shape of 3d joint bilateral filter layer must equal [B, C, X, Y, Z]."
        assert len(guidance_img.shape) == 5, "Guidance tensor shape of 3d joint bilateral filter layer must equal [B, C, X, Y, Z]."
        assert input_img.shape[1] == 1, "Currently channel dimensions >1 are not supported."
        assert input_img.shape == guidance_img.shape, "Shape of input tensor must equal shape of guidance tensor."

        # Use c++ implementation for better performance. Besides the filtered output the
        # kernel returns the weights and the partial derivatives consumed by ``backward``.
        (outputTensor, outputWeightsTensor, dO_dz_ki,
         dO_dsig_r, dO_dsig_x, dO_dsig_y, dO_dsig_z) = jointbilateralfilter_cpu_lib.forward_3d_cpu(
            input_img, guidance_img, sigma_x, sigma_y, sigma_z, color_sigma)

        # The order below must match the unpacking in ``backward``.
        ctx.save_for_backward(input_img,
                              sigma_x,
                              sigma_y,
                              sigma_z,
                              color_sigma,
                              outputTensor,
                              outputWeightsTensor,
                              dO_dz_ki,
                              dO_dsig_r,
                              dO_dsig_x,
                              dO_dsig_y,
                              dO_dsig_z,
                              guidance_img)
        return outputTensor

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate ``grad_output`` to the two images and the four sigmas.

        Returns:
            One gradient per ``forward`` input, in order: input image, guidance image,
            sigma_x, sigma_y, sigma_z, color_sigma. A sigma's gradient is ``None`` when
            that sigma does not require a gradient.
        """
        # Unpack once: every access to ``ctx.saved_tensors`` rebuilds the whole tuple.
        (input_img,
         sigma_x,
         sigma_y,
         sigma_z,
         color_sigma,
         outputTensor,         # filtered image
         outputWeightsTensor,  # weights
         dO_dz_ki,             # derivative of output with respect to input, while k==i
         dO_dsig_r,            # derivative of output with respect to range sigma
         dO_dsig_x,            # derivative of output with respect to sigma x
         dO_dsig_y,            # derivative of output with respect to sigma y
         dO_dsig_z,            # derivative of output with respect to sigma z
         guidance_img) = ctx.saved_tensors

        # Autograd may hand over a non-contiguous gradient (e.g. the expanded result of
        # ``out.sum()``'s backward). This is a no-op for contiguous tensors.
        # NOTE(review): confirm whether the C++ side already copes with arbitrary strides.
        grad_output = grad_output.contiguous()

        # Gradients with respect to the sigmas (chain rule, summed over all voxels);
        # skipped for sigmas that do not require a gradient.
        needs_grad = ctx.needs_input_grad
        grad_sig_x = torch.sum(grad_output * dO_dsig_x) if needs_grad[2] else None
        grad_sig_y = torch.sum(grad_output * dO_dsig_y) if needs_grad[3] else None
        grad_sig_z = torch.sum(grad_output * dO_dsig_z) if needs_grad[4] else None
        grad_color_sigma = torch.sum(grad_output * dO_dsig_r) if needs_grad[5] else None

        # Gradients with respect to the input and guidance images come from the C++ kernel.
        grad_output_tensor, grad_guidance_tensor = jointbilateralfilter_cpu_lib.backward_3d_cpu(
            grad_output,
            input_img,
            guidance_img,
            outputTensor,
            outputWeightsTensor,
            dO_dz_ki,
            sigma_x,
            sigma_y,
            sigma_z,
            color_sigma)

        return grad_output_tensor, grad_guidance_tensor, grad_sig_x, grad_sig_y, grad_sig_z, grad_color_sigma
class JointBilateralFilterFunction3dGPU(torch.autograd.Function):
    """
    3D Differentiable joint bilateral filter to remove noise while preserving edges. CUDA accelerated layer.

    See:
        Paris, S. (2007). A gentle introduction to bilateral filtering and its applications: https://dl.acm.org/doi/pdf/10.1145/1281500.1281604

    Args:
        input_img: input tensor: [B, C, X, Y, Z]
        guidance_img: guidance tensor: [B, C, X, Y, Z]
        sigma_x: standard deviation of the spatial blur in x direction.
        sigma_y: standard deviation of the spatial blur in y direction.
        sigma_z: standard deviation of the spatial blur in z direction.
        color_sigma: standard deviation of the range kernel.

    Returns:
        output (torch.Tensor): Filtered tensor.
    """

    @staticmethod
    def forward(ctx, input_img, guidance_img, sigma_x, sigma_y, sigma_z, color_sigma):
        """Filter ``input_img`` with the CUDA kernel and save what ``backward`` needs.

        Raises:
            AssertionError: if either tensor is not 5-D, the channel dimension is not 1,
                or the input and guidance shapes differ.
        """
        assert len(input_img.shape) == 5, "Input tensor shape of 3d joint bilateral filter layer must equal [B, C, X, Y, Z]."
        assert len(guidance_img.shape) == 5, "Guidance tensor shape of 3d joint bilateral filter layer must equal [B, C, X, Y, Z]."
        assert input_img.shape[1] == 1, "Currently channel dimensions >1 are not supported."
        assert input_img.shape == guidance_img.shape, "Shape of input tensor must equal shape of guidance tensor."

        # Use the compiled extension for better performance. Besides the filtered output the
        # kernel returns the weights and the partial derivatives consumed by ``backward``.
        (outputTensor, outputWeightsTensor, dO_dz_ki,
         dO_dsig_r, dO_dsig_x, dO_dsig_y, dO_dsig_z) = jointbilateralfilter_gpu_lib.forward_3d_gpu(
            input_img, guidance_img, sigma_x, sigma_y, sigma_z, color_sigma)

        # The order below must match the unpacking in ``backward``.
        ctx.save_for_backward(input_img,
                              sigma_x,
                              sigma_y,
                              sigma_z,
                              color_sigma,
                              outputTensor,
                              outputWeightsTensor,
                              dO_dz_ki,
                              dO_dsig_r,
                              dO_dsig_x,
                              dO_dsig_y,
                              dO_dsig_z,
                              guidance_img)
        return outputTensor

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate ``grad_output`` to the two images and the four sigmas.

        Returns:
            One gradient per ``forward`` input, in order: input image, guidance image,
            sigma_x, sigma_y, sigma_z, color_sigma. A sigma's gradient is ``None`` when
            that sigma does not require a gradient.
        """
        # Unpack once: every access to ``ctx.saved_tensors`` rebuilds the whole tuple.
        (input_img,
         sigma_x,
         sigma_y,
         sigma_z,
         color_sigma,
         outputTensor,         # filtered image
         outputWeightsTensor,  # weights
         dO_dz_ki,             # derivative of output with respect to input, while k==i
         dO_dsig_r,            # derivative of output with respect to range sigma
         dO_dsig_x,            # derivative of output with respect to sigma x
         dO_dsig_y,            # derivative of output with respect to sigma y
         dO_dsig_z,            # derivative of output with respect to sigma z
         guidance_img) = ctx.saved_tensors

        # Autograd may hand over a non-contiguous gradient (e.g. the expanded result of
        # ``out.sum()``'s backward). This is a no-op for contiguous tensors.
        # NOTE(review): confirm whether the CUDA side already copes with arbitrary strides.
        grad_output = grad_output.contiguous()

        # Gradients with respect to the sigmas (chain rule, summed over all voxels);
        # skipped for sigmas that do not require a gradient.
        needs_grad = ctx.needs_input_grad
        grad_sig_x = torch.sum(grad_output * dO_dsig_x) if needs_grad[2] else None
        grad_sig_y = torch.sum(grad_output * dO_dsig_y) if needs_grad[3] else None
        grad_sig_z = torch.sum(grad_output * dO_dsig_z) if needs_grad[4] else None
        grad_color_sigma = torch.sum(grad_output * dO_dsig_r) if needs_grad[5] else None

        # Gradients with respect to the input and guidance images come from the CUDA kernel.
        grad_output_tensor, grad_guidance_tensor = jointbilateralfilter_gpu_lib.backward_3d_gpu(
            grad_output,
            input_img,
            guidance_img,
            outputTensor,
            outputWeightsTensor,
            dO_dz_ki,
            sigma_x,
            sigma_y,
            sigma_z,
            color_sigma)

        return grad_output_tensor, grad_guidance_tensor, grad_sig_x, grad_sig_y, grad_sig_z, grad_color_sigma
class JointBilateralFilter3d(nn.Module):
    """Trainable 3D joint bilateral filter layer.

    The four filter sigmas are stored as scalar ``nn.Parameter``s so they can be
    optimised together with the rest of a network.

    Args:
        sigma_x: initial standard deviation of the spatial blur in x direction.
        sigma_y: initial standard deviation of the spatial blur in y direction.
        sigma_z: initial standard deviation of the spatial blur in z direction.
        color_sigma: initial standard deviation of the range kernel.
        use_gpu: if True, dispatch to the CUDA implementation, otherwise to the CPU one.
    """

    def __init__(self, sigma_x, sigma_y, sigma_z, color_sigma, use_gpu=True):
        super(JointBilateralFilter3d, self).__init__()
        self.use_gpu = use_gpu

        # make sigmas trainable parameters
        self.sigma_x = self._make_sigma_parameter(sigma_x)
        self.sigma_y = self._make_sigma_parameter(sigma_y)
        self.sigma_z = self._make_sigma_parameter(sigma_z)
        self.color_sigma = self._make_sigma_parameter(color_sigma)

    @staticmethod
    def _make_sigma_parameter(value):
        """Wrap ``value`` in a trainable parameter with a floating-point (or complex) dtype."""
        tensor = torch.tensor(value)
        # ``torch.tensor(1)`` is int64, and ``nn.Parameter`` cannot make a non-floating
        # tensor require gradients, so integral/bool sigmas are promoted to the default
        # float dtype. Floating-point inputs keep their dtype unchanged.
        if not (tensor.is_floating_point() or tensor.is_complex()):
            tensor = tensor.to(torch.get_default_dtype())
        return nn.Parameter(tensor)

    def forward(self, input_tensor, guidance_tensor):
        """Filter ``input_tensor`` using ``guidance_tensor`` as the range guide.

        Args:
            input_tensor: input tensor: [B, C, X, Y, Z] with C == 1.
            guidance_tensor: guidance tensor with the same shape as ``input_tensor``.

        Returns:
            The filtered tensor produced by the selected autograd function.

        Raises:
            AssertionError: if either tensor is not 5-D, the channel dimension is not 1,
                or the input and guidance shapes differ.
        """
        assert len(input_tensor.shape) == 5, "Input tensor shape of 3d joint bilateral filter layer must equal [B, C, X, Y, Z]."
        assert len(guidance_tensor.shape) == 5, "Guidance tensor shape of 3d joint bilateral filter layer must equal [B, C, X, Y, Z]."
        assert input_tensor.shape[1] == 1, "Currently channel dimensions >1 are not supported."
        assert input_tensor.shape == guidance_tensor.shape, "Shape of input tensor must equal shape of guidance tensor."

        # Choose between CPU processing and CUDA acceleration.
        # NOTE(review): presumably the tensors must live on the matching device; this is not checked here.
        filter_function = JointBilateralFilterFunction3dGPU if self.use_gpu else JointBilateralFilterFunction3dCPU
        return filter_function.apply(input_tensor,
                                     guidance_tensor,
                                     self.sigma_x,
                                     self.sigma_y,
                                     self.sigma_z,
                                     self.color_sigma)