Skip to content

Commit c9ca03e

Browse files
committed
feat: add optimization for explicit affects
1 parent d4614e7 commit c9ca03e

File tree

5 files changed

+124
-88
lines changed

5 files changed

+124
-88
lines changed

src/systems/callbacks.jl

+112-73
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ struct AffectSystem
6969
discretes::Vector
7070
"""Maps the symbols of unknowns/observed in the ImplicitDiscreteSystem to its corresponding unknown/parameter in the parent system."""
7171
aff_to_sys::Dict
72+
explicit::Bool
7273
end
7374

7475
system(a::AffectSystem) = a.system
@@ -77,6 +78,7 @@ unknowns(a::AffectSystem) = a.unknowns
7778
parameters(a::AffectSystem) = a.parameters
7879
aff_to_sys(a::AffectSystem) = a.aff_to_sys
7980
previous_vals(a::AffectSystem) = parameters(system(a))
81+
is_explicit(a::AffectSystem) = a.explicit
8082

8183
function Base.show(iio::IO, aff::AffectSystem)
8284
eqs = vcat(equations(system(aff)), observed(system(aff)))
@@ -105,6 +107,8 @@ Base.nameof(::Pre) = :Pre
105107
Base.show(io::IO, x::Pre) = print(io, "Pre")
106108
input_timedomain(::Pre, _ = nothing) = ContinuousClock()
107109
output_timedomain(::Pre, _ = nothing) = ContinuousClock()
110+
unPre(x::Num) = unPre(unwrap(x))
111+
unPre(x::BasicSymbolic) = operation(x) isa Pre ? only(arguments(x)) : x
108112

109113
function (p::Pre)(x)
110114
iw = Symbolics.iswrapped(x)
@@ -229,24 +233,28 @@ function make_affect(affect::Vector{Equation}; iv = nothing, algeeqs = Equation[
229233
isempty(affect) && return nothing
230234
isempty(algeeqs) && @warn "No algebraic equations were found. If the system has no algebraic equations, this can be disregarded. Otherwise pass in `algeeqs` to the SymbolicContinuousCallback constructor."
231235

236+
explicit = true
232237
affect = scalarize(affect)
233238
dvs = OrderedSet()
234239
params = OrderedSet()
240+
params = OrderedSet()
235241
for eq in affect
236242
if !haspre(eq) && !(symbolic_type(eq.rhs) === NotSymbolic())
237243
@warn "Affect equation $eq has no `Pre` operator. As such it will be interpreted as an algebraic equation to be satisfied after the callback. If you intended to use the value of a variable x before the affect, use Pre(x)."
244+
explicit = false
238245
end
239246
collect_vars!(dvs, params, eq, iv; op = Pre)
240247
end
241248
for eq in algeeqs
242249
collect_vars!(dvs, params, eq, iv)
250+
expilcit = false
243251
end
244252
if isnothing(iv)
245253
iv = isempty(dvs) ? iv : only(arguments(dvs[1]))
246254
isnothing(iv) && @warn "No independent variable specified and could not be inferred. If the iv appears in an affect equation explicitly, like x ~ t + 1, then it must be specified as an argument to the SymbolicContinuousCallback or SymbolicDiscreteCallback constructor. Otherwise this warning can be disregarded."
247255
end
248256

249-
# System parameters should become unknowns in the ImplicitDiscreteSystem.
257+
# Parameters in affect equations should become unknowns in the ImplicitDiscreteSystem.
250258
cb_params = Any[]
251259
discretes = Any[]
252260
p_as_dvs = Any[]
@@ -268,15 +276,15 @@ function make_affect(affect::Vector{Equation}; iv = nothing, algeeqs = Equation[
268276
aff_map = Dict(zip(p_as_dvs, discretes))
269277
rev_map = Dict([v => k for (k, v) in aff_map])
270278
affect = Symbolics.substitute(affect, rev_map)
271-
@mtkbuild affectsys = ImplicitDiscreteSystem(vcat(affect, algeeqs), iv, collect(union(dvs, p_as_dvs)), cb_params)
279+
@named affectsys = ImplicitDiscreteSystem(vcat(affect, algeeqs), iv, collect(union(dvs, p_as_dvs)), cb_params)
272280
# get accessed parameters p from Pre(p) in the callback parameters
273-
params = filter(isparameter, map(x -> only(arguments(unwrap(x))), cb_params))
281+
params = filter(isparameter, map(x -> unPre(x), cb_params))
274282
# add unknowns to the map
275283
for u in dvs
276284
aff_map[u] = u
277285
end
278286

279-
return AffectSystem(affectsys, collect(dvs), params, discretes, aff_map)
287+
return AffectSystem(affectsys, collect(dvs), params, discretes, aff_map, explicit)
280288
end
281289

282290
function make_affect(affect; kwargs...)
@@ -468,7 +476,7 @@ end
468476
########## Namespacing Utilities ###########
469477
############################################
470478

471-
function namespace_affect(affect::FunctionalAffect, s)
479+
function namespace_affects(affect::FunctionalAffect, s)
472480
FunctionalAffect(func(affect),
473481
renamespace.((s,), unknowns(affect)),
474482
unknowns_syms(affect),
@@ -478,35 +486,35 @@ function namespace_affect(affect::FunctionalAffect, s)
478486
context(affect))
479487
end
480488

481-
function namespace_affect(affect::AffectSystem, s)
489+
function namespace_affects(affect::AffectSystem, s)
482490
AffectSystem(renamespace(s, system(affect)),
483491
renamespace.((s,), unknowns(affect)),
484492
renamespace.((s,), parameters(affect)),
485493
renamespace.((s,), discretes(affect)),
486-
Dict([k => renamespace(s, v) for (k, v) in aff_to_sys(affect)]))
494+
Dict([k => renamespace(s, v) for (k, v) in aff_to_sys(affect)]), is_explicit(affect))
487495
end
488-
namespace_affect(af::Nothing, s) = nothing
496+
namespace_affects(af::Nothing, s) = nothing
489497

490498
function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback
491499
SymbolicContinuousCallback(
492500
namespace_equation.(equations(cb), (s,)),
493-
namespace_affect(affects(cb), s),
494-
affect_neg = namespace_affect(affect_negs(cb), s),
495-
initialize = namespace_affect(initialize_affects(cb), s),
496-
finalize = namespace_affect(finalize_affects(cb), s),
501+
namespace_affects(affects(cb), s),
502+
affect_neg = namespace_affects(affect_negs(cb), s),
503+
initialize = namespace_affects(initialize_affects(cb), s),
504+
finalize = namespace_affects(finalize_affects(cb), s),
497505
rootfind = cb.rootfind)
498506
end
499507

500-
function namespace_condition(condition, s)
508+
function namespace_conditions(condition, s)
501509
is_timed_condition(condition) ? condition : namespace_expr(condition, s)
502510
end
503511

504512
function namespace_callback(cb::SymbolicDiscreteCallback, s)::SymbolicDiscreteCallback
505513
SymbolicDiscreteCallback(
506-
namespace_condition(condition(cb), s),
514+
namespace_conditions(conditions(cb), s),
507515
namespace_affects(affects(cb), s),
508-
namespace_affects(initialize_affects(cb), s),
509-
namespace_affects(finalize_affects(cb), s))
516+
initialize = namespace_affects(initialize_affects(cb), s),
517+
finalize = namespace_affects(finalize_affects(cb), s))
510518
end
511519

512520
function Base.hash(cb::SymbolicContinuousCallback, s::UInt)
@@ -623,8 +631,6 @@ function compile_condition(cbs::Union{AbstractCallback, Vector{<:AbstractCallbac
623631
end
624632
end
625633
end
626-
627-
cond
628634
end
629635

630636
"""
@@ -707,12 +713,12 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
707713
inits = []
708714
finals = []
709715
for cb in cbs
710-
affect = compile_affect(cb.affect, cb, sys, default = EMPTY_FUNCTION)
716+
affect = compile_affect(cb.affect, cb, sys, default = EMPTY_FUNCTION, kwargs...)
711717
push!(affects, affect)
712-
affect_neg = (cb.affect_neg === cb.affect) ? affect : compile_affect(cb.affect_neg, cb, sys, default = EMPTY_FUNCTION)
718+
affect_neg = (cb.affect_neg === cb.affect) ? affect : compile_affect(cb.affect_neg, cb, sys, default = EMPTY_FUNCTION, kwargs...)
713719
push!(affect_negs, affect_neg)
714-
push!(inits, compile_affect(cb.initialize, cb, sys; default = nothing, is_init = true))
715-
push!(finals, compile_affect(cb.finalize, cb, sys; default = nothing))
720+
push!(inits, compile_affect(cb.initialize, cb, sys; default = nothing, is_init = true), kwargs...)
721+
push!(finals, compile_affect(cb.finalize, cb, sys; default = nothing), kwargs...)
716722
end
717723

718724
# Since there may be different number of conditions and affects,
@@ -729,8 +735,8 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
729735
isnothing(f) && return
730736
f(integ)
731737
end
732-
initialize = compile_vector_optional_affect(inits, SciMLBase.INITIALIZE_DEFAULT)
733-
finalize = compile_vector_optional_affect(finals, SciMLBase.FINALIZE_DEFAULT)
738+
initialize = wrap_vector_optional_affect(inits, SciMLBase.INITIALIZE_DEFAULT)
739+
finalize = wrap_vector_optional_affect(finals, SciMLBase.FINALIZE_DEFAULT)
734740

735741
return VectorContinuousCallback(
736742
trigger, affect, affect_neg, length(eqs); initialize, finalize,
@@ -743,14 +749,14 @@ function generate_callback(cb, sys; kwargs...)
743749
ps = parameters(sys; initial_parameters = true)
744750

745751
trigger = is_timed ? conditions(cb) : compile_condition(cb, sys, dvs, ps; kwargs...)
746-
affect = compile_affect(cb.affect, cb, sys, default = EMPTY_FUNCTION)
752+
affect = compile_affect(cb.affect, cb, sys, default = EMPTY_FUNCTION, kwargs...)
747753
affect_neg = if is_discrete(cb)
748754
nothing
749755
else
750-
(cb.affect === cb.affect_neg) ? affect : compile_affect(cb.affect_neg, cb, sys, default = EMPTY_FUNCTION)
756+
(cb.affect === cb.affect_neg) ? affect : compile_affect(cb.affect_neg, cb, sys, default = EMPTY_FUNCTION, kwargs...)
751757
end
752-
init = compile_affect(cb.initialize, cb, sys, default = SciMLBase.INITIALIZE_DEFAULT, is_init = true)
753-
final = compile_affect(cb.finalize, cb, sys, default = SciMLBase.FINALIZE_DEFAULT)
758+
init = compile_affect(cb.initialize, cb, sys, default = SciMLBase.INITIALIZE_DEFAULT, is_init = true, kwargs...)
759+
final = compile_affect(cb.finalize, cb, sys, default = SciMLBase.FINALIZE_DEFAULT, kwargs...)
754760

755761
initialize = isnothing(cb.initialize) ? init : ((c, u, t, i) -> init(i))
756762
finalize = isnothing(cb.finalize) ? final : ((c, u, t, i) -> final(i))
@@ -795,32 +801,29 @@ function compile_affect(
795801
get(ic.callback_to_clocks, cb, Int[])
796802
end
797803

798-
f = if isnothing(aff)
799-
default
804+
if isnothing(aff)
805+
full_args = is_init && (default === SciMLBase.INITIALIZE_DEFAULT)
806+
is_init ? wrap_save_discretes(f, save_idxs; full_args) : default
800807
elseif aff isa AffectSystem
801-
compile_equational_affect(aff, sys)
808+
f = compile_equational_affect(aff, sys; kwargs...)
809+
wrap_save_discretes(f, save_idxs)
802810
elseif aff isa FunctionalAffect || aff isa ImperativeAffect
803-
compile_functional_affect(aff, sys; kwargs...)
811+
f = compile_functional_affect(aff, sys; kwargs...)
812+
wrap_save_discretes(f, save_idxs; full_args = true)
804813
end
805-
wrap_save_discretes(f, save_idxs; is_init)
806814
end
807815

808-
# Init can be: user defined function, nothing, or INITIALIZE_DEFAULT
809-
function wrap_save_discretes(f, save_idxs; is_init = false)
810-
if isempty(save_idxs) || f === SciMLBase.FINALIZE_DEFAULT || (isnothing(f) && !is_init)
811-
return f
812-
elseif f === SciMLBase.INITIALIZE_DEFAULT
813-
let save_idxs = save_idxs
814-
(c, u, t, i) -> begin
815-
f(c, u, t, i)
816+
function wrap_save_discretes(f, save_idxs; full_args = false)
817+
let save_idxs = save_idxs
818+
if full_args
819+
return (c, u, t, i) -> begin
820+
isnothing(f) || f(c, u, t, i)
816821
for idx in save_idxs
817822
SciMLBase.save_discretes!(i, idx)
818823
end
819824
end
820-
end
821-
else
822-
let save_idxs = save_idxs
823-
(i) -> begin
825+
else
826+
return (i) -> begin
824827
isnothing(f) || f(i)
825828
for idx in save_idxs
826829
SciMLBase.save_discretes!(i, idx)
@@ -831,9 +834,9 @@ function wrap_save_discretes(f, save_idxs; is_init = false)
831834
end
832835

833836
"""
834-
Initialize and Finalize for VectorContinuousCallback.
837+
Initialize and finalize for VectorContinuousCallback.
835838
"""
836-
function compile_vector_optional_affect(funs, default)
839+
function wrap_vector_optional_affect(funs, default)
837840
all(isnothing, funs) && return default
838841
return let funs = funs
839842
function (cb, u, t, integ)
@@ -844,35 +847,71 @@ function compile_vector_optional_affect(funs, default)
844847
end
845848
end
846849

847-
function compile_equational_affect(aff::AffectSystem, sys; kwargs...)
850+
function add_integrator_header(
851+
sys::AbstractSystem, integrator = gensym(:MTKIntegrator), out = :u)
852+
expr -> Func([DestructuredArgs(expr.args, integrator, inds = [:u, :p, :t])], [],
853+
expr.body),
854+
expr -> Func(
855+
[DestructuredArgs(expr.args, integrator, inds = [out, :u, :p, :t])], [],
856+
expr.body)
857+
end
858+
859+
"""
860+
Compile an affect defined by a set of equations. Systems with algebraic equations will solve implicit discrete problems to obtain their next state. Systems without will generate functions that perform explicit updates.
861+
"""
862+
function compile_equational_affect(aff::AffectSystem, sys; reset_jumps = false, kwargs...)
848863
affsys = system(aff)
849-
aff_map = aff_to_sys(aff)
850-
sys_map = Dict([v => k for (k, v) in aff_map])
851-
ps_to_modify = discretes(aff)
852-
dvs_to_modify = setdiff(unknowns(aff), getfield.(observed(sys), :lhs))
853-
#TODO: Add an optimization for systems without algebraic equations
854-
855-
return let dvs_to_modify = dvs_to_modify, aff_map = aff_map, sys_map = sys_map, affsys = affsys, ps_to_modify = ps_to_modify
856-
857-
@inline function affect!(integrator)
858-
pmap = Pair[]
859-
for pre_p in parameters(affsys)
860-
p = only(arguments(unwrap(pre_p)))
861-
pval = isparameter(p) ? integrator.ps[p] : integrator[p]
862-
push!(pmap, pre_p => pval)
863-
end
864-
guesses = Pair[u => integrator[aff_map[u]] for u in unknowns(affsys)]
865-
affprob = ImplicitDiscreteProblem(affsys, Pair[], (integrator.t, integrator.t), pmap; guesses, build_initializeprob = false)
864+
reinit = has_alg_equations(sys) || has_alg_equations(affsys)
865+
ps_to_update = discretes(aff)
866+
dvs_to_update = setdiff(unknowns(aff), getfield.(observed(sys), :lhs))
867+
868+
if is_explicit(aff)
869+
update_eqs = equations(affsys)
870+
update_eqs = Symbolics.fast_substitute(equations, Dict([p => unPre(p) for p in parameters(affsys)]))
871+
rhss = map(x -> x.rhs, update_eqs)
872+
lhss = map(x -> x.lhs, update_eqs)
873+
is_p = [lhs ps_to_update for lhs in lhss]
874+
875+
dvs = unknowns(sys)
876+
ps = parameters(sys)
877+
t = get_iv(sys)
878+
879+
u_idxs = indexin((@view lhss[.!is_p]), dvs)
880+
p_idxs = indexin((@view lhss[is_p]), ps)
881+
_ps = reorder_parameters(sys, ps)
882+
integ = gensym(:MTKIntegrator)
883+
884+
u_up, u_up! = build_function_wrapper(sys, (@view rhss[.!is_p]), dvs, _ps..., t; wrap_code = add_integrator_header(sys, integ, :u), expression = Val{false}, outputidxs = u_idxs)
885+
p_up, p_up! = build_function_wrapper(sys, (@view rhss[is_p]), dvs, _ps..., t; wrap_code = add_integrator_header(sys, integ, :p), expression = Val{false}, outputidxs = p_idxs)
886+
887+
return (integ) -> begin
888+
u_up!(integ)
889+
p_up!(integ)
890+
reset_jumps && reset_aggregated_jumps!(integ)
891+
end
892+
else
893+
aff_map = aff_to_sys(aff)
894+
sys_map = Dict([v => k for (k, v) in aff_map])
895+
896+
return let dvs_to_update = dvs_to_update, aff_map = aff_map, sys_map = sys_map, affsys = affsys, ps_to_update = ps_to_update
897+
(integ) -> begin
898+
pmap = Pair[]
899+
for pre_p in parameters(affsys)
900+
p = unPre(pre_p)
901+
pval = isparameter(p) ? integ.ps[p] : integ[p]
902+
push!(pmap, pre_p => pval)
903+
end
904+
guesses = Pair[u => integ[aff_map[u]] for u in unknowns(affsys)]
905+
affprob = ImplicitDiscreteProblem(affsys, Pair[], (integ.t, integ.t), pmap; guesses, build_initializeprob = false)
866906

867-
affsol = init(affprob, SimpleIDSolve())
868-
for u in dvs_to_modify
869-
integrator[u] = affsol[sys_map[u]]
870-
end
871-
for p in ps_to_modify
872-
integrator.ps[p] = affsol[sys_map[p]]
907+
affsol = init(affprob, SimpleIDSolve())
908+
for u in dvs_to_update
909+
integ[u] = affsol[sys_map[u]]
910+
end
911+
for p in ps_to_update
912+
integ.ps[p] = affsol[sys_map[p]]
913+
end
873914
end
874-
875-
sys isa JumpSystem && reset_aggregated_jumps!(integrator)
876915
end
877916
end
878917
end

src/systems/imperative_affect.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ function Base.hash(a::ImperativeAffect, s::UInt)
9999
hash(a.ctx, s)
100100
end
101101

102-
function namespace_affect(affect::ImperativeAffect, s)
102+
function namespace_affects(affect::ImperativeAffect, s)
103103
ImperativeAffect(func(affect),
104104
namespace_expr.(observed(affect), (s,)),
105105
observed_syms(affect),

0 commit comments

Comments
 (0)