Skip to content

Commit

Permalink
added ShearNoise33 noise model that assumes that convergence variable…
Browse files Browse the repository at this point in the history
…s have a noise prior 2x the other parameters.
  • Loading branch information
bwpriest committed Aug 15, 2024
1 parent 699ff87 commit 8421462
Show file tree
Hide file tree
Showing 14 changed files with 196 additions and 27 deletions.
11 changes: 7 additions & 4 deletions MuyGPyS/_src/gp/noise/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@

from MuyGPyS._src.util import _collect_implementation

(_homoscedastic_perturb, _heteroscedastic_perturb) = _collect_implementation(
"MuyGPyS._src.gp.noise",
"_homoscedastic_perturb",
"_heteroscedastic_perturb",
(_homoscedastic_perturb, _heteroscedastic_perturb, _shear_perturb33) = (
_collect_implementation(
"MuyGPyS._src.gp.noise",
"_homoscedastic_perturb",
"_heteroscedastic_perturb",
"_shear_perturb33",
)
)
27 changes: 27 additions & 0 deletions MuyGPyS/_src/gp/noise/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,33 @@ def _homoscedastic_perturb(
)


@jit
def _shear_perturb33(Kin: jnp.ndarray, noise_variance: float) -> jnp.ndarray:
convergence_variance = noise_variance * 2
if Kin.ndim == 5:
b, in_count, nn_count, in_count2, nn_count2 = Kin.shape
assert nn_count == nn_count2
assert in_count == in_count2
assert in_count == 3
all_count = in_count * nn_count
Kin_flat = Kin.reshape(b, all_count, all_count)
nugget = jnp.diag(
jnp.hstack(
(
convergence_variance * jnp.ones(nn_count),
noise_variance * jnp.ones(2 * nn_count),
)
)
)
Kin_flat = Kin_flat + nugget
return Kin_flat.reshape(b, in_count, nn_count, in_count, nn_count)
else:
raise ValueError(
"homoscedastic perturbation is not implemented for tensors of "
f"shape {Kin.shape}"
)


@jit
def _heteroscedastic_perturb(
Kin: jnp.ndarray, noise_variances: jnp.ndarray
Expand Down
6 changes: 5 additions & 1 deletion MuyGPyS/_src/gp/noise/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
#
# SPDX-License-Identifier: MIT

from MuyGPyS._src.gp.noise.numpy import _homoscedastic_perturb, np
from MuyGPyS._src.gp.noise.numpy import (
_homoscedastic_perturb,
_shear_perturb33,
np,
)


def _heteroscedastic_perturb(
Expand Down
26 changes: 26 additions & 0 deletions MuyGPyS/_src/gp/noise/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,32 @@ def _homoscedastic_perturb(
)


def _shear_perturb33(Kin: np.ndarray, noise_variance: float) -> np.ndarray:
convergence_variance = noise_variance * 2
if Kin.ndim == 5:
b, in_count, nn_count, in_count2, nn_count2 = Kin.shape
assert nn_count == nn_count2
assert in_count == in_count2
assert in_count == 3
all_count = in_count * nn_count
Kin_flat = Kin.reshape(b, all_count, all_count)
nugget = np.diag(
np.hstack(
(
convergence_variance * np.ones(nn_count),
noise_variance * np.ones(2 * nn_count),
)
)
)
Kin_flat = Kin_flat + nugget
return Kin_flat.reshape(b, in_count, nn_count, in_count, nn_count)
else:
raise ValueError(
"homoscedastic perturbation is not implemented for tensors of "
f"shape {Kin.shape}"
)


def _heteroscedastic_perturb(
Kin: np.ndarray, noise_variances: np.ndarray
) -> np.ndarray:
Expand Down
28 changes: 28 additions & 0 deletions MuyGPyS/_src/gp/noise/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,34 @@ def _homoscedastic_perturb(
)


def _shear_perturb33(
Kin: torch.ndarray, noise_variance: float
) -> torch.ndarray:
convergence_variance = noise_variance * 2
if Kin.ndim == 5:
b, in_count, nn_count, in_count2, nn_count2 = Kin.shape
assert nn_count == nn_count2
assert in_count == in_count2
assert in_count == 3
all_count = in_count * nn_count
Kin_flat = Kin.reshape(b, all_count, all_count)
nugget = torch.diag(
torch.hstack(
(
convergence_variance * torch.ones(nn_count),
noise_variance * torch.ones(2 * nn_count),
)
)
)
Kin_flat = Kin_flat + nugget
return Kin_flat.reshape(b, in_count, nn_count, in_count, nn_count)
else:
raise ValueError(
"homoscedastic perturbation is not implemented for tensors of "
f"shape {Kin.shape}"
)


def _heteroscedastic_perturb(
Kin: torch.ndarray, noise_variances: torch.ndarray
) -> torch.ndarray:
Expand Down
2 changes: 2 additions & 0 deletions MuyGPyS/_src/math/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
count_nonzero,
cov,
choose,
diag,
divide,
einsum,
equal,
Expand All @@ -28,6 +29,7 @@
float32,
float64,
histogram,
hstack,
inf,
int32,
int64,
Expand Down
1 change: 1 addition & 0 deletions MuyGPyS/_src/math/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
cov,
count_nonzero,
choose,
diag,
divide,
dot,
einsum,
Expand Down
2 changes: 2 additions & 0 deletions MuyGPyS/_src/math/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
cov,
corrcoef,
cuda,
diag,
einsum,
eq,
equal,
exp,
hstack,
inf,
int32,
int64,
Expand Down
42 changes: 32 additions & 10 deletions MuyGPyS/_test/shear.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from MuyGPyS.gp.deformation import DifferenceIsotropy, F2
from MuyGPyS.gp.hyperparameter import ScalarParam, Parameter, FixedScale
from MuyGPyS.gp.kernels.experimental import ShearKernel, ShearKernel2in3out
from MuyGPyS.gp.noise import HomoscedasticNoise
from MuyGPyS.gp.noise import HomoscedasticNoise, ShearNoise33


def kk_f(x1, y1, x2, y2, a=1, b=1):
Expand Down Expand Up @@ -184,13 +184,24 @@ def conventional_mean(Kin, Kcross, targets, noise):
nugget_size = Kin.shape[0]
test_count = int(Kcross.shape[0] / 3)
return (
(
Kcross
@ np.linalg.solve(
Kin + noise * np.eye(nugget_size),
targets,
)
(Kcross @ np.linalg.solve(Kin + noise * np.eye(nugget_size), targets))
.reshape(3, test_count)
.swapaxes(0, 1)
)


def conventional_mean33(Kin, Kcross, targets, noise):
nugget_size = Kin.shape[0]
assert nugget_size % 3 == 0
test_count = int(Kcross.shape[0] / 3)
train_count = int(nugget_size / 3)
nugget = np.diag(
np.hstack(
(2 * noise * np.ones(train_count), noise * np.ones(2 * train_count))
)
)
return (
(Kcross @ np.linalg.solve(Kin + nugget, targets))
.reshape(3, test_count)
.swapaxes(0, 1)
)
Expand All @@ -199,9 +210,20 @@ def conventional_mean(Kin, Kcross, targets, noise):
def conventional_variance(Kin, Kcross, Kout, noise):
nugget_size = Kin.shape[0]
return Kout - Kcross @ np.linalg.solve(
Kin + noise * np.eye(nugget_size),
Kcross.T,
Kin + noise * np.eye(nugget_size), Kcross.T
)


def conventional_variance33(Kin, Kcross, Kout, noise):
nugget_size = Kin.shape[0]
assert nugget_size % 3 == 0
train_count = int(nugget_size / 3)
nugget = np.diag(
np.hstack(
(2 * noise * np.ones(train_count), noise * np.ones(2 * train_count))
)
)
return Kout - Kcross @ np.linalg.solve(Kin + nugget, Kcross.T)


class BenchmarkTestCase(parameterized.TestCase):
Expand Down Expand Up @@ -236,7 +258,7 @@ def setUpClass(cls):
length_scale=Parameter(0.04, [0.02, 0.07]),
),
),
noise=HomoscedasticNoise(cls.noise_prior),
noise=ShearNoise33(cls.noise_prior),
scale=FixedScale(),
)
cls.model23 = MuyGPS(
Expand Down
1 change: 1 addition & 0 deletions MuyGPyS/gp/noise/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from .heteroscedastic import HeteroscedasticNoise
from .null import NullNoise
from .noise_fn import NoiseFn
from .shear import ShearNoise33
29 changes: 29 additions & 0 deletions MuyGPyS/gp/noise/shear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other
# MuyGPyS Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: MIT

"""
Noise modeling
Defines data structures and functors that handle noise priors for MuyGPs models.
"""

from typing import Callable, Tuple, Union

import MuyGPyS._src.math as mm

from MuyGPyS._src.gp.noise import _shear_perturb33

from MuyGPyS.gp.noise.homoscedastic import HomoscedasticNoise


class ShearNoise33(HomoscedasticNoise):

def __init__(
self,
val: Union[str, float],
bounds: Union[str, Tuple[float, float]] = "fixed",
_backend_fn: Callable = _shear_perturb33,
):
super(ShearNoise33, self).__init__(val, bounds, _backend_fn=_backend_fn)
5 changes: 3 additions & 2 deletions experimental/shear_2x3_offset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -912,7 +912,7 @@
"A^{-1} + A^{-1} C^\\top S^{-1} C A^{-1} & -A^{-1}C^\\top S^{-1} \\\\\n",
"-S^{-1} CA^{-1} & S^{-1} \\\\\n",
"\\end{pmatrix}, \\textrm{ where} \\\\\n",
"S &= D - C K_{\\gamma_1, \\gamma_1}^{-1} C^\\top\n",
"S &= D - C A^{-1} C^\\top\n",
"\\end{align}.\n",
"\n",
"Rewriting for $K_{2,3}(X, X)$ and dropping the $(X, X)$ arguments for legibility, this resolves to\n",
Expand All @@ -922,7 +922,8 @@
"\\begin{pmatrix}\n",
"K_{\\gamma_1, \\gamma_1} & K_{\\gamma_1, \\gamma_2} \\\\\n",
"K_{\\gamma_2, \\gamma_1} & K_{\\gamma_2, \\gamma_2} \\\\\n",
"\\end{pmatrix}^{-1} =\n",
"\\end{pmatrix}^{-1} \\\\\n",
"&=\n",
"\\begin{pmatrix}\n",
"K_{\\gamma_1, \\gamma_1}^{-1} + K_{\\gamma_1, \\gamma_1}^{-1} K_{\\gamma_1, \\gamma_2} S^{-1} K_{\\gamma_2, \\gamma_1} K_{\\gamma_1, \\gamma_1}^{-1} & -K_{\\gamma_1, \\gamma_1}^{-1} K_{\\gamma_1, \\gamma_2} S^{-1} \\\\\n",
"-S^{-1} K_{\\gamma_2, \\gamma_1} K_{\\gamma_1, \\gamma_1}^{-1} & S^{-1} \\\\\n",
Expand Down
8 changes: 4 additions & 4 deletions experimental/shear_kernel.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
"from MuyGPyS.gp.hyperparameter import Parameter\n",
"from MuyGPyS.gp.kernels.experimental import ShearKernel\n",
"from MuyGPyS.neighbors import NN_Wrapper\n",
"from MuyGPyS.gp.noise import HomoscedasticNoise"
"from MuyGPyS.gp.noise import HomoscedasticNoise, ShearNoise33"
]
},
{
Expand Down Expand Up @@ -107,7 +107,7 @@
" length_scale=Parameter(length_scale),\n",
" ),\n",
" ),\n",
" noise = HomoscedasticNoise(1e-4),\n",
" noise = ShearNoise33(1e-4),\n",
")\n",
"diffs = shear_model.kernel.deformation.pairwise_tensor(features, np.arange(data_count))"
]
Expand Down Expand Up @@ -1447,7 +1447,7 @@
" length_scale=Parameter(0.5, [0.01, 0.9]),\n",
" ),\n",
" ),\n",
" noise=HomoscedasticNoise(noise_prior),\n",
" noise=ShearNoise33(noise_prior),\n",
" scale=AnalyticScale(),\n",
")\n",
"\n",
Expand Down Expand Up @@ -1695,7 +1695,7 @@
" length_scale=Parameter(shear_mse_optimized.kernel.deformation.length_scale()),\n",
" ),\n",
" ),\n",
" noise=HomoscedasticNoise(noise_prior),\n",
" noise=ShearNoise33(noise_prior),\n",
")"
]
},
Expand Down
Loading

0 comments on commit 8421462

Please sign in to comment.