Skip to content

Commit

Permalink
Merge pull request #1105 from AayushSabharwal/as/late-tstops
Browse files Browse the repository at this point in the history
refactor: check if algorithm supports late binding tstops
  • Loading branch information
ChrisRackauckas authored Nov 14, 2024
2 parents b52cf18 + 831d823 commit e2c83e9
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/Downstream.yml
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,4 @@ jobs:
with:
file: lcov.info
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: true
fail_ci_if_error: false
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ Printf = "1.9"
RecursiveArrayTools = "3"
Reexport = "1.0"
ReverseDiff = "1"
SciMLBase = "2.56.0"
SciMLBase = "2.60.0"
SciMLOperators = "0.3"
SciMLStructures = "1.5"
Setfield = "1"
Expand Down
27 changes: 27 additions & 0 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,19 @@ function Base.showerror(io::IO, e::IncompatibleMassMatrixError)
println(io, TruncatedStacktraces.VERBOSE_MSG)
end

const LATE_BINDING_TSTOPS_ERROR_MESSAGE = """
This solver does not support providing `tstops` as a function.
Consider using a different solver or providing `tstops` as an array
of times.
"""

struct LateBindingTstopsNotSupportedError <: Exception end

function Base.showerror(io::IO, e::LateBindingTstopsNotSupportedError)
println(io, LATE_BINDING_TSTOPS_ERROR_MESSAGE)
println(io, TruncatedStacktraces.VERBOSE_MSG)
end

function init_call(_prob, args...; merge_callbacks = true, kwargshandle = nothing,
kwargs...)
kwargshandle = kwargshandle === nothing ? KeywordArgError : kwargshandle
Expand Down Expand Up @@ -555,6 +568,13 @@ function init_up(prob::AbstractDEProblem, sensealg, u0, p, args...; kwargs...)
p = p, kwargs...)
init_call(_prob, args...; kwargs...)
else
tstops = get(kwargs, :tstops, nothing)
if tstops === nothing && has_kwargs(prob)
tstops = get(prob.kwargs, :tstops, nothing)
end
if !(tstops isa Union{Nothing, AbstractArray, Tuple, Real}) && !SciMLBase.allows_late_binding_tstops(alg)
throw(LateBindingTstopsNotSupportedError())
end
_prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...)
_alg = prepare_alg(alg, _prob.u0, _prob.p, _prob)
check_prob_alg_pairing(_prob, alg) # alg for improved inference
Expand Down Expand Up @@ -1084,6 +1104,13 @@ function solve_up(prob::Union{AbstractDEProblem, NonlinearProblem}, sensealg, u0
p = p, kwargs...)
solve_call(_prob, args...; kwargs...)
else
tstops = get(kwargs, :tstops, nothing)
if tstops === nothing && has_kwargs(prob)
tstops = get(prob.kwargs, :tstops, nothing)
end
if !(tstops isa Union{Nothing, AbstractArray, Tuple, Real}) && !SciMLBase.allows_late_binding_tstops(alg)
throw(LateBindingTstopsNotSupportedError())
end
_prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...)
_alg = prepare_alg(alg, _prob.u0, _prob.p, _prob)
check_prob_alg_pairing(_prob, alg) # use alg for improved inference
Expand Down

0 comments on commit e2c83e9

Please sign in to comment.