Merge pull request #98 from teresamg/logosplit
ENH: Outsource leave-one-out splitter so it can be used across data types
oesteban authored Apr 3, 2024
2 parents 2d71ba0 + 9591b81 commit a1bcacd
Showing 5 changed files with 154 additions and 56 deletions.
58 changes: 4 additions & 54 deletions src/eddymotion/data/dmri.py
@@ -71,64 +71,14 @@ class DWI:
    )
    """A path to an HDF5 file to store the whole dataset."""

    def get_filename(self):
        """Get the filepath of the HDF5 file."""
        return self._filepath

    def __len__(self):
        """Obtain the number of high-*b* orientations."""
        return self.dataobj.shape[-1]

    def logo_split(self, index, with_b0=False):
        """
        Produce one fold of LOGO (leave-one-gradient-out).

        Parameters
        ----------
        index : :obj:`int`
            Index of the DWI orientation to be left out in this fold.
        with_b0 : :obj:`bool`
            Insert the *b=0* reference at the beginning of the training dataset.

        Returns
        -------
        (train_data, train_gradients) : :obj:`tuple`
            Training DWI and corresponding gradients.
            Training data/gradients come **from the updated dataset**.
        (test_data, test_gradients) : :obj:`tuple`
            Test 3D map (one DWI orientation) and corresponding b-vector/value.
            The test data/gradient come **from the original dataset**.
        """
        if not Path(self._filepath).exists():
            self.to_filename(self._filepath)

        # read original DWI data & b-vector
        with h5py.File(self._filepath, "r") as in_file:
            root = in_file["/0"]
            dwframe = np.asanyarray(root["dataobj"][..., index])
            bframe = np.asanyarray(root["gradients"][..., index])

        # if the size of the mask does not match data, cache is stale
        mask = np.zeros(len(self), dtype=bool)
        mask[index] = True

        train_data = self.dataobj[..., ~mask]
        train_gradients = self.gradients[..., ~mask]

        if with_b0:
            train_data = np.concatenate(
                (np.asanyarray(self.bzero)[..., np.newaxis], train_data),
                axis=-1,
            )
            b0vec = np.zeros((4, 1))
            b0vec[0, 0] = 1
            train_gradients = np.concatenate(
                (b0vec, train_gradients),
                axis=-1,
            )

        return (
            (train_data, train_gradients),
            (dwframe, bframe),
        )

    def set_transform(self, index, affine, order=3):
        """Set an affine, and update data object and gradients."""
        reference = namedtuple("ImageGrid", ("shape", "affine"))(
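For code that consumed the removed method, the change is essentially a drop-in replacement: the bound method on the DWI object becomes a free function that takes the dataset as its first argument. A minimal before/after sketch, assuming `dwi` is a loaded DWI instance and using index 10 as in the updated test below:

from eddymotion.data.dmri import DWI
from eddymotion.data.splitting import lovo_split

dwi = DWI.from_filename("dwi.h5")  # hypothetical input file
index = 10                         # volume to leave out

# Before this change (method removed in this PR):
#   data_train, data_test = dwi.logo_split(index, with_b0=True)
# After this change (free function, reusable across data types):
data_train, data_test = lovo_split(dwi, index, with_b0=True)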
84 changes: 84 additions & 0 deletions src/eddymotion/data/splitting.py
@@ -0,0 +1,84 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
# Copyright 2022 The NiPreps Developers <nipreps@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We support and encourage derived works from this project, please read
# about our expectations at
#
# https://www.nipreps.org/community/licensing/
#
"""Data splitting helpers."""
from pathlib import Path
import numpy as np
import h5py


def lovo_split(dataset, index, with_b0=False):
    """
    Produce one fold of LOVO (leave-one-volume-out).

    Parameters
    ----------
    dataset : :obj:`eddymotion.data.dmri.DWI`
        DWI object.
    index : :obj:`int`
        Index of the DWI orientation to be left out in this fold.
    with_b0 : :obj:`bool`
        Insert the *b=0* reference at the beginning of the training dataset.

    Returns
    -------
    (train_data, train_gradients) : :obj:`tuple`
        Training DWI and corresponding gradients.
        Training data/gradients come **from the updated dataset**.
    (test_data, test_gradients) : :obj:`tuple`
        Test 3D map (one DWI orientation) and corresponding b-vector/value.
        The test data/gradient come **from the original dataset**.
    """
    if not Path(dataset.get_filename()).exists():
        dataset.to_filename(dataset.get_filename())

    # read original DWI data & b-vector
    with h5py.File(dataset.get_filename(), "r") as in_file:
        root = in_file["/0"]
        data = np.asanyarray(root["dataobj"])
        gradients = np.asanyarray(root["gradients"])

    # build a boolean mask that selects the left-out volume
    mask = np.zeros(data.shape[-1], dtype=bool)
    mask[index] = True

    train_data = data[..., ~mask]
    train_gradients = gradients[..., ~mask]
    test_data = data[..., mask]
    test_gradients = gradients[..., mask]

    if with_b0:
        train_data = np.concatenate(
            (np.asanyarray(dataset.bzero)[..., np.newaxis], train_data),
            axis=-1,
        )
        b0vec = np.zeros((4, 1))
        b0vec[0, 0] = 1
        train_gradients = np.concatenate(
            (b0vec, train_gradients),
            axis=-1,
        )

    return (
        (train_data, train_gradients),
        (test_data, test_gradients),
    )
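A minimal usage sketch of the new helper, assuming a DWI dataset stored in HDF5 (the filename is hypothetical) and gradients laid out with volumes along the last axis, as the b=0 handling above suggests:

import numpy as np
from eddymotion.data.dmri import DWI
from eddymotion.data.splitting import lovo_split

dwi = DWI.from_filename("dwi.h5")  # hypothetical path
n_volumes = len(dwi)

# Leave out volume 0; prepend the b=0 reference to the training fold
(train_data, train_grads), (test_data, test_grads) = lovo_split(dwi, 0, with_b0=True)

# The training fold holds the remaining volumes plus the b=0 map;
# the test fold is the single left-out volume and its b-vector/value.
assert train_data.shape[-1] == n_volumes  # (n_volumes - 1) + 1 b=0 volume
assert test_data.shape[-1] == 1
assert test_grads.shape[-1] == 1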
3 changes: 2 additions & 1 deletion src/eddymotion/estimator.py
@@ -34,6 +34,7 @@
from pkg_resources import resource_filename as pkg_fn
from tqdm import tqdm

from eddymotion.data.splitting import lovo_split
from eddymotion.model import ModelFactory


@@ -132,7 +133,7 @@ def fit(
pbar.set_description_str(
f"Pass {i_iter + 1}/{n_iter} | Fit and predict b-index <{i}>"
)
data_train, data_test = dwdata.logo_split(i, with_b0=True)
data_train, data_test = lovo_split(dwdata, i, with_b0=True)
grad_str = f"{i}, {data_test[1][:3]}, b={int(data_test[1][3])}"
pbar.set_description_str(f"[{grad_str}], {n_jobs} jobs")

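In context, the estimator's fit loop now draws each fold from the free function rather than from the DWI method. A simplified sketch of that loop (not the actual fit() signature), with the model fitting step elided:

from eddymotion.data.splitting import lovo_split

def fit_sketch(dwdata, n_iter=1):
    # hypothetical, stripped-down version of the per-volume loop in fit()
    for i_iter in range(n_iter):
        for i in range(len(dwdata)):
            # One LOVO fold: train on every other volume (plus b=0), test on volume i
            data_train, data_test = lovo_split(dwdata, i, with_b0=True)
            # ... fit a model on data_train, predict and register data_test ...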
3 changes: 2 additions & 1 deletion test/test_model.py
Expand Up @@ -26,6 +26,7 @@
import pytest

from eddymotion import model
from eddymotion.data.splitting import lovo_split
from eddymotion.data.dmri import DWI


@@ -97,7 +98,7 @@ def test_two_initialisations(datadir):
dmri_dataset = DWI.from_filename(datadir / "dwi.h5")

# Split data into test and train set
data_train, data_test = dmri_dataset.logo_split(10)
data_train, data_test = lovo_split(dmri_dataset, 10)

# Direct initialisation
model1 = model.AverageDWModel(
62 changes: 62 additions & 0 deletions test/test_splitting.py
@@ -0,0 +1,62 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
# Copyright 2021 The NiPreps Developers <nipreps@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We support and encourage derived works from this project, please read
# about our expectations at
#
# https://www.nipreps.org/community/licensing/
#
"""Unit test testing the lovo_split function."""
import numpy as np
from eddymotion.data.dmri import DWI
from eddymotion.data.splitting import lovo_split


def test_lovo_split(datadir):
    """
    Test the lovo_split function.

    Parameters
    ----------
    datadir
        The directory containing the test data.
    """
    data = DWI.from_filename(datadir / "dwi.h5")

    # Set zeros in dataobj and gradients of the DWI object
    data.dataobj[:] = 0
    data.gradients[:] = 0

    # Select a random index
    index = np.random.randint(len(data))

    # Set ones in dataobj and gradients of the DWI object at this specific index
    data.dataobj[..., index] = 1
    data.gradients[..., index] = 1

    # Apply the lovo_split function at the specified index
    (train_data, train_gradients), \
        (test_data, test_gradients) = lovo_split(data, index)

    # Check that the test data contains only ones
    # and the train data contains only zeros after the split
    assert np.all(test_data == 1)
    assert np.all(train_data == 0)
    assert np.all(test_gradients == 1)
    assert np.all(train_gradients == 0)
