Skip to content

Commit

Permalink
simple working jupyter lab notebook example
Browse files Browse the repository at this point in the history
  • Loading branch information
gray95 committed Sep 9, 2024
1 parent acc0703 commit 0a10f20
Showing 1 changed file with 166 additions and 0 deletions.
166 changes: 166 additions & 0 deletions examples/0dim_normflow_example.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "cc68f0ae-c4c3-4654-9b2f-149e0e6a8ce3",
"metadata": {},
"source": [
"This notebook serves as a gentle introduction to the normflow package. Let's begin by importing a some standard libraries."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "8c51a10e-8602-4843-90e6-862d9b8f26ef",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys"
]
},
{
"cell_type": "markdown",
"id": "444e3c9a-928f-4b9b-b449-9996d961f026",
"metadata": {},
"source": [
"Now let's import the key objects from the normflow library. Net, Action, etc..."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d830795e-1e48-452f-b223-80b970d8d9f0",
"metadata": {},
"outputs": [],
"source": [
"from normflow import np, torch, Model\n",
"from normflow import backward_sanitychecker\n",
"from normflow.nn import DistConvertor_\n",
"from normflow.action import ScalarPhi4Action\n",
"from normflow.prior import NormalPrior"
]
},
{
"cell_type": "markdown",
"id": "3adc2eec-8328-43c3-9b96-3a696c337198",
"metadata": {},
"source": [
"We define the parameters of our scalar field theory and the machine learning parameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f410b258-8c0f-46ef-b0a1-98e120fe8fbb",
"metadata": {},
"outputs": [],
"source": [
"m_sq=-1.2\n",
"lambd=0.5\n",
"knots_len=10\n",
"n_epochs=1000 \n",
"batch_size=1024\n",
"lat_shape=1 # basically a zero dimensional problem\n",
"nranks=1"
]
},
{
"cell_type": "markdown",
"id": "b4600e67-511c-4ef9-badd-25cf358d7f30",
"metadata": {},
"source": [
"It's time to instantiate the neural network and do the training."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6e315101-1fff-4f73-819b-2ed35a8921e3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Not saving model snapshots\n",
"\n",
">>> Training progress (cpu) <<<\n",
"\n",
"Note: log(q/p) is estimated with normalized p; mean & error are obtained from samples in a batch\n",
"\n",
"Epoch: 1 | loss: -0.530407 | ess: 0.870286 | rho: 0.818643 | log(z): 1.11124(38) | log(q/p): 0.6(36) | accept_rate: 0.797(9)\n",
"Epoch: 10 | loss: -0.812281 | ess: 0.895045 | rho: 0.808877 | log(z): 1.10289(33) | log(q/p): 0.3(22) | accept_rate: 0.808(10)\n",
"Epoch: 100 | loss: -1.08288 | ess: 0.991174 | rho: 0.794849 | log(z): 1.115250(92) | log(q/p): 0.03(86) | accept_rate: 0.966(8)\n",
"Epoch: 200 | loss: -1.1083 | ess: 0.996703 | rho: 0.983689 | log(z): 1.111443(56) | log(q/p): 0.00(10) | accept_rate: 0.982(5)\n",
"Epoch: 300 | loss: -1.11372 | ess: 0.998849 | rho: 0.996756 | log(z): 1.114207(33) | log(q/p): 0.000(30) | accept_rate: 0.988(3)\n",
"Epoch: 400 | loss: -1.11222 | ess: 0.997468 | rho: 0.991354 | log(z): 1.113817(49) | log(q/p): 0.002(63) | accept_rate: 0.979(4)\n",
"Epoch: 500 | loss: -1.11301 | ess: 0.99917 | rho: 0.998099 | log(z): 1.113437(28) | log(q/p): 0.000(30) | accept_rate: 0.986(3)\n",
"Epoch: 600 | loss: -1.11163 | ess: 0.99911 | rho: 0.997981 | log(z): 1.112047(29) | log(q/p): 0.000(28) | accept_rate: 0.985(4)\n",
"Epoch: 700 | loss: -1.11178 | ess: 0.999655 | rho: 0.999024 | log(z): 1.111957(18) | log(q/p): 0.000(19) | accept_rate: 0.989(3)\n",
"Epoch: 800 | loss: -1.11132 | ess: 0.999335 | rho: 0.998904 | log(z): 1.111686(25) | log(q/p): 0.000(28) | accept_rate: 0.988(3)\n",
"Epoch: 900 | loss: -1.11311 | ess: 0.998195 | rho: 0.99595 | log(z): 1.114003(42) | log(q/p): 0.001(42) | accept_rate: 0.977(5)\n",
"Epoch: 1000 | loss: -1.1122 | ess: 0.99951 | rho: 0.999103 | log(z): 1.112445(22) | log(q/p): 0.000(22) | accept_rate: 0.988(3)\n",
"(cpu) Time = 2.91 sec.\n",
"Sanity check is OK if following numbers are zero up to round off:\n",
"1.38778e-15 1.11022e-15\n"
]
}
],
"source": [
"net_ = DistConvertor_(knots_len, symmetric=True)\n",
"\n",
"action_dict = dict(kappa=0, m_sq=m_sq, lambd=lambd)\n",
"prior = NormalPrior(shape=lat_shape)\n",
"action = ScalarPhi4Action(**action_dict)\n",
"\n",
"model = Model(net_=net_, prior=prior, action=action)\n",
"\n",
"snapshot_path = None\n",
"\n",
"hyperparam = dict(lr=0.01, weight_decay=0.)\n",
"\n",
"fit_kwargs = dict(\n",
" n_epochs=n_epochs,\n",
" save_every=None,\n",
" batch_size=batch_size // nranks,\n",
" hyperparam=hyperparam,\n",
" checkpoint_dict=dict(print_stride=100, snapshot_path=snapshot_path)\n",
" )\n",
"\n",
"model.fit(**fit_kwargs)\n",
"\n",
"backward_sanitychecker(model)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c6d3976-79b5-451b-987a-e7c7efe42610",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit 0a10f20

Please sign in to comment.