diff --git a/src/eddymotion/data/dmri.py b/src/eddymotion/data/dmri.py
index da2c1388..09fbe350 100644
--- a/src/eddymotion/data/dmri.py
+++ b/src/eddymotion/data/dmri.py
@@ -71,64 +71,14 @@ class DWI:
     )
     """A path to an HDF5 file to store the whole dataset."""
 
+    def get_filename(self):
+        """Get the filepath of the HDF5 file."""
+        return self._filepath
+
     def __len__(self):
         """Obtain the number of high-*b* orientations."""
         return self.dataobj.shape[-1]
 
-    def logo_split(self, index, with_b0=False):
-        """
-        Produce one fold of LOGO (leave-one-gradient-out).
-
-        Parameters
-        ----------
-        index : :obj:`int`
-            Index of the DWI orientation to be left out in this fold.
-        with_b0 : :obj:`bool`
-            Insert the *b=0* reference at the beginning of the training dataset.
-
-        Returns
-        -------
-        (train_data, train_gradients) : :obj:`tuple`
-            Training DWI and corresponding gradients.
-            Training data/gradients come **from the updated dataset**.
-        (test_data, test_gradients) :obj:`tuple`
-            Test 3D map (one DWI orientation) and corresponding b-vector/value.
-            The test data/gradient come **from the original dataset**.
-
-        """
-        if not Path(self._filepath).exists():
-            self.to_filename(self._filepath)
-
-        # read original DWI data & b-vector
-        with h5py.File(self._filepath, "r") as in_file:
-            root = in_file["/0"]
-            dwframe = np.asanyarray(root["dataobj"][..., index])
-            bframe = np.asanyarray(root["gradients"][..., index])
-
-        # if the size of the mask does not match data, cache is stale
-        mask = np.zeros(len(self), dtype=bool)
-        mask[index] = True
-
-        train_data = self.dataobj[..., ~mask]
-        train_gradients = self.gradients[..., ~mask]
-
-        if with_b0:
-            train_data = np.concatenate(
-                (np.asanyarray(self.bzero)[..., np.newaxis], train_data),
-                axis=-1,
-            )
-            b0vec = np.zeros((4, 1))
-            b0vec[0, 0] = 1
-            train_gradients = np.concatenate(
-                (b0vec, train_gradients),
-                axis=-1,
-            )
-
-        return (
-            (train_data, train_gradients),
-            (dwframe, bframe),
-        )
-
     def set_transform(self, index, affine, order=3):
         """Set an affine, and update data object and gradients."""
         reference = namedtuple("ImageGrid", ("shape", "affine"))(
diff --git a/src/eddymotion/data/splitting.py b/src/eddymotion/data/splitting.py
new file mode 100644
index 00000000..7d7ff8b4
--- /dev/null
+++ b/src/eddymotion/data/splitting.py
@@ -0,0 +1,86 @@
+# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
+# vi: set ft=python sts=4 ts=4 sw=4 et:
+#
+# Copyright 2022 The NiPreps Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# We support and encourage derived works from this project, please read
+# about our expectations at
+#
+#     https://www.nipreps.org/community/licensing/
+#
+"""Data splitting helpers."""
+from pathlib import Path
+import numpy as np
+import h5py
+
+
+def lovo_split(dataset, index, with_b0=False):
+    """
+    Produce one fold of LOVO (leave-one-volume-out).
+
+    Parameters
+    ----------
+    dataset : :obj:`eddymotion.data.dmri.DWI`
+        The DWI dataset to split.
+    index : :obj:`int`
+        Index of the DWI orientation to be left out in this fold.
+    with_b0 : :obj:`bool`
+        Insert the *b=0* reference at the beginning of the training dataset.
+
+    Returns
+    -------
+    (train_data, train_gradients) : :obj:`tuple`
+        Training DWI volumes and corresponding gradients.
+        Both are read from the HDF5 file backing *dataset*.
+    (test_data, test_gradients) : :obj:`tuple`
+        Test 3D volume (the left-out DWI orientation) and corresponding
+        b-vector/value, also read from the HDF5 file backing *dataset*.
+
+    """
+
+    if not Path(dataset.get_filename()).exists():
+        dataset.to_filename(dataset.get_filename())
+
+    # read the full DWI data and gradient table from the HDF5 file
+    with h5py.File(dataset.get_filename(), "r") as in_file:
+        root = in_file["/0"]
+        data = np.asanyarray(root["dataobj"])
+        gradients = np.asanyarray(root["gradients"])
+
+    # build a boolean mask that singles out the left-out volume
+    mask = np.zeros(data.shape[-1], dtype=bool)
+    mask[index] = True
+
+    train_data = data[..., ~mask]
+    train_gradients = gradients[..., ~mask]
+    test_data = data[..., mask]
+    test_gradients = gradients[..., mask]
+
+    if with_b0:
+        train_data = np.concatenate(
+            (np.asanyarray(dataset.bzero)[..., np.newaxis], train_data),
+            axis=-1,
+        )
+        b0vec = np.zeros((4, 1))
+        b0vec[0, 0] = 1
+        train_gradients = np.concatenate(
+            (b0vec, train_gradients),
+            axis=-1,
+        )
+
+    return (
+        (train_data, train_gradients),
+        (test_data, test_gradients),
+    )
diff --git a/src/eddymotion/estimator.py b/src/eddymotion/estimator.py
index ead5f7fe..4d1686aa 100644
--- a/src/eddymotion/estimator.py
+++ b/src/eddymotion/estimator.py
@@ -34,6 +34,7 @@
 from pkg_resources import resource_filename as pkg_fn
 from tqdm import tqdm
 
+from eddymotion.data.splitting import lovo_split
 from eddymotion.model import ModelFactory
 
 
@@ -132,7 +133,7 @@ def fit(
                     pbar.set_description_str(
                         f"Pass {i_iter + 1}/{n_iter} | Fit and predict b-index <{i}>"
                     )
-                    data_train, data_test = dwdata.logo_split(i, with_b0=True)
+                    data_train, data_test = lovo_split(dwdata, i, with_b0=True)
                     grad_str = f"{i}, {data_test[1][:3]}, b={int(data_test[1][3])}"
                     pbar.set_description_str(f"[{grad_str}], {n_jobs} jobs")
 
diff --git a/test/test_model.py b/test/test_model.py
index cad6b48a..6ade1d3a 100644
--- a/test/test_model.py
+++ b/test/test_model.py
@@ -26,6 +26,7 @@
 import pytest
 
 from eddymotion import model
+from eddymotion.data.splitting import lovo_split
 from eddymotion.data.dmri import DWI
 
 
@@ -97,7 +98,7 @@ def test_two_initialisations(datadir):
     dmri_dataset = DWI.from_filename(datadir / "dwi.h5")
 
     # Split data into test and train set
-    data_train, data_test = dmri_dataset.logo_split(10)
+    data_train, data_test = lovo_split(dmri_dataset, 10)
 
     # Direct initialisation
     model1 = model.AverageDWModel(
diff --git a/test/test_splitting.py b/test/test_splitting.py
new file mode 100644
index 00000000..155a8794
--- /dev/null
+++ b/test/test_splitting.py
@@ -0,0 +1,62 @@
+# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
+# vi: set ft=python sts=4 ts=4 sw=4 et:
+#
+# Copyright 2021 The NiPreps Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# We support and encourage derived works from this project, please read
+# about our expectations at
+#
+#     https://www.nipreps.org/community/licensing/
+#
+"""Unit tests for the lovo_split function."""
+import numpy as np
+from eddymotion.data.dmri import DWI
+from eddymotion.data.splitting import lovo_split
+
+
+def test_lovo_split(datadir):
+    """
+    Test the lovo_split function.
+
+    Parameters
+    ----------
+    datadir : :obj:`pathlib.Path`
+        The directory containing the test data.
+
+    """
+    data = DWI.from_filename(datadir / "dwi.h5")
+
+    # Set dataobj and gradients of the DWI object to all zeros
+    data.dataobj[:] = 0
+    data.gradients[:] = 0
+
+    # Select a random volume index
+    index = np.random.randint(len(data))
+
+    # Set dataobj and gradients of the DWI object to one at that index
+    data.dataobj[..., index] = 1
+    data.gradients[..., index] = 1
+
+    # Apply the lovo_split function at the specified index
+    (train_data, train_gradients), \
+        (test_data, test_gradients) = lovo_split(data, index)
+
+    # Check that the test fold contains only ones
+    # and the train fold contains only zeros after the split
+    assert np.all(test_data == 1)
+    assert np.all(train_data == 0)
+    assert np.all(test_gradients == 1)
+    assert np.all(train_gradients == 0)
+
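
Usage sketch (not part of the changeset): the snippet below illustrates how the new lovo_split helper is called in place of the former DWI.logo_split method, based only on the API shown in this diff; the "dwi.h5" path is a hypothetical example dataset.

    from eddymotion.data.dmri import DWI
    from eddymotion.data.splitting import lovo_split

    # load a DWI dataset previously serialized to HDF5 (hypothetical path)
    dwi = DWI.from_filename("dwi.h5")

    # leave volume 10 out; prepend the b=0 reference to the training fold
    (train_data, train_grads), (test_data, test_grads) = lovo_split(dwi, 10, with_b0=True)

    # the training fold keeps all remaining volumes plus the b=0 map;
    # the test fold holds the single left-out volume
    print(train_data.shape[-1], test_data.shape[-1])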