diff --git a/desc/objectives/_free_boundary.py b/desc/objectives/_free_boundary.py index bdd06832d6..2927437f37 100644 --- a/desc/objectives/_free_boundary.py +++ b/desc/objectives/_free_boundary.py @@ -14,7 +14,6 @@ from desc.singularities import ( DFTInterpolator, FFTInterpolator, - singular_integral, virtual_casing_biot_savart, ) from desc.utils import Timer, errorif, warnif @@ -801,6 +800,9 @@ class QuadraticFlux(_Objective): Whether or not the external field's DOF (params) change during optimization. vacuum : bool If true, Bplasma is set to zero. + loop : bool + If True, evaluate integral using loops, as opposed to vmap. Slower, but uses + less memory. name : str Name of the objective function. @@ -829,6 +831,7 @@ def __init__( eq_fixed=False, field_fixed=False, vacuum=False, + loop=True, name="Quadratic flux", ): self._src_grid = src_grid @@ -841,6 +844,7 @@ def __init__( self._eq_fixed = eq_fixed self._field_fixed = field_fixed self._vacuum = vacuum + self._loop = loop if not eq_fixed and not field_fixed: things = [ext_field, eq] elif eq_fixed and not field_fixed: @@ -859,13 +863,11 @@ def __init__( name=name, ) - def build(self, eq=None, use_jit=True, verbose=1): + def build(self, use_jit=True, verbose=1): """Build constant arrays. Parameters ---------- - eq : Equilibrium, optional - Equilibrium that will be optimized to satisfy the Objective. use_jit : bool, optional Whether to just-in-time compile the objective and derivatives. verbose : int, optional @@ -985,11 +987,11 @@ def build(self, eq=None, use_jit=True, verbose=1): ) # don't need extra B/2 since we only care about normal component - Bplasma = -singular_integral( + Bplasma = virtual_casing_biot_savart( eval_data, src_data, - "biot_savart", - interpolator, + self._constants["interpolator"], + loop=self._loop, ) self._constants.update( @@ -1063,11 +1065,11 @@ def compute(self, params_1=None, params_2=None, constants=None): ) # don't need extra B/2 since we only care about normal component - Bplasma = -singular_integral( + Bplasma = virtual_casing_biot_savart( eval_data, src_data, - "biot_savart", - constants["interpolator"], + self._constants["interpolator"], + loop=self._loop, ) x = jnp.array( @@ -1080,7 +1082,7 @@ def compute(self, params_1=None, params_2=None, constants=None): # can't pre-compute Bext because it is dependent on eval_grid Bext = self._ext_field.compute_magnetic_field( - x, grid=self._field_grid, basis="rpz", params=field_params + x, source_grid=self._field_grid, basis="rpz", params=field_params ) f = jnp.sum((Bext + Bplasma) * eval_data["n_rho"], axis=-1) diff --git a/tests/test_objective_funs.py b/tests/test_objective_funs.py index 8fbce8d34d..9f0a395fc8 100644 --- a/tests/test_objective_funs.py +++ b/tests/test_objective_funs.py @@ -637,7 +637,7 @@ def test_quadratic_flux(self, quadratic_flux_equilibriums): obj = QuadraticFlux(t_field, eq, eq_fixed=True) obj.build(eq) f = obj.compute(params_1=t_field.params_dict) - np.testing.assert_allclose(f, 0) + np.testing.assert_allclose(f, 0, rtol=1e-15, atol=1e-15) # test nonaxisymmetric surface eq = desc.io.load("desc/examples/precise_QA_output.h5")[0] @@ -659,6 +659,7 @@ def test_quadratic_flux(self, quadratic_flux_equilibriums): np.testing.assert_allclose(f, Bnorm, atol=1e-3) + @pytest.mark.unit def test_quadratic_flux_with_field_fixed(self, quadratic_flux_equilibriums): """Test with field_fixed = True, eq_fixed = False. @@ -716,6 +717,7 @@ def test_quadratic_flux_with_field_fixed(self, quadratic_flux_equilibriums): np.testing.assert_allclose(new_Rb_lmn, 0, atol=1e-10) np.testing.assert_allclose(new_Zb_lmn, 0, atol=1e-10) + @pytest.mark.unit def test_quadratic_flux_with_eq_fixed(self, quadratic_flux_equilibriums): """Test with eq_fixed = True, field_fixed = True. @@ -755,6 +757,7 @@ def test_quadratic_flux_with_eq_fixed(self, quadratic_flux_equilibriums): np.testing.assert_allclose(things[0].Phi_mn, 0, atol=1e-8) + @pytest.mark.unit def test_quadratic_flux_with_eq_and_field_unfixed( self, quadratic_flux_equilibriums ): @@ -816,6 +819,7 @@ def test_quadratic_flux_with_eq_and_field_unfixed( np.testing.assert_allclose(things[0].Phi_mn, 0, atol=1e-7) + @pytest.mark.unit def test_quadratic_flux_with_analytic_field(self): """Test analytic field optimization when eq_fixed=True, field_fixed=False. @@ -850,6 +854,7 @@ def test_quadratic_flux_with_analytic_field(self): # to get to Bnorm = 0 np.testing.assert_allclose(things[0].B0, 0, atol=1e-3) + @pytest.mark.unit def test_quadratic_flux_vacuum(self, quadratic_flux_equilibriums): """Test vacuum flag.""" # equilibrium that has Bplasma != 0