passing loss functions instead of strings. (#180)
* unhooked experimental optimize checks.

* updated tests

* updated docs
bwpriest authored Aug 1, 2023
1 parent 62e83af commit ebf0787
Showing 22 changed files with 305 additions and 296 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/develop-test.yml
@@ -14,7 +14,7 @@ jobs:
       fail-fast: false
       matrix:
         python-version: ["3.8", "3.9", "3.10"]
-        test-group: [short, optimize, optimize-experimental, multivariate]
+        test-group: [short, optimize, multivariate]
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python ${{ matrix.python-version }}
@@ -46,9 +46,9 @@ jobs:
       - name: Optimize Tests
         if: matrix.test-group == 'optimize'
         run: python tests/optimize.py
-      - name: Optimize Tests
-        if: matrix.test-group == 'optimize-experimental'
-        run: python tests/mini_batch.py
+      # - name: Optimize Tests - experimental
+      #   if: matrix.test-group == 'optimize-experimental'
+      #   run: python tests/mini_batch.py
       - name: Multivariate Tests
         if: matrix.test-group == 'multivariate'
         run: python tests/multivariate.py
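
Note that the experimental mini-batch checks are commented out of CI rather than deleted; assuming `tests/mini_batch.py` remains in the tree, the suite can presumably still be run by hand the same way CI invoked it:

    python tests/mini_batch.py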
32 changes: 16 additions & 16 deletions MuyGPyS/_test/api.py
@@ -36,7 +36,7 @@ def _do_classify_test_chassis(
     target_acc: float,
     nn_count: int,
     batch_count: int,
-    loss_method: str,
+    loss_fn: Callable,
     obj_method: str,
     opt_method: str,
     nn_kwargs: Dict,
@@ -54,7 +54,7 @@ def _do_classify_test_chassis(
         test,
         nn_count,
         batch_count,
-        loss_method,
+        loss_fn,
         obj_method,
         opt_method,
         nn_kwargs,
@@ -100,7 +100,7 @@ def _do_classify(
     test: Dict[str, np.ndarray],
     nn_count: int,
     batch_count: int,
-    loss_method: str,
+    loss_fn: Callable,
     obj_method: str,
     opt_method: str,
     nn_kwargs: Dict,
@@ -114,7 +114,7 @@ def _do_classify(
         train["output"],
         nn_count=nn_count,
         batch_count=batch_count,
-        loss_method=loss_method,
+        loss_fn=loss_fn,
         obj_method=obj_method,
         opt_method=opt_method,
         k_kwargs=k_kwargs,
@@ -155,7 +155,7 @@ def _do_classify_uq_test_chassis(
     nn_count: int,
     opt_batch_count: int,
     uq_batch_count: int,
-    loss_method: str,
+    loss_fn: Callable,
     obj_method: str,
     opt_method: str,
     uq_objectives: Union[List[Callable], Tuple[Callable, ...]],
@@ -176,7 +176,7 @@ def _do_classify_uq_test_chassis(
         nn_count,
         opt_batch_count,
         uq_batch_count,
-        loss_method,
+        loss_fn,
         obj_method,
         opt_method,
         uq_objectives,
@@ -231,7 +231,7 @@ def _do_classify_uq(
     nn_count: int,
     opt_batch_count: int,
     uq_batch_count: int,
-    loss_method: str,
+    loss_fn: Callable,
     obj_method: str,
     opt_method: str,
     uq_objectives: Union[List[Callable], Tuple[Callable, ...]],
@@ -247,7 +247,7 @@ def _do_classify_uq(
         nn_count=nn_count,
         opt_batch_count=opt_batch_count,
         uq_batch_count=uq_batch_count,
-        loss_method=loss_method,
+        loss_fn=loss_fn,
         obj_method=obj_method,
         opt_method=opt_method,
         uq_objectives=uq_objectives,
@@ -272,7 +272,7 @@ def _do_regress_test_chassis(
     target_mse: float,
     nn_count: int,
     batch_count: int,
-    loss_method: str,
+    loss_fn: Callable,
     obj_method: str,
     opt_method: str,
     sigma_method: Optional[str],
@@ -286,7 +286,7 @@ def _do_regress_test_chassis(
         test,
         nn_count,
         batch_count,
-        loss_method,
+        loss_fn,
         obj_method,
         opt_method,
         sigma_method,
@@ -335,7 +335,7 @@ def _do_regress(
     test: Dict[str, np.ndarray],
     nn_count: int,
     batch_count: int,
-    loss_method: str,
+    loss_fn: Callable,
     obj_method: str,
     opt_method: str,
     sigma_method: Optional[str],
@@ -351,7 +351,7 @@ def _do_regress(
         train["output"],
         nn_count=nn_count,
         batch_count=batch_count,
-        loss_method=loss_method,
+        loss_fn=loss_fn,
         obj_method=obj_method,
         opt_method=opt_method,
         sigma_method=sigma_method,
@@ -375,7 +375,7 @@ def _do_fast_posterior_mean_test_chassis(
     target_mse: float,
     nn_count: int,
     batch_count: int,
-    loss_method: str,
+    loss_fn: Callable,
     obj_method: str,
     opt_method: str,
     nn_kwargs: Dict,
@@ -388,7 +388,7 @@ def _do_fast_posterior_mean_test_chassis(
         test=test,
         nn_count=nn_count,
         batch_count=batch_count,
-        loss_method=loss_method,
+        loss_fn=loss_fn,
         obj_method=obj_method,
         opt_method=opt_method,
         nn_kwargs=nn_kwargs,
@@ -417,7 +417,7 @@ def _do_fast_posterior_mean(
     test: Dict[str, np.ndarray],
     nn_count: int,
     batch_count: int,
-    loss_method: str,
+    loss_fn: Callable,
     obj_method: str,
     opt_method: str,
     nn_kwargs: Dict,
@@ -437,7 +437,7 @@ def _do_fast_posterior_mean(
         train["output"],
         nn_count=nn_count,
         batch_count=batch_count,
-        loss_method=loss_method,
+        loss_fn=loss_fn,
         obj_method=obj_method,
         opt_method=opt_method,
         k_kwargs=k_kwargs,
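
Because the chassis now accepts any `Callable` where it previously accepted a string, a user-defined loss of the same shape as the built-ins could in principle be passed through. A hedged sketch, assuming (as the built-in `cross_entropy_fn` and `mse_fn` suggest) that losses are invoked as `fn(predictions, targets)` on batched arrays; `mad_fn` is a hypothetical name, not part of MuyGPyS, and nothing in this diff confirms that arbitrary callables survive the optimizer plumbing:

    import numpy as np

    def mad_fn(predictions: np.ndarray, targets: np.ndarray) -> float:
        # Mean absolute deviation between batched predictions and targets,
        # mirroring the (predictions, targets) signature of mse_fn.
        return float(np.mean(np.abs(predictions - targets)))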
50 changes: 23 additions & 27 deletions MuyGPyS/examples/classify.py
@@ -20,22 +20,23 @@
 import numpy as np

 from time import perf_counter
-from typing import Dict, List, Tuple, Union
+from typing import Callable, Dict, List, Tuple, Union

 from MuyGPyS.examples.from_indices import posterior_mean_from_indices
 from MuyGPyS.gp import MuyGPS, MultivariateMuyGPS as MMuyGPS
 from MuyGPyS.gp.tensors import make_train_tensors
 from MuyGPyS.neighbors import NN_Wrapper
 from MuyGPyS.optimize import optimize_from_tensors
 from MuyGPyS.optimize.batch import get_balanced_batch
+from MuyGPyS.optimize.loss import cross_entropy_fn


 def make_classifier(
     train_features: np.ndarray,
     train_labels: np.ndarray,
     nn_count: int = 30,
     batch_count: int = 200,
-    loss_method: str = "log",
+    loss_fn: Callable = cross_entropy_fn,
     obj_method: str = "loo_crossval",
     opt_method: str = "bayes",
     k_kwargs: Dict = dict(),
@@ -67,7 +68,7 @@ def make_classifier(
     ...     train['output'],
     ...     nn_count=30,
     ...     batch_count=200,
-    ...     loss_method="log",
+    ...     loss_fn=cross_entropy_fn,
     ...     obj_method="loo_crossval",
     ...     opt_method="bayes",
     ...     k_kwargs=k_kwargs,
@@ -79,7 +80,7 @@ def make_classifier(
     ...     train['output'],
     ...     nn_count=30,
     ...     batch_count=200,
-    ...     loss_method="log",
+    ...     loss_fn=cross_entropy_fn,
     ...     obj_method="loo_crossval",
     ...     opt_method="bayes",
     ...     k_kwargs=k_kwargs,
@@ -99,11 +100,9 @@ def make_classifier(
     batch_count:
         The number of elements to sample batch for hyperparameter
         optimization.
-    loss_method:
-        The loss method to use in hyperparameter optimization. Ignored if
+    loss_fn:
+        The loss functor to use in hyperparameter optimization. Ignored if
         all of the parameters specified by argument `k_kwargs` are fixed.
-        Currently supports only `"log"` (or `"cross-entropy"`) and `"mse"`
-        for classification.
     opt_method:
         Indicates the optimization method to be used. Currently restricted
         to `"bayesian"` and `"scipy"`.
@@ -174,7 +173,7 @@ def make_classifier(
         batch_nn_targets,
         crosswise_diffs,
         pairwise_diffs,
-        loss_method=loss_method,
+        loss_fn=loss_fn,
         obj_method=obj_method,
         opt_method=opt_method,
         sigma_method=None,
@@ -197,7 +196,7 @@ def make_multivariate_classifier(
     train_labels: np.ndarray,
     nn_count: int = 30,
     batch_count: int = 200,
-    loss_method: str = "mse",
+    loss_fn: Callable = cross_entropy_fn,
     obj_method: str = "loo_crossval",
     opt_method: str = "bayes",
     k_args: Union[List[Dict], Tuple[Dict, ...]] = list(),
@@ -233,7 +232,7 @@ def make_multivariate_classifier(
     ...     train['output'],
     ...     nn_count=30,
     ...     batch_count=200,
-    ...     loss_method="mse",
+    ...     loss_fn=cross_entropy_fn,
     ...     obj_method="loo_crossval",
     ...     opt_method="bayes",
     ...     k_args=k_args,
@@ -245,7 +244,7 @@ def make_multivariate_classifier(
     ...     train['output'],
     ...     nn_count=30,
     ...     batch_count=200,
-    ...     loss_method="mse",
+    ...     loss_fn=cross_entropy_fn,
     ...     obj_method="loo_crossval",
     ...     opt_method="bayes",
     ...     k_args=k_args,
@@ -265,10 +264,9 @@ def make_multivariate_classifier(
     batch_count:
         The number of elements to sample batch for hyperparameter
         optimization.
-    loss_method:
-        The loss method to use in hyperparameter optimization. Ignored if
+    loss_fn:
+        The loss functor to use in hyperparameter optimization. Ignored if
         all of the parameters specified by argument `k_kwargs` are fixed.
-        Currently supports only `"log"` for classification.
     obj_method:
         Indicates the objective function to be minimized. Currently
         restricted to `"loo_crossval"`.
@@ -349,7 +347,7 @@ def make_multivariate_classifier(
         ),
         crosswise_diffs,
         pairwise_diffs,
-        loss_method=loss_method,
+        loss_fn=loss_fn,
         obj_method=obj_method,
         opt_method=opt_method,
         sigma_method=None,
@@ -372,7 +370,7 @@ def _decide_and_make_classifier(
     train_labels: np.ndarray,
     nn_count: int = 30,
     batch_count: int = 200,
-    loss_method: str = "log",
+    loss_fn: Callable = cross_entropy_fn,
     obj_method: str = "loo_crossval",
     opt_method: str = "bayes",
     k_kwargs: Union[Dict, Union[List[Dict], Tuple[Dict, ...]]] = dict(),
@@ -386,7 +384,7 @@ def _decide_and_make_classifier(
             train_labels,
             nn_count=nn_count,
             batch_count=batch_count,
-            loss_method=loss_method,
+            loss_fn=loss_fn,
             obj_method=obj_method,
             opt_method=opt_method,
             k_args=k_kwargs,
@@ -401,7 +399,7 @@ def _decide_and_make_classifier(
             train_labels,
             nn_count=nn_count,
             batch_count=batch_count,
-            loss_method=loss_method,
+            loss_fn=loss_fn,
             obj_method=obj_method,
             opt_method=opt_method,
             k_kwargs=k_kwargs,
@@ -423,7 +421,7 @@ def do_classify(
     train_labels: np.ndarray,
     nn_count: int = 30,
     batch_count: int = 200,
-    loss_method: str = "log",
+    loss_fn: Callable = cross_entropy_fn,
     obj_method: str = "loo_crossval",
     opt_method: str = "bayes",
     k_kwargs: Union[Dict, Union[List[Dict], Tuple[Dict, ...]]] = dict(),
@@ -460,7 +458,7 @@ def do_classify(
     ...     train['output'],
     ...     nn_count=30,
     ...     batch_count=200,
-    ...     loss_method="log",
+    ...     loss_fn=cross_entropy_fn,
     ...     obj_method="loo_crossval",
     ...     opt_method="bayes",
     ...     k_kwargs=k_kwargs,
@@ -487,11 +485,9 @@ def do_classify(
         The number of nearest neighbors to employ.
     batch_count:
         The batch size for hyperparameter optimization.
-    loss_method:
-        The loss method to use in hyperparameter optimization. Ignored if
-        all of the parameters specified by `k_kwargs` are fixed. Currently
-        supports only `"log"` (also known as `"cross_entropy"`) and `"mse"`
-        for classification.
+    loss_fn:
+        The loss functor to use in hyperparameter optimization. Ignored if
+        all of the parameters specified by `k_kwargs` are fixed.
     obj_method:
         Indicates the objective function to be minimized. Currently
         restricted to `"loo_crossval"`.
@@ -533,7 +529,7 @@ def do_classify(
         train_labels,
         nn_count=nn_count,
         batch_count=batch_count,
-        loss_method=loss_method,
+        loss_fn=loss_fn,
         obj_method=obj_method,
         opt_method=opt_method,
         k_kwargs=k_kwargs,
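
For completeness, an end-to-end sketch of the new calling convention, assembled from the docstring examples above. The toy data, the empty `k_kwargs`, the positional argument order (test features, train features, train labels), and the three-value return are assumptions drawn from the library's documented usage rather than from this diff:

    import numpy as np

    from MuyGPyS.examples.classify import do_classify
    from MuyGPyS.optimize.loss import cross_entropy_fn

    # Toy two-class data standing in for the train/test dicts in the docstrings.
    rng = np.random.default_rng(0)
    train_features = rng.random((1000, 10))
    test_features = rng.random((100, 10))
    # One-hot (+1/-1) labels, as the MuyGPyS classification examples use.
    train_labels = 2 * np.eye(2)[rng.integers(0, 2, 1000)] - 1

    muygps, nbrs_lookup, surrogate_predictions = do_classify(
        test_features,
        train_features,
        train_labels,
        nn_count=30,
        batch_count=200,
        loss_fn=cross_entropy_fn,  # formerly loss_method="log"
        obj_method="loo_crossval",
        opt_method="bayes",
        k_kwargs=dict(),  # kernel configuration placeholder; see the docs
    )
    predicted_labels = np.argmax(surrogate_predictions, axis=1)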
(Diff truncated: the remaining 19 of the 22 changed files are not shown.)
