Skip to content

Commit

Permalink
Merge branch 'master' into yge/rotate_zeta
Browse files Browse the repository at this point in the history
  • Loading branch information
YigitElma authored Nov 19, 2024
2 parents 4a98856 + 51d637c commit 68f3e2d
Show file tree
Hide file tree
Showing 51 changed files with 181 additions and 163 deletions.
2 changes: 1 addition & 1 deletion desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
)
)

if use_jax: # noqa: C901 - FIXME: simplify this, define globally and then assign?
if use_jax: # noqa: C901
from jax import custom_jvp, jit, vmap

imap = jax.lax.map
Expand Down
5 changes: 2 additions & 3 deletions desc/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,13 +1098,13 @@ def evaluate(
if not len(modes):
return np.array([]).reshape((len(nodes), 0))

# TODO: avoid duplicate calculations when mixing derivatives
# TODO(#1243): avoid duplicate calculations when mixing derivatives
r, t, z = nodes.T
l, m, n = modes.T
lm = modes[:, :2]

if unique:
# TODO: can avoid this here by using grid.unique_idx etc
# TODO(#1243): can avoid this here by using grid.unique_idx etc
# and adding unique_modes attributes to basis
_, ridx, routidx = np.unique(
r, return_index=True, return_inverse=True, axis=0
Expand Down Expand Up @@ -1364,7 +1364,6 @@ def polyval_vec(p, x, prec=None):
def _polyval_exact(p, x, prec):
p = np.atleast_2d(p)
x = np.atleast_1d(x).flatten()
# TODO: possibly multithread this bit
mpmath.mp.dps = prec
y = np.array([np.asarray(mpmath.polyval(list(pi), x)) for pi in p])
return y.astype(float)
Expand Down
1 change: 0 additions & 1 deletion desc/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def f_(carry, x):
return res_append


# TODO in_axes a la vmap?
def _scanmap(fun, scan_fun, argnums=0):
"""A helper function to wrap f with a scan_fun."""

Expand Down
37 changes: 28 additions & 9 deletions desc/coils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1633,7 +1633,14 @@ def compute_magnetic_vector_potential(

@classmethod
def linspaced_angular(
cls, coil, current=None, axis=[0, 0, 1], angle=2 * np.pi, n=10, endpoint=False
cls,
coil,
current=None,
axis=[0, 0, 1],
angle=2 * np.pi,
n=10,
endpoint=False,
check_intersection=True,
):
"""Create a CoilSet by repeating a coil at equal spacing around the torus.
Expand All @@ -1651,6 +1658,8 @@ def linspaced_angular(
Number of copies of original coil.
endpoint : bool
Whether to include a coil at final rotation angle. Default = False.
check_intersection : bool
whether to check the resulting coilsets for intersecting coils.
"""
assert isinstance(coil, _Coil) and not isinstance(coil, CoilSet)
Expand All @@ -1664,11 +1673,17 @@ def linspaced_angular(
coili.rotate(axis=axis, angle=phi[i])
coili.current = currents[i]
coils.append(coili)
return cls(*coils)
return cls(*coils, check_intersection=check_intersection)

@classmethod
def linspaced_linear(
cls, coil, current=None, displacement=[2, 0, 0], n=4, endpoint=False
cls,
coil,
current=None,
displacement=[2, 0, 0],
n=4,
endpoint=False,
check_intersection=True,
):
"""Create a CoilSet by repeating a coil at equal spacing in a straight line.
Expand All @@ -1685,6 +1700,8 @@ def linspaced_linear(
Number of copies of original coil.
endpoint : bool
Whether to include a coil at final displacement location. Default = False.
check_intersection : bool
whether to check the resulting coilsets for intersecting coils.
"""
assert isinstance(coil, _Coil) and not isinstance(coil, CoilSet)
Expand All @@ -1699,10 +1716,10 @@ def linspaced_linear(
coili.translate(a[i] * displacement)
coili.current = currents[i]
coils.append(coili)
return cls(*coils)
return cls(*coils, check_intersection=check_intersection)

@classmethod
def from_symmetry(cls, coils, NFP=1, sym=False):
def from_symmetry(cls, coils, NFP=1, sym=False, check_intersection=True):
"""Create a coil group by reflection and symmetry.
Given coils over one field period, repeat coils NFP times between
Expand All @@ -1721,6 +1738,8 @@ def from_symmetry(cls, coils, NFP=1, sym=False):
sym : bool (optional)
Whether to enforce stellarator symmetry.
If True, the coils will be duplicated 2*NFP times. Default = False.
check_intersection : bool
whether to check the resulting coilsets for intersecting coils.
Returns
-------
Expand Down Expand Up @@ -1763,7 +1782,7 @@ def from_symmetry(cls, coils, NFP=1, sym=False):
rotated_coils.rotate(axis=[0, 0, 1], angle=2 * jnp.pi * k / NFP)
coilset += rotated_coils

return cls(*coilset)
return cls(*coilset, check_intersection=check_intersection)

@classmethod
def from_makegrid_coilfile(cls, coil_file, method="cubic", check_intersection=True):
Expand Down Expand Up @@ -1901,8 +1920,8 @@ def save_in_makegrid_format(self, coilsFilename, NFP=None, grid=None):
if None, will default to the coil compute functions's
default grid
"""
# TODO: name each group based off of CoilSet name?
# TODO: have CoilGroup be automatically assigned based off of
# TODO(#1376): name each group based off of CoilSet name?
# TODO(#1376): have CoilGroup be automatically assigned based off of
# CoilSet if current coilset is a collection of coilsets?

NFP = 1 if NFP is None else NFP
Expand Down Expand Up @@ -2697,7 +2716,7 @@ def insert(self, i, new_item):
self._coils.insert(i, new_item)

@classmethod
def from_makegrid_coilfile( # noqa: C901 - FIXME: simplify this
def from_makegrid_coilfile( # noqa: C901
cls, coil_file, method="cubic", ignore_groups=False, check_intersection=True
):
"""Create a MixedCoilSet of SplineXYZCoils from a MAKEGRID coil txtfile.
Expand Down
6 changes: 3 additions & 3 deletions desc/compute/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def _A_of_z(params, transforms, profiles, data, **kwargs):
data=["Z", "n_rho", "e_theta|r,p", "rho"],
parameterization=["desc.geometry.surface.FourierRZToroidalSurface"],
resolution_requirement="rt", # just need max(rho) near 1
# FIXME: Add source grid requirement once omega is nonzero.
# TODO(#568): Add source grid requirement once omega is nonzero.
)
def _A_of_z_FourierRZToroidalSurface(params, transforms, profiles, data, **kwargs):
# Denote any vector v = [vᴿ, v^ϕ, vᶻ] with a tuple of its contravariant components.
Expand All @@ -213,7 +213,7 @@ def _A_of_z_FourierRZToroidalSurface(params, transforms, profiles, data, **kwarg
line_integrals(
transforms["grid"],
data["Z"] * n[:, 2] * safenorm(data["e_theta|r,p"], axis=-1),
# FIXME: Works currently for omega = zero, but for nonzero omega
# TODO(#568): Works currently for omega = zero, but for nonzero omega
# we need to integrate over theta at constant phi.
# Should be simple once we have coordinate mapping and source grid
# logic from GitHub pull request #1024.
Expand Down Expand Up @@ -449,7 +449,7 @@ def _perimeter_of_z(params, transforms, profiles, data, **kwargs):
line_integrals(
transforms["grid"],
safenorm(data["e_theta|r,p"], axis=-1),
# FIXME: Works currently for omega = zero, but for nonzero omega
# TODO(#568): Works currently for omega = zero, but for nonzero omega
# we need to integrate over theta at constant phi.
# Should be simple once we have coordinate mapping and source grid
# logic from GitHub pull request #1024.
Expand Down
4 changes: 2 additions & 2 deletions desc/compute/_omnigenity.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def fitfun(x):
return data


# TODO: do math to change definition of nu so that we can just use B_zeta_mn here
# TODO (#568): do math to change definition of nu so that we can just use B_zeta_mn here
@register_compute_fun(
name="B_phi_mn",
label="B_{\\phi, m, n}",
Expand All @@ -63,7 +63,7 @@ def fitfun(x):
data=["B_phi|r,t"],
resolution_requirement="tz",
grid_requirement={"is_meshgrid": True},
aliases="B_zeta_mn", # TODO: remove when phi != zeta
aliases="B_zeta_mn", # TODO(#568): remove when phi != zeta
M_booz="int: Maximum poloidal mode number for Boozer harmonics. Default 2*eq.M",
N_booz="int: Maximum toroidal mode number for Boozer harmonics. Default 2*eq.N",
)
Expand Down
6 changes: 3 additions & 3 deletions desc/compute/_profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -1378,7 +1378,7 @@ def _iota_num_rrr(params, transforms, profiles, data, **kwargs):
- beta * data["sqrt(g)_rrr"],
data["sqrt(g)"],
),
# Todo: axis limit of beta_rrr
# TODO(#587): axis limit of beta_rrr
# Computed with four applications of l’Hôpital’s rule.
# Requires sqrt(g)_rrrr and fourth derivatives of basis vectors.
jnp.nan,
Expand Down Expand Up @@ -1656,7 +1656,7 @@ def _iota_den_rrr(params, transforms, profiles, data, **kwargs):
- gamma * data["sqrt(g)_rrr"],
data["sqrt(g)"],
),
# Todo: axis limit
# TODO(#587): axis limit
# Computed with four applications of l’Hôpital’s rule.
# Requires sqrt(g)_rrrr and fourth derivatives of basis vectors.
jnp.nan,
Expand Down Expand Up @@ -1713,7 +1713,7 @@ def _q(params, transforms, profiles, data, **kwargs):
return data


# TODO: add K(rho,theta,zeta)*grad(rho) term
# TODO (#1381): add K(rho,theta,zeta)*grad(rho) term
@register_compute_fun(
name="I",
label="I",
Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _D_current(params, transforms, profiles, data, **kwargs):
/ data["|grad(psi)|"] ** 3
* dot(Xi, data["B"])
),
# Todo: implement equivalent of equation 4.3 in desc coordinates
# TODO(#671): implement equivalent of equation 4.3 in desc coordinates
jnp.nan,
)
)
Expand Down
4 changes: 2 additions & 2 deletions desc/compute/_surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .data_index import register_compute_fun
from .geom_utils import rpz2xyz

# TODO: review when zeta no longer equals phi
# TODO(#568): review when zeta no longer equals phi


@register_compute_fun(
Expand All @@ -27,7 +27,7 @@
def _x_FourierRZToroidalSurface(params, transforms, profiles, data, **kwargs):
R = transforms["R"].transform(params["R_lmn"])
Z = transforms["Z"].transform(params["Z_lmn"])
# TODO: change when zeta no longer equals phi
# TODO(#568): change when zeta no longer equals phi
phi = transforms["grid"].nodes[:, 2]
coords = jnp.stack([R, phi, Z], axis=1)
# default basis for "x" is rpz, the conversion will be done
Expand Down
1 change: 0 additions & 1 deletion desc/continuation.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,6 @@ def solve_continuation( # noqa: C901
if len(deltas) > 0:
if verbose > 0:
print("Perturbing equilibrium")
# TODO: pass Jx if available
eqp = eqfam[ii - 1].copy()
objective_i = get_equilibrium_objective(
eq=eqp, mode=objective, jac_chunk_size=jac_chunk_size
Expand Down
8 changes: 4 additions & 4 deletions desc/equilibrium/coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def map_coordinates( # noqa: C901

profiles = get_profiles(inbasis + basis_derivs, eq)

# TODO: make this work for permutations of in/out basis
# TODO (#1382): make this work for permutations of in/out basis
if outbasis == ("rho", "theta", "zeta"):
if inbasis == ("rho", "alpha", "zeta"):
if "iota" in kwargs:
Expand Down Expand Up @@ -286,7 +286,7 @@ def _distance_body(i, idx):
return yg[idx]


# TODO: decide later whether to assume given phi instead of zeta.
# TODO(#568): decide later whether to assume given phi instead of zeta.
def _map_PEST_coordinates(
coords,
L_lmn,
Expand Down Expand Up @@ -395,7 +395,7 @@ def fixup(x, *args):
return out


# TODO: decide later whether to assume given phi instead of zeta.
# TODO(#568): decide later whether to assume given phi instead of zeta.
def _map_clebsch_coordinates(
coords,
iota,
Expand Down Expand Up @@ -766,7 +766,7 @@ def get_rtz_grid(
return desc_grid


# TODO: deprecated, remove eventually
# TODO(#1383): deprecated, remove eventually
def compute_theta_coords(
eq, flux_coords, L_lmn=None, tol=1e-6, maxiter=20, full_output=False, **kwargs
):
Expand Down
3 changes: 0 additions & 3 deletions desc/equilibrium/equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -1443,7 +1443,6 @@ def axis(self, new):
@property
def spectral_indexing(self):
"""str: Type of indexing used for the spectral basis."""
# TODO: allow this to change?
return self._spectral_indexing

@property
Expand Down Expand Up @@ -1972,8 +1971,6 @@ def from_near_axis(
raise ValueError("Input must be a pyQSC or pyQIC solution.") from e

rho, _ = special.js_roots(L, 2, 2)
# TODO: could make this an OCS grid to improve fitting, need to figure out
# how concentric grids work with QSC
grid = LinearGrid(rho=rho, theta=ntheta, zeta=na_eq.phi, NFP=na_eq.nfp)
basis_R = FourierZernikeBasis(
L=L,
Expand Down
2 changes: 1 addition & 1 deletion desc/equilibrium/initial_guess.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from desc.utils import copy_coeffs, warnif


def set_initial_guess(eq, *args, ensure_nested=True): # noqa: C901 - FIXME: simplify
def set_initial_guess(eq, *args, ensure_nested=True): # noqa: C901
"""Set the initial guess for the flux surfaces, eg R_lmn, Z_lmn, L_lmn.
Parameters
Expand Down
4 changes: 2 additions & 2 deletions desc/equilibrium/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def parse_axis(axis, NFP=1, sym=True, surface=None):
name="axis",
)
elif axis is None: # use the center of surface
# TODO: make this method of surface, surface.get_axis()?
# TODO (#1384): make this method of surface, surface.get_axis()?
if isinstance(surface, FourierRZToroidalSurface):
axis = FourierRZCurve(
R_n=surface.R_lmn[np.where(surface.R_basis.modes[:, 1] == 0)],
Expand All @@ -160,7 +160,7 @@ def parse_axis(axis, NFP=1, sym=True, surface=None):
NFP=NFP,
)
elif isinstance(surface, ZernikeRZToroidalSection):
# FIXME: include m=0 l!=0 modes
# TODO (#782): include m=0 l!=0 modes
axis = FourierRZCurve(
R_n=surface.R_lmn[
np.where(
Expand Down
2 changes: 1 addition & 1 deletion desc/geometry/surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def from_input_file(cls, path, **kwargs):
)
return surf

# TODO: add k value for number of rotations per field period
# TODO (#1385): add k value for number of rotations per field period
@classmethod
def from_qp_model(
cls,
Expand Down
1 change: 0 additions & 1 deletion desc/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
class _Grid(IOAble, ABC):
"""Base class for collocation grids."""

# TODO: calculate weights automatically using voronoi / delaunay triangulation
_io_attrs_ = [
"_L",
"_M",
Expand Down
5 changes: 2 additions & 3 deletions desc/input_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def _get_parser_(self):
"""
return get_parser()

def parse_inputs(self, fname=None): # noqa: C901 - FIXME: simplify this
def parse_inputs(self, fname=None): # noqa: C901
"""Read input from DESC input file; converts from VMEC input if necessary.
Parameters
Expand Down Expand Up @@ -386,7 +386,6 @@ def parse_inputs(self, fname=None): # noqa: C901 - FIXME: simplify this
if match:
inputs["bdry_mode"] = words[0].lower()
flag = True
# TODO: set bdry_mode automatically based on bdry coeffs

# coefficient indices
match = re.search(r"l\s*:\s*" + num_form, command, re.IGNORECASE)
Expand Down Expand Up @@ -970,7 +969,7 @@ def vmec_to_desc_input(vmec_fname, desc_fname):
InputReader.write_desc_input(desc_fname, inputs, header)

@staticmethod
def parse_vmec_inputs(vmec_fname, threshold=0): # noqa: C901 - FIXME: simplify this
def parse_vmec_inputs(vmec_fname, threshold=0): # noqa: C901
"""Parse a VMEC input file into a dictionary of DESC inputs.
Parameters
Expand Down
5 changes: 3 additions & 2 deletions desc/integrals/bounce_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def _check_spline_shape(knots, g, dg_dz, pitch_inv=None):
to that field line.
"""
errorif(knots.ndim != 1, msg=f"knots should be 1d; got shape {knots.shape}.")
errorif(knots.ndim != 1, msg=f"knots should be 1d, got shape {knots.shape}.")
errorif(
g.shape[-2] != (knots.size - 1),
msg=(
Expand Down Expand Up @@ -390,7 +390,8 @@ def loop(z): # over num well axis
)

result = jnp.moveaxis(
# TODO: Use batch_size arg of imap after increasing JAX version requirement.
# TODO (#1386): Use batch_size arg of imap after
# increasing JAX version requirement.
imap(loop, (jnp.moveaxis(z1, -1, 0), jnp.moveaxis(z2, -1, 0)))[1],
source=0,
destination=-1,
Expand Down
Loading

0 comments on commit 68f3e2d

Please sign in to comment.