# get_pth.py
import torch
import os
import copy
from collections import OrderedDict
# Directory that holds the source checkpoint file `ckpt.pt`.
base_ckpt = './logs/CIFAR10/8'

# Load the whole checkpoint onto the CPU, read its `time_scale` entry, then
# rebind `ckpt_teacher` to one of the dicts stored inside it: the
# 'ema_model' entry when time_scale is 1, the 'net_model' entry otherwise.
ckpt_teacher = torch.load(os.path.join(base_ckpt, 'ckpt.pt'), map_location='cpu')
time_scale = ckpt_teacher['time_scale']
ckpt_teacher = ckpt_teacher['ema_model' if time_scale == 1 else 'net_model']

# Debug output: container type and the key names that will be remapped.
print(type(ckpt_teacher))
print(ckpt_teacher.keys())
# Destination for the re-keyed entries.
od = OrderedDict()

# Index translation tables, old index -> new index. Both sides are strings
# because they are matched against dot-separated components of key names.
# Each listed old index is renumbered to its position in the tuple, so the
# retained indices become consecutive; indices that are not listed have no
# entry in the table.
down_map = {
    old_idx: str(new_idx)
    for new_idx, old_idx in enumerate(('0', '2', '3', '5', '6', '8', '9'))
}
up_map = {
    old_idx: str(new_idx)
    for new_idx, old_idx in enumerate(
        ('0', '2', '3', '4', '6', '7', '8', '10', '11', '12', '14')
    )
}
# Copy entries of `ckpt_teacher` into `od`, renumbering block indices.
#   * Keys containing 'downblocks' or 'upblocks': the second dot-separated
#     component is looked up in `down_map` / `up_map`. If it is present, the
#     entry is stored under the key with that component replaced; if it is
#     absent, the entry is dropped.
#   * Every other key is copied unchanged.
# The new key is rebuilt from the split components so that only the index
# component is rewritten. A first-occurrence `str.replace` would instead
# rewrite whichever part of the key happens to contain those digits first.
for k, v in ckpt_teacher.items():
    if 'downblocks' in k:
        index_map = down_map
    elif 'upblocks' in k:
        index_map = up_map
    else:
        od[k] = v
        continue

    parts = k.split('.')
    new_idx = index_map.get(parts[1])
    if new_idx is None:
        # Index not listed in the map: this entry is not carried over.
        continue
    parts[1] = new_idx
    od['.'.join(parts)] = v

# Debug output: the resulting key names.
print(od.keys())
# Reload the complete checkpoint; earlier in the script `ckpt_teacher` was
# rebound to just one of the dicts stored inside it.
ckpt_teacher = torch.load(os.path.join(base_ckpt, 'ckpt.pt'), map_location='cpu')

# The entry that was remapped into `od`: 'ema_model' when time_scale is 1,
# 'net_model' otherwise. Only that entry is swapped; every other top-level
# entry is carried over as-is, in the original order.
replaced_key = 'ema_model' if time_scale == 1 else 'net_model'
new_pt = OrderedDict(
    (k, od if k == replaced_key else v) for k, v in ckpt_teacher.items()
)

# torch.save does not create missing parent directories, so make sure the
# destination directory exists before writing.
out_path = './logs/ours/8_small/ckpt.pt'
os.makedirs(os.path.dirname(out_path), exist_ok=True)
torch.save(new_pt, out_path)