-
Notifications
You must be signed in to change notification settings - Fork 2
/
main.py
83 lines (61 loc) · 3.33 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""
Adapted from github.com/affinelayer/pix2pix-tensorflow
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from model import GAN
from utils import *
import argparse
import tensorflow as tf
# Hyper-parameters exposed on the command line through tf.flags.
# Each entry is (definer, flag name, default value, help text); entries are
# registered in the order listed so the generated --help output is unchanged.
flags = tf.flags
_FLAG_SPECS = (
    (flags.DEFINE_integer, 'batch_size', 1, 'The number of images in each batch.'),
    (flags.DEFINE_integer, 'max_epochs', 200, 'The number of training epochs.'),
    (flags.DEFINE_integer, 'ngf', 64, 'The number of generator filters in the first convolution layer.'),
    (flags.DEFINE_integer, 'ndf', 64, 'The number of discriminator filters in the first convolution layer.'),
    (flags.DEFINE_float, 'lr', 0.0002, 'The initial learning rate for ADAM.'),
    (flags.DEFINE_float, 'beta1', 0.5, 'The momentum term of ADAM.'),
    (flags.DEFINE_float, 'l1_weight', 100.0, 'The weight on the L1 term for the generator gradient.'),
    (flags.DEFINE_float, 'gan_weight', 1.0, 'The weight on the GAN term for the generator gradient.'),
    (flags.DEFINE_integer, 'progress_freq', 50, 'The number of steps to take before displaying progress.'),
    (flags.DEFINE_integer, 'save_freq', 5000, 'The number of steps to take before saving the model.'),
)
for _define, _flag_name, _default, _help in _FLAG_SPECS:
    # Positional arguments only: keyword names differ across TF 1.x releases.
    _define(_flag_name, _default, _help)
# Drop the loop temporaries so they do not linger in the module namespace.
del _define, _flag_name, _default, _help
FLAGS = flags.FLAGS
def parse_args(argv=None):
    """
    Parse the command-line arguments that are not declared through ``tf.flags``.

    The hyper-parameters declared with ``tf.flags`` at module level (for
    example ``--batch_size``) are supplied on the same command line. Arguments
    this parser does not recognise are therefore left for ``tf.flags`` rather
    than causing argparse to exit with "unrecognized arguments".

    :param argv: Argument list to parse. When None, argparse falls back to
        ``sys.argv[1:]``.
    :return: ``argparse.Namespace`` with ``input_dir``, ``mode``,
        ``output_dir`` and ``checkpoint`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", required=True, help="directory containing the input images")
    parser.add_argument("--mode", required=True, choices=["train", "test"], help="operation that will be used")
    parser.add_argument("--output_dir", required=True, help="directory where the output images will be saved")
    parser.add_argument("--checkpoint", default=None, help="checkpoint to resume training from or use for testing")
    # parse_known_args tolerates flags owned by tf.flags (e.g. "--batch_size 4");
    # parse_args would reject them and terminate the program.
    args, _unknown = parser.parse_known_args(argv)
    return args
def _encode_pngs(images, name):
    """
    Build graph ops that turn a batch of processed images into PNG strings.

    Applies ``de_process`` -> ``rgbxy_to_rgb`` -> ``convert`` to the batch, then
    ``tf.image.encode_png`` to each element via ``tf.map_fn``.

    :param images: Batched image tensor in the model's processed form
        (presumably the range expected by ``utils.de_process`` - TODO confirm).
    :param name: Name given to the ``tf.map_fn`` op.
    :return: ``tf.string`` tensor with one PNG-encoded image per batch element.
    """
    return tf.map_fn(tf.image.encode_png, convert(rgbxy_to_rgb(de_process(images))),
                     dtype=tf.string, name=name)


def main():
    """
    Entry point: parse arguments, build the GAN graph and run training or testing.

    :raises ValueError: If ``--mode test`` is requested without ``--checkpoint``.
    """
    # Parse the arguments from the command line
    args = parse_args()
    # Validate before touching the filesystem so a bad invocation does not
    # leave an empty output directory behind.
    if args.mode == "test" and args.checkpoint is None:
        raise ValueError("Checkpoint is required for test mode")
    # Create output directory if it does not exist
    check_folder(args.output_dir)
    # Load the images from the input directory
    paths, inputs, targets, steps_per_epoch = load_images(args.input_dir, FLAGS.batch_size, args.mode)
    # Initialise the GAN before running
    model = GAN(args.input_dir, args.output_dir, args.checkpoint, paths, inputs, targets, FLAGS.batch_size,
                steps_per_epoch, FLAGS.ngf, FLAGS.ndf, FLAGS.lr, FLAGS.beta1, FLAGS.l1_weight, FLAGS.gan_weight)
    # The PNG-encoding ops are only consumed by model.test, so they are built
    # only in test mode. They must exist before tf.train.Supervisor is created,
    # because the Supervisor finalizes the graph.
    display_images = None
    if args.mode == "test":
        display_images = {
            "paths": paths,
            "inputs": _encode_pngs(inputs, "inputs_pngs"),
            "targets": _encode_pngs(targets, "target_pngs"),
            "outputs": _encode_pngs(model.get_outputs(), "output_pngs"),
        }
    sv = tf.train.Supervisor(logdir=args.output_dir, save_summaries_secs=0, saver=None)
    with sv.managed_session() as sess:
        # Train or test the initialised GAN based on the chosen mode
        if args.mode == "train":
            model.train(sv, sess, FLAGS.max_epochs, FLAGS.progress_freq, FLAGS.save_freq)
        else:
            model.test(sess, display_images)
if __name__ == "__main__":
    # Only run the pipeline when executed as a script, not when imported.
    main()