Skip to content

Commit

Permalink
Merge pull request #227 from jhlegarreta/AddGPExperimentScripts
Browse files Browse the repository at this point in the history
ENH: Add GP error analysis experiment script
  • Loading branch information
oesteban authored Oct 23, 2024
2 parents 7414275 + edc0331 commit 796c501
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 1 deletion.
170 changes: 170 additions & 0 deletions scripts/dwi_estimation_error_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
# Copyright The NiPreps Developers <nipreps@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We support and encourage derived works from this project, please read
# about our expectations at
#
# https://www.nipreps.org/community/licensing/
#

"""
Simulate the DWI signal from a single fiber and analyze the prediction error of an estimator using
Gaussian processes.
"""

from __future__ import annotations

import argparse
from collections import defaultdict

# import nibabel as nib
import numpy as np
import pandas as pd
from sklearn.model_selection import RepeatedKFold, cross_val_score

from eddymotion.model._sklearn import (
EddyMotionGPR,
SphericalKriging,
)
from eddymotion.testing import simulations as testsims


def cross_validate(
X: np.ndarray,
y: np.ndarray,
cv: int,
) -> dict[int, list[tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]]:
"""
Perform the experiment by estimating the dMRI signal using a Gaussian process model.
Parameters
----------
gtab : :obj:`~dipy.core.gradients.gradient_table`
Gradient table.
S0 : :obj:`float`
S0 value.
evals1 : :obj:`~numpy.ndarray`
Eigenvalues of the tensor.
evecs : :obj:`~numpy.ndarray`
Eigenvectors of the tensor.
snr : :obj:`float`
Signal-to-noise ratio.
cv : :obj:`int`
number of folds
Returns
-------
:obj:`dict`
Data for the predicted signal and its error.
"""
gpm = EddyMotionGPR(
kernel=SphericalKriging(a=1.15, lambda_s=120),
alpha=100,
optimizer=None,
)

rkf = RepeatedKFold(n_splits=cv, n_repeats=120 // cv)
scores = cross_val_score(gpm, X, y, scoring="neg_root_mean_squared_error", cv=rkf)
return scores


def _build_arg_parser() -> argparse.ArgumentParser:
"""
Build argument parser for command-line interface.
Returns
-------
:obj:`~argparse.ArgumentParser`
Argument parser for the script.
"""
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawTextHelpFormatter
)
parser.add_argument(
"hsph_dirs",
help="Number of diffusion gradient-encoding directions in the half sphere",
type=int,
)
parser.add_argument("bval_shell", help="Shell b-value", type=float)
parser.add_argument("S0", help="S0 value", type=float)
parser.add_argument("--evals1", help="Eigenvalues of the tensor", nargs="+", type=float)
parser.add_argument("--snr", help="Signal to noise ratio", type=float)
parser.add_argument("--repeats", help="Number of repeats", type=int, default=5)
parser.add_argument(
"--kfold", help="Number of directions to leave out/predict", nargs="+", type=int
)
return parser


def _parse_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
"""
Parse command-line arguments.
Parameters
----------
parser : :obj:`~argparse.ArgumentParser`
Argument parser for the script.
Returns
-------
:obj:`~argparse.Namespace`
Parsed arguments.
"""
return parser.parse_args()


def main() -> None:
"""Main function for running the experiment and plotting the results."""
parser = _build_arg_parser()
args = _parse_args(parser)

data, gtab = testsims.simulate_voxels(
args.S0,
args.evals1,
args.hsph_dirs,
bval_shell=args.bval_shell,
snr=args.snr,
n_voxels=100,
seed=None,
)

X = gtab[~gtab.b0s_mask].bvecs
y = data[:, ~gtab.b0s_mask]

# Use Scikit-learn cross validation
scores = defaultdict(list, {})
for n in args.kfold:
for i in range(args.repeats):
cv_scores = -1.0 * cross_validate(X, y.T, n)
scores["rmse"] += cv_scores.tolist()
scores["repeat"] += [i] * len(cv_scores)
scores["n_folds"] += [n] * len(cv_scores)

print(f"Finished {n}-fold cross-validation")

scores_df = pd.DataFrame(scores)
scores_df.to_csv("cv_scores.tsv", sep="\t", index=None, na_rep="n/a")

grouped = scores_df.groupby(["n_folds"])
print(grouped[["rmse"]].mean())
print(grouped[["rmse"]].std())


if __name__ == "__main__":
main()
32 changes: 31 additions & 1 deletion src/eddymotion/testing/simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from dipy.core.geometry import sphere2cart
from dipy.core.gradients import gradient_table
from dipy.core.sphere import HemiSphere, Sphere, disperse_charges
from dipy.sims.voxel import all_tensor_evecs
from dipy.sims.voxel import all_tensor_evecs, single_tensor


def add_b0(bvals: np.ndarray, bvecs: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
Expand Down Expand Up @@ -195,3 +195,33 @@ def get_query_vectors(
"""
idx = np.logical_and(~train_mask, ~gtab.b0s_mask)
return gtab.bvecs[idx], np.where(idx)[0]


def single_fiber_voxel(gtab, S0, evals, theta=0, phi=0, snr=20):
# create eigenvectors for a single fiber
evecs = create_single_fiber_evecs(theta=theta, phi=phi)

# Generate some data
return single_tensor(gtab, S0=S0, evals=evals, evecs=evecs, snr=snr)


def simulate_voxels(S0, evals, hsph_dirs, bval_shell=1000, snr=20, n_voxels=1, seed=None):
# Create a gradient table for a single-shell
gtab = create_single_shell_gradient_table(hsph_dirs, bval_shell)

rng = np.random.default_rng(seed)

angles = zip(
rng.uniform(0, np.pi, size=n_voxels),
rng.uniform(0, 2.0 * np.pi, size=n_voxels),
strict=False,
)

signal = np.vstack(
[
single_fiber_voxel(gtab, S0, evals, theta=theta, phi=phi, snr=snr)
for theta, phi in angles
]
)

return signal, gtab

0 comments on commit 796c501

Please sign in to comment.