eval_data.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

# Per-channel ImageNet statistics used to standardize input images.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
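
# Worked example of the normalization below: a red-channel value of 127
# decodes to 127 / 255 ≈ 0.498 and normalizes to (0.498 - 0.485) / 0.229 ≈ 0.057.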


class Dataset(object):
  """
  Wrapper class around TensorFlow's tf.data input pipeline.
  Handles loading, decoding, and preparing multi-view evaluation data.
  """

  def __init__(self, tfrecord_path, num_views, height, width, batch_size=1):
    self.num_views = num_views
    self.resize_h = height
    self.resize_w = width
    dataset = tf.data.TFRecordDataset(tfrecord_path,
                                      compression_type='GZIP',
                                      num_parallel_reads=batch_size * 4)
    # The map transformation takes a function and applies it to every
    # element of the dataset.
    dataset = dataset.map(self.decode, num_parallel_calls=8)
    dataset = dataset.map(self.augment, num_parallel_calls=8)
    dataset = dataset.map(self.normalize, num_parallel_calls=8)
    # Prefetch a batch at a time to smooth out the time taken to load input
    # files for shuffling and processing.
    dataset = dataset.prefetch(buffer_size=batch_size)
    # The shuffle transformation uses a finite-sized buffer to shuffle
    # elements in memory. The parameter is the number of elements in the
    # buffer. For completely uniform shuffling, set it to the number of
    # elements in the dataset.
    dataset = dataset.shuffle(1000 + 3 * batch_size)
    # A single repeat yields exactly one pass over the evaluation data.
    dataset = dataset.repeat(1)
    self.dataset = dataset.batch(batch_size)

  def decode(self, serialized_example):
    """Parses the encoded multi-view images and filenames from
    `serialized_example` and resizes each view."""
    features = tf.parse_single_example(
        serialized_example,
        # Defaults are not specified since both keys are required.
        features={
            'image/filename': tf.FixedLenFeature([self.num_views], tf.string),
            'image/encoded': tf.FixedLenFeature([self.num_views], tf.string),
            # 'image/label': tf.FixedLenFeature([], tf.int64),
        })
    images = []
    filenames = []
    img_lst = tf.unstack(features['image/encoded'])
    filename_lst = tf.unstack(features['image/filename'])
    for i, img in enumerate(img_lst):
      # Convert each scalar PNG string to a float32 tensor of shape
      # [resize_h, resize_w, 3].
      image_decoded = tf.image.decode_png(img, channels=3)
      image = tf.image.resize_images(image_decoded,
                                     [self.resize_h, self.resize_w])
      images.append(image)
      filenames.append(filename_lst[i])
    return images, filenames

  def augment(self, images, filenames):
    """Placeholder for data augmentation."""
    # No distortions are applied here: this is the evaluation pipeline, so
    # images pass through unchanged. Random flips, crops, or color jitter
    # could be inserted at this point in a training pipeline.
    return images, filenames

  def normalize(self, images, filenames):
    """Standardizes each view: (input - MEAN) / STD, per channel."""
    img_lst = []
    img_tensor_lst = tf.unstack(images)
    for image in img_tensor_lst:
      # decode_png/resize_images leave pixel values in [0, 255]; scale to
      # [0, 1] so the ImageNet statistics apply.
      image = image * (1.0 / 255.0)
      img_lst.append(tf.div(tf.subtract(image, MEAN), STD))
    return img_lst, filenames
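

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): drives the
# pipeline with a TF1-style one-shot iterator. The TFRecord path, view count,
# and image size below are hypothetical placeholders.
if __name__ == '__main__':
  data = Dataset('/path/to/eval.tfrecord', num_views=8,
                 height=224, width=224, batch_size=4)
  iterator = data.dataset.make_one_shot_iterator()
  next_batch = iterator.get_next()
  with tf.Session() as sess:
    try:
      while True:
        images, filenames = sess.run(next_batch)
        # `images` is a list of num_views arrays, each [batch, h, w, 3].
    except tf.errors.OutOfRangeError:
      pass  # repeat(1) means one full pass over the evaluation set.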