pix2pix_train.py
# https://arxiv.org/abs/1611.07004 ##pix2pix
# https://arxiv.org/abs/1505.04597 ##Unet
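# Training driver for the pix2pix model defined in pix2pix.py: reads edges2shoes
# image pairs in batches, alternates discriminator and generator updates, logs
# both losses to TensorBoard, and every 3 epochs saves a checkpoint plus
# validation samples (input | target | generated) under ./generate/<epoch>/.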
from pix2pix import pix2pix
import image_functions as img
import tensorflow as tf #version 1.4
import scipy
import numpy as np
import os
trainset_path = './edges2shoes/train/'
valiset_path = './edges2shoes/val/'
saver_path = './saver/'
make_image_path = './generate/'
batch_size = 4
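
# Assumed dataset layout: each image file is a concatenated A|B pair that
# img.image_split_A_B separates into the model input X and target Y
# (256x256x3 each, per the shape comments in gen_image below).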

def train(model, train_set, epoch):
    total_D_loss = 0
    total_G_loss = 0
    iteration = int(np.ceil(len(train_set) / batch_size))

    for i in range(iteration):
        print('epoch', epoch, 'batch', i + 1, '/', iteration)
        filelist = train_set[batch_size * i: batch_size * (i + 1)]

        # Read the batch from disk, preprocess it, and split it into A|B halves.
        batch = []
        for name in filelist:
            batch.append(img.image_read(trainset_path + name))
        batch = img.image_preprocess(np.array(batch, dtype=np.float32))
        X, Y = img.image_split_A_B(batch)

        # Train the discriminator.
        _, D_loss = sess.run([model.D_minimize, model.D_loss],
                             {model.X: X, model.Y: Y, model.is_train: True})

        # Train the generator.
        _, G_loss = sess.run([model.G_minimize, model.G_loss],
                             {model.X: X, model.Y: Y, model.is_train: True})

        # Accumulate the per-batch losses.
        total_D_loss += D_loss
        total_G_loss += G_loss

    return total_D_loss / iteration, total_G_loss / iteration

def write_tensorboard(model, D_loss, G_loss, epoch):
    summary = sess.run(model.merged,
                       {model.D_loss_tensorboard: D_loss,
                        model.G_loss_tensorboard: G_loss})
    model.writer.add_summary(summary, epoch)

def gen_image(model, vali_set, epoch):
    # Write an input | target | generated triplet for every validation image.
    path = make_image_path + str(epoch) + '/'
    if not os.path.exists(path):
        os.makedirs(path)

    for name in vali_set:
        vali_image = img.image_read(valiset_path + name)
        vali_image = img.image_preprocess(np.array(vali_image, dtype=np.float32))
        X, Y = img.image_split_A_B(vali_image, 1)  # 256 256 3, 256 256 3
        generated = sess.run(model.Gen,  # 1 256 256 3
                             {model.X: [X], model.is_train: False})
        concat = np.concatenate((X, Y, generated[0]), axis=1)
        img.image_save(path + name, concat)

def run(model, train_set, vali_set, restore=0):
    # If restore is given, reload the weights saved at that epoch.
    if restore != 0:
        model.saver.restore(sess, saver_path + str(restore) + ".ckpt")

    print('training start')

    # Training loop.
    for epoch in range(restore + 1, 2001):
        D_loss, G_loss = train(model, train_set, epoch)
        print("epoch : ", epoch, " D_loss : ", D_loss, " G_loss : ", G_loss)

        if epoch % 3 == 0:
            # TensorBoard logging.
            write_tensorboard(model, D_loss, G_loss, epoch)

            # Create the checkpoint folder and save the weights.
            if not os.path.exists(saver_path):
                os.makedirs(saver_path)
            save_path = model.saver.save(sess, saver_path + str(epoch) + ".ckpt")

            # Create the folder for generated images and write validation samples.
            if not os.path.exists(make_image_path):
                os.makedirs(make_image_path)
            gen_image(model, vali_set, epoch)

sess = tf.Session()

# Model.
model = pix2pix(sess)

# Only the file names are listed here; images are read from disk one batch at a time.
train_set = img.get_image_filelist(trainset_path)
vali_set = img.get_image_filelist(valiset_path)
#print(train_set.shape, vali_set.shape)

run(model, train_set, vali_set)
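
# To resume training, pass a previously saved epoch number, e.g.
# run(model, train_set, vali_set, restore=300) restores ./saver/300.ckpt
# (if that checkpoint exists) and continues from epoch 301.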