# fake_quantize.py
import torch
from torch.nn import Module
from torch.quantization.observer import _with_args


class FakeQuantizeBFloat16(Module):
    """Simulate the quantize and dequantize operations at training time for bfloat16"""

    def __init__(self, **observer_kwargs):
        super(FakeQuantizeBFloat16, self).__init__()
        self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8))

    def enable_fake_quant(self, enabled=True):
        self.fake_quant_enabled[0] = 1 if enabled else 0

    def disable_fake_quant(self):
        self.enable_fake_quant(False)

    def forward(self, X):
        if self.fake_quant_enabled[0] == 1:
            if isinstance(X, (tuple, list)):
                return [self.forward(x) for x in X]
            elif isinstance(X, torch.Tensor) and X.is_floating_point():
                # Round-trip through bfloat16 to simulate its reduced precision,
                # then restore the original dtype
                dtype = X.dtype
                X = X.to(dtype=torch.bfloat16).to(dtype=dtype)
        return X

    with_args = classmethod(_with_args)
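

# A minimal usage sketch (hedged; the tensor below is illustrative). The round
# trip through bfloat16 keeps the original dtype but drops mantissa precision:
#
#     fq = FakeQuantizeBFloat16()
#     x = torch.randn(8)
#     y = fq(x)                       # still float32, bfloat16 precision
#     fq.disable_fake_quant()
#     assert torch.equal(fq(x), x)    # pass-through when disabled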


class FakeQuantizeTFLite(torch.quantization.FakeQuantize):
    def forward(self, X):
        observer_enabled = self.observer_enabled[0] == 1
        fake_quant_enabled = self.fake_quant_enabled[0] == 1

        # Run the observer on the raw (not fake-quantized) input
        if observer_enabled:
            if fake_quant_enabled:
                torch.quantization.disable_fake_quant(self)
            X = super().forward(X)
            if fake_quant_enabled:
                torch.quantization.enable_fake_quant(self)

        # Run fake quantization without updating the observer; the small
        # sign-dependent offset pushes values sitting exactly on a rounding
        # boundary away from zero, which mimics TFLite's rounding behavior
        if fake_quant_enabled:
            if observer_enabled:
                torch.quantization.disable_observer(self)
            X = X + self.scale * 1e-6 * torch.sign(X.detach())
            X = super().forward(X)
            if observer_enabled:
                torch.quantization.enable_observer(self)

        return X
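

# A hedged sketch of plugging FakeQuantizeTFLite into a QAT qconfig; the
# observer choice and quantization ranges below are illustrative assumptions,
# not mandated by this file:
#
#     act_fq = FakeQuantizeTFLite.with_args(
#         observer=torch.quantization.MovingAverageMinMaxObserver,
#         quant_min=0,
#         quant_max=255,
#         dtype=torch.quint8,
#     )
#     qconfig = torch.quantization.QConfig(
#         activation=act_fq,
#         weight=torch.quantization.default_weight_fake_quant,
#     )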


def disable_fake_quant(mod):
    """
    Disable fake quantization for this module, if applicable. Example usage::

        # model is any PyTorch model
        model.apply(tinynn.graph.quantization.disable_fake_quant)

    """
    if isinstance(mod, FakeQuantizeBFloat16):
        mod.disable_fake_quant()


def enable_fake_quant(mod):
    """
    Enable fake quantization for this module, if applicable. Example usage::

        # model is any PyTorch model
        model.apply(tinynn.graph.quantization.enable_fake_quant)

    """
    if isinstance(mod, FakeQuantizeBFloat16):
        mod.enable_fake_quant()


class PTQFakeQuantize(torch.quantization.FakeQuantize):
    """Use fake quantization to do PTQ and speed up quantization error analysis.

    When calibrating, enable the observer and disable fake quantization;
    when validating the model's quantization error, enable fake quantization
    and disable the observer.
    """

    def forward(self, X):
        if self.observer_enabled[0] == 1:
            self.activation_post_process(X.detach())
        if self.fake_quant_enabled[0] == 1:
            # Compute qparams lazily on the first fake-quantized forward pass,
            # i.e. while scale/zero_point still hold their default values
            if self.scale == 1 and self.zero_point == 0:
                _scale, _zero_point = self.calculate_qparams()
                _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(self.zero_point.device)
                if self.scale.shape != _scale.shape:
                    self.scale.resize_(_scale.shape)
                    self.zero_point.resize_(_zero_point.shape)
                self.scale.copy_(_scale)
                self.zero_point.copy_(_zero_point)
            if self.is_per_channel:
                X = torch.fake_quantize_per_channel_affine(
                    X, self.scale, self.zero_point, self.ch_axis, self.quant_min, self.quant_max
                )
            else:
                X = torch.fake_quantize_per_tensor_affine(
                    X, self.scale, self.zero_point, self.quant_min, self.quant_max
                )
        return X


def set_ptq_fake_quantize(name, module):
    # Weights: symmetric qint8 with a simple min-max observer
    weight_fq = PTQFakeQuantize.with_args(
        observer=torch.quantization.MinMaxObserver,
        quant_min=-128,
        quant_max=127,
        dtype=torch.qint8,
        qscheme=torch.per_tensor_symmetric,
        reduce_range=False,
    )
    # Activations: asymmetric quint8 with a histogram observer
    asym_fq = PTQFakeQuantize.with_args(
        observer=torch.quantization.HistogramObserver,
        quant_min=0,
        quant_max=255,
        dtype=torch.quint8,
        reduce_range=False,
    )
    qconfig_new = torch.quantization.QConfig(asym_fq, weight_fq)
    return qconfig_new
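

# A hedged end-to-end sketch of the PTQ flow this file enables; `model` and
# `calib_loader` are hypothetical, and the exact prepare call depends on your
# setup:
#
#     for name, module in model.named_modules():
#         module.qconfig = set_ptq_fake_quantize(name, module)
#     torch.quantization.prepare(model, inplace=True)
#
#     # Calibrate: observers on, fake quant off
#     model.apply(torch.quantization.disable_fake_quant)
#     with torch.no_grad():
#         for inputs in calib_loader:
#             model(inputs)
#
#     # Measure quantization error: fake quant on, observers off
#     model.apply(torch.quantization.enable_fake_quant)
#     model.apply(torch.quantization.disable_observer)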