-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
171 lines (128 loc) · 4.82 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""
Adapted from github.com/affinelayer/pix2pix-tensorflow
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import glob
import math
import tensorflow as tf
def check_folder(path_dir):
    """
    Ensure that the directory ``path_dir`` exists, creating it (and any
    missing parent directories) when necessary.

    ``exist_ok=True`` makes the call safe when the directory already exists
    or is created concurrently by another process, which a separate
    "check, then create" sequence is not.

    Raises:
        FileExistsError: if ``path_dir`` exists but is not a directory.
    """
    os.makedirs(path_dir, exist_ok=True)
def check_image(image):
    """
    Validate that ``image`` has 3 colour channels and return it.

    A runtime assertion on the size of the last dimension is attached to the
    returned tensor through a control dependency. The static rank must be
    3 or 4, otherwise ``ValueError`` is raised. The static size of the last
    dimension of the returned tensor is set to 3.
    """
    channels_ok = tf.assert_equal(
        tf.shape(image)[-1], 3, message="image must have 3 color channels"
    )
    # Route the tensor through identity so the assertion runs whenever the
    # returned tensor is evaluated.
    with tf.control_dependencies([channels_ok]):
        image = tf.identity(image)

    rank = image.get_shape().ndims
    if rank != 3 and rank != 4:
        raise ValueError("image must be either 3 or 4 dimensions")

    # Record the now-verified channel count in the static shape.
    static_dims = list(image.get_shape())
    static_dims[-1] = 3
    image.set_shape(static_dims)
    return image
def rgb_to_rgbxy(image):
    """
    Append x and y co-ordinate channels to an RGB image.

    The co-ordinate channels are fixed 256 x 256 grids with values
    ``i / 256``: the x channel varies along each row (column index), the
    y channel varies down the rows (row index).
    NOTE(review): the hard-coded grid presumably requires a single
    256 x 256 image (not a batch) - confirm against callers.
    """
    image = check_image(image)
    red, green, blue = tf.unstack(image, axis=-1)

    size = 256
    ramp = [step / size for step in range(size)]
    # Every row of the x grid is the same ramp; every row of the y grid is
    # a constant equal to that row's ramp value.
    x_grid = [list(ramp) for _ in range(size)]
    y_grid = [[value] * size for value in ramp]

    return tf.stack([red, green, blue, x_grid, y_grid], axis=-1)
def rgbxy_to_rgb(image):
    """
    Drop the x and y co-ordinate channels from a 5-channel RGBXY image,
    returning only the red, green and blue channels stacked on the last axis.
    """
    # Unpacking into exactly five names keeps the requirement that the last
    # axis holds five channels.
    red, green, blue, _x, _y = tf.unstack(image, axis=-1)
    rgb_channels = [red, green, blue]
    return tf.stack(rgb_channels, axis=-1)
def convert(image):
    """
    Convert ``image`` to ``tf.uint8`` using
    ``tf.image.convert_image_dtype`` with saturation enabled.
    """
    as_uint8 = tf.image.convert_image_dtype(
        image,
        dtype=tf.uint8,
        saturate=True,
    )
    return as_uint8
def pre_process(image):
    """
    Map pixel values from [0, 1] to [-1, 1] by computing ``image * 2 - 1``.
    """
    doubled = image * 2  # [0, 1] => [0, 2]
    return doubled - 1  # [0, 2] => [-1, 1]
def de_process(image):
    """
    Map pixel values from [-1, 1] back to [0, 1] by computing
    ``(image + 1) / 2``.
    """
    shifted = image + 1  # [-1, 1] => [0, 2]
    return shifted / 2  # [0, 2] => [0, 1]
def get_name(path):
    """
    Return the file name of ``path`` with its directory and its final
    extension removed.
    """
    base = os.path.basename(path)
    return os.path.splitext(base)[0]
def load_images(input_dir, batch_size, mode):
    """
    Build a queue-based input pipeline over the ``*.png`` files in
    ``input_dir``.

    Files are ordered numerically when every file name is made of digits,
    otherwise lexicographically; the path queue is shuffled only when
    ``mode == "train"``. Each decoded image is split down the middle of its
    width: the left half becomes the input, the right half the target. Both
    halves get x/y co-ordinate channels appended and are rescaled with
    ``pre_process``.

    Returns:
        ``(paths_batch, inputs_batch, targets_batch, steps_per_epoch)`` where
        ``steps_per_epoch`` is ``ceil(number_of_files / batch_size)``.

    Raises:
        Exception: if ``input_dir`` is ``None``, does not exist, or contains
            no ``.png`` files.
    """
    if input_dir is None or not os.path.exists(input_dir):
        raise Exception("input_dir does not exist")

    input_paths = glob.glob(os.path.join(input_dir, "*.png"))
    if not input_paths:
        raise Exception("input_dir contains no image files")

    # Sort "2.png" before "10.png" when all names are plain numbers.
    all_numeric = all(get_name(candidate).isdigit() for candidate in input_paths)
    sort_key = (lambda candidate: int(get_name(candidate))) if all_numeric else None
    input_paths = sorted(input_paths, key=sort_key)

    shuffle_paths = mode == "train"
    path_queue = tf.train.string_input_producer(input_paths, shuffle=shuffle_paths)
    file_reader = tf.WholeFileReader()
    paths, contents = file_reader.read(path_queue)

    raw_image = tf.image.decode_png(contents)
    raw_image = tf.image.convert_image_dtype(raw_image, dtype=tf.float32)
    raw_image.set_shape([None, None, 3])

    # Left half of the width is the input, right half is the target.
    width = tf.shape(raw_image)[1]
    inputs = raw_image[:, :width // 2, :]
    targets = raw_image[:, width // 2:, :]

    inputs = rgb_to_rgbxy(inputs)
    targets = rgb_to_rgbxy(targets)

    inputs = pre_process(inputs)
    targets = pre_process(targets)

    paths_batch, inputs_batch, targets_batch = tf.train.batch(
        [paths, inputs, targets], batch_size=batch_size
    )
    steps_per_epoch = int(math.ceil(len(input_paths) / batch_size))
    return paths_batch, inputs_batch, targets_batch, steps_per_epoch
def save_images(results, output_dir):
    """
    Write the result images into ``<output_dir>/images``.

    ``results["paths"]`` holds the source paths as UTF-8 encoded bytes, and
    ``results["inputs"]``, ``results["outputs"]`` and ``results["targets"]``
    hold, at the same index, the bytes written to
    ``<name>-<kind>.png``.

    Returns:
        A list with one dict per path, containing ``"name"`` plus the file
        name written for each of ``"inputs"``, ``"outputs"`` and
        ``"targets"``.
    """
    image_dir = os.path.join(output_dir, "images")
    check_folder(image_dir)

    kinds = ("inputs", "outputs", "targets")
    filesets = []
    for position, raw_path in enumerate(results["paths"]):
        base = os.path.basename(raw_path.decode("utf8"))
        name = os.path.splitext(base)[0]

        entry = {"name": name}
        for kind in kinds:
            filename = f"{name}-{kind}.png"
            entry[kind] = filename
            destination = os.path.join(image_dir, filename)
            # Look the payload up before opening so a missing entry does not
            # leave an empty file behind.
            payload = results[kind][position]
            with open(destination, "wb") as handle:
                handle.write(payload)
        filesets.append(entry)
    return filesets
def append_index(filesets, output_dir):
    """
    Append one table row per fileset to ``<output_dir>/index.html``.

    When the index file does not exist yet it is created and the opening
    HTML and table header are written first; otherwise rows are appended to
    the existing file. The table is intentionally left unclosed so later
    calls can keep appending rows.

    Each fileset is a dict with the keys ``"name"``, ``"inputs"``,
    ``"outputs"`` and ``"targets"``; the latter three are image file names
    relative to the ``images/`` directory.

    Returns:
        The path of the index file.
    """
    index_path = os.path.join(output_dir, "index.html")
    is_new = not os.path.exists(index_path)

    # The context manager guarantees the handle is flushed and closed, even
    # if a fileset is missing a key part-way through.
    with open(index_path, "w" if is_new else "a") as index:
        if is_new:
            index.write("<html><body><table><tr><th>Name</th><th>Input</th><th>Output</th><th>Target</th></tr>")
        for fileset in filesets:
            index.write(f"<tr><td>{fileset['name']}</td>")
            for kind in ["inputs", "outputs", "targets"]:
                index.write(f"<td><img src='images/{fileset[kind]}'></td>")
            index.write("</tr>")
    return index_path