Working TR-BDF2 implementation with implicit error estimate, but Radau IIA is way superior.
JonasBreuling committed Aug 1, 2024
1 parent e031a76 commit e8c8e94
Showing 4 changed files with 102 additions and 79 deletions.
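A minimal sketch of the two-stage structure behind this commit, assuming the plain ODE form y' = f(t, y) rather than the DAE form F(t, y, y') = 0 that the solver below handles; gamma = 2 - sqrt(2), d = gamma / 2 and w = sqrt(2) / 4 are the same constants defined at the top of trbdf2.py, and scipy.optimize.fsolve only stands in for the solver's Newton iteration:

import numpy as np
from scipy.optimize import fsolve

S2 = np.sqrt(2)
gamma = 2 - S2   # intermediate node
d = gamma / 2    # diagonal coefficient of both implicit stages
w = S2 / 4       # quadrature weight of the first two stages

def trbdf2_step(f, t, y, h):
    # TR (trapezoidal) stage: y_tr at t + gamma * h
    res_tr = lambda y_tr: y_tr - y - d * h * (f(t, y) + f(t + gamma * h, y_tr))
    y_tr = fsolve(res_tr, y)
    # BDF2 stage: y_new at t + h, combining y, y_tr and f(t + h, y_new)
    res_bdf = lambda y_new: y_new - y - (w / d) * (y_tr - y) - d * h * f(t + h, y_new)
    return fsolve(res_bdf, y_tr)

# quick check on y' = -y, y(0) = 1 over [0, 1]
f = lambda t, y: -y
t, h, y = 0.0, 0.1, np.array([1.0])
for _ in range(10):
    y = trbdf2_step(f, t, y, h)
    t += h
print(y, np.exp(-1.0))  # both close to 0.368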
examples/particle_on_circular_track.py: 2 changes (1 addition, 1 deletion)
@@ -89,7 +89,7 @@ def sol_true(t):
print(f"fnorm: {fnorm}")

# solver options
atol = rtol = 1e-5
atol = rtol = 1e-4

##############
# dae solution
examples/pendulum.py: 1 change (1 addition, 0 deletions)
@@ -54,6 +54,7 @@ def f(t, z):

# method = "BDF"
method = "Radau"
method = "TR-BDF2"

# initial conditions
y0 = np.array([l, 0, 0, 0, 0, 0], dtype=float)
examples/robertson.py: 4 changes (2 additions, 2 deletions)
@@ -58,7 +58,7 @@ def fun_composite(t, z):
t1 = 1e7
t_span = (t0, t1)
t_eval = np.logspace(-6, 7, num=1000)
t_eval = None
# t_eval = None

# method = "BDF"
method = "Radau"
@@ -76,7 +76,7 @@ def fun_composite(t, z):
print(f"fnorm: {fnorm}")

# solver options
atol = rtol = 1e-6
atol = rtol = 1e-5

####################
# reference solution
scipy_dae/integrate/_dae/trbdf2.py: 174 changes (98 additions, 76 deletions)
@@ -13,48 +13,35 @@
gamma = 2 - S2
d = gamma / 2
w = S2 / 4

# Butcher Tableau coefficients for the error estimator
e0 = (1 - w) / 3
e1 = w + (1 / 3)
e2 = d / 3

# embedded implicit method
b1_hat = (1 - w) / 3
# b1_hat = d
b2_hat = w + 1 / 3
b3_hat = d / 3
# b4_hat = d

# b23_hat = np.linalg.solve(
# np.array([
# [ gamma, 1],
# [gamma**2, 1],
# ]),
# np.array([
# 1 / 2 - b1_hat - b4_hat,
# 1 / 3 - b4_hat,
# ])
# )

# b2_hat, b3_hat = b23_hat

# compute embedded method for error estimate
c_hat = np.array([0, gamma, 1])
vander = np.vander(c_hat, increasing=True).T

rhs = 1 / np.arange(1, len(c_hat))
gamma0 = 1 / d
# gamma0 = d
b0 = gamma0
rhs[0] -= b0
rhs -= gamma0

b_hat = np.linalg.solve(vander[:-1, 1:], rhs)
# b_hat = np.array([b0, *b_hat])
# b = np.array([w, w, d])
b = np.array([w, d])
v = b - b_hat
C = np.array([0, gamma, 1])

# compute embedded implicit method
vander = np.array([
[1, 1, 1],
[0, gamma, 1],
[0, gamma**2, 1],
])
vander_inv = np.array([
[1, -(gamma + 1) / gamma, 1 / gamma],
[0, gamma / (-gamma**3 + gamma**2), -gamma / (-gamma**3 + gamma**2)],
[0, -gamma**2 / (-gamma**2 + gamma), gamma / (-gamma**2 + gamma)],
])
b4_hat = d
rhs = np.array([
1 - b4_hat,
1 / 2 - b4_hat,
1 / 3 - b4_hat,
])
b_hat = np.linalg.solve(vander, rhs)
b_hat = vander_inv @ rhs
b1_hat, b2_hat, b3_hat = b_hat
assert np.allclose(np.sum(b_hat) + b4_hat, 1)

# Compute the inverse of the Vandermonde matrix to get the
# interpolation matrix P.
# vander = np.vander(C, increasing=True).T
# vander_inv = np.linalg.inv(vander)
P = vander_inv[1:, 1:]

# Coefficients required for interpolating y'
pd0 = 1.5 + S2
@@ -333,7 +320,6 @@ def _step_impl(self):
current_jac = self.current_jac
jac = self.jac

rejected = False
step_accepted = False
message = None
while not step_accepted:
@@ -342,7 +328,6 @@

h = h_abs * self.direction
t_new = t + h
# print(f"t_new: {t_new}")

if self.direction * (t_new - self.t_bound) > 0:
t_new = self.t_bound
@@ -355,16 +340,22 @@
scale = atol + np.abs(y) * rtol

# TODO: Better initial guess for z0 as explained by Hosea?
if self.sol is None:
Z0 = np.zeros((2, y.shape[0]))
else:
Z0 = self.sol(t + h * C).T - y

# TR stage
# z_bdf0 = z0
z_bdf0 = Z0[0]
converged_tr = False
while not converged_tr:
if LU is None:
LU = self.lu(Jy + Jyp / (d * h))

t_gamma = t + h * gamma
fun_tr = lambda z: self.fun(t_gamma, y + z, z / (d * h) - yp)
converged_tr, n_iter_tr, z_tr, rate_tr = solve_trbdf2_system(fun_tr, z0, LU, self.solve_lu, scale, self.newton_tol)
converged_tr, n_iter_tr, z_tr, rate_tr = solve_trbdf2_system(fun_tr, z_bdf0, LU, self.solve_lu, scale, self.newton_tol)

if not converged_tr:
if current_jac:
@@ -382,7 +373,8 @@
yp_tr = z_tr / (h * d) - yp

# BDF stage
z_bdf0 = pd0 * z0 + pd1 * z_tr + pd2 * (y_tr - y)
# z_bdf0 = pd0 * z0 + pd1 * z_tr + pd2 * (y_tr - y)
z_bdf0 = Z0[1]
converged_bdf = False
while not converged_bdf:
if LU is None:
@@ -405,39 +397,43 @@

n_iter = max(n_iter_tr, n_iter_bdf)
rate = max(rate_tr, rate_bdf)
# if rate_tr is not None:
# if rate_bdf is not None:
# rate = max(rate_tr, rate_bdf)
# else:
# rate = rate_tr
# else:
# rate = 0

y_new = y + z_bdf
yp_new = z_bdf / (h * d) - (w / d) * (yp + yp_tr)

# error = 0.5 * (y + e0 * z0 + e1 * z_tr + e2 * z_bdf - y_new)
error = h * ((b1_hat - w) * yp + (b2_hat - w) * yp_tr + (b3_hat - d) * yp_new)
# error = self.solve_lu(LU, error) #* (d * h) #* d / 3

# # implicit error estimate
# # yp_hat_new = (y_new - y) / (b4_hat * h) - (b1_hat / b4_hat) * yp - (b2_hat / b4_hat) * yp_tr - (b3_hat / b4_hat) * yp_new
# yp_hat_new = (v @ np.array([yp_tr, yp_new]) - b0 * yp) * d
# F = self.fun(t_new, y_new, yp_hat_new)
# error = self.solve_lu(LU, -F)
# error = h * d * (v @ np.array([yp_tr, yp_new]) - b0 * yp - yp_new / d)

# # # error_Fabien = h * MU_REAL * (v @ Yp - b0 * yp - yp_new / MU_REAL)
# # yp_hat_new = MU_REAL * (v @ Yp - b0 * yp)
# # F = self.fun(t_new, y_new, yp_hat_new)
# # error_Fabien = self.solve_lu(LU_real, -F)
# embedded method of Hosea
# error = h * ((b1_hat - w) * yp + (b2_hat - w) * yp_tr + (b3_hat - d) * yp_new)

# implicit error estimate
# yp_hat_new = (
# # (y_new - y) / (b4_hat * h)
# z_bdf / (b4_hat * h)
# - (b1_hat / b4_hat) * yp
# - (b2_hat / b4_hat) * yp_tr
# - (b3_hat / b4_hat) * yp_new
# )
# yp_hat_new = (
# z_bdf / h
# - (b1_hat) * yp
# - (b2_hat) * yp_tr
# - (b3_hat) * yp_new
# ) / b4_hat
# yp_hat_new = (
# z_bdf / h # = (w * yp + w * yp_tr + d * yp_new)
# - (b1_hat) * yp
# - (b2_hat) * yp_tr
# - (b3_hat) * yp_new
# ) / b4_hat
yp_hat_new = (
(w - b1_hat) * yp
+ (w - b2_hat) * yp_tr
+ (d - b3_hat) * yp_new
) / b4_hat
F = self.fun(t_new, y_new, yp_hat_new)
error = self.solve_lu(LU, -F)

scale = atol + np.maximum(np.abs(y), np.abs(y_new)) * rtol

# TODO: Add stabilized error
# stabilised_error = self.solve_lu(LU, error)
stabilised_error = error
error_norm = norm(stabilised_error / scale)
error_norm = norm(error / scale)

safety = 0.9 * (2 * NEWTON_MAXITER + 1) / (2 * NEWTON_MAXITER + n_iter)

@@ -495,12 +491,16 @@ def _step_impl(self):
self.z_bdf = z_bdf
self.y_tr = y_tr

self.Z = np.vstack((z_tr, z_bdf))

return step_accepted, message

def _dense_output_impl(self):
return Trbdf2DenseOutput(self.t_old, self.t, self.h_old,
self.z0, self.z_tr, self.z_bdf,
self.y_old, self.y_tr, self.y)
# return Trbdf2DenseOutput(self.t_old, self.t, self.h_old,
# self.z0, self.z_tr, self.z_bdf,
# self.y_old, self.y_tr, self.y)
Q = np.dot(self.Z.T, P)
return RadauDenseOutput(self.t_old, self.t, self.y_old, Q)


class Trbdf2DenseOutput(DenseOutput):
Expand Down Expand Up @@ -549,3 +549,25 @@ def _call_impl(self, t):
y.append(self._call_impl(tk))
y = np.array(y).T
return y


class RadauDenseOutput(DenseOutput):
def __init__(self, t_old, t, y_old, Q):
super().__init__(t_old, t)
self.h = t - t_old
self.Q = Q
self.order = Q.shape[1] - 1
self.y_old = y_old

def _call_impl(self, t):
x = (t - self.t_old) / self.h
x = np.atleast_1d(x)
p = np.tile(x, (self.order + 1, 1))
p = np.cumprod(p, axis=0)
# Here we don't multiply by h, not a mistake.
y = np.dot(self.Q, p)
y += self.y_old[:, None]
if t.ndim == 0:
y = np.squeeze(y)

return y
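
A quick standalone sketch of the coefficient computation above, assuming only the constants already defined in trbdf2.py: fixing b4_hat = d and solving the Vandermonde system enforces the quadrature order conditions sum_i b_i * c_i**(k - 1) = 1/k for k = 1, 2, 3 at the nodes (0, gamma, 1, 1), so the embedded weights meet them through order 3 while the TR-BDF2 weights (w, w, d) stop at order 2, which is the gap the implicit error estimate measures:

import numpy as np

S2 = np.sqrt(2)
gamma = 2 - S2
d = gamma / 2
w = S2 / 4

# embedded weights, recomputed exactly as at the top of trbdf2.py
b4_hat = d
vander = np.array([
    [1.0, 1.0, 1.0],
    [0.0, gamma, 1.0],
    [0.0, gamma**2, 1.0],
])
rhs = np.array([1 - b4_hat, 1 / 2 - b4_hat, 1 / 3 - b4_hat])
b_hat = np.linalg.solve(vander, rhs)

c = np.array([0.0, gamma, 1.0, 1.0])   # nodes; the last one belongs to yp_hat_new
b_full = np.array([*b_hat, b4_hat])    # embedded weights
b = np.array([w, w, d, 0.0])           # TR-BDF2 weights, zero-padded for comparison
for k in (1, 2, 3):
    print(k, b_full @ c**(k - 1), b @ c**(k - 1), 1 / k)
# b_full matches 1/k for k = 1, 2, 3; b only for k = 1, 2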
