Skip to content

Commit

Permalink
Merge pull request #243 from xopt-org/forgetting_bo
Browse files Browse the repository at this point in the history
Forgetting bo
  • Loading branch information
roussel-ryan authored Sep 17, 2024
2 parents db63e78 + 85efa10 commit c8adf98
Show file tree
Hide file tree
Showing 7 changed files with 173 additions and 118 deletions.
212 changes: 102 additions & 110 deletions docs/examples/single_objective_bayes_opt/time_dependent_bo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
}
},
"source": [
"## Time dependent Bayesian Optimization\n",
"# Time dependent Bayesian Optimization\n",
"\n",
"In this example we demonstrate time dependent optimization. In this case we are not\n",
"only interested in finding an optimum point in input space, but also maintain the\n",
Expand All @@ -19,41 +19,25 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2023-09-19T20:39:46.889328100Z",
"start_time": "2023-09-19T20:39:43.762243Z"
},
"execution": {
"iopub.execute_input": "2024-09-13T15:55:15.936192Z",
"iopub.status.busy": "2024-09-13T15:55:15.935968Z",
"iopub.status.idle": "2024-09-13T15:55:17.648118Z",
"shell.execute_reply": "2024-09-13T15:55:17.647731Z"
}
},
"metadata": {},
"outputs": [],
"source": [
"# set values if testing\n",
"import os\n",
"from xopt.generators.bayesian.upper_confidence_bound import (\n",
" TDUpperConfidenceBoundGenerator,\n",
")\n",
"import time\n",
"import warnings\n",
"import torch\n",
"from matplotlib import pyplot as plt\n",
"from tqdm import trange\n",
"from xopt.generators.bayesian import TDUpperConfidenceBoundGenerator\n",
"from xopt.vocs import VOCS\n",
"from xopt.evaluator import Evaluator\n",
"\n",
"from xopt import Xopt\n",
"\n",
"import torch\n",
"from matplotlib import pyplot as plt\n",
"\n",
"import time\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"SMOKE_TEST = os.environ.get(\"SMOKE_TEST\")\n",
"N_MC_SAMPLES = 1 if SMOKE_TEST else 128\n",
"NUM_RESTARTS = 1 if SMOKE_TEST else 20"
"NUM_RESTARTS = 1 if SMOKE_TEST else 20\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
Expand All @@ -65,20 +49,24 @@
}
},
"source": [
"### Time dependent test problem\n",
"## Time dependent test problem\n",
"Optimization is carried out over a single variable `x`. The test function is a simple\n",
" quadratic, with a minimum location that drifts in the positive `x` direction over\n",
" (real) time."
" quadratic, with a minimum location that drifts and changes as a function of time `t`."
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"Define test functions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2023-09-19T20:39:46.948328700Z",
"start_time": "2023-09-19T20:39:46.891299700Z"
},
"execution": {
"iopub.execute_input": "2024-09-13T15:55:17.650035Z",
"iopub.status.busy": "2024-09-13T15:55:17.649856Z",
Expand All @@ -88,43 +76,79 @@
},
"outputs": [],
"source": [
"# test evaluate function and vocs\n",
"# location of time dependent minimum\n",
"def k(t_):\n",
" return torch.where(\n",
" t_ < 50, 0.25 * torch.sin(t_ * 6 / 10.0) + 0.1e-2 * t_, -1.5e-2 * (t_ - 50.0)\n",
" )\n",
"\n",
"\n",
"# define function in time and position space\n",
"def g(x_, t_):\n",
" return (x_ - k(t_)) ** 2\n",
"\n",
"\n",
"start_time = time.time()\n",
"\n",
"\n",
"# create callable function for Xopt\n",
"def f(inputs):\n",
" x_ = inputs[\"x\"]\n",
" current_time = time.time()\n",
" t_ = current_time - start_time\n",
" y_ = 5 * (x_ - t_ * 1e-2) ** 2\n",
" return {\"y\": y_, \"time\": current_time}\n",
"\n",
" y_ = g(x_, torch.tensor(t_))\n",
"\n",
" return {\"y\": float(y_), \"time\": float(current_time)}"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## Define Xopt objects including optimization algorithm"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"variables = {\"x\": [-1, 1]}\n",
"objectives = {\"y\": \"MINIMIZE\"}\n",
"\n",
"vocs = VOCS(variables=variables, objectives=objectives)\n",
"print(vocs)\n",
"\n",
"evaluator = Evaluator(function=f)\n",
"generator = TDUpperConfidenceBoundGenerator(vocs=vocs)\n",
"generator.added_time = 1.0\n",
"generator.beta = 2.0\n",
"generator = TDUpperConfidenceBoundGenerator(\n",
" vocs=vocs,\n",
" beta=0.01,\n",
" added_time=0.1,\n",
" forgetting_time=20.0,\n",
")\n",
"generator.n_monte_carlo_samples = N_MC_SAMPLES\n",
"generator.numerical_optimizer.n_restarts = NUM_RESTARTS\n",
"generator.max_travel_distances = [0.1]\n",
"\n",
"X = Xopt(evaluator=evaluator, generator=generator, vocs=vocs)\n",
"X"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## Run optimization"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2023-09-19T20:39:53.521345600Z",
"start_time": "2023-09-19T20:39:46.922297700Z"
},
"execution": {
"iopub.execute_input": "2024-09-13T15:55:17.679779Z",
"iopub.status.busy": "2024-09-13T15:55:17.679658Z",
Expand All @@ -136,70 +160,37 @@
"source": [
"X.random_evaluate(1)\n",
"\n",
"for _ in range(20):\n",
"for _ in trange(300):\n",
" # note that in this example we can ignore warnings if computation time is greater\n",
" # than added time\n",
" with warnings.catch_warnings():\n",
" warnings.filterwarnings(\"ignore\", category=RuntimeWarning)\n",
" X.step()\n",
" time.sleep(0.1)\n",
"\n",
"print(X.generator.generate(1))"
" time.sleep(0.1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"cell_type": "markdown",
"metadata": {
"ExecuteTime": {
"end_time": "2023-09-19T20:39:53.564407600Z",
"start_time": "2023-09-19T20:39:53.522344500Z"
},
"execution": {
"iopub.execute_input": "2024-09-13T15:55:40.920087Z",
"iopub.status.busy": "2024-09-13T15:55:40.919502Z",
"iopub.status.idle": "2024-09-13T15:55:40.932446Z",
"shell.execute_reply": "2024-09-13T15:55:40.932014Z"
}
"collapsed": false
},
"outputs": [],
"source": [
"X.data"
"## Visualize GP model of objective function and plot trajectory"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2023-09-19T20:39:55.700392Z",
"start_time": "2023-09-19T20:39:53.537383500Z"
},
"execution": {
"iopub.execute_input": "2024-09-13T15:55:40.934767Z",
"iopub.status.busy": "2024-09-13T15:55:40.934580Z",
"iopub.status.idle": "2024-09-13T15:55:41.787991Z",
"shell.execute_reply": "2024-09-13T15:55:41.787705Z"
}
},
"metadata": {},
"outputs": [],
"source": [
"# plot model\n",
"\n",
"# plot model predictions\n",
"\n",
"data = X.data\n",
"\n",
"xbounds = generator.vocs.bounds\n",
"tbounds = [data[\"time\"].min(), data[\"time\"].max()]\n",
"\n",
"\n",
"def gt(inpts):\n",
" return 5 * (inpts[:, 1] - (inpts[:, 0] - start_time) * 1e-2) ** 2\n",
"\n",
"\n",
"model = X.generator.model\n",
"n = 200\n",
"n = 100\n",
"t = torch.linspace(*tbounds, n, dtype=torch.double)\n",
"x = torch.linspace(*xbounds.flatten(), n, dtype=torch.double)\n",
"tt, xx = torch.meshgrid(t, x)\n",
Expand All @@ -210,7 +201,7 @@
"# NOTE: the model inputs are such that t is the last dimension\n",
"gp_pts = torch.flip(pts, dims=[-1])\n",
"\n",
"gt_vals = gt(pts)\n",
"gt_vals = g(gp_pts.T[0], gp_pts.T[1] - start_time)\n",
"\n",
"with torch.no_grad():\n",
" post = model.posterior(gp_pts)\n",
Expand All @@ -223,6 +214,10 @@
" ax.set_xlabel(\"unix time\")\n",
" ax.set_ylabel(\"x\")\n",
" c = ax.pcolor(tt, xx, mean.reshape(n, n))\n",
" ax.plot(data[\"time\"].to_numpy(), data[\"x\"].to_numpy(), \"oC1\", label=\"samples\")\n",
"\n",
" ax.plot(t, k(t - start_time), \"C3--\", label=\"ideal path\", zorder=10)\n",
" ax.legend()\n",
" fig.colorbar(c)\n",
"\n",
" fig2, ax2 = plt.subplots()\n",
Expand All @@ -232,45 +227,30 @@
" c = ax2.pcolor(tt, xx, std.reshape(n, n))\n",
" fig2.colorbar(c)\n",
"\n",
" ax.plot(data[\"time\"].to_numpy(), data[\"x\"].to_numpy(), \"oC1\")\n",
" ax2.plot(data[\"time\"].to_numpy(), data[\"x\"].to_numpy(), \"oC1\")\n",
"\n",
" fig3, ax3 = plt.subplots()\n",
" ax3.set_title(\"ground truth value\")\n",
" ax3.set_xlabel(\"unix time\")\n",
" ax3.set_ylabel(\"x\")\n",
" c = ax3.pcolor(tt, xx, gt_vals.reshape(n, n))\n",
" fig3.colorbar(c)"
" fig3.colorbar(c)\n",
"\n",
" ax2.plot(data[\"time\"].to_numpy(), data[\"x\"].to_numpy(), \"oC1\")\n",
" ax3.plot(data[\"time\"].to_numpy(), data[\"x\"].to_numpy(), \"oC1\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"cell_type": "markdown",
"metadata": {
"ExecuteTime": {
"end_time": "2023-09-19T20:39:55.714394400Z",
"start_time": "2023-09-19T20:39:55.701395500Z"
},
"execution": {
"iopub.execute_input": "2024-09-13T15:55:41.789510Z",
"iopub.status.busy": "2024-09-13T15:55:41.789388Z",
"iopub.status.idle": "2024-09-13T15:55:41.793271Z",
"shell.execute_reply": "2024-09-13T15:55:41.793028Z"
}
"collapsed": false
},
"outputs": [],
"source": [
"list(model.named_parameters())"
"## plot the acquisition function"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2023-09-19T20:39:56.432421500Z",
"start_time": "2023-09-19T20:39:55.716392Z"
},
"execution": {
"iopub.execute_input": "2024-09-13T15:55:41.794609Z",
"iopub.status.busy": "2024-09-13T15:55:41.794517Z",
Expand All @@ -280,7 +260,6 @@
},
"outputs": [],
"source": [
"# plot the acquisition function\n",
"# note that target time is only updated during the generate call\n",
"target_time = X.generator.target_prediction_time\n",
"print(target_time - start_time)\n",
Expand All @@ -293,11 +272,24 @@
"\n",
" fig, ax = plt.subplots()\n",
" c = ax.pcolor(tt, xx, full_acq.reshape(n, n))\n",
" ax.set_xlabel(\"unix time\")\n",
" ax.set_ylabel(\"x\")\n",
" ax.set_title(\"acquisition function\")\n",
" fig.colorbar(c)\n",
"\n",
" fi2, ax2 = plt.subplots()\n",
" ax2.plot(x.flatten(), fixed_acq.flatten())"
" ax2.plot(x.flatten(), fixed_acq.flatten())\n",
" ax2.set_xlabel(\"x\")\n",
" ax2.set_ylabel(\"acquisition function\")\n",
" ax2.set_title(\"acquisition function at last time step\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
8 changes: 7 additions & 1 deletion xopt/generators/bayesian/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from xopt.generators.bayesian.bayesian_exploration import BayesianExplorationGenerator
from xopt.generators.bayesian.expected_improvement import ExpectedImprovementGenerator
from xopt.generators.bayesian.expected_improvement import (
ExpectedImprovementGenerator,
TDExpectedImprovementGenerator,
)
from xopt.generators.bayesian.mobo import MOBOGenerator
from xopt.generators.bayesian.multi_fidelity import MultiFidelityGenerator
from xopt.generators.bayesian.upper_confidence_bound import (
TDUpperConfidenceBoundGenerator,
UpperConfidenceBoundGenerator,
)

Expand All @@ -13,4 +17,6 @@
"UpperConfidenceBoundGenerator",
"ExpectedImprovementGenerator",
"MultiFidelityGenerator",
"TDUpperConfidenceBoundGenerator",
"TDExpectedImprovementGenerator",
]
1 change: 1 addition & 0 deletions xopt/generators/bayesian/bayesian_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ class BayesianGenerator(Generator, ABC):
description="custom objective for optimization, replaces objective specified by VOCS",
)
n_interpolate_points: Optional[PositiveInt] = None
memory_length: Optional[PositiveInt] = None

n_candidates: int = 1

Expand Down
Loading

0 comments on commit c8adf98

Please sign in to comment.