Skip to content

Commit 63d4c7e

Browse files
committed
fix NoInit() error
1 parent eaf7ae1 commit 63d4c7e

File tree

5 files changed

+202
-207
lines changed

5 files changed

+202
-207
lines changed

Project.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2828
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
2929
FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
3030
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
31+
ImplicitDiscreteSolve = "3263718b-31ed-49cf-8a0f-35a466e8af96"
3132
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
3233
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
3334
JumpProcesses = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5"
@@ -40,6 +41,7 @@ NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
4041
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
4142
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
4243
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
44+
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
4345
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
4446
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
4547
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -49,7 +51,6 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
4951
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
5052
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
5153
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
52-
SimpleImplicitDiscreteSolve = "3263718b-31ed-49cf-8a0f-35a466e8af96"
5354
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
5455
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
5556
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

src/ModelingToolkit.jl

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ import Moshi
5454
using Moshi.Data: @data
5555
using NonlinearSolve
5656
import SCCNonlinearSolve
57+
using ImplicitDiscreteSolve
5758
using Reexport
5859
using RecursiveArrayTools
5960
import Graphs: SimpleDiGraph, add_edge!, incidence_matrix

src/systems/callbacks.jl

+47-57
Original file line numberDiff line numberDiff line change
@@ -233,8 +233,9 @@ function make_affect(affect::Vector{Equation}; iv = nothing, algeeqs = Equation[
233233
dvs = OrderedSet()
234234
params = OrderedSet()
235235
for eq in affect
236-
!haspre(eq) &&
236+
if !haspre(eq) && !(symbolic_type(eq.rhs) === NotSymbolic())
237237
@warn "Affect equation $eq has no `Pre` operator. As such it will be interpreted as an algebraic equation to be satisfied after the callback. If you intended to use the value of a variable x before the affect, use Pre(x)."
238+
end
238239
collect_vars!(dvs, params, eq, iv; op = Pre)
239240
end
240241
for eq in algeeqs
@@ -299,19 +300,19 @@ function SymbolicContinuousCallbacks(events; algeeqs::Vector{Equation} = Equatio
299300
callbacks
300301
end
301302

302-
function Base.show(io::IO, cb::SymbolicContinuousCallback)
303+
function Base.show(io::IO, cb::AbstractCallback)
303304
indent = get(io, :indent, 0)
304305
iio = IOContext(io, :indent => indent + 1)
305-
print(io, "SymbolicContinuousCallback(")
306-
print(iio, "Equations:")
306+
is_discrete(cb) ? print(io, "SymbolicDiscreteCallback(") : print(io, "SymbolicContinuousCallback(")
307+
print(iio, "Conditions:")
307308
show(iio, equations(cb))
308309
print(iio, "; ")
309310
if affects(cb) != nothing
310311
print(iio, "Affect:")
311312
show(iio, affects(cb))
312313
print(iio, ", ")
313314
end
314-
if affect_negs(cb) != nothing
315+
if !is_discrete(cb) && affect_negs(cb) != nothing
315316
print(iio, "Negative-edge affect:")
316317
show(iio, affect_negs(cb))
317318
print(iio, ", ")
@@ -328,19 +329,19 @@ function Base.show(io::IO, cb::SymbolicContinuousCallback)
328329
print(iio, ")")
329330
end
330331

331-
function Base.show(io::IO, mime::MIME"text/plain", cb::SymbolicContinuousCallback)
332+
function Base.show(io::IO, mime::MIME"text/plain", cb::AbstractCallback)
332333
indent = get(io, :indent, 0)
333334
iio = IOContext(io, :indent => indent + 1)
334-
println(io, "SymbolicContinuousCallback:")
335-
println(iio, "Equations:")
335+
is_discrete(cb) ? println(io, "SymbolicDiscreteCallback:") : println(io, "SymbolicContinuousCallback:")
336+
println(iio, "Conditions:")
336337
show(iio, mime, equations(cb))
337338
print(iio, "\n")
338339
if affects(cb) != nothing
339340
println(iio, "Affect:")
340341
show(iio, mime, affects(cb))
341342
print(iio, "\n")
342343
end
343-
if affect_negs(cb) != nothing
344+
if !is_discrete(cb) && affect_negs(cb) != nothing
344345
print(iio, "Negative-edge affect:\n")
345346
show(iio, mime, affect_negs(cb))
346347
print(iio, "\n")
@@ -394,8 +395,8 @@ Arguments:
394395
- algeeqs: Algebraic equations of the system that must be satisfied after the callback occurs.
395396
"""
396397
struct SymbolicDiscreteCallback <: AbstractCallback
397-
conditions::Any
398-
affect::Affect
398+
conditions::Union{Real, Vector{<:Real}, Vector{Equation}}
399+
affect::Union{Affect, Nothing}
399400
initialize::Union{Affect, Nothing}
400401
finalize::Union{Affect, Nothing}
401402

@@ -409,6 +410,9 @@ struct SymbolicDiscreteCallback <: AbstractCallback
409410
end # Default affect to nothing
410411
end
411412

413+
SymbolicDiscreteCallback(p::Pair, args...; kwargs...) = SymbolicDiscreteCallback(p[1], p[2])
414+
SymbolicDiscreteCallback(cb::SymbolicDiscreteCallback, args...; kwargs...) = cb
415+
412416
"""
413417
Generate discrete callbacks.
414418
"""
@@ -438,29 +442,6 @@ function is_timed_condition(condition::T) where {T}
438442
end
439443
end
440444

441-
function Base.show(io::IO, db::SymbolicDiscreteCallback)
442-
indent = get(io, :indent, 0)
443-
iio = IOContext(io, :indent => indent + 1)
444-
println(io, "SymbolicDiscreteCallback:")
445-
println(iio, "Conditions:")
446-
print(iio, "; ")
447-
if affects(db) != nothing
448-
print(iio, "Affect:")
449-
show(iio, affects(db))
450-
print(iio, ", ")
451-
end
452-
if initialize_affects(db) != nothing
453-
print(iio, "Initialization affect:")
454-
show(iio, initialize_affects(db))
455-
print(iio, ", ")
456-
end
457-
if finalize_affects(db) != nothing
458-
print(iio, "Finalization affect:")
459-
show(iio, finalize_affects(db))
460-
end
461-
print(iio, ")")
462-
end
463-
464445
function vars!(vars, cb::SymbolicDiscreteCallback; op = Differential)
465446
if symbolic_type(conditions(cb)) == NotSymbolic
466447
if conditions(cb) isa AbstractArray
@@ -529,7 +510,7 @@ function namespace_callback(cb::SymbolicDiscreteCallback, s)::SymbolicDiscreteCa
529510
end
530511

531512
function Base.hash(cb::SymbolicContinuousCallback, s::UInt)
532-
s = foldr(hash, cb.eqs, init = s)
513+
s = foldr(hash, cb.conditions, init = s)
533514
s = hash(cb.affect, s)
534515
s = hash(cb.affect_neg, s)
535516
s = hash(cb.initialize, s)
@@ -538,8 +519,8 @@ function Base.hash(cb::SymbolicContinuousCallback, s::UInt)
538519
end
539520

540521
function Base.hash(cb::SymbolicDiscreteCallback, s::UInt)
541-
s = hash(cb.condition, s)
542-
s = hash(cb.affects, s)
522+
s = foldr(hash, cb.conditions, init = s)
523+
s = hash(cb.affect, s)
543524
s = hash(cb.initialize, s)
544525
hash(cb.finalize, s)
545526
end
@@ -649,7 +630,9 @@ end
649630
"""
650631
Compile user-defined functional affect.
651632
"""
652-
function compile_functional_affect(affect::FunctionalAffect, cb, sys, dvs, ps; kwargs...)
633+
function compile_functional_affect(affect::FunctionalAffect, cb, sys; kwargs...)
634+
dvs = unknowns(sys)
635+
ps = parameters(sys)
653636
dvs_ind = Dict(reverse(en) for en in enumerate(dvs))
654637
v_inds = map(sym -> dvs_ind[sym], unknowns(affect))
655638

@@ -686,7 +669,18 @@ is_discrete(cb::Vector{<:AbstractCallback}) = eltype(cb) isa SymbolicDiscreteCal
686669
function generate_continuous_callbacks(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters(sys; initial_parameters = true); kwargs...)
687670
cbs = continuous_events(sys)
688671
isempty(cbs) && return nothing
689-
generate_callback(cbs, sys; kwargs...)
672+
cb_classes = Dict{SciMLBase.RootfindOpt, Vector{SymbolicContinuousCallback}}()
673+
for cb in cbs
674+
_cbs = get!(() -> SymbolicContinuousCallback[], cb_classes, cb.rootfind)
675+
push!(_cbs, cb)
676+
end
677+
cb_classes = sort!(OrderedDict(cb_classes))
678+
compiled_callbacks = [generate_callback(cb, sys; kwargs...) for (rf, cb) in cb_classes]
679+
if length(compiled_callbacks) == 1
680+
return only(compiled_callbacks)
681+
else
682+
return CallbackSet(compiled_callbacks...)
683+
end
690684
end
691685

692686
function generate_discrete_callbacks(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters(sys; initial_parameters = true); kwargs...)
@@ -716,9 +710,9 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
716710
finals = []
717711
for cb in cbs
718712
affect = compile_affect(cb.affect, cb, sys, default = (args...) -> ())
719-
720713
push!(affects, affect)
721-
push!(affect_negs, compile_affect(cb.affect_neg, cb, sys, default = affect))
714+
affect_neg = (cb.affect_neg === cb.affect) ? affect : compile_affect(cb.affect_neg, cb, sys, default = (args...) -> ())
715+
push!(affect_negs, affect_neg)
722716
push!(inits, compile_affect(cb.initialize, cb, sys, default = nothing))
723717
push!(finals, compile_affect(cb.finalize, cb, sys, default = nothing))
724718
end
@@ -728,8 +722,6 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
728722
eq2affect = reduce(vcat,
729723
[fill(i, num_eqs[i]) for i in eachindex(affects)])
730724
eqs = reduce(vcat, eqs)
731-
@assert length(eq2affect) == length(eqs)
732-
@assert maximum(eq2affect) == length(affects)
733725

734726
affect = function (integ, idx)
735727
affects[eq2affect[idx]](integ)
@@ -744,7 +736,7 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
744736

745737
return VectorContinuousCallback(
746738
trigger, affect, affect_neg, length(eqs); initialize, finalize,
747-
rootfind = cbs[1].rootfind, initializealg = SciMLBase.NoInit)
739+
rootfind = cbs[1].rootfind, initializealg = SciMLBase.NoInit())
748740
end
749741

750742
function generate_callback(cb, sys; kwargs...)
@@ -762,16 +754,16 @@ function generate_callback(cb, sys; kwargs...)
762754
if is_discrete(cb)
763755
if is_timed && conditions(cb) isa AbstractVector
764756
return PresetTimeCallback(trigger, affect; initialize,
765-
finalize, initializealg = SciMLBase.NoInit)
757+
finalize, initializealg = SciMLBase.NoInit())
766758
elseif is_timed
767759
return PeriodicCallback(affect, trigger; initialize, finalize)
768760
else
769761
return DiscreteCallback(trigger, affect; initialize,
770-
finalize, initializealg = SciMLBase.NoInit)
762+
finalize, initializealg = SciMLBase.NoInit())
771763
end
772764
else
773765
return ContinuousCallback(trigger, affect, affect_neg; initialize, finalize,
774-
rootfind = cb.rootfind, initializealg = SciMLBase.NoInit)
766+
rootfind = cb.rootfind, initializealg = SciMLBase.NoInit())
775767
end
776768
end
777769

@@ -793,27 +785,25 @@ Notes
793785
"""
794786
function compile_affect(
795787
aff::Union{Nothing, Affect}, cb::AbstractCallback, sys::AbstractSystem; default = nothing, kwargs...)
788+
isnothing(aff) && return default
789+
796790
save_idxs = if !(has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing)
797791
Int[]
798792
else
799793
get(ic.callback_to_clocks, cb, Int[])
800794
end
801795

802-
isnothing(aff) && return default
803-
804-
ps = parameters(aff)
805-
dvs = unknowns(aff)
806-
dvs_to_modify = setdiff(dvs, getfield.(observed(sys), :lhs))
807-
808796
if aff isa AffectSystem
809797
affsys = system(aff)
810798
aff_map = aff_to_sys(aff)
811799
sys_map = Dict([v => k for (k, v) in aff_map])
812800
reinit = has_alg_eqs(sys)
801+
ps_to_modify = discretes(aff)
802+
dvs_to_modify = setdiff(unknowns(aff), getfield.(observed(sys), :lhs))
813803

814804
function affect!(integrator)
815805
pmap = Pair[]
816-
for pre_p in previous_vals(aff)
806+
for pre_p in parameters(affsys)
817807
p = only(arguments(unwrap(pre_p)))
818808
pval = isparameter(p) ? integrator.ps[p] : integrator[p]
819809
push!(pmap, pre_p => pval)
@@ -825,17 +815,17 @@ function compile_affect(
825815
for u in dvs_to_modify
826816
integrator[u] = affsol[sys_map[u]]
827817
end
828-
for p in discretes(aff)
818+
for p in ps_to_modify
829819
integrator.ps[p] = affsol[sys_map[p]]
830820
end
831821
for idx in save_idxs
832-
SciMLBase.save_discretes!(integ, idx)
822+
SciMLBase.save_discretes!(integrator, idx)
833823
end
834824

835825
sys isa JumpSystem && reset_aggregated_jumps!(integrator)
836826
end
837827
elseif aff isa FunctionalAffect || aff isa ImperativeAffect
838-
compile_functional_affect(aff, cb, sys, dvs, ps; kwargs...)
828+
compile_functional_affect(aff, cb, sys; kwargs...)
839829
end
840830
end
841831

src/systems/imperative_affect.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ function check_assignable(sys, sym)
155155
end
156156
end
157157

158-
function compile_functional_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs...)
158+
function compile_functional_affect(affect::ImperativeAffect, cb, sys; kwargs...)
159159
#=
160160
Implementation sketch:
161161
generate observed function (oop), should save to a component array under obs_syms
@@ -179,6 +179,9 @@ function compile_functional_affect(affect::ImperativeAffect, cb, sys, dvs, ps; k
179179
return (syms_dedup, exprs_dedup)
180180
end
181181

182+
dvs = unknowns(sys)
183+
ps = parameters(sys)
184+
182185
obs_exprs = observed(affect)
183186
if !affect.skip_checks
184187
for oexpr in obs_exprs

0 commit comments

Comments
 (0)