Skip to content

Commit ddb80bc

Browse files
committed
fix: fix initialization and finalization affects
1 parent 63d4c7e commit ddb80bc

File tree

6 files changed

+93
-72
lines changed

6 files changed

+93
-72
lines changed

src/structural_transformation/bipartite_tearing/modia_tearing.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ function tear_graph_modia(structure::SystemStructure, isder::F = nothing,
9696
ieqs = Int[]
9797
filtered_vars = BitSet()
9898
free_eqs = free_equations(graph, var_sccs, var_eq_matching, varfilter)
99-
is_overdetemined = !isempty(free_eqs)
99+
is_overdetermined = !isempty(free_eqs)
100100
for vars in var_sccs
101101
for var in vars
102102
if varfilter(var)
@@ -112,7 +112,7 @@ function tear_graph_modia(structure::SystemStructure, isder::F = nothing,
112112
filtered_vars, isder)
113113
# If the systems is overdetemined, we cannot assume the free equations
114114
# will not form algebraic loops with equations in the sccs.
115-
if !is_overdetemined
115+
if !is_overdetermined
116116
vargraph.ne = 0
117117
for var in vars
118118
vargraph.matching[var] = unassigned
@@ -121,7 +121,7 @@ function tear_graph_modia(structure::SystemStructure, isder::F = nothing,
121121
empty!(ieqs)
122122
empty!(filtered_vars)
123123
end
124-
if is_overdetemined
124+
if is_overdetermined
125125
free_vars = findall(x -> !(x isa Int), var_eq_matching)
126126
tear_graph_block_modia!(var_eq_matching, ict, solvable_graph, free_eqs,
127127
BitSet(free_vars), isder)

src/structural_transformation/utils.jl

+1
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = no
218218
all_int_vars = true
219219
coeffs === nothing || empty!(coeffs)
220220
empty!(to_rm)
221+
221222
for j in 𝑠neighbors(graph, ieq)
222223
var = fullvars[j]
223224
isirreducible(var) && (all_int_vars = false; continue)

src/systems/callbacks.jl

+83-51
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ struct SymbolicContinuousCallback <: AbstractCallback
217217
end # Default affect to nothing
218218
end
219219

220-
SymbolicContinuousCallback(p::Pair, args...; kwargs...) = SymbolicContinuousCallback(p[1], p[2])
220+
SymbolicContinuousCallback(p::Pair, args...; kwargs...) = SymbolicContinuousCallback(p[1], p[2], args...; kwargs...)
221221
SymbolicContinuousCallback(cb::SymbolicContinuousCallback, args...; kwargs...) = cb
222222

223223
make_affect(affect::Nothing; kwargs...) = nothing
@@ -395,7 +395,7 @@ Arguments:
395395
- algeeqs: Algebraic equations of the system that must be satisfied after the callback occurs.
396396
"""
397397
struct SymbolicDiscreteCallback <: AbstractCallback
398-
conditions::Union{Real, Vector{<:Real}, Vector{Equation}}
398+
conditions::Any
399399
affect::Union{Affect, Nothing}
400400
initialize::Union{Affect, Nothing}
401401
finalize::Union{Affect, Nothing}
@@ -410,7 +410,7 @@ struct SymbolicDiscreteCallback <: AbstractCallback
410410
end # Default affect to nothing
411411
end
412412

413-
SymbolicDiscreteCallback(p::Pair, args...; kwargs...) = SymbolicDiscreteCallback(p[1], p[2])
413+
SymbolicDiscreteCallback(p::Pair, args...; kwargs...) = SymbolicDiscreteCallback(p[1], p[2], args...; kwargs...)
414414
SymbolicDiscreteCallback(cb::SymbolicDiscreteCallback, args...; kwargs...) = cb
415415

416416
"""
@@ -630,7 +630,7 @@ end
630630
"""
631631
Compile user-defined functional affect.
632632
"""
633-
function compile_functional_affect(affect::FunctionalAffect, cb, sys; kwargs...)
633+
function compile_functional_affect(affect::FunctionalAffect, sys; kwargs...)
634634
dvs = unknowns(sys)
635635
ps = parameters(sys)
636636
dvs_ind = Dict(reverse(en) for en in enumerate(dvs))
@@ -639,11 +639,9 @@ function compile_functional_affect(affect::FunctionalAffect, cb, sys; kwargs...)
639639
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
640640
p_inds = [(pind = parameter_index(sys, sym)) === nothing ? sym : pind
641641
for sym in parameters(affect)]
642-
save_idxs = get(ic.callback_to_clocks, cb, Int[])
643642
else
644643
ps_ind = Dict(reverse(en) for en in enumerate(ps))
645644
p_inds = map(sym -> get(ps_ind, sym, sym), parameters(affect))
646-
save_idxs = Int[]
647645
end
648646
# HACK: filter out eliminated symbols. Not clear this is the right thing to do
649647
# (MTK should keep these symbols)
@@ -652,13 +650,9 @@ function compile_functional_affect(affect::FunctionalAffect, cb, sys; kwargs...)
652650
p = filter(x -> !isnothing(x[2]), collect(zip(parameters_syms(affect), p_inds))) |>
653651
NamedTuple
654652

655-
let u = u, p = p, user_affect = func(affect), ctx = context(affect),
656-
save_idxs = save_idxs
657-
function (integ)
653+
let u = u, p = p, user_affect = func(affect), ctx = context(affect)
654+
(integ) -> begin
658655
user_affect(integ, u, p, ctx)
659-
for idx in save_idxs
660-
SciMLBase.save_discretes!(integ, idx)
661-
end
662656
end
663657
end
664658
end
@@ -670,6 +664,8 @@ function generate_continuous_callbacks(sys::AbstractSystem, dvs = unknowns(sys),
670664
cbs = continuous_events(sys)
671665
isempty(cbs) && return nothing
672666
cb_classes = Dict{SciMLBase.RootfindOpt, Vector{SymbolicContinuousCallback}}()
667+
668+
# Sort the callbacks by their rootfinding method
673669
for cb in cbs
674670
_cbs = get!(() -> SymbolicContinuousCallback[], cb_classes, cb.rootfind)
675671
push!(_cbs, cb)
@@ -709,12 +705,12 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
709705
inits = []
710706
finals = []
711707
for cb in cbs
712-
affect = compile_affect(cb.affect, cb, sys, default = (args...) -> ())
708+
affect = compile_affect(cb.affect, cb, sys, default = nothing)
713709
push!(affects, affect)
714-
affect_neg = (cb.affect_neg === cb.affect) ? affect : compile_affect(cb.affect_neg, cb, sys, default = (args...) -> ())
710+
affect_neg = (cb.affect_neg == cb.affect) ? affect : compile_affect(cb.affect_neg, cb, sys, default = nothing)
715711
push!(affect_negs, affect_neg)
716-
push!(inits, compile_affect(cb.initialize, cb, sys, default = nothing))
717-
push!(finals, compile_affect(cb.finalize, cb, sys, default = nothing))
712+
push!(inits, compile_affect(cb.initialize, cb, sys; default = nothing, is_init = true))
713+
push!(finals, compile_affect(cb.finalize, cb, sys; default = nothing))
718714
end
719715

720716
# Since there may be different number of conditions and affects,
@@ -746,10 +742,16 @@ function generate_callback(cb, sys; kwargs...)
746742

747743
trigger = is_timed ? conditions(cb) : compile_condition(cb, sys, dvs, ps; kwargs...)
748744
affect = compile_affect(cb.affect, cb, sys, default = (args...) -> ())
749-
affect_neg = hasfield(typeof(cb), :affect_neg) ?
750-
compile_affect(cb.affect_neg, cb, sys, default = affect) : nothing
751-
initialize = compile_affect(cb.initialize, cb, sys, default = SciMLBase.INITIALIZE_DEFAULT)
752-
finalize = compile_affect(cb.finalize, cb, sys, default = SciMLBase.FINALIZE_DEFAULT)
745+
affect_neg = if is_discrete(cb)
746+
nothing
747+
else
748+
(cb.affect == cb.affect_neg) ? affect : compile_affect(cb.affect_neg, cb, sys, default = nothing)
749+
end
750+
init = compile_affect(cb.initialize, cb, sys, default = SciMLBase.INITIALIZE_DEFAULT, is_init = true)
751+
final = compile_affect(cb.finalize, cb, sys, default = SciMLBase.FINALIZE_DEFAULT)
752+
753+
initialize = isnothing(cb.initialize) ? init : ((c, u, t, i) -> init(i))
754+
finalize = isnothing(cb.finalize) ? final : ((c, u, t, i) -> final(i))
753755

754756
if is_discrete(cb)
755757
if is_timed && conditions(cb) isa AbstractVector
@@ -784,32 +786,81 @@ Notes
784786
- `kwargs` are passed through to `Symbolics.build_function`.
785787
"""
786788
function compile_affect(
787-
aff::Union{Nothing, Affect}, cb::AbstractCallback, sys::AbstractSystem; default = nothing, kwargs...)
788-
isnothing(aff) && return default
789-
789+
aff::Union{Nothing, Affect}, cb::AbstractCallback, sys::AbstractSystem; default = nothing, is_init = false, kwargs...)
790790
save_idxs = if !(has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing)
791791
Int[]
792792
else
793793
get(ic.callback_to_clocks, cb, Int[])
794794
end
795795

796-
if aff isa AffectSystem
797-
affsys = system(aff)
798-
aff_map = aff_to_sys(aff)
799-
sys_map = Dict([v => k for (k, v) in aff_map])
800-
reinit = has_alg_eqs(sys)
801-
ps_to_modify = discretes(aff)
802-
dvs_to_modify = setdiff(unknowns(aff), getfield.(observed(sys), :lhs))
796+
f = if isnothing(aff)
797+
default
798+
elseif aff isa AffectSystem
799+
compile_equational_affect(aff, sys)
800+
elseif aff isa FunctionalAffect || aff isa ImperativeAffect
801+
compile_functional_affect(aff, sys; kwargs...)
802+
end
803+
wrap_save_discretes(f, save_idxs; is_init)
804+
end
805+
806+
# Init can be: user defined function, nothing, or INITIALIZE_DEFAULT
807+
function wrap_save_discretes(f, save_idxs; is_init = false)
808+
if isempty(save_idxs) || f === SciMLBase.FINALIZE_DEFAULT || (isnothing(f) && !is_init)
809+
return f
810+
elseif f === SciMLBase.INITIALIZE_DEFAULT
811+
let save_idxs = save_idxs
812+
(c, u, t, i) -> begin
813+
f(c, u, t, i)
814+
for idx in save_idxs
815+
SciMLBase.save_discretes!(i, idx)
816+
end
817+
end
818+
end
819+
else
820+
let save_idxs = save_idxs
821+
(i) -> begin
822+
isnothing(f) || f(i)
823+
for idx in save_idxs
824+
SciMLBase.save_discretes!(i, idx)
825+
end
826+
end
827+
end
828+
end
829+
end
830+
831+
"""
832+
Initialize and Finalize for VectorContinuousCallback.
833+
"""
834+
function compile_vector_optional_affect(funs, default)
835+
all(isnothing, funs) && return default
836+
return let funs = funs
837+
function (cb, u, t, integ)
838+
for func in funs
839+
isnothing(func) ? continue : func(integ)
840+
end
841+
end
842+
end
843+
end
844+
845+
function compile_equational_affect(aff::AffectSystem, sys; kwargs...)
846+
affsys = system(aff)
847+
aff_map = aff_to_sys(aff)
848+
sys_map = Dict([v => k for (k, v) in aff_map])
849+
ps_to_modify = discretes(aff)
850+
dvs_to_modify = setdiff(unknowns(aff), getfield.(observed(sys), :lhs))
851+
#TODO: Add an optimization for systems without algebraic equations
803852

804-
function affect!(integrator)
853+
return let dvs_to_modify = dvs_to_modify, aff_map = aff_map, sys_map = sys_map, affsys = affsys, ps_to_modify = ps_to_modify
854+
855+
@inline function affect!(integrator)
805856
pmap = Pair[]
806857
for pre_p in parameters(affsys)
807858
p = only(arguments(unwrap(pre_p)))
808859
pval = isparameter(p) ? integrator.ps[p] : integrator[p]
809860
push!(pmap, pre_p => pval)
810861
end
811862
guesses = Pair[u => integrator[aff_map[u]] for u in unknowns(affsys)]
812-
affprob = ImplicitDiscreteProblem(affsys, Pair[], (0, 1), pmap; guesses, build_initializeprob = reinit)
863+
affprob = ImplicitDiscreteProblem(affsys, Pair[], (0, 1), pmap; guesses, build_initializeprob = false)
813864

814865
affsol = init(affprob, SimpleIDSolve())
815866
for u in dvs_to_modify
@@ -818,28 +869,9 @@ function compile_affect(
818869
for p in ps_to_modify
819870
integrator.ps[p] = affsol[sys_map[p]]
820871
end
821-
for idx in save_idxs
822-
SciMLBase.save_discretes!(integrator, idx)
823-
end
824872

825873
sys isa JumpSystem && reset_aggregated_jumps!(integrator)
826874
end
827-
elseif aff isa FunctionalAffect || aff isa ImperativeAffect
828-
compile_functional_affect(aff, cb, sys; kwargs...)
829-
end
830-
end
831-
832-
"""
833-
Initialize and Finalize for VectorContinuousCallback.
834-
"""
835-
function compile_vector_optional_affect(funs, default)
836-
all(isnothing, funs) && return default
837-
return let funs = funs
838-
function (cb, u, t, integ)
839-
for func in funs
840-
isnothing(func) ? continue : func(integ)
841-
end
842-
end
843875
end
844876
end
845877

src/systems/discrete_system/implicit_discrete_system.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ function flatten(sys::ImplicitDiscreteSystem, noeqs = false)
264264
end
265265

266266
function generate_function(
267-
sys::ImplicitDiscreteSystem, dvs = unknowns(sys), ps = parameters(sys); wrap_code = identity, kwargs...)
267+
sys::ImplicitDiscreteSystem, dvs = unknowns(sys), ps = parameters(sys); wrap_code = identity, cachesyms::Tuple = (), kwargs...)
268268
iv = get_iv(sys)
269269
# Algebraic equations get shifted forward 1, to match with differential equations
270270
exprs = map(equations(sys)) do eq
@@ -280,8 +280,9 @@ function generate_function(
280280

281281
u_next = map(Shift(iv, 1), dvs)
282282
u = dvs
283+
p = (reorder_parameters(sys, unwrap.(ps))..., cachesyms...)
283284
build_function_wrapper(
284-
sys, exprs, u_next, u, ps..., iv; p_start = 3, extra_assignments, kwargs...)
285+
sys, exprs, u_next, u, p..., iv; p_start = 3, extra_assignments, kwargs...)
285286
end
286287

287288
function shift_u0map_forward(sys::ImplicitDiscreteSystem, u0map, defs)

src/systems/imperative_affect.jl

+2-16
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,6 @@ function namespace_affect(affect::ImperativeAffect, s)
109109
affect.skip_checks)
110110
end
111111

112-
function compile_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs...)
113-
compile_functional_affect(affect, cb, sys, dvs, ps; kwargs...)
114-
end
115-
116112
function invalid_variables(sys, expr)
117113
filter(x -> !any(isequal(x), all_symbols(sys)), reduce(vcat, vars(expr); init = []))
118114
end
@@ -155,7 +151,7 @@ function check_assignable(sys, sym)
155151
end
156152
end
157153

158-
function compile_functional_affect(affect::ImperativeAffect, cb, sys; kwargs...)
154+
function compile_functional_affect(affect::ImperativeAffect, sys; kwargs...)
159155
#=
160156
Implementation sketch:
161157
generate observed function (oop), should save to a component array under obs_syms
@@ -235,14 +231,8 @@ function compile_functional_affect(affect::ImperativeAffect, cb, sys; kwargs...)
235231

236232
upd_funs = NamedTuple{mod_names}((setu.((sys,), first.(mod_pairs))...,))
237233

238-
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
239-
save_idxs = get(ic.callback_to_clocks, cb, Int[])
240-
else
241-
save_idxs = Int[]
242-
end
243-
244234
let user_affect = func(affect), ctx = context(affect)
245-
function (integ)
235+
@inline function (integ)
246236
# update the to-be-mutated values; this ensures that if you do a no-op then nothing happens
247237
modvals = mod_og_val_fun(integ.u, integ.p, integ.t)
248238
upd_component_array = NamedTuple{mod_names}(modvals)
@@ -256,10 +246,6 @@ function compile_functional_affect(affect::ImperativeAffect, cb, sys; kwargs...)
256246

257247
# write the new values back to the integrator
258248
_generated_writeback(integ, upd_funs, upd_vals)
259-
260-
for idx in save_idxs
261-
SciMLBase.save_discretes!(integ, idx)
262-
end
263249
end
264250
end
265251
end

src/systems/systemstructure.jl

+1
Original file line numberDiff line numberDiff line change
@@ -688,6 +688,7 @@ function _structural_simplify!(state::TearingState, io; simplify = false,
688688
check_consistency = true, fully_determined = true, warn_initialize_determined = false,
689689
dummy_derivative = true,
690690
kwargs...)
691+
691692
if fully_determined isa Bool
692693
check_consistency &= fully_determined
693694
else

0 commit comments

Comments
 (0)