Skip to content

Commit

Permalink
Merge branch 'master' into rc/anisotropy
Browse files Browse the repository at this point in the history
  • Loading branch information
f0uriest committed Aug 23, 2023
2 parents b0a1a5e + 61797b6 commit 52f7c5c
Show file tree
Hide file tree
Showing 117 changed files with 4,282 additions and 2,501 deletions.
861 changes: 443 additions & 418 deletions .test_durations

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def fori_loop(lower, upper, body_fun, init_val):
val = body_fun(i, val)
return val

def cond(pred, true_fun, false_fun, operand):
def cond(pred, true_fun, false_fun, *operand):
"""Conditionally apply true_fun or false_fun.
This version is for the numpy backend, for jax backend see jax.lax.cond
Expand All @@ -227,9 +227,9 @@ def cond(pred, true_fun, false_fun, operand):
"""
if pred:
return true_fun(operand)
return true_fun(*operand)
else:
return false_fun(operand)
return false_fun(*operand)

def switch(index, branches, operand):
"""Apply exactly one of branches given by index.
Expand Down
10 changes: 5 additions & 5 deletions desc/compute/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
Parameters
----------
params : dict of ndarray
Parameters from the equilibrium, such as R_lmn, Z_lmn, i_l, p_l, etc
Parameters from the equilibrium, such as R_lmn, Z_lmn, i_l, p_l, etc.
transforms : dict of Transform
Transforms for R, Z, lambda, etc
Transforms for R, Z, lambda, etc.
profiles : dict of Profile
Profile objects for pressure, iota, current, etc
Profile objects for pressure, iota, current, etc.
data : dict of ndarray
Data computed so far, generally output from other compute functions
kwargs : dict
Expand Down Expand Up @@ -59,8 +59,8 @@
# import the compute module.
def _build_data_index():

for p in data_index.keys():
for key in data_index[p].keys():
for p in data_index:
for key in data_index[p]:
full = {
"data": get_data_deps(key, p, has_axis=False),
"transforms": get_derivs(key, p, has_axis=False),
Expand Down
52 changes: 42 additions & 10 deletions desc/compute/_basis_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
transforms={},
profiles=[],
coordinates="rtz",
data=["B"],
data=["B", "|B|"],
)
def _b(params, transforms, profiles, data, **kwargs):
data["b"] = (data["B"].T / jnp.linalg.norm(data["B"], axis=-1)).T
data["b"] = (data["B"].T / data["|B|"]).T
return data


Expand All @@ -47,6 +47,8 @@ def _b(params, transforms, profiles, data, **kwargs):
data=["e_theta/sqrt(g)", "e_zeta"],
)
def _e_sup_rho(params, transforms, profiles, data, **kwargs):
# At the magnetic axis, this function returns the multivalued map whose
# image is the set { 𝐞^ρ | ρ=0 }.
data["e^rho"] = cross(data["e_theta/sqrt(g)"], data["e_zeta"])
return data

Expand Down Expand Up @@ -196,8 +198,14 @@ def _e_sup_theta(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="rtz",
data=["e_rho", "e_zeta"],
parameterization=[
"desc.equilibrium.equilibrium.Equilibrium",
"desc.geometry.core.Surface",
],
)
def _e_sup_theta_times_sqrt_g(params, transforms, profiles, data, **kwargs):
# At the magnetic axis, this function returns the multivalued map whose
# image is the set { 𝐞^θ √g | ρ=0 }.
data["e^theta*sqrt(g)"] = cross(data["e_zeta"], data["e_rho"])
return data

Expand Down Expand Up @@ -299,6 +307,8 @@ def _e_sup_theta_z(params, transforms, profiles, data, **kwargs):
data=["e_rho", "e_theta/sqrt(g)"],
)
def _e_sup_zeta(params, transforms, profiles, data, **kwargs):
# At the magnetic axis, this function returns the multivalued map whose
# image is the set { 𝐞^΢ | ρ=0 }.
data["e^zeta"] = cross(data["e_rho"], data["e_theta/sqrt(g)"])
return data

Expand Down Expand Up @@ -453,6 +463,8 @@ def _e_sub_phi(params, transforms, profiles, data, **kwargs):
data=["R", "R_r", "Z_r", "omega_r"],
)
def _e_sub_rho(params, transforms, profiles, data, **kwargs):
# At the magnetic axis, this function returns the multivalued map whose
# image is the set { 𝐞ᡨ | ρ=0 }.
data["e_rho"] = jnp.array([data["R_r"], data["R"] * data["omega_r"], data["Z_r"]]).T
return data

Expand Down Expand Up @@ -1386,6 +1398,8 @@ def _e_sub_theta(params, transforms, profiles, data, **kwargs):
axis_limit_data=["e_theta_r", "sqrt(g)_r"],
)
def _e_sub_theta_over_sqrt_g(params, transforms, profiles, data, **kwargs):
# At the magnetic axis, this function returns the multivalued map whose
# image is the set { 𝐞_θ / √g | ρ=0 }.
data["e_theta/sqrt(g)"] = transforms["grid"].replace_at_axis(
(data["e_theta"].T / data["sqrt(g)"]).T,
lambda: (data["e_theta_r"].T / data["sqrt(g)_r"]).T,
Expand Down Expand Up @@ -1426,6 +1440,8 @@ def _e_sub_theta_pest(params, transforms, profiles, data, **kwargs):
data=["R", "R_r", "R_rt", "R_t", "Z_rt", "omega_r", "omega_rt", "omega_t"],
)
def _e_sub_theta_r(params, transforms, profiles, data, **kwargs):
# At the magnetic axis, this function returns the multivalued map whose
# image is the set { βˆ‚α΅¨ 𝐞_ΞΈ | ρ=0 }
data["e_theta_r"] = jnp.array(
[
-data["R"] * data["omega_t"] * data["omega_r"] + data["R_rt"],
Expand Down Expand Up @@ -3428,16 +3444,22 @@ def _gradpsi(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="rtz",
data=["e_theta", "e_zeta", "|e_theta x e_zeta|"],
axis_limit_data=["e_theta_r", "|e_theta x e_zeta|_r"],
parameterization=[
"desc.equilibrium.equilibrium.Equilibrium",
"desc.geometry.core.Surface",
],
)
def _n_rho(params, transforms, profiles, data, **kwargs):
# equal to e^rho / |e^rho| but works correctly for surfaces as well that don't have
# contravariant basis defined
data["n_rho"] = (
cross(data["e_theta"], data["e_zeta"]) / data["|e_theta x e_zeta|"][:, None]
# Equal to 𝐞^ρ / β€–πž^ρ‖ but works correctly for surfaces as well that don't
# have contravariant basis defined.
data["n_rho"] = transforms["grid"].replace_at_axis(
(cross(data["e_theta"], data["e_zeta"]).T / data["|e_theta x e_zeta|"]).T,
# At the magnetic axis, this function returns the multivalued map whose
# image is the set { 𝐞^ρ / β€–πž^ρ‖ | ρ=0 }.
lambda: (
cross(data["e_theta_r"], data["e_zeta"]).T / data["|e_theta x e_zeta|_r"]
).T,
)
return data

Expand All @@ -3460,9 +3482,11 @@ def _n_rho(params, transforms, profiles, data, **kwargs):
],
)
def _n_theta(params, transforms, profiles, data, **kwargs):
# Equal to 𝐞^ΞΈ / β€–πž^ΞΈβ€– but works correctly for surfaces as well that don't
# have contravariant basis defined.
data["n_theta"] = (
cross(data["e_zeta"], data["e_rho"]) / data["|e_zeta x e_rho|"][:, None]
)
cross(data["e_zeta"], data["e_rho"]).T / data["|e_zeta x e_rho|"]
).T
return data


Expand All @@ -3478,13 +3502,21 @@ def _n_theta(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="rtz",
data=["e_rho", "e_theta", "|e_rho x e_theta|"],
axis_limit_data=["e_theta_r", "|e_rho x e_theta|_r"],
parameterization=[
"desc.equilibrium.equilibrium.Equilibrium",
"desc.geometry.core.Surface",
],
)
def _n_zeta(params, transforms, profiles, data, **kwargs):
data["n_zeta"] = (
cross(data["e_rho"], data["e_theta"]) / data["|e_rho x e_theta|"][:, None]
# Equal to 𝐞^ΞΆ / β€–πž^ΞΆβ€– but works correctly for surfaces as well that don't
# have contravariant basis defined.
data["n_zeta"] = transforms["grid"].replace_at_axis(
(cross(data["e_rho"], data["e_theta"]).T / data["|e_rho x e_theta|"]).T,
# At the magnetic axis, this function returns the multivalued map whose
# image is the set { 𝐞^ΞΆ / β€–πž^ΞΆβ€– | ρ=0 }.
lambda: (
cross(data["e_rho"], data["e_theta_r"]).T / data["|e_rho x e_theta|_r"]
).T,
)
return data
Loading

0 comments on commit 52f7c5c

Please sign in to comment.