"""
Single-scale super-resolution neural model and its internal functional modules.
The model is a simplification of NAS's SingleNetwork.
References: https://github.com/kaist-ina/NAS_public
"""
__author__ = "Yihang Wu"
import math
import torch.nn as nn


class SingleNetwork(nn.Module):
    VALID_SCALES = (1, 2, 3, 4)  # the supported up-scaling factors

    def __init__(self, scale, num_blocks, num_channels, num_features, bias=True, activation=nn.ReLU(True)):
        """
        A single-scale single-image super-resolution neural model.

        Args:
            scale (int): up-scaling factor. The width of an HR image is [scale] times larger than that of an LR image
            num_blocks (int): the number of residual blocks
            num_channels (int): the number of channels in an image, default 3 for BGR color channels
            num_features (int): the number of channels used throughout the convolutional computations
            bias (bool): whether to use bias in convolutional layers
            activation (nn.Module): activation function used in the residual blocks
        """
        super(SingleNetwork, self).__init__()

        self.scale = scale
        self.num_blocks = num_blocks
        self.num_channels = num_channels
        self.num_features = num_features

        if self.scale not in SingleNetwork.VALID_SCALES:
            raise NotImplementedError

        # No early-exit implemented

        # Head of model
        self.head = nn.Sequential(nn.Conv2d(in_channels=self.num_channels, out_channels=self.num_features,
                                            kernel_size=3, stride=1, padding=1, bias=bias))

        # Body of model - consecutive residual blocks
        self.body = nn.ModuleList()  # ModuleList does not have a forward method
        for _ in range(self.num_blocks):
            self.body.append(nn.Sequential(ResidualBlock(num_feats=self.num_features, bias=bias, act=activation)))
        self.body_end = nn.Sequential(nn.Conv2d(in_channels=self.num_features, out_channels=self.num_features,
                                                kernel_size=3, stride=1, padding=1, bias=bias))

        # Upsampling
        if self.scale > 1:
            self.upsampler = nn.Sequential(Upsampler(scale=self.scale, num_feats=self.num_features, bias=bias))

        # Tail of model
        self.tail = nn.Sequential(nn.Conv2d(in_channels=self.num_features, out_channels=self.num_channels,
                                            kernel_size=3, stride=1, padding=1, bias=bias))

    def forward(self, x):
        """
        input shape: (*, num_channels, input_height, input_width)
        output shape: (*, num_channels, target_height, target_width)
        """
        x = self.head(x)  # (*, num_features, input_height, input_width)

        res = x  # global residual
        for i in range(self.num_blocks):
            res = self.body[i](res)
        res = self.body_end(res)
        res += x  # residual connection
        x = res  # (*, num_features, input_height, input_width)

        if self.scale > 1:
            x = self.upsampler(x)  # (*, num_features, target_height, target_width)

        x = self.tail(x)  # (*, num_channels, target_height, target_width)

        return x
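
# Illustrative shape trace for SingleNetwork.forward (added for clarity; the
# concrete values below are assumed examples, not defaults from the repo).
# Assuming scale=3, num_channels=3, num_features=64 and a 100x100 LR input:
#
#   (N, 3, 100, 100)  --head------------>  (N, 64, 100, 100)
#   (N, 64, 100, 100) --body + body_end->  (N, 64, 100, 100)   # plus global skip
#   (N, 64, 100, 100) --upsampler------->  (N, 64, 300, 300)   # PixelShuffle(3)
#   (N, 64, 300, 300) --tail------------>  (N, 3, 300, 300)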


class ResidualBlock(nn.Module):
    def __init__(self, num_feats: int, bias: bool = True, batch_norm: bool = False, act: nn.Module = nn.ReLU(True),
                 residual_scale=1):
        """
        The residual block in SingleNetwork, which is a stack of Conv, ReLU, Conv and Sum layers.

        Args:
            num_feats (int): the number of channels of the convolutional kernel
            bias (bool): whether to use bias in convolutional layers, default True
            batch_norm (bool): whether to apply batch normalization after convolutional layers, default False (different from SRResNet)
            act (nn.Module): activation function
            residual_scale (float): the factor to scale the residual
        """
        super(ResidualBlock, self).__init__()

        modules = []
        for i in range(2):
            modules.append(nn.Conv2d(in_channels=num_feats, out_channels=num_feats, kernel_size=3, stride=1, padding=1, bias=bias))
            if batch_norm:
                modules.append(nn.BatchNorm2d(num_feats))
            if i == 0:
                modules.append(act)

        self.block = nn.Sequential(*modules)
        self.residual_scale = residual_scale

    def forward(self, x):
        if self.residual_scale != 1:  # scale the residual
            res = self.block(x).mul(self.residual_scale)
        else:
            res = self.block(x)
        res += x  # residual connection

        return res
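
# Shape-preservation sketch for ResidualBlock (illustrative, assuming torch is
# imported and num_feats=32 as an example value): the 3x3 convolutions use
# stride 1 and padding 1, so the output shape equals the input shape and the
# skip connection can be added element-wise.
#
#   block = ResidualBlock(num_feats=32)
#   y = block(torch.randn(2, 32, 48, 48))   # y.shape == (2, 32, 48, 48)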


class Upsampler(nn.Module):
    def __init__(self, scale: int, num_feats: int, bias: bool = True, batch_norm: bool = False, act: nn.Module = None):
        """
        This module up-samples the inputs to the target outputs according to a specified scaling factor.

        Args:
            scale (int): scaling factor
            num_feats (int): the number of channels of the image
            bias (bool): whether to use bias in convolutional layers, default True
            batch_norm (bool): whether to apply batch normalization, default False
            act (nn.Module): activation function
        """
        super(Upsampler, self).__init__()

        modules = []
        if scale & (scale - 1) == 0:  # scale = 1, 2, 4
            for _ in range(int(math.log(scale, 2))):
                modules.append(nn.Conv2d(in_channels=num_feats, out_channels=4 * num_feats, kernel_size=3, stride=1, padding=1, bias=bias))
                modules.append(nn.PixelShuffle(2))
                if batch_norm:
                    modules.append(nn.BatchNorm2d(num_feats))
                if act:
                    modules.append(act)
        elif scale == 3:
            modules.append(nn.Conv2d(in_channels=num_feats, out_channels=9 * num_feats, kernel_size=3, stride=1, padding=1, bias=bias))
            modules.append(nn.PixelShuffle(3))
            if batch_norm:
                modules.append(nn.BatchNorm2d(num_feats))
            if act:
                modules.append(act)
        else:
            raise NotImplementedError

        self.upsampler = nn.Sequential(*modules)

    def forward(self, x):
        return self.upsampler(x)
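

# Minimal smoke-test sketch (added for illustration; the hyperparameters below
# are assumed example values, not defaults taken from NAS_public).
if __name__ == "__main__":
    import torch

    model = SingleNetwork(scale=2, num_blocks=8, num_channels=3, num_features=32)
    lr = torch.randn(1, 3, 64, 64)   # a batch of one 64x64 low-resolution image
    sr = model(lr)                   # up-scaled via PixelShuffle in the Upsampler
    print(sr.shape)                  # expected: torch.Size([1, 3, 128, 128])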