Skip to content

Commit

Permalink
Merge pull request #108 from hzdr-MedImaging/feature-channelwise-Norm…
Browse files Browse the repository at this point in the history
…alize

implemented channelwise support for Normalize transform.
  • Loading branch information
wolny authored Apr 16, 2024
2 parents 5dc65f4 + 41493f5 commit 22914f3
Showing 1 changed file with 39 additions and 10 deletions.
49 changes: 39 additions & 10 deletions pytorch3dunet/augment/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,28 +546,57 @@ def __call__(self, m):

class Normalize:
"""
Apply simple min-max scaling to a given input tensor, i.e. shrinks the range of the data in a fixed range of [-1, 1].
Apply simple min-max scaling to a given input tensor, i.e. shrinks the range of the data
in a fixed range of [-1, 1] or in case of norm01==True to [0, 1]. In addition, data can be
clipped by specifying min_value/max_value either globally using single values or via a
list/tuple channelwise if enabled.
"""

def __init__(self, min_value=None, max_value=None, norm01=False, eps=1e-10, **kwargs):
def __init__(self, min_value=None, max_value=None, norm01=False, channelwise=False,
eps=1e-10, **kwargs):
if min_value is not None and max_value is not None:
assert max_value > min_value
assert max_value > min_value
self.min_value = min_value
self.max_value = max_value
self.norm01 = norm01
self.channelwise = channelwise
self.eps = eps

def __call__(self, m):
if self.min_value is None:
min_value = np.min(m)
if self.channelwise:
# get min/max channelwise
axes = list(range(m.ndim))
axes = tuple(axes[1:])
if self.min_value is None or 'None' in self.min_value:
min_value = np.min(m, axis=axes, keepdims=True)

if self.max_value is None or 'None' in self.max_value:
max_value = np.max(m, axis=axes, keepdims=True)

# check if non None in self.min_value/self.max_value
# if present and if so copy value to min_value
if self.min_value is not None:
for i,v in enumerate(self.min_value):
if v != 'None':
min_value[i] = v

if self.max_value is not None:
for i,v in enumerate(self.max_value):
if v != 'None':
max_value[i] = v
else:
min_value = self.min_value
if self.min_value is None:
min_value = np.min(m)
else:
min_value = self.min_value

if self.max_value is None:
max_value = np.max(m)
else:
max_value = self.max_value
if self.max_value is None:
max_value = np.max(m)
else:
max_value = self.max_value

# calculate norm_0_1 with min_value / max_value with the same dimension
# in case of channelwise application
norm_0_1 = (m - min_value) / (max_value - min_value + self.eps)

if self.norm01 is True:
Expand Down

0 comments on commit 22914f3

Please sign in to comment.