Skip to content

Commit

Permalink
move rotate_zeta to compat
Browse files Browse the repository at this point in the history
  • Loading branch information
YigitElma committed Nov 18, 2024
1 parent 4579870 commit 337e214
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 83 deletions.
63 changes: 63 additions & 0 deletions desc/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,66 @@ def rescale(
eq.surface = eq.get_surface_at(rho=1)

return eq


def rotate_zeta(eq, angle, copy=False):
"""Rotate the equilibrium about the toroidal direction.
Parameters
----------
eq : Equilibrium
Equilibrium to rotate.
angle : float
Angle to rotate the equilibrium in radians. The actual physical rotation
is by angle / self.NFP.
copy : bool, optional
Whether to update the existing equilibrium or make a copy (Default).
Returns
-------
eq_rotated : Equilibrium
Equilibrium rotated about the toroidal direction
"""
eq_rotated = eq.copy() if copy else eq
if eq.sym and not (angle % np.pi == 0) and eq.N != 0:
warnings.warn(
"Rotating a stellarator symmetric equilibrium by an angle "
"that is not a multiple of pi will break the symmetry. "
"Changing the symmetry to False to rotate the equilibrium."
)
eq_rotated.change_resolution(sym=0)

def _get_new_coeffs(fun):
if fun == "R":
f_lmn = np.array(eq_rotated.R_lmn)
modes = eq_rotated.R_basis.modes
elif fun == "Z":
f_lmn = np.array(eq_rotated.Z_lmn)
modes = eq_rotated.Z_basis.modes
elif fun == "L":
f_lmn = np.array(eq_rotated.L_lmn)
modes = eq_rotated.L_basis.modes
else:
raise ValueError("fun must be 'R', 'Z' or 'L'")

Check warning on line 304 in desc/compat.py

View check run for this annotation

Codecov / codecov/patch

desc/compat.py#L304

Added line #L304 was not covered by tests

new_coeffs = f_lmn.copy()
mode_lookup = {(l, m, n): idx for idx, (l, m, n) in enumerate(modes)}
for i, (l, m, n) in enumerate(modes):
id_sin = mode_lookup.get((l, m, -n), None)
v_sin = np.sin(np.abs(n) * angle)
v_cos = np.cos(np.abs(n) * angle)
c_sin = f_lmn[id_sin] if id_sin is not None else 0
if n >= 0:
new_coeffs[i] = f_lmn[i] * v_cos + c_sin * v_sin
elif n < 0:
new_coeffs[i] = f_lmn[i] * v_cos - c_sin * v_sin
return new_coeffs

eq_rotated.R_lmn = _get_new_coeffs(fun="R")
eq_rotated.Z_lmn = _get_new_coeffs(fun="Z")
eq_rotated.L_lmn = _get_new_coeffs(fun="L")

eq_rotated.surface = eq_rotated.get_surface_at(rho=1.0)
eq_rotated.axis = eq_rotated.get_axis()

return eq_rotated
60 changes: 0 additions & 60 deletions desc/equilibrium/equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -1408,66 +1408,6 @@ def to_sfl(
"""
return to_sfl(self, L, M, N, L_grid, M_grid, N_grid, rcond, copy)

def rotate_zeta(self, angle, copy=False):
"""Rotate the equilibrium about the toroidal direction.
Parameters
----------
angle : float
Angle to rotate the equilibrium in radians. The actual physical rotation
is by angle / self.NFP.
copy : bool, optional
Whether to update the existing equilibrium or make a copy (Default).
Returns
-------
eq : Equilibrium
Equilibrium rotated about the toroidal direction
"""
eq = self.copy() if copy else self
if eq.sym and not (angle % np.pi == 0) and eq.N != 0:
warnings.warn(
"Rotating a stellarator symmetric equilibrium by an angle "
"that is not a multiple of pi will break the symmetry. "
"Changing the symmetry to False to rotate the equilibrium."
)
eq.change_resolution(sym=0)

def _get_new_coeffs(fun):
if fun == "R":
f_lmn = np.array(eq.R_lmn)
modes = eq.R_basis.modes
elif fun == "Z":
f_lmn = np.array(eq.Z_lmn)
modes = eq.Z_basis.modes
elif fun == "L":
f_lmn = np.array(eq.L_lmn)
modes = eq.L_basis.modes
else:
raise ValueError("fun must be 'R', 'Z' or 'L'")

new_coeffs = f_lmn.copy()
mode_lookup = {(l, m, n): idx for idx, (l, m, n) in enumerate(modes)}
for i, (l, m, n) in enumerate(modes):
id_sin = mode_lookup.get((l, m, -n), None)
v_sin = np.sin(np.abs(n) * angle)
v_cos = np.cos(np.abs(n) * angle)
c_sin = f_lmn[id_sin] if id_sin is not None else 0
if n >= 0:
new_coeffs[i] = f_lmn[i] * v_cos + c_sin * v_sin
elif n < 0:
new_coeffs[i] = f_lmn[i] * v_cos - c_sin * v_sin
return new_coeffs

eq.R_lmn = _get_new_coeffs(fun="R")
eq.Z_lmn = _get_new_coeffs(fun="Z")
eq.L_lmn = _get_new_coeffs(fun="L")

eq.surface = eq.get_surface_at(rho=1.0)
eq.axis = eq.get_axis()

return eq

@property
def surface(self):
"""Surface: Geometric surface defining boundary conditions."""
Expand Down
24 changes: 23 additions & 1 deletion tests/test_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pytest

from desc.compat import flip_helicity, flip_theta, rescale
from desc.compat import flip_helicity, flip_theta, rescale, rotate_zeta
from desc.examples import get
from desc.grid import Grid, LinearGrid, QuadratureGrid

Expand Down Expand Up @@ -277,3 +277,25 @@ def fun(eq):
np.testing.assert_allclose(new_vals["B_max"], 2)
np.testing.assert_allclose(new_vals["R0/a"], old_vals["R0/a"])
np.testing.assert_allclose(new_vals["err"], old_vals["err"], atol=1e-10)


@pytest.mark.unit
@pytest.mark.solve
def test_rotate_zeta():
"""Test rotating Equilibrium around Z axis."""
eq = get("ARIES-CS")
eq_no_sym = eq.copy()
eq_no_sym.change_resolution(sym=False)
with pytest.warns():
eq1 = rotate_zeta(eq, np.pi / 2, copy=True)
eq2 = rotate_zeta(eq1, 3 * np.pi / 2, copy=False)

assert np.allclose(eq_no_sym.R_lmn, eq2.R_lmn)
assert np.allclose(eq_no_sym.Z_lmn, eq2.Z_lmn)
assert np.allclose(eq_no_sym.L_lmn, eq2.L_lmn)

eq = get("DSHAPE")
eq3 = rotate_zeta(eq, np.pi / 3, copy=True)
assert np.allclose(eq.R_lmn, eq3.R_lmn)
assert np.allclose(eq.Z_lmn, eq3.Z_lmn)
assert np.allclose(eq.L_lmn, eq3.L_lmn)
22 changes: 0 additions & 22 deletions tests/test_equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,25 +402,3 @@ def test_backward_compatible_load_and_resolve():
f_obj = ForceBalance(eq=eq)
obj = ObjectiveFunction(f_obj, use_jit=False)
eq.solve(maxiter=1, objective=obj)


@pytest.mark.unit
@pytest.mark.solve
def test_rotate_zeta():
"""Test rotating Equilibrium around Z axis."""
eq = get("ARIES-CS")
eq_no_sym = eq.copy()
eq_no_sym.change_resolution(sym=False)
with pytest.warns():
eq1 = eq.rotate_zeta(np.pi / 2, copy=True)
eq2 = eq1.rotate_zeta(3 * np.pi / 2, copy=False)

assert np.allclose(eq_no_sym.R_lmn, eq2.R_lmn)
assert np.allclose(eq_no_sym.Z_lmn, eq2.Z_lmn)
assert np.allclose(eq_no_sym.L_lmn, eq2.L_lmn)

eq = get("DSHAPE")
eq3 = eq.rotate_zeta(np.pi / 3, copy=True)
assert np.allclose(eq.R_lmn, eq3.R_lmn)
assert np.allclose(eq.Z_lmn, eq3.Z_lmn)
assert np.allclose(eq.L_lmn, eq3.L_lmn)

0 comments on commit 337e214

Please sign in to comment.