-
Notifications
You must be signed in to change notification settings - Fork 52
/
data_loader.py
85 lines (70 loc) · 3.02 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from builtins import object
import numpy as np
from abc import ABCMeta, abstractmethod
class DataLoaderBase(object):
    """Base class for batch generators: derive from it and override generate_train_batch.

    If you don't want to use this class you can use any generator. You can modify this class however you want.
    How the data is presented as a batch is your responsibility. You can sample randomly, cycle through the
    training examples or sample the data according to a specific pattern. Just make sure to use our default
    data structure!
    {'data':your_batch_of_shape_(b, c, x, y(, z)),
    'seg':your_batch_of_shape_(b, c, x, y(, z)),
    'anything_else1':whatever,
    'anything_else2':whatever2,
    ...}
    (seg is optional)

    The loader is its own iterator. Once num_batches batches have been produced it raises StopIteration; the
    next call to __next__ afterwards starts a fresh pass (the rng is re-seeded and the batch counter is reset),
    so the same instance can be iterated over several times.

    generate_train_batch is marked with @abstractmethod, but this class does not use ABCMeta as its metaclass,
    so overriding it is not enforced at instantiation time.

    Args:
        data (anything): Your dataset. Stored as member variable self._data
        BATCH_SIZE (int): batch size. Stored as member variable self.BATCH_SIZE
        num_batches (int): How many batches will be generated before raising StopIteration. None=unlimited.
            Careful when using MultiThreadedAugmenter: Each process will produce num_batches batches.
        seed (False, None, int): seed to seed the numpy rng with. False = no seeding
    """

    def __init__(self, data, BATCH_SIZE, num_batches=None, seed=False):
        self._data = data
        self.BATCH_SIZE = BATCH_SIZE
        self._num_batches = num_batches
        self._seed = seed
        self._resetted_rng = False
        self._iter_initialized = False
        self._p = None
        if self._num_batches is None:
            # "Unlimited" is modelled as a limit that the batch counter will never reach in practice.
            self._num_batches = 1e100
        self._batches_generated = 0

    def _initialize_iter(self):
        """Prepare a fresh pass over the data.

        Seeds numpy's global rng with self._seed (skipped when the seed is False) and resets the batch counter.
        Resetting the counter here is what allows an exhausted loader to be iterated again; without it every
        later pass would stop immediately.
        """
        if self._seed is not False:
            np.random.seed(self._seed)
        self._batches_generated = 0
        self._iter_initialized = True

    def __iter__(self):
        """Return self; the loader acts as its own iterator."""
        return self

    def __next__(self):
        """Return the next batch produced by generate_train_batch.

        Raises:
            StopIteration: once num_batches batches have been generated in the current pass.
        """
        if not self._iter_initialized:
            self._initialize_iter()
        if self._batches_generated >= self._num_batches:
            # Mark the pass as finished so that the next call re-initializes and starts over.
            self._iter_initialized = False
            raise StopIteration
        minibatch = self.generate_train_batch()
        self._batches_generated += 1
        return minibatch

    @abstractmethod
    def generate_train_batch(self):
        """Override this.

        Generate your batch from self._data. Make sure you generate the correct batch size (self.BATCH_SIZE).
        The return value is handed to the caller of __next__ unchanged. This base implementation returns None.
        """
        pass