Skip to content

Commit

Permalink
Refactor coil test to use parameterize (#1362)
Browse files Browse the repository at this point in the history
Resolves #1104
  • Loading branch information
f0uriest authored Nov 13, 2024
2 parents a880816 + 71dc2c6 commit df95b90
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 59 deletions.
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,13 +284,13 @@ def DummyMixedCoilSet(tmpdir_factory):
vf_coil, displacement=[0, 0, 2], n=3, endpoint=True
)
xyz_coil = FourierXYZCoil(current=2)
phi = 2 * np.pi * np.linspace(0, 1, 20, endpoint=True) ** 2
phi = 2 * np.pi * np.linspace(0, 1, 20, endpoint=True)
spline_coil = SplineXYZCoil(
current=1,
X=np.cos(phi),
Y=np.sin(phi),
Z=np.zeros_like(phi),
knots=np.linspace(0, 2 * np.pi, len(phi)),
knots=phi,
)
full_coilset = MixedCoilSet(
(tf_coilset, vf_coilset, xyz_coil, spline_coil), check_intersection=False
Expand Down
124 changes: 67 additions & 57 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -1324,7 +1324,20 @@ def test_second_stage_optimization_CoilSet():

@pytest.mark.slow
@pytest.mark.unit
def test_optimize_with_all_coil_types(DummyCoilSet, DummyMixedCoilSet):
@pytest.mark.parametrize(
"coil_type",
[
"FourierPlanarCoil",
"FourierRZCoil",
"FourierXYZCoil",
"SplineXYZCoil",
"CoilSet sym",
"CoilSet asym",
"MixedCoilSet",
"nested CoilSet",
],
)
def test_optimize_with_all_coil_types(DummyCoilSet, DummyMixedCoilSet, coil_type):
"""Test optimizing for every type of coil and dummy coil sets."""
sym_coils = load(load_from=str(DummyCoilSet["output_path_sym"]), file_format="hdf5")
asym_coils = load(
Expand All @@ -1340,66 +1353,63 @@ def test_optimize_with_all_coil_types(DummyCoilSet, DummyMixedCoilSet):
quad_eval_grid = LinearGrid(M=2, sym=True)
quad_field_grid = LinearGrid(N=2)

def test(c, method):
target = 11
rtol = 1e-3
# first just check that quad flux works for a couple iterations
# as this is an expensive objective to compute
obj = ObjectiveFunction(
QuadraticFlux(
eq=eq,
field=c,
vacuum=True,
weight=1e-4,
eval_grid=quad_eval_grid,
field_grid=quad_field_grid,
)
)
optimizer = Optimizer(method)
(c,), _ = optimizer.optimize(c, obj, maxiter=2, ftol=0, xtol=1e-15)

# now check with optimizing geometry and actually check result
objs = [
CoilLength(c, target=target),
]
extra_msg = ""
if isinstance(c, MixedCoilSet):
# just to check they work without error
objs.extend(
[
CoilCurvature(c, target=0.5, weight=1e-2),
CoilTorsion(c, target=0, weight=1e-2),
]
)
rtol = 3e-2
extra_msg = " with curvature and torsion obj"

obj = ObjectiveFunction(objs)

(c,), _ = optimizer.optimize(c, obj, maxiter=25, ftol=5e-3, xtol=1e-15)
flattened_coils = tree_leaves(
c, is_leaf=lambda x: isinstance(x, _Coil) and not isinstance(x, CoilSet)
)
lengths = [coil.compute("length")["length"] for coil in flattened_coils]
np.testing.assert_allclose(
lengths, target, rtol=rtol, err_msg=f"lengths {c}" + extra_msg
)

spline_coil = mixed_coils.coils[-1].copy()

# single coil
test(FourierPlanarCoil(), "fmintr")
test(FourierRZCoil(), "fmintr")
test(FourierXYZCoil(), "fmintr")
test(spline_coil, "fmintr")
types = {
"FourierPlanarCoil": (FourierPlanarCoil(), "fmintr"),
"FourierRZCoil": (FourierRZCoil(), "fmintr"),
"FourierXYZCoil": (FourierXYZCoil(), "fmintr"),
"SplineXYZCoil": (spline_coil, "fmintr"),
"CoilSet sym": (sym_coils, "lsq-exact"),
"CoilSet asym": (asym_coils, "lsq-exact"),
"MixedCoilSet": (mixed_coils, "lsq-exact"),
"nested CoilSet": (nested_coils, "lsq-exact"),
}
c, method = types[coil_type]

target = 11
rtol = 1e-3
# first just check that quad flux works for a couple iterations
# as this is an expensive objective to compute
obj = ObjectiveFunction(
QuadraticFlux(
eq=eq,
field=c,
vacuum=True,
weight=1e-4,
eval_grid=quad_eval_grid,
field_grid=quad_field_grid,
)
)
optimizer = Optimizer(method)
(cc,), _ = optimizer.optimize(c, obj, maxiter=2, ftol=0, xtol=1e-8, copy=True)

# now check with optimizing geometry and actually check result
objs = [
CoilLength(c, target=target),
]
extra_msg = ""
if isinstance(c, MixedCoilSet):
# just to check they work without error
objs.extend(
[
CoilCurvature(c, target=0.5, weight=1e-2),
CoilTorsion(c, target=0, weight=1e-2),
]
)
rtol = 3e-2
extra_msg = " with curvature and torsion obj"

# CoilSet
test(sym_coils, "lsq-exact")
test(asym_coils, "lsq-exact")
obj = ObjectiveFunction(objs)

# MixedCoilSet
test(mixed_coils, "lsq-exact")
test(nested_coils, "lsq-exact")
(c,), _ = optimizer.optimize(c, obj, maxiter=25, ftol=5e-3, xtol=1e-8)
flattened_coils = tree_leaves(
c, is_leaf=lambda x: isinstance(x, _Coil) and not isinstance(x, CoilSet)
)
lengths = [coil.compute("length")["length"] for coil in flattened_coils]
np.testing.assert_allclose(
lengths, target, rtol=rtol, err_msg=f"lengths {c}" + extra_msg
)


@pytest.mark.unit
Expand Down

0 comments on commit df95b90

Please sign in to comment.