Skip to content

Commit

Permalink
Merge branch 'master' into yge/rotate_zeta
Browse files Browse the repository at this point in the history
  • Loading branch information
dpanici authored Nov 18, 2024
2 parents 337e214 + 0b52a24 commit 4a98856
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 111 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ New Features
Bug Fixes

- Fixes bug that occurs when taking the gradient of ``root`` and ``root_scalar`` with newer versions of JAX (>=0.4.34) and unpins the JAX version

- Changes ``FixLambdaGauge`` constraint to now enforce zero flux surface average for lambda, instead of enforcing lambda(rho,0,0)=0 as it was incorrectly doing before.

v0.12.3
-------
Expand Down
79 changes: 11 additions & 68 deletions desc/objectives/linear_objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from termcolor import colored

from desc.backend import execute_on_cpu, jnp, tree_leaves, tree_map, tree_structure
from desc.basis import zernike_radial, zernike_radial_coeffs
from desc.basis import zernike_radial
from desc.geometry import FourierRZCurve
from desc.utils import broadcast_tree, errorif, setdefault

Expand Down Expand Up @@ -774,8 +774,8 @@ def build(self, use_jit=False, verbose=1):
super().build(use_jit=use_jit, verbose=verbose)


class FixLambdaGauge(_Objective):
"""Fixes gauge freedom for lambda: lambda(theta=0,zeta=0)=0.
class FixLambdaGauge(FixParameters):
"""Fixes gauge freedom for lambda, which sets the flux surface avg of lambda to 0.
Note: this constraint is automatically applied when needed, and does not need to be
included by the user.
Expand All @@ -793,9 +793,6 @@ class FixLambdaGauge(_Objective):
"""

_scalar = False
_linear = True
_fixed = False # not "diagonal", since it is fixing a sum
_units = "(rad)"
_print_value_fmt = "lambda gauge error: "

Expand All @@ -806,75 +803,21 @@ def __init__(
normalize_target=True,
name="lambda gauge",
):
if eq.sym:
indices = False
else:
indices = np.where(
np.logical_and(eq.L_basis.modes[:, 1] == 0, eq.L_basis.modes[:, 2] == 0)
)[0]
super().__init__(
things=eq,
thing=eq,
params={"L_lmn": indices},
target=0,
bounds=None,
weight=1,
normalize=normalize,
normalize_target=normalize_target,
name=name,
)

def build(self, use_jit=False, verbose=1):
"""Build constant arrays.
Parameters
----------
use_jit : bool, optional
Whether to just-in-time compile the objective and derivatives.
verbose : int, optional
Level of output.
"""
eq = self.things[0]
L_basis = eq.L_basis

if L_basis.sym:
self._A = np.zeros((0, L_basis.num_modes))
else:
# l(rho,0,0) = 0
# at theta=zeta=0, basis for lambda reduces to just a polynomial in rho
# what this constraint does is make all the coefficients of each power
# of rho equal to zero
# i.e. if lambda = (L_200 + 2*L_310) rho**2 + (L_100 + 2*L_210)*rho
# this constraint will make
# L_200 + 2*L_310 = 0
# L_100 + 2*L_210 = 0
L_modes = L_basis.modes
mnpos = np.where((L_modes[:, 1:] >= [0, 0]).all(axis=1))[0]
l_lmn = L_modes[mnpos, :]
if len(l_lmn) > 0:
c = zernike_radial_coeffs(l_lmn[:, 0], l_lmn[:, 1])
else:
c = np.zeros((0, 0))

A = np.zeros((c.shape[1], L_basis.num_modes))
A[:, mnpos] = c.T
self._A = A

self._dim_f = self._A.shape[0]
super().build(use_jit=use_jit, verbose=verbose)

def compute(self, params, constants=None):
"""Compute lambda gauge freedom errors.
Parameters
----------
params : dict
Dictionary of equilibrium degrees of freedom, eg Equilibrium.params_dict
constants : dict
Dictionary of constant data, eg transforms, profiles etc. Defaults to
self.constants
Returns
-------
f : ndarray
gauge freedom errors.
"""
return jnp.dot(self._A, params["L_lmn"])


class FixThetaSFL(FixParameters):
"""Fixes lambda=0 so that poloidal angle is the SFL poloidal angle.
Expand Down
Binary file modified tests/inputs/HELIO_asym.h5
Binary file not shown.
17 changes: 6 additions & 11 deletions tests/test_linear_objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@

import numpy as np
import pytest
import scipy.linalg
from qsc import Qsc

import desc.examples
from desc.backend import jnp
from desc.equilibrium import Equilibrium
from desc.geometry import FourierRZToroidalSurface
from desc.grid import LinearGrid
from desc.io import load
from desc.magnetic_fields import OmnigenousField
from desc.objectives import (
Expand Down Expand Up @@ -69,10 +67,10 @@ def test_LambdaGauge_sym(DummyStellarator):
eq = load(load_from=str(DummyStellarator["output_path"]), file_format="hdf5")
with pytest.warns(UserWarning, match="Reducing radial"):
eq.change_resolution(L=2, M=1, N=1)
correct_constraint_matrix = np.zeros((0, 5))
lam_con = FixLambdaGauge(eq)
lam_con.build()
np.testing.assert_array_equal(lam_con._A, correct_constraint_matrix)
# should have no indices to fix
assert lam_con._params["L_lmn"].size == 0


@pytest.mark.unit
Expand Down Expand Up @@ -105,13 +103,10 @@ def test_LambdaGauge_asym():
lam_con = FixLambdaGauge(eq)
lam_con.build()

# make sure that any lambda in the null space gives lambda==0 at theta=zeta=0
Z = scipy.linalg.null_space(lam_con._A)
grid = LinearGrid(L=10, theta=[0], zeta=[0])
for z in Z.T:
eq.L_lmn = z
lam = eq.compute("lambda", grid=grid)["lambda"]
np.testing.assert_allclose(lam, 0, atol=1e-15)
indices = np.where(
np.logical_and(eq.L_basis.modes[:, 1] == 0, eq.L_basis.modes[:, 2] == 0)
)[0]
np.testing.assert_allclose(indices, lam_con._params["L_lmn"])


@pytest.mark.regression
Expand Down
70 changes: 39 additions & 31 deletions tests/test_vmec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,7 +1022,7 @@ def test_vmec_save_asym(VMEC_save_asym):
vmec.variables["jcurv"][20:100], desc.variables["jcurv"][20:100], rtol=2
)
np.testing.assert_allclose(
vmec.variables["DShear"][20:100], desc.variables["DShear"][20:100], rtol=3e-2
vmec.variables["DShear"][20:100], desc.variables["DShear"][20:100], rtol=6e-2
)
np.testing.assert_allclose(
vmec.variables["DCurr"][20:100],
Expand Down Expand Up @@ -1074,13 +1074,22 @@ def test_vmec_save_asym(VMEC_save_asym):
# Next, calculate some quantities and compare
# the DESC wout -> DESC (should be very close)
# and the DESC wout -> VMEC wout (should be approximately close)
surfs = desc.variables["ns"][:]
s_full = np.linspace(0, 1, surfs)
s_half = s_full[0:-1] + 0.5 / (surfs - 1)
r_full = np.sqrt(s_full)
r_half = np.sqrt(s_half)

vol_grid = LinearGrid(
rho=np.sqrt(
abs(
vmec.variables["phi"][:].filled()
/ np.max(np.abs(vmec.variables["phi"][:].filled()))
)
)[10::10],
rho=r_full[10::10],
M=15,
N=15,
NFP=eq.NFP,
axis=False,
sym=False,
)
vol_half_grid = LinearGrid(
rho=r_half[10::10],
M=15,
N=15,
NFP=eq.NFP,
Expand All @@ -1100,12 +1109,13 @@ def test(
atol_vmec_desc_wout=1e-5,
rtol_vmec_desc_wout=1e-2,
grid=vol_grid,
is_half_grid=False,
):
"""Helper fxn to evaluate Fourier series from wout and compare to DESC."""
xm = desc.variables["xm_nyq"][:] if use_nyq else desc.variables["xm"][:]
xn = desc.variables["xn_nyq"][:] if use_nyq else desc.variables["xn"][:]

si = abs(vmec.variables["phi"][:] / np.max(np.abs(vmec.variables["phi"][:])))
si = np.insert(s_half, 0, 0) if is_half_grid else s_full
rho = grid.nodes[:, 0]
s = rho**2
# some quantities must be negated before comparison bc
Expand Down Expand Up @@ -1166,66 +1176,64 @@ def test(
rtol=rtol_vmec_desc_wout,
)

# R & Z & lambda
# R & Z
test("rmn", "R", use_nyq=False)
test("zmn", "Z", use_nyq=False, atol_vmec_desc_wout=4e-2)

test("zmn", "Z", use_nyq=False, atol_vmec_desc_wout=1e-2)
test(
"lmn",
"lambda",
use_nyq=False,
atol_vmec_desc_wout=1e-2,
negate_DESC_quant=True,
grid=vol_half_grid,
is_half_grid=True,
)
# |B|
test("bmn", "|B|", rtol_desc_desc_wout=7e-4)

# B^zeta
test("bsupvmn", "B^zeta") # ,rtol_desc_desc_wout=6e-5)
test("bsupvmn", "B^zeta")

# B_zeta
test("bsubvmn", "B_zeta", rtol_desc_desc_wout=3e-4)
test("bsubvmn", "B_zeta", grid=vol_half_grid, is_half_grid=True)

# hard to compare to VMEC for the currents, since
# VMEC F error is worse and equilibria are not exactly similar
# just compare back to DESC
test("currumn", "J^theta", atol_vmec_desc_wout=1e4)
test("currvmn", "J^zeta", negate_DESC_quant=True, atol_vmec_desc_wout=1e5)

# can only compare lambda, sqrt(g) B_psi B^theta and B_theta at bdry
test(
"lmn",
"lambda",
use_nyq=False,
negate_DESC_quant=True,
grid=bdry_grid,
atol_desc_desc_wout=4e-4,
atol_vmec_desc_wout=5e-2,
)
test(
"gmn",
"sqrt(g)",
convert_sqrt_g_or_B_rho=True,
negate_DESC_quant=True,
grid=bdry_grid,
rtol_desc_desc_wout=5e-4,
rtol_vmec_desc_wout=4e-2,
grid=vol_half_grid,
is_half_grid=True,
)

# Compare B_psi B^theta and B_theta at bdry only
test(
"bsupumn",
"B^theta",
negate_DESC_quant=True,
grid=bdry_grid,
is_half_grid=True,
atol_vmec_desc_wout=6e-4,
)
test(
"bsubumn",
"B_theta",
negate_DESC_quant=True,
grid=bdry_grid,
atol_desc_desc_wout=1e-4,
atol_vmec_desc_wout=4e-4,
is_half_grid=True,
)
test(
"bsubsmn",
"B_rho",
grid=bdry_grid,
convert_sqrt_g_or_B_rho=True,
rtol_vmec_desc_wout=6e-2,
atol_vmec_desc_wout=9e-3,
grid=bdry_grid,
atol_vmec_desc_wout=2e-3,
)


Expand Down

0 comments on commit 4a98856

Please sign in to comment.