diff --git a/CHANGELOG.md b/CHANGELOG.md index a1faa87b1b..c9f1de2e3a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ New Features Bug Fixes - Fixes bug that occurs when taking the gradient of ``root`` and ``root_scalar`` with newer versions of JAX (>=0.4.34) and unpins the JAX version - +- Changes ``FixLambdaGauge`` constraint to now enforce zero flux surface average for lambda, instead of enforcing lambda(rho,0,0)=0 as it was incorrectly doing before. v0.12.3 ------- diff --git a/desc/objectives/linear_objectives.py b/desc/objectives/linear_objectives.py index 6f302f2df0..3e6275d0c9 100644 --- a/desc/objectives/linear_objectives.py +++ b/desc/objectives/linear_objectives.py @@ -12,7 +12,7 @@ from termcolor import colored from desc.backend import execute_on_cpu, jnp, tree_leaves, tree_map, tree_structure -from desc.basis import zernike_radial, zernike_radial_coeffs +from desc.basis import zernike_radial from desc.geometry import FourierRZCurve from desc.utils import broadcast_tree, errorif, setdefault @@ -774,8 +774,8 @@ def build(self, use_jit=False, verbose=1): super().build(use_jit=use_jit, verbose=verbose) -class FixLambdaGauge(_Objective): - """Fixes gauge freedom for lambda: lambda(theta=0,zeta=0)=0. +class FixLambdaGauge(FixParameters): + """Fixes gauge freedom for lambda, which sets the flux surface avg of lambda to 0. Note: this constraint is automatically applied when needed, and does not need to be included by the user. @@ -793,9 +793,6 @@ class FixLambdaGauge(_Objective): """ - _scalar = False - _linear = True - _fixed = False # not "diagonal", since it is fixing a sum _units = "(rad)" _print_value_fmt = "lambda gauge error: " @@ -806,75 +803,21 @@ def __init__( normalize_target=True, name="lambda gauge", ): + if eq.sym: + indices = False + else: + indices = np.where( + np.logical_and(eq.L_basis.modes[:, 1] == 0, eq.L_basis.modes[:, 2] == 0) + )[0] super().__init__( - things=eq, + thing=eq, + params={"L_lmn": indices}, target=0, - bounds=None, - weight=1, normalize=normalize, normalize_target=normalize_target, name=name, ) - def build(self, use_jit=False, verbose=1): - """Build constant arrays. - - Parameters - ---------- - use_jit : bool, optional - Whether to just-in-time compile the objective and derivatives. - verbose : int, optional - Level of output. - - """ - eq = self.things[0] - L_basis = eq.L_basis - - if L_basis.sym: - self._A = np.zeros((0, L_basis.num_modes)) - else: - # l(rho,0,0) = 0 - # at theta=zeta=0, basis for lambda reduces to just a polynomial in rho - # what this constraint does is make all the coefficients of each power - # of rho equal to zero - # i.e. if lambda = (L_200 + 2*L_310) rho**2 + (L_100 + 2*L_210)*rho - # this constraint will make - # L_200 + 2*L_310 = 0 - # L_100 + 2*L_210 = 0 - L_modes = L_basis.modes - mnpos = np.where((L_modes[:, 1:] >= [0, 0]).all(axis=1))[0] - l_lmn = L_modes[mnpos, :] - if len(l_lmn) > 0: - c = zernike_radial_coeffs(l_lmn[:, 0], l_lmn[:, 1]) - else: - c = np.zeros((0, 0)) - - A = np.zeros((c.shape[1], L_basis.num_modes)) - A[:, mnpos] = c.T - self._A = A - - self._dim_f = self._A.shape[0] - super().build(use_jit=use_jit, verbose=verbose) - - def compute(self, params, constants=None): - """Compute lambda gauge freedom errors. - - Parameters - ---------- - params : dict - Dictionary of equilibrium degrees of freedom, eg Equilibrium.params_dict - constants : dict - Dictionary of constant data, eg transforms, profiles etc. Defaults to - self.constants - - Returns - ------- - f : ndarray - gauge freedom errors. - - """ - return jnp.dot(self._A, params["L_lmn"]) - class FixThetaSFL(FixParameters): """Fixes lambda=0 so that poloidal angle is the SFL poloidal angle. diff --git a/tests/inputs/HELIO_asym.h5 b/tests/inputs/HELIO_asym.h5 index c66a6cb100..3d57653cb6 100644 Binary files a/tests/inputs/HELIO_asym.h5 and b/tests/inputs/HELIO_asym.h5 differ diff --git a/tests/test_linear_objectives.py b/tests/test_linear_objectives.py index d0c17f8e39..83f9a32936 100644 --- a/tests/test_linear_objectives.py +++ b/tests/test_linear_objectives.py @@ -2,14 +2,12 @@ import numpy as np import pytest -import scipy.linalg from qsc import Qsc import desc.examples from desc.backend import jnp from desc.equilibrium import Equilibrium from desc.geometry import FourierRZToroidalSurface -from desc.grid import LinearGrid from desc.io import load from desc.magnetic_fields import OmnigenousField from desc.objectives import ( @@ -69,10 +67,10 @@ def test_LambdaGauge_sym(DummyStellarator): eq = load(load_from=str(DummyStellarator["output_path"]), file_format="hdf5") with pytest.warns(UserWarning, match="Reducing radial"): eq.change_resolution(L=2, M=1, N=1) - correct_constraint_matrix = np.zeros((0, 5)) lam_con = FixLambdaGauge(eq) lam_con.build() - np.testing.assert_array_equal(lam_con._A, correct_constraint_matrix) + # should have no indices to fix + assert lam_con._params["L_lmn"].size == 0 @pytest.mark.unit @@ -105,13 +103,10 @@ def test_LambdaGauge_asym(): lam_con = FixLambdaGauge(eq) lam_con.build() - # make sure that any lambda in the null space gives lambda==0 at theta=zeta=0 - Z = scipy.linalg.null_space(lam_con._A) - grid = LinearGrid(L=10, theta=[0], zeta=[0]) - for z in Z.T: - eq.L_lmn = z - lam = eq.compute("lambda", grid=grid)["lambda"] - np.testing.assert_allclose(lam, 0, atol=1e-15) + indices = np.where( + np.logical_and(eq.L_basis.modes[:, 1] == 0, eq.L_basis.modes[:, 2] == 0) + )[0] + np.testing.assert_allclose(indices, lam_con._params["L_lmn"]) @pytest.mark.regression diff --git a/tests/test_vmec.py b/tests/test_vmec.py index 0fef594b3c..91b34667aa 100644 --- a/tests/test_vmec.py +++ b/tests/test_vmec.py @@ -1022,7 +1022,7 @@ def test_vmec_save_asym(VMEC_save_asym): vmec.variables["jcurv"][20:100], desc.variables["jcurv"][20:100], rtol=2 ) np.testing.assert_allclose( - vmec.variables["DShear"][20:100], desc.variables["DShear"][20:100], rtol=3e-2 + vmec.variables["DShear"][20:100], desc.variables["DShear"][20:100], rtol=6e-2 ) np.testing.assert_allclose( vmec.variables["DCurr"][20:100], @@ -1074,13 +1074,22 @@ def test_vmec_save_asym(VMEC_save_asym): # Next, calculate some quantities and compare # the DESC wout -> DESC (should be very close) # and the DESC wout -> VMEC wout (should be approximately close) + surfs = desc.variables["ns"][:] + s_full = np.linspace(0, 1, surfs) + s_half = s_full[0:-1] + 0.5 / (surfs - 1) + r_full = np.sqrt(s_full) + r_half = np.sqrt(s_half) + vol_grid = LinearGrid( - rho=np.sqrt( - abs( - vmec.variables["phi"][:].filled() - / np.max(np.abs(vmec.variables["phi"][:].filled())) - ) - )[10::10], + rho=r_full[10::10], + M=15, + N=15, + NFP=eq.NFP, + axis=False, + sym=False, + ) + vol_half_grid = LinearGrid( + rho=r_half[10::10], M=15, N=15, NFP=eq.NFP, @@ -1100,12 +1109,13 @@ def test( atol_vmec_desc_wout=1e-5, rtol_vmec_desc_wout=1e-2, grid=vol_grid, + is_half_grid=False, ): """Helper fxn to evaluate Fourier series from wout and compare to DESC.""" xm = desc.variables["xm_nyq"][:] if use_nyq else desc.variables["xm"][:] xn = desc.variables["xn_nyq"][:] if use_nyq else desc.variables["xn"][:] - si = abs(vmec.variables["phi"][:] / np.max(np.abs(vmec.variables["phi"][:]))) + si = np.insert(s_half, 0, 0) if is_half_grid else s_full rho = grid.nodes[:, 0] s = rho**2 # some quantities must be negated before comparison bc @@ -1166,18 +1176,26 @@ def test( rtol=rtol_vmec_desc_wout, ) - # R & Z & lambda + # R & Z test("rmn", "R", use_nyq=False) - test("zmn", "Z", use_nyq=False, atol_vmec_desc_wout=4e-2) - + test("zmn", "Z", use_nyq=False, atol_vmec_desc_wout=1e-2) + test( + "lmn", + "lambda", + use_nyq=False, + atol_vmec_desc_wout=1e-2, + negate_DESC_quant=True, + grid=vol_half_grid, + is_half_grid=True, + ) # |B| test("bmn", "|B|", rtol_desc_desc_wout=7e-4) # B^zeta - test("bsupvmn", "B^zeta") # ,rtol_desc_desc_wout=6e-5) + test("bsupvmn", "B^zeta") # B_zeta - test("bsubvmn", "B_zeta", rtol_desc_desc_wout=3e-4) + test("bsubvmn", "B_zeta", grid=vol_half_grid, is_half_grid=True) # hard to compare to VMEC for the currents, since # VMEC F error is worse and equilibria are not exactly similar @@ -1185,30 +1203,22 @@ def test( test("currumn", "J^theta", atol_vmec_desc_wout=1e4) test("currvmn", "J^zeta", negate_DESC_quant=True, atol_vmec_desc_wout=1e5) - # can only compare lambda, sqrt(g) B_psi B^theta and B_theta at bdry - test( - "lmn", - "lambda", - use_nyq=False, - negate_DESC_quant=True, - grid=bdry_grid, - atol_desc_desc_wout=4e-4, - atol_vmec_desc_wout=5e-2, - ) test( "gmn", "sqrt(g)", convert_sqrt_g_or_B_rho=True, negate_DESC_quant=True, - grid=bdry_grid, - rtol_desc_desc_wout=5e-4, - rtol_vmec_desc_wout=4e-2, + grid=vol_half_grid, + is_half_grid=True, ) + + # Compare B_psi B^theta and B_theta at bdry only test( "bsupumn", "B^theta", negate_DESC_quant=True, grid=bdry_grid, + is_half_grid=True, atol_vmec_desc_wout=6e-4, ) test( @@ -1216,16 +1226,14 @@ def test( "B_theta", negate_DESC_quant=True, grid=bdry_grid, - atol_desc_desc_wout=1e-4, - atol_vmec_desc_wout=4e-4, + is_half_grid=True, ) test( "bsubsmn", "B_rho", - grid=bdry_grid, convert_sqrt_g_or_B_rho=True, - rtol_vmec_desc_wout=6e-2, - atol_vmec_desc_wout=9e-3, + grid=bdry_grid, + atol_vmec_desc_wout=2e-3, )