Skip to content

Commit

Permalink
address some reviews
Browse files Browse the repository at this point in the history
- deleted eq arg in build()
- used wrapper for Bplasma
- updated grid to source_grid for compute_magnetic_field
- add tolerances for testing axisymmetric Bnorm (maybe #864)
  • Loading branch information
kianorr committed Feb 27, 2024
1 parent a73fceb commit b8c2070
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 12 deletions.
24 changes: 13 additions & 11 deletions desc/objectives/_free_boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from desc.singularities import (
DFTInterpolator,
FFTInterpolator,
singular_integral,
virtual_casing_biot_savart,
)
from desc.utils import Timer, errorif, warnif
Expand Down Expand Up @@ -801,6 +800,9 @@ class QuadraticFlux(_Objective):
Whether or not the external field's DOF (params) change during optimization.
vacuum : bool
If true, Bplasma is set to zero.
loop : bool
If True, evaluate integral using loops, as opposed to vmap. Slower, but uses
less memory.
name : str
Name of the objective function.
Expand Down Expand Up @@ -829,6 +831,7 @@ def __init__(
eq_fixed=False,
field_fixed=False,
vacuum=False,
loop=True,
name="Quadratic flux",
):
self._src_grid = src_grid
Expand All @@ -841,6 +844,7 @@ def __init__(
self._eq_fixed = eq_fixed
self._field_fixed = field_fixed
self._vacuum = vacuum
self._loop = loop
if not eq_fixed and not field_fixed:
things = [ext_field, eq]
elif eq_fixed and not field_fixed:
Expand All @@ -859,13 +863,11 @@ def __init__(
name=name,
)

def build(self, eq=None, use_jit=True, verbose=1):
def build(self, use_jit=True, verbose=1):
"""Build constant arrays.
Parameters
----------
eq : Equilibrium, optional
Equilibrium that will be optimized to satisfy the Objective.
use_jit : bool, optional
Whether to just-in-time compile the objective and derivatives.
verbose : int, optional
Expand Down Expand Up @@ -985,11 +987,11 @@ def build(self, eq=None, use_jit=True, verbose=1):
)

# don't need extra B/2 since we only care about normal component
Bplasma = -singular_integral(
Bplasma = virtual_casing_biot_savart(
eval_data,
src_data,
"biot_savart",
interpolator,
self._constants["interpolator"],
loop=self._loop,
)

self._constants.update(
Expand Down Expand Up @@ -1063,11 +1065,11 @@ def compute(self, params_1=None, params_2=None, constants=None):
)

# don't need extra B/2 since we only care about normal component
Bplasma = -singular_integral(
Bplasma = virtual_casing_biot_savart(
eval_data,
src_data,
"biot_savart",
constants["interpolator"],
self._constants["interpolator"],
loop=self._loop,
)

x = jnp.array(
Expand All @@ -1080,7 +1082,7 @@ def compute(self, params_1=None, params_2=None, constants=None):

# can't pre-compute Bext because it is dependent on eval_grid
Bext = self._ext_field.compute_magnetic_field(
x, grid=self._field_grid, basis="rpz", params=field_params
x, source_grid=self._field_grid, basis="rpz", params=field_params
)

f = jnp.sum((Bext + Bplasma) * eval_data["n_rho"], axis=-1)
Expand Down
7 changes: 6 additions & 1 deletion tests/test_objective_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ def test_quadratic_flux(self, quadratic_flux_equilibriums):
obj = QuadraticFlux(t_field, eq, eq_fixed=True)
obj.build(eq)
f = obj.compute(params_1=t_field.params_dict)
np.testing.assert_allclose(f, 0)
np.testing.assert_allclose(f, 0, rtol=1e-15, atol=1e-15)

# test nonaxisymmetric surface
eq = desc.io.load("desc/examples/precise_QA_output.h5")[0]
Expand All @@ -659,6 +659,7 @@ def test_quadratic_flux(self, quadratic_flux_equilibriums):

np.testing.assert_allclose(f, Bnorm, atol=1e-3)

@pytest.mark.unit
def test_quadratic_flux_with_field_fixed(self, quadratic_flux_equilibriums):
"""Test with field_fixed = True, eq_fixed = False.
Expand Down Expand Up @@ -716,6 +717,7 @@ def test_quadratic_flux_with_field_fixed(self, quadratic_flux_equilibriums):
np.testing.assert_allclose(new_Rb_lmn, 0, atol=1e-10)
np.testing.assert_allclose(new_Zb_lmn, 0, atol=1e-10)

@pytest.mark.unit
def test_quadratic_flux_with_eq_fixed(self, quadratic_flux_equilibriums):
"""Test with eq_fixed = True, field_fixed = True.
Expand Down Expand Up @@ -755,6 +757,7 @@ def test_quadratic_flux_with_eq_fixed(self, quadratic_flux_equilibriums):

np.testing.assert_allclose(things[0].Phi_mn, 0, atol=1e-8)

@pytest.mark.unit
def test_quadratic_flux_with_eq_and_field_unfixed(
self, quadratic_flux_equilibriums
):
Expand Down Expand Up @@ -816,6 +819,7 @@ def test_quadratic_flux_with_eq_and_field_unfixed(

np.testing.assert_allclose(things[0].Phi_mn, 0, atol=1e-7)

@pytest.mark.unit
def test_quadratic_flux_with_analytic_field(self):
"""Test analytic field optimization when eq_fixed=True, field_fixed=False.
Expand Down Expand Up @@ -850,6 +854,7 @@ def test_quadratic_flux_with_analytic_field(self):
# to get to Bnorm = 0
np.testing.assert_allclose(things[0].B0, 0, atol=1e-3)

@pytest.mark.unit
def test_quadratic_flux_vacuum(self, quadratic_flux_equilibriums):
"""Test vacuum flag."""
# equilibrium that has Bplasma != 0
Expand Down

0 comments on commit b8c2070

Please sign in to comment.