forked from IST-DASLab/sparsegpt
-
Notifications
You must be signed in to change notification settings - Fork 4
/
llama.py
283 lines (239 loc) · 8.43 KB
/
llama.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
import math
import time
import torch
import torch.nn as nn
import transformers
from sparsegpt import *
from modelutils import *
# Module-level CUDA device handle.
# NOTE(review): nothing in the visible code reads this name -- llama_sequential
# and llama_eval take their own `dev` parameter (which shadows it), and the
# __main__ block passes `DEV` instead, presumably provided by
# `from modelutils import *`. Confirm whether this assignment is still needed.
dev = torch.device("cuda")
def get_llama(model):
    """Load a LLaMA causal-LM checkpoint with weight initialisation disabled.

    ``torch.nn.init.kaiming_uniform_``, ``uniform_`` and ``normal_`` are
    replaced by no-ops before loading (process-wide; they are not restored),
    presumably so no time is spent randomly initialising weights that the
    checkpoint then overwrites.

    Args:
        model: name or path passed to ``LlamaForCausalLM.from_pretrained``.

    Returns:
        The loaded model, loaded with ``torch_dtype='auto'``, with an extra
        ``seqlen`` attribute set to 2048.
    """
    import torch

    def _noop_init(*args, **kwargs):
        pass

    for init_fn_name in ('kaiming_uniform_', 'uniform_', 'normal_'):
        setattr(torch.nn.init, init_fn_name, _noop_init)

    from transformers import LlamaForCausalLM

    llama = LlamaForCausalLM.from_pretrained(model, torch_dtype='auto')
    llama.seqlen = 2048
    return llama
@torch.no_grad()
def llama_sequential(model, dataloader, dev):
    """Prune the decoder layers of a LLaMA model in place with SparseGPT.

    Overall flow:
    1. Capture. The calibration batches are pushed through the model with a
       ``Catcher`` wrapped around decoder layer 0. The Catcher records the
       layer-0 input and attention mask, then aborts the forward pass.
    2. Prune layer by layer. Each decoder layer is moved to ``dev`` in turn,
       and forward hooks feed the selected linear sublayers' inputs/outputs
       to ``SparseGPT.add_batch``.
    3. The sublayers are pruned with ``SparseGPT.fasterprune``.
    4. The pruned layer is re-run to produce the inputs for the next layer.

    Only one decoder layer lives on ``dev`` at a time.

    Hyper-parameters are read from the module-level ``args`` namespace
    (created in the ``__main__`` block): ``nsamples``, ``minlayer``,
    ``maxlayer``, ``prune_only``, ``invert``, ``sparsity``, ``prunen``,
    ``prunem``, ``percdamp``.

    Args:
        model: LLaMA causal-LM as returned by ``get_llama``; modified in
            place. ``model.config.use_cache`` is disabled while running and
            restored on return.
        dataloader: iterable of calibration batches; ``batch[0]`` is passed to
            ``model`` (presumably a (1, seqlen) token-id tensor -- TODO
            confirm). It is expected to yield at most ``args.nsamples``
            batches, since captured inputs are written to a buffer of that
            size.
        dev: torch device on which the computation runs.
    """
    print('Starting ...')

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    model.model.embed_tokens = model.model.embed_tokens.to(dev)
    model.model.norm = model.model.norm.to(dev)
    layers[0] = layers[0].to(dev)

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {'i': 0, 'attention_mask': None}

    class _StopForward(Exception):
        """Raised by Catcher once the layer-0 input is recorded.

        A dedicated type is used so that a genuine error raised earlier in the
        forward pass is not mistaken for the intentional early exit.
        """

    class Catcher(nn.Module):
        """Stand-in for decoder layer 0 that records its input and mask."""

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs['attention_mask']
            raise _StopForward

    layers[0] = Catcher(layers[0])
    for batch in dataloader:
        try:
            model(batch[0].to(dev))
        except _StopForward:
            pass
    layers[0] = layers[0].module

    layers[0] = layers[0].cpu()
    model.model.embed_tokens = model.model.embed_tokens.cpu()
    model.model.norm = model.model.norm.cpu()
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']

    print('Ready.')

    for i in range(len(layers)):
        layer = layers[i].to(dev)

        subset = find_layers(layer)
        gpts = {}
        for name in subset:
            # A sublayer is pruned iff
            # (minlayer <= i < maxlayer and prune_only in name) XOR invert.
            if (not (args.minlayer <= i < args.maxlayer and args.prune_only in name)) == (not args.invert):
                continue
            gpts[name] = SparseGPT(subset[name])

        def add_batch(name):
            # Factory so each hook binds its own `name` (avoids late binding).
            def tmp(_, inp, out):
                gpts[name].add_batch(inp[0].data, out.data)
            return tmp

        handles = []
        for name in gpts:
            handles.append(subset[name].register_forward_hook(add_batch(name)))
        # First pass: only feeds the hooks; outputs are recomputed after pruning.
        for j in range(args.nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
        for h in handles:
            h.remove()

        for name in gpts:
            print(i, name)
            print('pruning ...')
            gpts[name].fasterprune(
                args.sparsity, prunen=args.prunen, prunem=args.prunem, percdamp=args.percdamp
            )

        # Second pass: outputs of the pruned layer become the next layer's inputs.
        for j in range(args.nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]

        layers[i] = layer.cpu()
        del gpts
        torch.cuda.empty_cache()

        # Reuse the two buffers: this layer's outputs are the next layer's inputs.
        inps, outs = outs, inps

    model.config.use_cache = use_cache
def _magnitude_prune_(weight, sparsity):
    """Zero the ``int(weight.numel() * sparsity)`` smallest-magnitude entries in place.

    Used for the ``--gmp`` magnitude-pruning baseline, applied per weight
    matrix. Behaviour by case:
    - Entries whose magnitude ties the threshold are zeroed as well.
    - A ``sparsity`` that selects no entries leaves ``weight`` untouched.
    - One that selects ``numel`` or more entries zeroes it entirely.

    ``weight`` must be safe to modify in place (e.g. a parameter's ``.data``).
    """
    k = int(weight.numel() * sparsity)
    if k <= 0:
        return
    if k >= weight.numel():
        weight.zero_()
        return
    # Magnitude of the k-th smallest entry (0-based index k - 1); everything at
    # or below it is pruned, i.e. exactly k entries plus any ties.
    thresh = torch.sort(torch.abs(weight.flatten()))[0][k - 1]
    weight[torch.abs(weight) <= thresh] = 0


@torch.no_grad()
def llama_eval(model, testenc, dev):
    """Compute and print the perplexity of ``model`` on a tokenized test set.

    Windowing: the token stream is cut into
    ``nsamples = testenc.input_ids.numel() // model.seqlen`` non-overlapping
    windows of ``model.seqlen`` tokens; a trailing partial window is dropped.

    Layer-by-layer execution, mirroring ``llama_sequential``:
    1. The windows are embedded once. A ``Catcher`` on layer 0 records the
       hidden states and attention mask, then aborts the forward pass.
    2. Each decoder layer is moved to ``dev`` in turn and applied to all
       windows.

    GMP: when the module-level ``args.gmp`` flag is set, every linear
    sublayer found by ``find_layers`` is magnitude-pruned in place to
    ``args.sparsity`` just before its decoder layer runs.

    Args:
        model: LLaMA causal-LM from ``get_llama`` (must carry ``seqlen``).
            ``model.config.use_cache`` is disabled while running and restored
            on return. The final norm and ``lm_head`` are left on ``dev``.
        testenc: object exposing ``input_ids``; indexed as ``[:, start:end]``,
            so presumably a (1, num_tokens) tensor -- TODO confirm.
        dev: torch device on which the computation runs.
    """
    print('Evaluating ...')

    testenc = testenc.input_ids
    nsamples = testenc.numel() // model.seqlen

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    model.model.embed_tokens = model.model.embed_tokens.to(dev)
    layers[0] = layers[0].to(dev)

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {'i': 0, 'attention_mask': None}

    class _StopForward(Exception):
        """Raised by Catcher once the layer-0 input is recorded.

        A dedicated type is used so that a genuine error raised earlier in the
        forward pass is not mistaken for the intentional early exit.
        """

    class Catcher(nn.Module):
        """Stand-in for decoder layer 0 that records its input and mask."""

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs['attention_mask']
            raise _StopForward

    layers[0] = Catcher(layers[0])
    for i in range(nsamples):
        batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev)
        try:
            model(batch)
        except _StopForward:
            pass
    layers[0] = layers[0].module

    layers[0] = layers[0].cpu()
    model.model.embed_tokens = model.model.embed_tokens.cpu()
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']

    for i in range(len(layers)):
        print(i)
        layer = layers[i].to(dev)

        if args.gmp:
            subset = find_layers(layer)
            for name in subset:
                _magnitude_prune_(subset[name].weight.data, args.sparsity)

        for j in range(nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
        layers[i] = layer.cpu()
        del layer
        torch.cuda.empty_cache()
        # This layer's outputs are the next layer's inputs.
        inps, outs = outs, inps

    # Apply the final norm (if any) and the LM head to the last layer's hidden
    # states, then accumulate the per-window negative log-likelihood.
    if model.model.norm is not None:
        model.model.norm = model.model.norm.to(dev)
    model.lm_head = model.lm_head.to(dev)

    testenc = testenc.to(dev)
    loss_fct = nn.CrossEntropyLoss()
    nlls = []
    for i in range(nsamples):
        hidden_states = inps[i].unsqueeze(0)
        if model.model.norm is not None:
            hidden_states = model.model.norm(hidden_states)
        lm_logits = model.lm_head(hidden_states)
        # Position t predicts token t + 1.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = testenc[
            :, (i * model.seqlen):((i + 1) * model.seqlen)
        ][:, 1:]
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        # Mean loss over the window's seqlen - 1 predictions, scaled by seqlen,
        # to match the (nsamples * seqlen) normaliser below.
        neg_log_likelihood = loss.float() * model.seqlen
        nlls.append(neg_log_likelihood)
    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
    print(ppl.item())

    model.config.use_cache = use_cache
if __name__ == '__main__':
    # Command-line entry point. Pipeline:
    # 1. Parse arguments and load the model.
    # 2. Optionally prune it with SparseGPT (skipped with --gmp; the GMP
    #    baseline is applied inside llama_eval instead).
    # 3. Report perplexity on wikitext2, ptb and c4.
    # `args` is deliberately module-level: llama_sequential and llama_eval
    # read it as a global.
    import argparse

    from datautils import *

    parser = argparse.ArgumentParser()

    parser.add_argument(
        'model', type=str,
        help='LLaMA model to load, pass `decapoda-research/llama-65b-hf`'
    )
    parser.add_argument(
        'dataset', type=str, choices=['wikitext2', 'ptb', 'c4'],
        help='Where to extract calibration data from'
    )
    parser.add_argument(
        '--seed',
        type=int, default=0, help='Seed for sampling the calibration data'
    )
    parser.add_argument(
        '--nsamples', type=int, default=128,
        help='Number of calibration data samples.'
    )
    parser.add_argument(
        '--percdamp', type=float, default=.01,
        help='Percent of the average Hessian diagonal to use for dampening.'
    )
    parser.add_argument(
        '--sparsity', type=float, default=0,
        help='Target sparsity'
    )
    parser.add_argument(
        '--prunen', type=int, default=0,
        help='N for N:M pruning.'
    )
    parser.add_argument(
        '--prunem', type=int, default=0,
        help='M for N:M pruning.'
    )
    parser.add_argument(
        '--gmp', action='store_true',
        help='Whether to run the GMP baseline.'
    )
    parser.add_argument(
        '--minlayer', type=int, default=-1,
        help='Prune all layers with id >= this.'
    )
    parser.add_argument(
        '--maxlayer', type=int, default=1000,
        help='Prune all layers with id < this.'
    )
    parser.add_argument(
        '--prune_only', type=str, default='',
        help='Prune only layers that contain this text.'
    )
    parser.add_argument(
        '--invert', action='store_true',
        help='Invert subset.'
    )

    args = parser.parse_args()

    model = get_llama(args.model)
    model.eval()

    dataloader, testloader = get_loaders(
        args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen
    )

    # NOTE(review): `DEV` is not defined in this file; presumably it comes
    # from `from modelutils import *` -- confirm.
    if (args.sparsity or args.prunen) and not args.gmp:
        tick = time.time()
        llama_sequential(model, dataloader, DEV)
        # Print the fraction of zeros per parameter, stopping at the first MLP
        # down-projection. `down_proj` is the LLaMA name for that weight; the
        # previous `dense_4h_to_h` is a BLOOM/NeoX name that never matches
        # here, so the loop printed every parameter in the model.
        for n, p in model.named_parameters():
            print(n, torch.mean((p == 0).float()))
            if 'down_proj' in n:
                break
        print(time.time() - tick)

    for dataset in ['wikitext2', 'ptb', 'c4']:
        dataloader, testloader = get_loaders(
            dataset, seed=args.seed, model=args.model, seqlen=model.seqlen
        )
        print(dataset)
        llama_eval(model, testloader, DEV)