diff --git a/desc/compat.py b/desc/compat.py index 38f3d9520a..1398a56090 100644 --- a/desc/compat.py +++ b/desc/compat.py @@ -261,3 +261,66 @@ def rescale( eq.surface = eq.get_surface_at(rho=1) return eq + + +def rotate_zeta(eq, angle, copy=False): + """Rotate the equilibrium about the toroidal direction. + + Parameters + ---------- + eq : Equilibrium + Equilibrium to rotate. + angle : float + Angle to rotate the equilibrium in radians. The actual physical rotation + is by angle / self.NFP. + copy : bool, optional + Whether to update the existing equilibrium or make a copy (Default). + + Returns + ------- + eq_rotated : Equilibrium + Equilibrium rotated about the toroidal direction + """ + eq_rotated = eq.copy() if copy else eq + if eq.sym and not (angle % np.pi == 0) and eq.N != 0: + warnings.warn( + "Rotating a stellarator symmetric equilibrium by an angle " + "that is not a multiple of pi will break the symmetry. " + "Changing the symmetry to False to rotate the equilibrium." + ) + eq_rotated.change_resolution(sym=0) + + def _get_new_coeffs(fun): + if fun == "R": + f_lmn = np.array(eq_rotated.R_lmn) + modes = eq_rotated.R_basis.modes + elif fun == "Z": + f_lmn = np.array(eq_rotated.Z_lmn) + modes = eq_rotated.Z_basis.modes + elif fun == "L": + f_lmn = np.array(eq_rotated.L_lmn) + modes = eq_rotated.L_basis.modes + else: + raise ValueError("fun must be 'R', 'Z' or 'L'") + + new_coeffs = f_lmn.copy() + mode_lookup = {(l, m, n): idx for idx, (l, m, n) in enumerate(modes)} + for i, (l, m, n) in enumerate(modes): + id_sin = mode_lookup.get((l, m, -n), None) + v_sin = np.sin(np.abs(n) * angle) + v_cos = np.cos(np.abs(n) * angle) + c_sin = f_lmn[id_sin] if id_sin is not None else 0 + if n >= 0: + new_coeffs[i] = f_lmn[i] * v_cos + c_sin * v_sin + elif n < 0: + new_coeffs[i] = f_lmn[i] * v_cos - c_sin * v_sin + return new_coeffs + + eq_rotated.R_lmn = _get_new_coeffs(fun="R") + eq_rotated.Z_lmn = _get_new_coeffs(fun="Z") + eq_rotated.L_lmn = _get_new_coeffs(fun="L") + + eq_rotated.surface = eq_rotated.get_surface_at(rho=1.0) + eq_rotated.axis = eq_rotated.get_axis() + + return eq_rotated diff --git a/desc/equilibrium/equilibrium.py b/desc/equilibrium/equilibrium.py index 3786af2d60..db4d506607 100644 --- a/desc/equilibrium/equilibrium.py +++ b/desc/equilibrium/equilibrium.py @@ -1408,66 +1408,6 @@ def to_sfl( """ return to_sfl(self, L, M, N, L_grid, M_grid, N_grid, rcond, copy) - def rotate_zeta(self, angle, copy=False): - """Rotate the equilibrium about the toroidal direction. - - Parameters - ---------- - angle : float - Angle to rotate the equilibrium in radians. The actual physical rotation - is by angle / self.NFP. - copy : bool, optional - Whether to update the existing equilibrium or make a copy (Default). - - Returns - ------- - eq : Equilibrium - Equilibrium rotated about the toroidal direction - """ - eq = self.copy() if copy else self - if eq.sym and not (angle % np.pi == 0) and eq.N != 0: - warnings.warn( - "Rotating a stellarator symmetric equilibrium by an angle " - "that is not a multiple of pi will break the symmetry. " - "Changing the symmetry to False to rotate the equilibrium." - ) - eq.change_resolution(sym=0) - - def _get_new_coeffs(fun): - if fun == "R": - f_lmn = np.array(eq.R_lmn) - modes = eq.R_basis.modes - elif fun == "Z": - f_lmn = np.array(eq.Z_lmn) - modes = eq.Z_basis.modes - elif fun == "L": - f_lmn = np.array(eq.L_lmn) - modes = eq.L_basis.modes - else: - raise ValueError("fun must be 'R', 'Z' or 'L'") - - new_coeffs = f_lmn.copy() - mode_lookup = {(l, m, n): idx for idx, (l, m, n) in enumerate(modes)} - for i, (l, m, n) in enumerate(modes): - id_sin = mode_lookup.get((l, m, -n), None) - v_sin = np.sin(np.abs(n) * angle) - v_cos = np.cos(np.abs(n) * angle) - c_sin = f_lmn[id_sin] if id_sin is not None else 0 - if n >= 0: - new_coeffs[i] = f_lmn[i] * v_cos + c_sin * v_sin - elif n < 0: - new_coeffs[i] = f_lmn[i] * v_cos - c_sin * v_sin - return new_coeffs - - eq.R_lmn = _get_new_coeffs(fun="R") - eq.Z_lmn = _get_new_coeffs(fun="Z") - eq.L_lmn = _get_new_coeffs(fun="L") - - eq.surface = eq.get_surface_at(rho=1.0) - eq.axis = eq.get_axis() - - return eq - @property def surface(self): """Surface: Geometric surface defining boundary conditions.""" diff --git a/tests/test_compat.py b/tests/test_compat.py index ca1a0c55da..e6d98aefb4 100644 --- a/tests/test_compat.py +++ b/tests/test_compat.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from desc.compat import flip_helicity, flip_theta, rescale +from desc.compat import flip_helicity, flip_theta, rescale, rotate_zeta from desc.examples import get from desc.grid import Grid, LinearGrid, QuadratureGrid @@ -277,3 +277,25 @@ def fun(eq): np.testing.assert_allclose(new_vals["B_max"], 2) np.testing.assert_allclose(new_vals["R0/a"], old_vals["R0/a"]) np.testing.assert_allclose(new_vals["err"], old_vals["err"], atol=1e-10) + + +@pytest.mark.unit +@pytest.mark.solve +def test_rotate_zeta(): + """Test rotating Equilibrium around Z axis.""" + eq = get("ARIES-CS") + eq_no_sym = eq.copy() + eq_no_sym.change_resolution(sym=False) + with pytest.warns(): + eq1 = rotate_zeta(eq, np.pi / 2, copy=True) + eq2 = rotate_zeta(eq1, 3 * np.pi / 2, copy=False) + + assert np.allclose(eq_no_sym.R_lmn, eq2.R_lmn) + assert np.allclose(eq_no_sym.Z_lmn, eq2.Z_lmn) + assert np.allclose(eq_no_sym.L_lmn, eq2.L_lmn) + + eq = get("DSHAPE") + eq3 = rotate_zeta(eq, np.pi / 3, copy=True) + assert np.allclose(eq.R_lmn, eq3.R_lmn) + assert np.allclose(eq.Z_lmn, eq3.Z_lmn) + assert np.allclose(eq.L_lmn, eq3.L_lmn) diff --git a/tests/test_equilibrium.py b/tests/test_equilibrium.py index 272d12f8de..601b559104 100644 --- a/tests/test_equilibrium.py +++ b/tests/test_equilibrium.py @@ -402,25 +402,3 @@ def test_backward_compatible_load_and_resolve(): f_obj = ForceBalance(eq=eq) obj = ObjectiveFunction(f_obj, use_jit=False) eq.solve(maxiter=1, objective=obj) - - -@pytest.mark.unit -@pytest.mark.solve -def test_rotate_zeta(): - """Test rotating Equilibrium around Z axis.""" - eq = get("ARIES-CS") - eq_no_sym = eq.copy() - eq_no_sym.change_resolution(sym=False) - with pytest.warns(): - eq1 = eq.rotate_zeta(np.pi / 2, copy=True) - eq2 = eq1.rotate_zeta(3 * np.pi / 2, copy=False) - - assert np.allclose(eq_no_sym.R_lmn, eq2.R_lmn) - assert np.allclose(eq_no_sym.Z_lmn, eq2.Z_lmn) - assert np.allclose(eq_no_sym.L_lmn, eq2.L_lmn) - - eq = get("DSHAPE") - eq3 = eq.rotate_zeta(np.pi / 3, copy=True) - assert np.allclose(eq.R_lmn, eq3.R_lmn) - assert np.allclose(eq.Z_lmn, eq3.Z_lmn) - assert np.allclose(eq.L_lmn, eq3.L_lmn)