Skip to content

Commit 946322a

Browse files
Merge pull request #3144 from BenChung/affect-reinitalize
Add support for the initializealg argument in SciMLBase callbacks
2 parents 2770a8f + 47f84fe commit 946322a

File tree

2 files changed

+152
-15
lines changed

2 files changed

+152
-15
lines changed

src/systems/callbacks.jl

+51-15
Original file line numberDiff line numberDiff line change
@@ -106,15 +106,25 @@ Affects (i.e. `affect` and `affect_neg`) can be specified as either:
106106
+ `read_parameters` is a vector of the parameters that are *used* by `f!`. Their indices are passed to `f` in `p` similarly to the indices of `unknowns` passed in `u`.
107107
+ `modified_parameters` is a vector of the parameters that are *modified* by `f!`. Note that a parameter will not appear in `p` if it only appears in `modified_parameters`; it must appear in both `parameters` and `modified_parameters` if it is used in the affect definition.
108108
+ `ctx` is a user-defined context object passed to `f!` when invoked. This value is aliased for each problem.
109+
110+
DAEs will be reinitialized using `reinitializealg` (which defaults to `SciMLBase.CheckInit`) after callbacks are applied.
111+
This reinitialization algorithm ensures that the DAE is satisfied after the callback runs. The default value of `CheckInit` will simply validate
112+
that the newly-assigned values indeed satisfy the algebraic system; see the documentation on DAE initialization for a more detailed discussion of
113+
initialization.
109114
"""
110115
struct SymbolicContinuousCallback
111116
eqs::Vector{Equation}
112117
affect::Union{Vector{Equation}, FunctionalAffect}
113118
affect_neg::Union{Vector{Equation}, FunctionalAffect, Nothing}
114119
rootfind::SciMLBase.RootfindOpt
115-
function SymbolicContinuousCallback(; eqs::Vector{Equation}, affect = NULL_AFFECT,
116-
affect_neg = affect, rootfind = SciMLBase.LeftRootFind)
117-
new(eqs, make_affect(affect), make_affect(affect_neg), rootfind)
120+
reinitializealg::SciMLBase.DAEInitializationAlgorithm
121+
function SymbolicContinuousCallback(;
122+
eqs::Vector{Equation},
123+
affect = NULL_AFFECT,
124+
affect_neg = affect,
125+
rootfind = SciMLBase.LeftRootFind,
126+
reinitializealg = SciMLBase.CheckInit())
127+
new(eqs, make_affect(affect), make_affect(affect_neg), rootfind, reinitializealg)
118128
end # Default affect to nothing
119129
end
120130
make_affect(affect) = affect
@@ -183,6 +193,12 @@ function affect_negs(cbs::Vector{SymbolicContinuousCallback})
183193
mapreduce(affect_negs, vcat, cbs, init = Equation[])
184194
end
185195

196+
reinitialization_alg(cb::SymbolicContinuousCallback) = cb.reinitializealg
197+
function reinitialization_algs(cbs::Vector{SymbolicContinuousCallback})
198+
mapreduce(
199+
reinitialization_alg, vcat, cbs, init = SciMLBase.DAEInitializationAlgorithm[])
200+
end
201+
186202
namespace_affects(af::Vector, s) = Equation[namespace_affect(a, s) for a in af]
187203
namespace_affects(af::FunctionalAffect, s) = namespace_affect(af, s)
188204
namespace_affects(::Nothing, s) = nothing
@@ -225,11 +241,13 @@ struct SymbolicDiscreteCallback
225241
# TODO: Iterative
226242
condition::Any
227243
affects::Any
244+
reinitializealg::SciMLBase.DAEInitializationAlgorithm
228245

229-
function SymbolicDiscreteCallback(condition, affects = NULL_AFFECT)
246+
function SymbolicDiscreteCallback(
247+
condition, affects = NULL_AFFECT, reinitializealg = SciMLBase.CheckInit())
230248
c = scalarize_condition(condition)
231249
a = scalarize_affects(affects)
232-
new(c, a)
250+
new(c, a, reinitializealg)
233251
end # Default affect to nothing
234252
end
235253

@@ -286,6 +304,12 @@ function affects(cbs::Vector{SymbolicDiscreteCallback})
286304
reduce(vcat, affects(cb) for cb in cbs; init = [])
287305
end
288306

307+
reinitialization_alg(cb::SymbolicDiscreteCallback) = cb.reinitializealg
308+
function reinitialization_algs(cbs::Vector{SymbolicDiscreteCallback})
309+
mapreduce(
310+
reinitialization_alg, vcat, cbs, init = SciMLBase.DAEInitializationAlgorithm[])
311+
end
312+
289313
function namespace_callback(cb::SymbolicDiscreteCallback, s)::SymbolicDiscreteCallback
290314
af = affects(cb)
291315
af = af isa AbstractVector ? namespace_affect.(af, Ref(s)) : namespace_affect(af, s)
@@ -579,13 +603,15 @@ function generate_single_rootfinding_callback(
579603
initfn = SciMLBase.INITIALIZE_DEFAULT
580604
end
581605
return ContinuousCallback(
582-
cond, affect_function.affect, affect_function.affect_neg,
583-
rootfind = cb.rootfind, initialize = initfn)
606+
cond, affect_function.affect, affect_function.affect_neg, rootfind = cb.rootfind,
607+
initialize = initfn,
608+
initializealg = reinitialization_alg(cb))
584609
end
585610

586611
function generate_vector_rootfinding_callback(
587612
cbs, sys::AbstractODESystem, dvs = unknowns(sys),
588-
ps = parameters(sys); rootfind = SciMLBase.RightRootFind, kwargs...)
613+
ps = parameters(sys); rootfind = SciMLBase.RightRootFind,
614+
reinitialization = SciMLBase.CheckInit(), kwargs...)
589615
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
590616
num_eqs = length.(eqs)
591617
# fuse equations to create VectorContinuousCallback
@@ -650,7 +676,8 @@ function generate_vector_rootfinding_callback(
650676
initfn = SciMLBase.INITIALIZE_DEFAULT
651677
end
652678
return VectorContinuousCallback(
653-
cond, affect, affect_neg, length(eqs), rootfind = rootfind, initialize = initfn)
679+
cond, affect, affect_neg, length(eqs), rootfind = rootfind,
680+
initialize = initfn, initializealg = reinitialization)
654681
end
655682

656683
"""
@@ -690,18 +717,24 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
690717
# group the cbs by what rootfind op they use
691718
# groupby would be very useful here, but alas
692719
cb_classes = Dict{
693-
@NamedTuple{rootfind::SciMLBase.RootfindOpt}, Vector{SymbolicContinuousCallback}}()
720+
@NamedTuple{
721+
rootfind::SciMLBase.RootfindOpt,
722+
reinitialization::SciMLBase.DAEInitializationAlgorithm}, Vector{SymbolicContinuousCallback}}()
694723
for cb in cbs
695724
push!(
696-
get!(() -> SymbolicContinuousCallback[], cb_classes, (rootfind = cb.rootfind,)),
725+
get!(() -> SymbolicContinuousCallback[], cb_classes,
726+
(
727+
rootfind = cb.rootfind,
728+
reinitialization = reinitialization_alg(cb))),
697729
cb)
698730
end
699731

700732
# generate the callbacks out; we sort by the equivalence class to ensure a deterministic preference order
701733
compiled_callbacks = map(collect(pairs(sort!(
702734
OrderedDict(cb_classes); by = p -> p.rootfind)))) do (equiv_class, cbs_in_class)
703735
return generate_vector_rootfinding_callback(
704-
cbs_in_class, sys, dvs, ps; rootfind = equiv_class.rootfind, kwargs...)
736+
cbs_in_class, sys, dvs, ps; rootfind = equiv_class.rootfind,
737+
reinitialization = equiv_class.reinitialization, kwargs...)
705738
end
706739
if length(compiled_callbacks) == 1
707740
return compiled_callbacks[]
@@ -772,10 +805,12 @@ function generate_timed_callback(cb, sys, dvs, ps; postprocess_affect_expr! = no
772805
end
773806
if cond isa AbstractVector
774807
# Preset Time
775-
return PresetTimeCallback(cond, as; initialize = initfn)
808+
return PresetTimeCallback(
809+
cond, as; initialize = initfn, initializealg = reinitialization_alg(cb))
776810
else
777811
# Periodic
778-
return PeriodicCallback(as, cond; initialize = initfn)
812+
return PeriodicCallback(
813+
as, cond; initialize = initfn, initializealg = reinitialization_alg(cb))
779814
end
780815
end
781816

@@ -800,7 +835,8 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! =
800835
else
801836
initfn = SciMLBase.INITIALIZE_DEFAULT
802837
end
803-
return DiscreteCallback(c, as; initialize = initfn)
838+
return DiscreteCallback(
839+
c, as; initialize = initfn, initializealg = reinitialization_alg(cb))
804840
end
805841
end
806842

test/symbolic_events.jl

+101
Original file line numberDiff line numberDiff line change
@@ -867,6 +867,88 @@ end
867867
@test sign.(cos.(3 * (required_crossings_c2 .+ 1e-6))) == sign.(last.(cr2))
868868
end
869869

870+
@testset "Discrete event reinitialization (#3142)" begin
871+
@connector LiquidPort begin
872+
p(t)::Float64, [description = "Set pressure in bar",
873+
guess = 1.01325]
874+
Vdot(t)::Float64,
875+
[description = "Volume flow rate in L/min",
876+
guess = 0.0,
877+
connect = Flow]
878+
end
879+
880+
@mtkmodel PressureSource begin
881+
@components begin
882+
port = LiquidPort()
883+
end
884+
@parameters begin
885+
p_set::Float64 = 1.01325, [description = "Set pressure in bar"]
886+
end
887+
@equations begin
888+
port.p ~ p_set
889+
end
890+
end
891+
892+
@mtkmodel BinaryValve begin
893+
@constants begin
894+
p_ref::Float64 = 1.0, [description = "Reference pressure drop in bar"]
895+
ρ_ref::Float64 = 1000.0, [description = "Reference density in kg/m^3"]
896+
end
897+
@components begin
898+
port_in = LiquidPort()
899+
port_out = LiquidPort()
900+
end
901+
@parameters begin
902+
k_V::Float64 = 1.0, [description = "Valve coefficient in L/min/bar"]
903+
k_leakage::Float64 = 1e-08, [description = "Leakage coefficient in L/min/bar"]
904+
ρ::Float64 = 1000.0, [description = "Density in kg/m^3"]
905+
end
906+
@variables begin
907+
S(t)::Float64, [description = "Valve state", guess = 1.0, irreducible = true]
908+
Δp(t)::Float64, [description = "Pressure difference in bar", guess = 1.0]
909+
Vdot(t)::Float64, [description = "Volume flow rate in L/min", guess = 1.0]
910+
end
911+
@equations begin
912+
# Port handling
913+
port_in.Vdot ~ -Vdot
914+
port_out.Vdot ~ Vdot
915+
Δp ~ port_in.p - port_out.p
916+
# System behavior
917+
D(S) ~ 0.0
918+
Vdot ~ S * k_V * sign(Δp) * sqrt(abs(Δp) / p_ref * ρ_ref / ρ) + k_leakage * Δp # softplus alpha function to avoid negative values under the sqrt
919+
end
920+
end
921+
922+
# Test System
923+
@mtkmodel TestSystem begin
924+
@components begin
925+
pressure_source_1 = PressureSource(p_set = 2.0)
926+
binary_valve_1 = BinaryValve(S = 1.0, k_leakage = 0.0)
927+
binary_valve_2 = BinaryValve(S = 1.0, k_leakage = 0.0)
928+
pressure_source_2 = PressureSource(p_set = 1.0)
929+
end
930+
@equations begin
931+
connect(pressure_source_1.port, binary_valve_1.port_in)
932+
connect(binary_valve_1.port_out, binary_valve_2.port_in)
933+
connect(binary_valve_2.port_out, pressure_source_2.port)
934+
end
935+
@discrete_events begin
936+
[30] => [binary_valve_1.S ~ 0.0, binary_valve_2.Δp ~ 0.0]
937+
[60] => [
938+
binary_valve_1.S ~ 1.0, binary_valve_2.S ~ 0.0, binary_valve_2.Δp ~ 1.0]
939+
[120] => [binary_valve_1.S ~ 0.0, binary_valve_2.Δp ~ 0.0]
940+
end
941+
end
942+
943+
# Test Simulation
944+
@mtkbuild sys = TestSystem()
945+
946+
# Test Simulation
947+
prob = ODEProblem(sys, [], (0.0, 150.0))
948+
sol = solve(prob)
949+
@test sol[end] == [0.0, 0.0, 0.0]
950+
end
951+
870952
@testset "Discrete variable timeseries" begin
871953
@variables x(t)
872954
@parameters a(t) b(t) c(t)
@@ -887,3 +969,22 @@ end
887969
@test sol[b] == [2.0, 5.0, 5.0]
888970
@test sol[c] == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
889971
end
972+
973+
@testset "Bump" begin
974+
@variables x(t) [irreducible = true] y(t) [irreducible = true]
975+
eqs = [x ~ y, D(x) ~ -1]
976+
cb = [x ~ 0.0] => [x ~ 0, y ~ 1]
977+
@mtkbuild pend = ODESystem(eqs, t; continuous_events = [cb])
978+
prob = ODEProblem(pend, [x => 1], (0.0, 3.0), guesses = [y => x])
979+
@test_throws "CheckInit specified but initialization" solve(prob, Rodas5())
980+
981+
cb = [x ~ 0.0] => [y ~ 1]
982+
@mtkbuild pend = ODESystem(eqs, t; continuous_events = [cb])
983+
prob = ODEProblem(pend, [x => 1], (0.0, 3.0), guesses = [y => x])
984+
@test_broken !SciMLBase.successful_retcode(solve(prob, Rodas5()))
985+
986+
cb = [x ~ 0.0] => [x ~ 1, y ~ 1]
987+
@mtkbuild pend = ODESystem(eqs, t; continuous_events = [cb])
988+
prob = ODEProblem(pend, [x => 1], (0.0, 3.0), guesses = [y => x])
989+
@test all((0.0; atol = 1e-9), solve(prob, Rodas5())[[x, y]][end])
990+
end

0 commit comments

Comments
 (0)