-
Notifications
You must be signed in to change notification settings - Fork 12
/
utils.py
96 lines (77 loc) · 2.71 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
import os
import imageio
# For logger
def to_np(x):
    """Return the data of tensor ``x`` as a NumPy array.

    ``.data`` drops the autograd wrapper and ``.cpu()`` moves the values to
    host memory (a no-op if they are already there), so ``.numpy()`` can be
    called on the result.
    """
    host_tensor = x.data.cpu()
    return host_tensor.numpy()
def to_var(x):
    """Move ``x`` to the GPU when CUDA is available and wrap it in ``Variable``.

    ``Variable`` is the legacy autograd wrapper; on modern PyTorch it yields a
    plain Tensor.
    """
    device_tensor = x.cuda() if torch.cuda.is_available() else x
    return Variable(device_tensor)
# De-normalization
def denorm(x):
    """Map values from the [-1, 1] range back to [0, 1].

    Applies ``(x + 1) / 2``, then clamps the result so that inputs outside
    [-1, 1] still land in [0, 1].
    """
    shifted = (x + 1) / 2
    return shifted.clamp(min=0, max=1)
# Plot losses
def plot_loss(d_losses, g_losses, num_epochs, save=False, save_dir='results/', show=False):
    """Plot discriminator and generator loss curves on a single axis.

    Args:
        d_losses: sequence of discriminator loss values (one point each).
        g_losses: sequence of generator loss values (one point each).
        num_epochs: right x-axis limit; also embedded in the saved file name.
        save: if True, write ``Loss_values_epoch_<num_epochs>.png`` into
            ``save_dir``, creating the directory (and parents) if needed.
        save_dir: output directory.
        show: if True, display the figure with ``plt.show()``; otherwise the
            figure is closed to free its memory.
    """
    fig, ax = plt.subplots()
    ax.set_xlim(0, num_epochs)
    # 10% headroom above the largest loss. A non-positive maximum would give a
    # degenerate or inverted range, so let matplotlib autoscale in that case.
    y_max = max(np.max(g_losses), np.max(d_losses)) * 1.1
    if y_max > 0:
        ax.set_ylim(0, y_max)
    ax.set_xlabel('# of Epochs')
    ax.set_ylabel('Loss values')
    ax.plot(d_losses, label='Discriminator')
    ax.plot(g_losses, label='Generator')
    ax.legend()

    if save:
        # makedirs handles nested paths; exist_ok replaces the exists() check.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        # join() works whether or not save_dir ends with a path separator.
        save_fn = os.path.join(save_dir, 'Loss_values_epoch_{:d}'.format(num_epochs) + '.png')
        fig.savefig(save_fn)

    if show:
        plt.show()
    else:
        plt.close(fig)
def plot_test_result(input, target, gen_image, epoch, training=True, save=False, save_dir='results/', show=False, fig_size=(5, 5)):
    """Draw input, generated and target images side by side in one figure.

    Args:
        input, target, gen_image: image batches indexed as
            (batch, channel, height, width); only item 0 of each is drawn.
            Panels are ordered input, generated, target.
        epoch: zero-based epoch index; ``epoch + 1`` appears in the caption
            and in the saved file name.
        training: if True, add an "Epoch N" caption and save as
            ``Result_epoch_N.png``. If False, size the figure from the image's
            pixel dimensions and save as ``Test_result_N.png``.
        save: if True, write the figure under ``save_dir``, creating the
            directory if needed.
        save_dir: prefix concatenated with the file name, so it should end
            with a path separator.
        show: if True, display with ``plt.show()``; otherwise close the figure.
        fig_size: figure size in inches; overridden when ``training`` is False.
    """
    if not training:
        # Roughly one image pixel per figure pixel, assuming matplotlib's
        # default 100 dpi. Width fits three images side by side; height fits
        # one. For NCHW, size(3) is the width and size(2) is the height.
        fig_size = (input.size(3) * 3 / 100, input.size(2) / 100)

    fig, axes = plt.subplots(1, 3, figsize=fig_size)
    for ax, img in zip(axes.flatten(), (input, gen_image, target)):
        ax.axis('off')
        # 'box-forced' was removed in matplotlib 3.0; 'box' replaces it.
        ax.set_adjustable('box')
        pixels = _to_uint8_image(img)
        # A 2-D array means a single-channel image: render it in grayscale.
        ax.imshow(pixels, cmap='gray' if pixels.ndim == 2 else None, aspect='equal')
    fig.subplots_adjust(wspace=0, hspace=0)

    if training:
        fig.text(0.5, 0.04, 'Epoch {0}'.format(epoch + 1), ha='center')

    if save:
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        # Plain concatenation keeps the same naming scheme that make_gif uses
        # to read the per-epoch frames back.
        if training:
            save_fn = save_dir + 'Result_epoch_{:d}'.format(epoch + 1) + '.png'
        else:
            save_fn = save_dir + 'Test_result_{:d}'.format(epoch + 1) + '.png'
        # Remove outer margins so the saved image is just the three panels.
        fig.subplots_adjust(bottom=0, top=1, right=1, left=0)
        fig.savefig(save_fn)

    if show:
        plt.show()
    else:
        plt.close(fig)


def _to_uint8_image(img):
    """Min-max rescale ``img[0]`` (CHW) to 0-255; return an HWC or HW uint8 array."""
    # detach/cpu/float: works for grad-requiring, CUDA and integer tensors.
    frame = img[0].detach().cpu().float()
    lo = frame.min()
    span = (frame.max() - lo).item()
    # A constant image has zero range; avoid dividing by zero, which yields NaN.
    if span == 0:
        span = 1.0
    pixels = ((frame - lo) * 255 / span).numpy().transpose(1, 2, 0).astype(np.uint8)
    # imshow rejects (H, W, 1) arrays, so drop the singleton channel.
    if pixels.shape[2] == 1:
        pixels = pixels[:, :, 0]
    return pixels
# Make gif
def make_gif(dataset, num_epochs, save_dir='results/'):
    """Assemble the per-epoch result images into an animated GIF.

    Reads ``Result_epoch_1.png`` .. ``Result_epoch_<num_epochs>.png`` from
    ``save_dir``. Writes ``<dataset>_pix2pix_epochs_<num_epochs>.gif`` to the
    same location at 5 frames per second. Paths are built by plain string
    concatenation, so ``save_dir`` should end with a path separator.
    """
    frames = [
        imageio.imread(save_dir + 'Result_epoch_{:d}'.format(epoch_no) + '.png')
        for epoch_no in range(1, num_epochs + 1)
    ]
    gif_fn = save_dir + dataset + '_pix2pix_epochs_{:d}'.format(num_epochs) + '.gif'
    imageio.mimsave(gif_fn, frames, fps=5)