-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ENH: Add a script to plot the signal estimated by the GP
Add a script to plot the signal estimated by the GP. Add the necessary helper functions to the signal visualization module.
- Loading branch information
1 parent
796c501
commit 737ee8d
Showing
2 changed files
with
435 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,341 @@ | ||
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- | ||
# vi: set ft=python sts=4 ts=4 sw=4 et: | ||
# | ||
# Copyright The NiPreps Developers <nipreps@gmail.com> | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
# We support and encourage derived works from this project, please read | ||
# about our expectations at | ||
# | ||
# https://www.nipreps.org/community/licensing/ | ||
# | ||
|
||
""" | ||
Simulate the DWI signal from a single fiber and plot the predicted signal using a Gaussian process | ||
estimator. | ||
""" | ||
|
||
import argparse | ||
|
||
import numpy as np | ||
from dipy.core.geometry import sphere2cart | ||
from dipy.core.gradients import gradient_table | ||
from dipy.core.sphere import HemiSphere, Sphere, disperse_charges | ||
from dipy.sims.voxel import all_tensor_evecs, multi_tensor, single_tensor | ||
|
||
from eddymotion.model._dipy import GaussianProcessModel | ||
from eddymotion.testing.simulations import add_b0, create_random_polar_coordinates | ||
from eddymotion.viz.signals import plot_prediction_surface | ||
|
||
SAMPLING_DIRECTIONS = 200 | ||
|
||
|
||
def create_single_fiber_evecs(angles): | ||
"""Create eigenvalues for a simulated single fiber.""" | ||
|
||
sticks = np.array(sphere2cart(1, np.deg2rad(angles[0]), np.deg2rad(angles[1]))) | ||
evecs = all_tensor_evecs(sticks) | ||
|
||
return evecs | ||
|
||
|
||
def create_diffusion_encoding_gradient_dirs(hsph_dirs, iterations=5000, seed=1234): | ||
"""Create the dMRI gradient-encoding directions.""" | ||
|
||
# Create the gradient-encoding directions placing random points on a hemisphere | ||
theta, phi = create_random_polar_coordinates(hsph_dirs, seed=seed) | ||
hsph_initial = HemiSphere(theta=theta, phi=phi) | ||
|
||
# Move the points so that the electrostatic potential energy is minimized | ||
hsph_updated, potential = disperse_charges(hsph_initial, iterations) | ||
|
||
# Create a sphere | ||
return Sphere(xyz=np.vstack((hsph_updated.vertices, -hsph_updated.vertices))) | ||
|
||
|
||
def create_single_shell_gradient_table(hsph_dirs, bval_shell, iterations=5000): | ||
"""Create a single-shell gradient table.""" | ||
|
||
# Create diffusion-encoding gradient directions | ||
sph = create_diffusion_encoding_gradient_dirs(hsph_dirs, iterations=iterations) | ||
|
||
# Create the gradient bvals and bvecs | ||
vertices = sph.vertices | ||
values = np.ones(vertices.shape[0]) | ||
bvecs = vertices | ||
bvals = bval_shell * values | ||
|
||
# Add a b0 value to the gradient table | ||
bvals, bvecs = add_b0(bvals, bvecs) | ||
return gradient_table(bvals, bvecs) | ||
|
||
|
||
def determine_fiber_count(evals): | ||
"""Determine the fiber count.""" | ||
|
||
evals_count = len(evals) | ||
|
||
# Create the DWI signal using a single tensor | ||
if evals_count == 3: | ||
return 1 | ||
elif evals_count == 6: | ||
return 2 | ||
elif evals_count == 9: | ||
return 3 | ||
else: | ||
raise NotImplementedError( | ||
"Diffusion gradient-encoding signal generation not implemented for more than 3 fibers" | ||
) | ||
|
||
|
||
def create_single_tensor_signal(angles, evals, S0, snr, rng, gtab): | ||
"""Create a DWI signal with a single tensor.""" | ||
|
||
# Create eigenvectors for a single fiber | ||
evecs = create_single_fiber_evecs(angles) | ||
|
||
return single_tensor(gtab, S0=S0, evals=evals, evecs=evecs, snr=snr, rng=rng) | ||
|
||
|
||
def create_multi_tensor_signal(angles, evals, S0, snr, rng, gtab): | ||
"""Create a DWI signal with multiple tensors.""" | ||
|
||
# Signal fraction: percentage of the contribution of each tensor | ||
fractions = [100 / len(evals)] * len(evals) | ||
|
||
# signal, sticks = multi_tensor( | ||
# gtab, evals, S0=S0, angles=angles, fractions=fractions, snr=snr, rng=rng | ||
# ) | ||
# _evecs = np.array([all_tensor_evecs(_stick) for _stick in _sticks]) | ||
signal, _ = multi_tensor( | ||
gtab, evals, S0=S0, angles=angles, fractions=fractions, snr=snr, rng=rng | ||
) | ||
|
||
return signal | ||
|
||
|
||
def create_single_shell_signal(angles, gtab, S0, evals, snr): | ||
"""Create a single-shell diffusion gradient-encoding signal.""" | ||
|
||
# Fix the random number generator for reproducibility when generating the | ||
# signal | ||
seed = 1234 | ||
rng = np.random.default_rng(seed) | ||
|
||
fiber_count = determine_fiber_count(evals) | ||
|
||
# Eigenvalues | ||
group_size = 3 | ||
_evals = np.asarray([evals[i : i + group_size] for i in range(0, len(evals), group_size)]) | ||
|
||
# Polar coordinates (theta, phi) of the principal axis of the tensor | ||
group_size = 2 | ||
_angles = [tuple(angles[i : i + group_size]) for i in range(0, len(angles), group_size)] | ||
# Get the only in the lists for the single fiber case | ||
if fiber_count == 1: | ||
_evals = _evals[0] | ||
_angles = _angles[0] | ||
|
||
# Create the DWI signal using a single tensor | ||
if fiber_count == 1: | ||
signal = create_single_tensor_signal(_angles, _evals, S0, snr, rng, gtab) | ||
elif fiber_count == 2: | ||
signal = create_multi_tensor_signal(_angles, _evals, S0, snr, rng, gtab) | ||
elif fiber_count == 3: | ||
signal = create_multi_tensor_signal(_angles, _evals, S0, snr, rng, gtab) | ||
else: | ||
raise NotImplementedError( | ||
"Diffusion gradient-encoding signal generation not implemented for more than 3 fibers" | ||
) | ||
|
||
return signal | ||
|
||
|
||
def get_query_vectors(gtab, train_mask): | ||
"""Get the diffusion-encoding gradient vectors where the signal is to be estimated from the | ||
gradient table and the training mask: the vectors of interest are those that are masked in | ||
the training mask. b0 values are excluded.""" | ||
|
||
idx = np.logical_and(~train_mask, ~gtab.b0s_mask) | ||
return gtab.bvecs[idx], np.where(idx)[0] | ||
|
||
|
||
def create_random_train_mask(gtab, size, seed=1234): | ||
"""Create a mask for the gradient table where a ``size`` number of indices will be | ||
excluded. b0 values are excluded.""" | ||
|
||
rng = np.random.default_rng(seed) | ||
|
||
# Get the indices of the non-zero diffusion-encoding gradient vector indices | ||
nnzero_degv_idx = np.where(~gtab.b0s_mask)[0] | ||
|
||
if nnzero_degv_idx.size < size: | ||
raise ValueError( | ||
f"Requested {size} values for masking; gradient table has {nnzero_degv_idx.size} " | ||
"non-zero diffusion-encoding gradient vectors. Reduce the number of masked values." | ||
) | ||
|
||
lo = rng.choice(nnzero_degv_idx, size=size, replace=False) | ||
|
||
# Exclude the b0s | ||
zero_degv_idx = np.asarray(list(set(range(len(gtab.bvals))).difference(nnzero_degv_idx))) | ||
lo = np.hstack([zero_degv_idx, lo]) | ||
|
||
train_mask = np.ones(len(gtab.bvals), dtype=bool) | ||
train_mask[lo] = False | ||
|
||
return train_mask | ||
|
||
|
||
def perform_experiment(gtab, signal): | ||
"""Perform experiment: estimate the dMRI signal on a set of directions fitting a | ||
Gaussian process to the rest of the data.""" | ||
|
||
# Fix the random number generator for reproducibility when generating the | ||
# sampling directions | ||
# seed = 1234 | ||
|
||
# Define the Gaussian process model parameters | ||
kernel_model = "spherical" | ||
lambda_s = 2.0 | ||
a = 1.0 | ||
sigma_sq = 0.5 | ||
|
||
# Define the Gaussian process model instance | ||
gp_model = GaussianProcessModel( | ||
kernel_model=kernel_model, lambda_s=lambda_s, a=a, sigma_sq=sigma_sq | ||
) | ||
|
||
# Use all available data for training | ||
gpfit = gp_model.fit(signal[~gtab.b0s_mask], gtab[~gtab.b0s_mask]) | ||
|
||
# Predict on an oversampled set of random directions over the unit sphere | ||
# theta, phi = create_random_polar_coordinates(SAMPLING_DIRECTIONS, seed=seed) | ||
# sph = Sphere(theta=theta, phi=phi) | ||
|
||
# ToDo | ||
# Not sure why all predictions are zero in gpfit.predict(sph.vertices) | ||
# Also, when creating the convex hull, the gtab required is the one that | ||
# would correspond to the new directions, so a new gtab would need to be | ||
# generated | ||
# return gpfit.predict(sph.vertices), sph.vertices | ||
# For now, predict on the same data | ||
return gpfit.predict(gtab[~gtab.b0s_mask].bvecs), gtab[~gtab.b0s_mask].bvecs | ||
|
||
|
||
def check_fiber_data_args(angles, evals): | ||
"""Check that the number of angle and eigenvalue elements to build a | ||
synthetic fiber signal are appropriate.""" | ||
|
||
angles_count = len(angles) | ||
evals_count = len(evals) | ||
|
||
if angles_count % 2 != 0: | ||
raise ValueError(f"Two fiber angles required per fiber; {angles_count} provided") | ||
|
||
# Create the DWI signal using a single tensor | ||
if evals_count % 3 != 0: | ||
raise ValueError( | ||
f"Three fiber DTI model eigenvalues required per fiber; {evals_count} provided" | ||
) | ||
|
||
if len(angles) == 2 and len(evals) == 3: | ||
pass | ||
elif len(angles) == 4 and len(evals) == 6: | ||
pass | ||
elif len(angles) == 6 and len(evals) == 9: | ||
pass | ||
else: | ||
raise ValueError( | ||
"Fiber angle and fiber DTI model eigenvalue counts do not match; " | ||
f"{angles_count}, {evals_count} provided" | ||
) | ||
|
||
|
||
def _build_arg_parser(): | ||
parser = argparse.ArgumentParser( | ||
description=__doc__, formatter_class=argparse.RawTextHelpFormatter | ||
) | ||
parser.add_argument( | ||
"hsph_dirs", | ||
help="Number of diffusion gradient-encoding directions in the half sphere", | ||
type=int, | ||
) | ||
parser.add_argument( | ||
"bval_shell", | ||
help="Shell b-value", | ||
type=float, | ||
) | ||
parser.add_argument( | ||
"S0", | ||
help="S0 value", | ||
type=float, | ||
) | ||
parser.add_argument( | ||
"--angles", | ||
help="Polar and azimuth angles of the tensor(s0", | ||
nargs="+", | ||
type=float, | ||
) | ||
parser.add_argument( | ||
"--evals", | ||
help="Eigenvalues of the tensor(s)", | ||
nargs="+", | ||
type=float, | ||
) | ||
parser.add_argument( | ||
"--snr", | ||
help="Signal to noise ratio", | ||
type=float, | ||
) | ||
return parser | ||
|
||
|
||
def _parse_args(parser): | ||
args = parser.parse_args() | ||
|
||
return args | ||
|
||
|
||
def main(): | ||
parser = _build_arg_parser() | ||
args = _parse_args(parser) | ||
|
||
check_fiber_data_args(args.angles, args.evals) | ||
|
||
# Create a gradient table for a single-shell | ||
gtab = create_single_shell_gradient_table(args.hsph_dirs, args.bval_shell) | ||
|
||
# Create the DWI signal | ||
signal = create_single_shell_signal(args.angles, gtab, args.S0, args.evals, args.snr) | ||
|
||
# Estimate the dMRI signal using a Gaussian process estimator | ||
y_pred, y_pred_dirs = perform_experiment(gtab, signal) | ||
|
||
# Plot the predicted signal | ||
title = "GP model signal prediction" | ||
fig, _, _ = plot_prediction_surface( | ||
signal[~gtab.b0s_mask], | ||
y_pred, | ||
args.S0, | ||
gtab.bvecs[~gtab.b0s_mask], | ||
y_pred_dirs, | ||
title, | ||
"gray", | ||
) | ||
fig.savefig(args.gp_pred_plot_fname, format="svg") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Oops, something went wrong.