matmul_utils_4bit.py
import torch
import numpy as np
from gptq_llama import quant_cuda

# Global buffer cache and dispatch settings
buffer_mat_dic = {}      # dequantization buffers, keyed by qweight.shape
use_new = True           # enable the reconstruction-based code path
auto_switch = True       # pick between fused and reconstruction kernels automatically
auto_switch_thd = 8      # use reconstruction kernels when prod(x.shape[:-1]) exceeds this
debug = False            # print which kernel path is taken

def get_buffer(shape_of_qweight, dtype=torch.float16, device='cuda'):
    """Return a cached dequantization buffer of shape (rows * 8, cols) for a given qweight shape."""
    if shape_of_qweight not in buffer_mat_dic:
        buffer_mat_dic[shape_of_qweight] = torch.zeros(
            (shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device)
    else:
        if buffer_mat_dic[shape_of_qweight].device != device:
            buffer_mat_dic[shape_of_qweight] = buffer_mat_dic[shape_of_qweight].to(device)
        if buffer_mat_dic[shape_of_qweight].dtype != dtype:
            buffer_mat_dic[shape_of_qweight] = buffer_mat_dic[shape_of_qweight].to(dtype=dtype)
    return buffer_mat_dic[shape_of_qweight]
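
# Note (illustrative, not part of the original file): the cache above keeps one
# dequantization buffer of size (in_features, out_features) per distinct qweight
# shape on the GPU. If that memory needs to be reclaimed, the references can be
# dropped with:
#   buffer_mat_dic.clear()
# after which PyTorch's caching allocator can reuse the space.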

def _matmul4bit_v1(x, qweight, scales, zeros):
    """
    input x: (n, m)
    qweight: (j, k), where m == j * 8 (8 packed 4-bit weights per int32 row)
    perform x @ qweight with the fused v1 kernel
    return y: (n, k)
    """
    if debug:
        print('_matmul4bit_v1')
    assert qweight.shape[0] * 8 == x.shape[-1]
    outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]])
    x = x.reshape(-1, x.shape[-1])
    y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
    dtype = x.dtype
    x = x.half()
    quant_cuda.vecquant4matmul_v1_faster(x, qweight, y, scales, zeros)
    y = y.to(dtype)
    return y.reshape(outshape)

def _matmul4bit_v2(x, qweight, scales, zeros, groupsize):
    """
    input x: (n, m)
    qweight: (j, k), where m == j * 8 (8 packed 4-bit weights per int32 row)
    groupsize: number of input channels sharing one scale/zero (v2, grouped format)
    perform x @ qweight with the fused v2 kernel
    return y: (n, k)
    """
    if debug:
        print('_matmul4bit_v2')
    assert qweight.shape[0] * 8 == x.shape[-1]
    outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]])
    x = x.reshape(-1, x.shape[-1])
    y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
    dtype = x.dtype
    x = x.half()
    quant_cuda.vecquant4matmul_faster(x, qweight, y, scales, zeros, groupsize, x.shape[-1] // 2)
    y = y.to(dtype)
    return y.reshape(outshape)

def _matmul4bit_v1_recons(x, qweight, scales, zeros, transpose=False):
    if debug:
        print('_matmul4bit_v1_recons')
    if not transpose:
        assert qweight.shape[0] * 8 == x.shape[-1]
    else:
        assert qweight.shape[1] == x.shape[-1]
    # Dequantize the full weight matrix into the cached buffer, then do a dense matmul.
    buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
    quant_cuda.vecquant4recons_v1(qweight, buffer, scales, zeros)
    if not transpose:
        output = torch.matmul(x, buffer)
    else:
        output = torch.matmul(x, buffer.T)
    return output

def _matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize, transpose=False):
    if debug:
        print('_matmul4bit_v2_recons')
    if not transpose:
        assert qweight.shape[0] * 8 == x.shape[-1]
    else:
        assert qweight.shape[1] == x.shape[-1]
    buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
    quant_cuda.vecquant4recons_v2(qweight, buffer, scales, zeros, groupsize)
    if not transpose:
        output = torch.matmul(x, buffer)
    else:
        output = torch.matmul(x, buffer.T)
    return output

def matmul4bit(x, qweight, scales, zeros, groupsize=-1):
    if groupsize == -1:
        # groupsize == -1: v1 (ungrouped) format
        if use_new:
            if auto_switch and np.prod(x.shape[:-1]) > auto_switch_thd:
                # Larger batches: reconstruct the weights and use a dense matmul.
                output = _matmul4bit_v1_recons(x.to(scales.dtype), qweight, scales, zeros)
            else:
                output = _matmul4bit_v1(x, qweight, scales.float(), zeros.float())
        else:
            output = _matmul4bit_v1(x, qweight, scales.float(), zeros.float())
    else:
        # grouped (v2) format
        if use_new:
            if auto_switch and np.prod(x.shape[:-1]) > auto_switch_thd:
                output = _matmul4bit_v2_recons(x.to(scales.dtype), qweight, scales, zeros, groupsize)
            else:
                output = _matmul4bit_v2(x, qweight, scales.float(), zeros, groupsize)
        else:
            output = _matmul4bit_v2(x, qweight, scales.float(), zeros, groupsize)
    return output
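
# Usage sketch (illustrative comment, not part of the original file). Shapes follow
# the asserts above; the exact packing of `scales` and `zeros` is defined by the
# quant_cuda kernels:
#   x       : (..., in_features) activations
#   qweight : (in_features // 8, out_features) int32, 8 packed 4-bit weights per entry
#   y = matmul4bit(x, qweight, scales, zeros, groupsize)   # y: (..., out_features)
# groupsize == -1 selects the v1 (ungrouped) kernels; any other value selects v2.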

def v2_to_v1(scales, zeros):
    """
    Convert zeros in a V2 model to the V1 format when group_num == 1, for debugging.
    """
    assert zeros.shape[0] == 1
    z_mat = torch.zeros((zeros.shape[1], 256), dtype=torch.int, device=zeros.device) + zeros.reshape((-1, 1))
    z_buffer = torch.zeros((z_mat.shape[0] * 8, z_mat.shape[1]), dtype=torch.float16, device=zeros.device)
    z_zeros = torch.zeros(z_mat.shape[1], dtype=torch.float16, device=zeros.device)
    z_scales = torch.ones(z_mat.shape[1], dtype=torch.float16, device=zeros.device)
    quant_cuda.vecquant4recons_v1(z_mat, z_buffer, z_scales, z_zeros)
    z_buffer = z_buffer[:, 0]
    zeros_recons = z_buffer * scales + scales
    return zeros_recons
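
# Minimal smoke test (an illustrative sketch, not part of the original module). It
# assumes a CUDA device and a working gptq_llama.quant_cuda build; the shapes chosen
# for `scales` and `zeros` follow the usual GPTQ v2 packing and are assumptions, so
# this only exercises the dispatch logic in matmul4bit, not numerical correctness
# against a real checkpoint.
if __name__ == '__main__':
    if torch.cuda.is_available():
        in_features, out_features, groupsize = 128, 64, 32
        x = torch.randn(4, in_features, dtype=torch.float16, device='cuda')
        # One int32 in qweight packs 8 4-bit weights along the input dimension.
        qweight = torch.randint(0, 2 ** 31 - 1, (in_features // 8, out_features),
                                dtype=torch.int32, device='cuda')
        scales = torch.rand(in_features // groupsize, out_features,
                            dtype=torch.float16, device='cuda')
        zeros = torch.randint(0, 2 ** 31 - 1, (in_features // groupsize, out_features // 8),
                              dtype=torch.int32, device='cuda')
        y = matmul4bit(x, qweight, scales, zeros, groupsize=groupsize)
        print('output shape:', tuple(y.shape))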