Skip to content

Commit

Permalink
partial fix to torch parameter optimization (#184)
Browse files Browse the repository at this point in the history
* cleanup

* partially fixed bug disconnecting model hyperparameters from torch optimization

* Only works for stationary isotropic models
  • Loading branch information
bwpriest authored Aug 21, 2023
1 parent 295176f commit d54b7db
Show file tree
Hide file tree
Showing 9 changed files with 48 additions and 6 deletions.
9 changes: 9 additions & 0 deletions MuyGPyS/_src/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
ndarray,
ones,
outer,
parameter,
repeat,
reshape,
sqrt,
Expand Down Expand Up @@ -89,6 +90,7 @@
"ndarray",
"ones",
"outer",
"parameter",
"repeat",
"reshape",
"sqrt",
Expand All @@ -99,3 +101,10 @@
"where",
"zeros",
)


def promote(x):
if isinstance(x, ndarray):
return x
else:
return array(x)
4 changes: 4 additions & 0 deletions MuyGPyS/_src/math/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,7 @@ def assign(x: ndarray, y: ndarray, *slices) -> ndarray:
diagonal, eye, full, linspace, ones, zeros = fix_function_types(
ftype, _diagonal, _eye, _full, _linspace, _ones, _zeros
)


def parameter(x):
return x
4 changes: 4 additions & 0 deletions MuyGPyS/_src/math/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,7 @@ def assign(x: ndarray, y: ndarray, *slices) -> ndarray:
diagonal, eye, full, linspace, ones, zeros = fix_function_types(
ftype, _diagonal, _eye, _full, _linspace, _ones, _zeros
)


def parameter(x):
return x
4 changes: 4 additions & 0 deletions MuyGPyS/_src/math/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,7 @@ def assign(x: ndarray, y: ndarray, *slices) -> ndarray:
diagonal, eye, full, linspace, ones, zeros = fix_function_types(
ftype, _diagonal, _eye, _full, _linspace, _ones, _zeros
)


def parameter(x):
return nn.Parameter(_array(x))
3 changes: 3 additions & 0 deletions MuyGPyS/gp/distortion/anisotropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def _get_length_scale_array(
batch_features=None,
**length_scales,
) -> mm.ndarray:
# NOTE[MWP] THIS WILL NOT WORK WITH TORCH OPTIMIZATION.
# We need to eliminate the implicit copy. Will need indirection.
# We should make this whole workflow ifless.
AnisotropicDistortion._lengths_agree(
len(length_scales), target_shape[-1]
)
Expand Down
6 changes: 3 additions & 3 deletions MuyGPyS/gp/distortion/isotropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,19 @@ def __call__(
self, diffs: mm.ndarray, length_scale: Union[float, mm.ndarray]
) -> mm.ndarray:
length_scale_array = self._get_length_scale_array(
mm.array, diffs.shape, length_scale
diffs.shape, length_scale
)
return self._dist_fn(diffs / length_scale_array)

@staticmethod
def _get_length_scale_array(
array_fn: Callable,
target_shape: mm.ndarray,
length_scale: Union[float, mm.ndarray],
) -> mm.ndarray:
# make sure length_scale is broadcastable when its shape is (batch_count,)
# NOTE[MWP] there is probably a better way to do this
shape = (-1,) + (1,) * (len(target_shape) - 1)
return mm.reshape(array_fn(length_scale), shape)
return mm.reshape(mm.promote(length_scale), shape)

def get_opt_params(
self,
Expand Down
2 changes: 1 addition & 1 deletion MuyGPyS/gp/hyperparameter/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def _set_val(self, val: Union[str, float]) -> None:
f"Hyperparameter value {val} is greater than the "
f"optimization upper bound {self._bounds[1]}"
)
self._val = val
self._val = mm.parameter(val)

def _set_bounds(
self,
Expand Down
1 change: 0 additions & 1 deletion MuyGPyS/gp/kernels/matern.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ def __call__(self, diffs):
tensor of shape `(data_count, nn_count, nn_count)` whose last two
dimensions are kernel matrices.
"""

return self._fn(diffs, nu=self.nu())

def get_opt_params(
Expand Down
21 changes: 20 additions & 1 deletion MuyGPyS/torch/muygps_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
"""
from MuyGPyS import config
from MuyGPyS._src.math.torch import nn
from MuyGPyS.gp.distortion import IsotropicDistortion
from MuyGPyS.gp.hyperparameter import ScalarHyperparameter
from MuyGPyS.gp.muygps import MuyGPS
from MuyGPyS.gp.tensors import (
pairwise_tensor,
crosswise_tensor,
)

from MuyGPyS.gp.muygps import MuyGPS

if config.state.backend != "torch":
import warnings
Expand Down Expand Up @@ -105,7 +107,23 @@ def __init__(
batch_nn_targets,
):
super().__init__()
if not isinstance(
muygps_model.kernel.distortion_fn, IsotropicDistortion
):
raise NotImplementedError(
"MuyGPyS/torch optimization does not support "
f"{type(muygps_model.kernel.distortion_fn)} distortions"
)
if not isinstance(
muygps_model.kernel.distortion_fn.length_scale, ScalarHyperparameter
):
raise NotImplementedError(
"MuyGPyS/torch optimization does not support "
f"{type(muygps_model.kernel.distortion_fn.length_scale)} "
"length scales"
)
self.muygps_model = muygps_model
self.length_scale = muygps_model.kernel.distortion_fn.length_scale._val
self.batch_indices = batch_indices
self.batch_nn_indices = batch_nn_indices
self.batch_targets = batch_targets
Expand All @@ -124,6 +142,7 @@ def forward(self, x):
A torch.ndarray of shape `(batch_count,response_count)`
consisting of the diagonal elements of the posterior variance.
"""
self.muygps_model._make()

crosswise_diffs = crosswise_tensor(
x,
Expand Down

0 comments on commit d54b7db

Please sign in to comment.