added downsampling scale class. added downsampling test case. split scale optimization into its own test file. cleaned up backend tests slightly.
bwpriest committed Oct 3, 2023
1 parent b26669c commit 98f82c0
Showing 9 changed files with 312 additions and 117 deletions.
1 change: 1 addition & 0 deletions .github/workflows/develop-test.yml
@@ -42,6 +42,7 @@ jobs:
python tests/batch.py
python tests/predict.py
python tests/precompute/fast_posterior_mean.py
python tests/scale_opt.py
- name: Optimize Tests
if: matrix.test-group == 'optimize'
run: python tests/optimize.py
2 changes: 2 additions & 0 deletions .gitlab-ci.yml
@@ -40,6 +40,7 @@ all_tests:
- salloc -N1 -ppvis -A muygps --mpibind=on python tests/neighbors.py
- salloc -N1 -ppvis -A muygps --mpibind=on python tests/kernels.py
- salloc -N1 -ppvis -A muygps --mpibind=on python tests/gp.py
- salloc -N1 -ppvis -A muygps --mpibind=on python tests/scale_opt.py
- salloc -N1 -ppvis -A muygps --mpibind=on python tests/optimize.py
- salloc -N1 -ppvis -A muygps --mpibind=on python tests/predict.py
- salloc -N1 -ppvis -A muygps --mpibind=on python tests/multivariate.py
@@ -52,6 +53,7 @@ all_tests:
- echo "performing MPI tests"
- salloc -N1 --tasks-per-node=36 -ppvis -A muygps --mpibind=on python tests/kernels.py
- salloc -N1 --tasks-per-node=36 -ppvis -A muygps --mpibind=on python tests/gp.py
- salloc -N1 --tasks-per-node=36 -ppvis -A muygps --mpibind=on python tests/scale_opt.py
- salloc -N1 --tasks-per-node=36 -ppvis -A muygps --mpibind=on python tests/optimize.py
- salloc -N1 --tasks-per-node=36 -ppvis -A muygps --mpibind=on python tests/predict.py
- salloc -N1 --tasks-per-node=36 -ppvis -A muygps --mpibind=on python tests/multivariate.py
6 changes: 5 additions & 1 deletion MuyGPyS/_src/optimize/scale/__init__.py
@@ -5,7 +5,11 @@

from MuyGPyS._src.util import _collect_implementation

[_analytic_scale_optim] = _collect_implementation(
[
_analytic_scale_optim,
_analytic_scale_optim_unnormalized,
] = _collect_implementation(
"MuyGPyS._src.optimize.scale",
"_analytic_scale_optim",
"_analytic_scale_optim_unnormalized",
)
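
For context: the analytic approach computes the closed-form maximum likelihood estimate of the :math:`\sigma^2` scale, roughly `sum_i y_i^T (K_i + noise)^{-1} y_i / (batch_count * nn_count)`. Below is a minimal NumPy sketch of the two collected variants, under the assumption that the unnormalized form simply skips the final division (consistent with `DownSampleScale` applying its own normalization later in this diff):

import numpy as np

def analytic_scale_optim_unnormalized_sketch(pK, nn_targets):
    # Sum the quadratic forms y^T K^{-1} y over the batch, yielding one
    # value per response dimension. pK is the noise-perturbed kernel
    # tensor of shape (batch_count, nn_count, nn_count); nn_targets has
    # shape (batch_count, nn_count, response_count).
    return np.einsum("bnr,bnr->r", nn_targets, np.linalg.solve(pK, nn_targets))

def analytic_scale_optim_sketch(pK, nn_targets):
    # Normalized variant: average over all batch_count * nn_count equations.
    batch_count, nn_count, _ = pK.shape
    return analytic_scale_optim_unnormalized_sketch(pK, nn_targets) / (
        batch_count * nn_count
    )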
2 changes: 1 addition & 1 deletion MuyGPyS/gp/hyperparameter/__init__.py
@@ -5,4 +5,4 @@

from .scalar import Parameter, Parameter as ScalarParam
from .tensor import TensorParam
from .scale import AnalyticScale, FixedScale, ScaleFn
from .scale import AnalyticScale, DownSampleScale, FixedScale, ScaleFn
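
A hedged construction sketch for the newly exported class (default arguments taken from the class definition in the next file; additional `ScaleFn` keyword arguments such as `response_count` pass through `**kwargs` and may be required by the base class):

from MuyGPyS.gp.hyperparameter import DownSampleScale

# down_count must remain strictly less than the model's nn_count, or the
# optimizer returned by get_opt_fn raises a ValueError.
scale = DownSampleScale(down_count=10, iteration_count=10)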
99 changes: 96 additions & 3 deletions MuyGPyS/gp/hyperparameter/scale.py
@@ -10,8 +10,12 @@
from typing import Callable, Tuple, Type

import MuyGPyS._src.math as mm
import MuyGPyS._src.math.numpy as np
from MuyGPyS._src.util import _fullname
from MuyGPyS._src.optimize.scale import _analytic_scale_optim
from MuyGPyS._src.optimize.scale import (
_analytic_scale_optim,
_analytic_scale_optim_unnormalized,
)


class ScaleFn:
@@ -40,14 +44,23 @@ def __init__(
_backend_outer: Callable = mm.outer,
**kwargs,
):
self.val = _backend_ones(response_count)
self.val = _backend_ones(
self._check_positive_integer(response_count, "response")
)
self._trained = False

self._backend_ndarray = _backend_ndarray
self._backend_ftype = _backend_ftype
self._backend_farray = _backend_farray
self._backend_outer = _backend_outer

def _check_positive_integer(self, count, name) -> int:
if not isinstance(count, int) or count < 1:
raise ValueError(
f"{name} count must be a positive integer, not {count}"
)
return count

def __str__(self, **kwargs):
return f"{type(self).__name__}({self.val})"

@@ -169,6 +182,10 @@ class AnalyticScale(ScaleFn):
The integer number of response dimensions.
"""

def __init__(self, _backend_fn: Callable = _analytic_scale_optim, **kwargs):
super().__init__(**kwargs)
self._fn = _backend_fn

def get_opt_fn(self, muygps) -> Callable:
"""
Get a function to optimize the value of the :math:`\\sigma^2` scale
@@ -203,6 +220,82 @@ def get_opt_fn(self, muygps) -> Callable:
"""

def analytic_scale_opt_fn(K, nn_targets, *args, **kwargs):
return _analytic_scale_optim(muygps.noise.perturb(K), nn_targets)
return self._fn(muygps.noise.perturb(K), nn_targets)

return analytic_scale_opt_fn


class DownSampleScale(ScaleFn):
"""
An optimizable :math:`\\sigma^2` covariance scale parameter.
Identical to :class:`~MuyGPyS.gp.scale.FixedScale`, save that its
`get_opt_fn` method performs a downsampled analytic optimization,
reporting the median scale found across repeated subsamples of each
neighborhood.
Args:
response_count:
The integer number of response dimensions.
down_count:
The integer number of neighbors to sample, without replacement.
Must be less than `nn_count`.
iteration_count:
The integer number of downsampling iterations to perform. The median
scale across iterations is reported.
"""

def __init__(
self,
down_count: int = 10,
iteration_count: int = 10,
_backend_fn: Callable = _analytic_scale_optim_unnormalized,
**kwargs,
):
super().__init__(**kwargs)
self._down_count = self._check_positive_integer(
down_count, "down sample"
)
self._iteration_count = self._check_positive_integer(
iteration_count, "down sample iteration"
)
self._fn = _backend_fn

def get_opt_fn(self, muygps) -> Callable:
"""
Get a function to optimize the value of the :math:`\\sigma^2` scale
parameter by repeatedly downsampling the neighborhoods.
Args:
muygps:
The model to be used to create and perturb the kernel.
Returns:
A function with signature
`(K, nn_targets, *args, **kwargs) -> mm.ndarray` that perturbs the
`(batch_count, nn_count, nn_count)` tensor `K` with `muygps`'s noise
model, then repeatedly solves random `down_count`-sized subsamples of
the result against the matching rows of the
`(batch_count, nn_count, response_count)` tensor `nn_targets`,
reporting the median scale across iterations.
"""

def downsample_analytic_scale_opt_fn(K, nn_targets, *args, **kwargs):
batch_count, nn_count, _ = K.shape
if nn_count <= self._down_count:
raise ValueError(
f"bad attempt to downsample {self._down_count} elements "
f"from a set of only {nn_count} options"
)
pK = muygps.noise.perturb(K)
scales = []
for _ in range(self._iteration_count):
sampled_indices = np.random.choice(
np.arange(nn_count),
size=self._down_count,
replace=False,
)
sampled_indices.sort()

pK_down = pK[:, sampled_indices, :]
pK_down = pK_down[:, :, sampled_indices]
nn_targets_down = nn_targets[:, sampled_indices, :]
scales.append(self._fn(pK_down, nn_targets_down))

return mm.atleast_1d(np.median(scales, axis=0)) / (
self._down_count * batch_count
)

return downsample_analytic_scale_opt_fn
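
To make the new estimator concrete, here is a self-contained NumPy sketch of the loop above, under the same assumption that the unnormalized backend function computes summed quadratic forms:

import numpy as np

def downsample_scale_sketch(pK, nn_targets, down_count=10, iteration_count=10, seed=0):
    # pK: noise-perturbed kernel tensor, (batch_count, nn_count, nn_count).
    # nn_targets: (batch_count, nn_count, response_count).
    rng = np.random.default_rng(seed)
    batch_count, nn_count, _ = pK.shape
    if nn_count <= down_count:
        raise ValueError(
            f"bad attempt to downsample {down_count} elements "
            f"from a set of only {nn_count} options"
        )
    scales = []
    for _ in range(iteration_count):
        # Sample a sorted subset of neighbor indices without replacement.
        idx = np.sort(rng.choice(nn_count, size=down_count, replace=False))
        pK_down = pK[:, idx, :][:, :, idx]
        y_down = nn_targets[:, idx, :]
        # Unnormalized analytic solve on the downsampled neighborhoods.
        scales.append(
            np.einsum("bnr,bnr->r", y_down, np.linalg.solve(pK_down, y_down))
        )
    # Median across iterations, normalized as a full analytic solve over
    # down_count neighbors and batch_count neighborhoods would be.
    return np.median(scales, axis=0) / (down_count * batch_count)

The median presumably buys robustness to occasionally ill-conditioned subsamples, at the cost of `iteration_count` linear solves per call.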
20 changes: 12 additions & 8 deletions tests/backend/jax_correctness.py
@@ -165,6 +165,7 @@ def _make_muygps_rbf_n(cls):
cls.noise, _backend_fn=homoscedastic_perturb_n
),
scale=AnalyticScale(
_backend_fn=analytic_scale_optim_n,
_backend_ones=np.ones,
_backend_ndarray=np.ndarray,
_backend_ftype=np.ftype,
@@ -191,6 +192,7 @@ def _make_muygps_n(
),
noise=noise,
scale=AnalyticScale(
_backend_fn=analytic_scale_optim_n,
_backend_ones=np.ones,
_backend_ndarray=np.ndarray,
_backend_ftype=np.ftype,
@@ -262,6 +264,7 @@ def _make_muygps_rbf_j(cls):
cls.noise, _backend_fn=homoscedastic_perturb_j
),
scale=AnalyticScale(
_backend_fn=analytic_scale_optim_j,
_backend_ones=jnp.ones,
_backend_ndarray=jnp.ndarray,
_backend_ftype=jnp.ftype,
@@ -288,6 +291,7 @@ def _make_muygps_j(
),
noise=noise,
scale=AnalyticScale(
_backend_fn=analytic_scale_optim_j,
_backend_ones=jnp.ones,
_backend_ndarray=jnp.ndarray,
_backend_ftype=jnp.ftype,
@@ -850,10 +854,10 @@ def test_diagonal_variance_heteroscedastic(self):
def test_scale_optim(self):
self.assertTrue(
allclose_inv(
analytic_scale_optim_n(
self.muygps_gen_n.scale.get_opt_fn(self.muygps_gen_n)(
self.homoscedastic_K_n, self.batch_nn_targets_n
),
analytic_scale_optim_j(
self.muygps_gen_j.scale.get_opt_fn(self.muygps_gen_j)(
self.homoscedastic_K_j, self.batch_nn_targets_j
),
)
@@ -862,12 +866,12 @@ def test_scale_optim_heteroscedastic(self):
def test_scale_optim_heteroscedastic(self):
self.assertTrue(
allclose_inv(
analytic_scale_optim_n(
self.heteroscedastic_K_n, self.batch_nn_targets_n
),
analytic_scale_optim_j(
self.heteroscedastic_K_j, self.batch_nn_targets_j
),
self.muygps_heteroscedastic_n.scale.get_opt_fn(
self.muygps_heteroscedastic_n
)(self.heteroscedastic_K_n, self.batch_nn_targets_n),
self.muygps_heteroscedastic_j.scale.get_opt_fn(
self.muygps_heteroscedastic_j
)(self.heteroscedastic_K_j, self.batch_nn_targets_j),
)
)

36 changes: 20 additions & 16 deletions tests/backend/torch_correctness.py
@@ -142,6 +142,7 @@ def _make_muygps_rbf_n(cls):
cls.noise, _backend_fn=homoscedastic_perturb_n
),
scale=AnalyticScale(
_backend_fn=analytic_scale_optim_n,
_backend_ones=np.ones,
_backend_ndarray=np.ndarray,
_backend_ftype=np.ftype,
@@ -166,6 +167,7 @@ def _make_muygps_n(cls, smoothness, noise, deformation):
),
noise=noise,
scale=AnalyticScale(
_backend_fn=analytic_scale_optim_n,
_backend_ones=np.ones,
_backend_ndarray=np.ndarray,
_backend_ftype=np.ftype,
@@ -233,6 +235,7 @@ def _make_muygps_rbf_t(cls):
cls.noise, _backend_fn=homoscedastic_perturb_t
),
scale=AnalyticScale(
_backend_fn=analytic_scale_optim_t,
_backend_ones=torch.ones,
_backend_ndarray=torch.ndarray,
_backend_ftype=torch.ftype,
@@ -257,6 +260,7 @@ def _make_muygps_t(cls, smoothness, noise, deformation):
),
noise=noise,
scale=AnalyticScale(
_backend_fn=analytic_scale_optim_t,
_backend_ones=torch.ones,
_backend_ndarray=torch.ndarray,
_backend_ftype=torch.ftype,
@@ -649,12 +653,12 @@ def test_heteroscedastic_noise(self):
def test_posterior_mean(self):
self.assertTrue(
_allclose(
muygps_posterior_mean_n(
self.muygps_05_n.posterior_mean(
self.homoscedastic_K_n,
self.Kcross_n,
self.batch_nn_targets_n,
),
muygps_posterior_mean_t(
self.muygps_05_t.posterior_mean(
self.homoscedastic_K_t,
self.Kcross_t,
self.batch_nn_targets_t,
@@ -665,12 +669,12 @@ def test_posterior_mean_heteroscedastic(self):
def test_posterior_mean_heteroscedastic(self):
self.assertTrue(
_allclose(
muygps_posterior_mean_n(
self.muygps_heteroscedastic_n.posterior_mean(
self.heteroscedastic_K_n,
self.Kcross_n,
self.batch_nn_targets_n,
),
muygps_posterior_mean_t(
self.muygps_heteroscedastic_t.posterior_mean(
self.heteroscedastic_K_t,
self.Kcross_t,
self.batch_nn_targets_t,
@@ -681,10 +685,10 @@ def test_posterior_mean_heteroscedastic(self):
def test_diagonal_variance(self):
self.assertTrue(
np.allclose(
muygps_diagonal_variance_n(
self.muygps_05_n.posterior_variance(
self.homoscedastic_K_n, self.Kcross_n
),
muygps_diagonal_variance_t(
self.muygps_05_t.posterior_variance(
self.homoscedastic_K_t, self.Kcross_t
),
)
@@ -693,10 +697,10 @@ def test_diagonal_variance_heteroscedastic(self):
def test_diagonal_variance_heteroscedastic(self):
self.assertTrue(
np.allclose(
muygps_diagonal_variance_n(
self.muygps_heteroscedastic_n.posterior_variance(
self.heteroscedastic_K_n, self.Kcross_n
),
muygps_diagonal_variance_t(
self.muygps_heteroscedastic_t.posterior_variance(
self.heteroscedastic_K_t, self.Kcross_t
),
)
@@ -705,10 +709,10 @@ def test_diagonal_variance_heteroscedastic(self):
def test_scale_optim(self):
self.assertTrue(
np.allclose(
analytic_scale_optim_n(
self.muygps_rbf_n.scale.get_opt_fn(self.muygps_rbf_n)(
self.homoscedastic_K_n, self.batch_nn_targets_n
),
analytic_scale_optim_t(
self.muygps_rbf_t.scale.get_opt_fn(self.muygps_rbf_t)(
self.homoscedastic_K_t, self.batch_nn_targets_t
),
)
@@ -717,12 +721,12 @@ def test_scale_optim(self):
def test_scale_optim_heteroscedastic(self):
self.assertTrue(
np.allclose(
analytic_scale_optim_n(
self.heteroscedastic_K_n, self.batch_nn_targets_n
),
analytic_scale_optim_t(
self.heteroscedastic_K_t, self.batch_nn_targets_t
),
self.muygps_heteroscedastic_n.scale.get_opt_fn(
self.muygps_heteroscedastic_n
)(self.heteroscedastic_K_n, self.batch_nn_targets_n),
self.muygps_heteroscedastic_t.scale.get_opt_fn(
self.muygps_heteroscedastic_t
)(self.heteroscedastic_K_t, self.batch_nn_targets_t),
)
)
