-
Notifications
You must be signed in to change notification settings - Fork 1
/
test.py
107 lines (82 loc) · 3.16 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os, time, argparse
import numpy as np
from PIL import Image
import glob
import torch
from torchvision.utils import save_image as imwrite
from utils import load_checkpoint, tensor2cuda
from model.models import Generator
import cv2
# Use the GPU: enumerate CUDA devices in PCI bus order and make only device 0
# visible to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})
def addTransparency(img, factor=1):
    """Return ``img`` as RGBA, blended toward a fully transparent black canvas.

    ``Image.blend(a, b, alpha)`` computes ``a * (1 - alpha) + b * alpha`` per
    band. Here ``a`` is a (0, 0, 0, 0) canvas of the same size. With the
    default ``factor=1`` the result is just ``img`` converted to RGBA. Smaller
    factors scale every band, including alpha, toward zero.

    Args:
        img: a PIL image of any mode convertible to RGBA.
        factor: weight given to ``img`` in the blend.

    Returns:
        A new RGBA PIL image.
    """
    rgba = img.convert('RGBA')
    transparent_canvas = Image.new('RGBA', rgba.size, (0, 0, 0, 0))
    return Image.blend(transparent_canvas, rgba, factor)


def main():
    """Parse command-line options, load the generator checkpoint and run ``test``.

    Options:
        --model       directory holding checkpoints (default ./checkpoint/)
        --model_name  checkpoint file stem; '.tar' is appended (default Gmodel_40)
        --intest      directory of input images (default ./input/)
        --outest      directory for output images, created if missing
                      (default ./output/)
    """
    parser = argparse.ArgumentParser(description="network pytorch")
    # checkpoint selection
    parser.add_argument("--model", type=str, default="./checkpoint/", help='checkpoint')
    parser.add_argument("--model_name", type=str, default='Gmodel_40', help='model name')
    # input / output locations
    parser.add_argument("--intest", type=str, default="./input/", help='input syn path')
    parser.add_argument("--outest", type=str, default="./output/", help='output syn path')
    # Parse once; the same namespace is echoed, passed to the loader and to test().
    argspar = parser.parse_args()

    print("\nnetwork pytorch")
    for key, value in vars(argspar).items():
        print('\t{}: {}'.format(key, value))
    print('\n')

    print('> Loading Generator...')
    Gmodel_name = argspar.model_name + '.tar'
    # load_checkpoint returns a 3-tuple; only its first element is used here,
    # as the generator handed to test().
    G_Model, _, _ = load_checkpoint(argspar.model, Generator, Gmodel_name, argspar)
    os.makedirs(argspar.outest, exist_ok=True)
    test(argspar, G_Model)


# Scale factor applied to each input image before it is fed to the model.
xishu = 0.75
def _has_alpha(img):
    """Return True when the PIL image ``img`` has an alpha band (e.g. RGBA, LA, PA)."""
    return any(band in ('A', 'a') for band in img.getbands())


def _restore_size(path, size, has_alpha):
    """Resize the image stored at ``path`` to ``size`` and overwrite it in place.

    When ``has_alpha`` is true, the resized image is passed through
    ``addTransparency`` so that it is saved as RGBA.
    """
    with Image.open(path) as img:
        restored = img.resize(size)
    if has_alpha:
        restored = addTransparency(restored)
    restored.save(path)


def test(argspar, model):
    """Run ``model`` over every image file in ``argspar.intest``.

    For each input image:
      1. Downscale it by ``xishu``.
      2. Scale its pixels from [0, 255] to [0, 1], then normalise to [-1, 1].
      3. Run the model under ``torch.no_grad``.
      4. Write the de-normalised output to ``argspar.outest`` under the same
         file name.
      5. Resize the written file back to the input's exact original size. If
         the input had an alpha band, store the result as RGBA.

    The per-image forward time is printed, followed by the mean.

    Args:
        argspar: parsed arguments; reads ``intest`` and ``outest``.
        model: callable returning an ``(output, attention)`` pair; only the
            output is used.
    """
    # map [0, 1] pixels to [-1, 1] for the model, and back again
    norm = lambda x: (x - 0.5) / 0.5
    denorm = lambda x: (x + 1) / 2
    # Sorted for a deterministic order; subdirectories would crash Image.open.
    files = sorted(
        name for name in os.listdir(argspar.intest)
        if os.path.isfile(os.path.join(argspar.intest, name))
    )
    time_test = []
    model.eval()
    for i, name in enumerate(files):
        in_path = os.path.join(argspar.intest, name)
        out_path = os.path.join(argspar.outest, name)
        with Image.open(in_path) as src:
            x, y = src.width, src.height
            has_alpha = _has_alpha(src)
            haze = src.resize((int(x * xishu), int(y * xishu)))
        print(x * xishu, y * xishu)
        haze = np.array(haze.convert('RGB')) / 255
        with torch.no_grad():
            # NOTE(review): tensor2cuda (from utils) may make the explicit
            # .cuda() redundant -- confirm before removing either.
            haze = torch.Tensor(haze.transpose(2, 0, 1)[np.newaxis, :, :, :]).cuda()
            haze = tensor2cuda(haze)
            starttime = time.time()
            haze = norm(haze)
            out, _att = model(haze)
            # CUDA kernels run asynchronously; wait for them so the timing
            # covers the forward pass itself, not just its launch.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            endtime1 = time.time()
            out = denorm(out)
        imwrite(out, out_path, value_range=(0, 1))
        # Resize to the exact input size; int(int(x * xishu) / xishu) can be off by a pixel.
        _restore_size(out_path, (x, y), has_alpha)
        time_test.append(endtime1 - starttime)
        print('The ' + str(i) + ' Time: %.4f s.' % (endtime1 - starttime))
    if time_test:
        print('Mean Time: %.4f s.' % (sum(time_test) / len(time_test)))


# Entry point: run only when executed as a script, not when imported.
if __name__ == '__main__':
    main()