@@ -217,7 +217,7 @@ struct SymbolicContinuousCallback <: AbstractCallback
217
217
end # Default affect to nothing
218
218
end
219
219
220
- SymbolicContinuousCallback (p:: Pair , args... ; kwargs... ) = SymbolicContinuousCallback (p[1 ], p[2 ])
220
+ SymbolicContinuousCallback (p:: Pair , args... ; kwargs... ) = SymbolicContinuousCallback (p[1 ], p[2 ], args ... ; kwargs ... )
221
221
SymbolicContinuousCallback (cb:: SymbolicContinuousCallback , args... ; kwargs... ) = cb
222
222
223
223
make_affect (affect:: Nothing ; kwargs... ) = nothing
@@ -395,7 +395,7 @@ Arguments:
395
395
- algeeqs: Algebraic equations of the system that must be satisfied after the callback occurs.
396
396
"""
397
397
struct SymbolicDiscreteCallback <: AbstractCallback
398
- conditions:: Union{Real, Vector{<:Real}, Vector{Equation}}
398
+ conditions:: Any
399
399
affect:: Union{Affect, Nothing}
400
400
initialize:: Union{Affect, Nothing}
401
401
finalize:: Union{Affect, Nothing}
@@ -410,7 +410,7 @@ struct SymbolicDiscreteCallback <: AbstractCallback
410
410
end # Default affect to nothing
411
411
end
412
412
413
- SymbolicDiscreteCallback (p:: Pair , args... ; kwargs... ) = SymbolicDiscreteCallback (p[1 ], p[2 ])
413
+ SymbolicDiscreteCallback (p:: Pair , args... ; kwargs... ) = SymbolicDiscreteCallback (p[1 ], p[2 ], args ... ; kwargs ... )
414
414
SymbolicDiscreteCallback (cb:: SymbolicDiscreteCallback , args... ; kwargs... ) = cb
415
415
416
416
"""
630
630
"""
631
631
Compile user-defined functional affect.
632
632
"""
633
- function compile_functional_affect (affect:: FunctionalAffect , cb, sys; kwargs... )
633
+ function compile_functional_affect (affect:: FunctionalAffect , sys; kwargs... )
634
634
dvs = unknowns (sys)
635
635
ps = parameters (sys)
636
636
dvs_ind = Dict (reverse (en) for en in enumerate (dvs))
@@ -639,11 +639,9 @@ function compile_functional_affect(affect::FunctionalAffect, cb, sys; kwargs...)
639
639
if has_index_cache (sys) && (ic = get_index_cache (sys)) != = nothing
640
640
p_inds = [(pind = parameter_index (sys, sym)) === nothing ? sym : pind
641
641
for sym in parameters (affect)]
642
- save_idxs = get (ic. callback_to_clocks, cb, Int[])
643
642
else
644
643
ps_ind = Dict (reverse (en) for en in enumerate (ps))
645
644
p_inds = map (sym -> get (ps_ind, sym, sym), parameters (affect))
646
- save_idxs = Int[]
647
645
end
648
646
# HACK: filter out eliminated symbols. Not clear this is the right thing to do
649
647
# (MTK should keep these symbols)
@@ -652,13 +650,9 @@ function compile_functional_affect(affect::FunctionalAffect, cb, sys; kwargs...)
652
650
p = filter (x -> ! isnothing (x[2 ]), collect (zip (parameters_syms (affect), p_inds))) |>
653
651
NamedTuple
654
652
655
- let u = u, p = p, user_affect = func (affect), ctx = context (affect),
656
- save_idxs = save_idxs
657
- function (integ)
653
+ let u = u, p = p, user_affect = func (affect), ctx = context (affect)
654
+ (integ) -> begin
658
655
user_affect (integ, u, p, ctx)
659
- for idx in save_idxs
660
- SciMLBase. save_discretes! (integ, idx)
661
- end
662
656
end
663
657
end
664
658
end
@@ -670,6 +664,8 @@ function generate_continuous_callbacks(sys::AbstractSystem, dvs = unknowns(sys),
670
664
cbs = continuous_events (sys)
671
665
isempty (cbs) && return nothing
672
666
cb_classes = Dict {SciMLBase.RootfindOpt, Vector{SymbolicContinuousCallback}} ()
667
+
668
+ # Sort the callbacks by their rootfinding method
673
669
for cb in cbs
674
670
_cbs = get! (() -> SymbolicContinuousCallback[], cb_classes, cb. rootfind)
675
671
push! (_cbs, cb)
@@ -709,12 +705,12 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
709
705
inits = []
710
706
finals = []
711
707
for cb in cbs
712
- affect = compile_affect (cb. affect, cb, sys, default = (args ... ) -> () )
708
+ affect = compile_affect (cb. affect, cb, sys, default = nothing )
713
709
push! (affects, affect)
714
- affect_neg = (cb. affect_neg === cb. affect) ? affect : compile_affect (cb. affect_neg, cb, sys, default = (args ... ) -> () )
710
+ affect_neg = (cb. affect_neg == cb. affect) ? affect : compile_affect (cb. affect_neg, cb, sys, default = nothing )
715
711
push! (affect_negs, affect_neg)
716
- push! (inits, compile_affect (cb. initialize, cb, sys, default = nothing ))
717
- push! (finals, compile_affect (cb. finalize, cb, sys, default = nothing ))
712
+ push! (inits, compile_affect (cb. initialize, cb, sys; default = nothing , is_init = true ))
713
+ push! (finals, compile_affect (cb. finalize, cb, sys; default = nothing ))
718
714
end
719
715
720
716
# Since there may be different number of conditions and affects,
@@ -746,10 +742,16 @@ function generate_callback(cb, sys; kwargs...)
746
742
747
743
trigger = is_timed ? conditions (cb) : compile_condition (cb, sys, dvs, ps; kwargs... )
748
744
affect = compile_affect (cb. affect, cb, sys, default = (args... ) -> ())
749
- affect_neg = hasfield (typeof (cb), :affect_neg ) ?
750
- compile_affect (cb. affect_neg, cb, sys, default = affect) : nothing
751
- initialize = compile_affect (cb. initialize, cb, sys, default = SciMLBase. INITIALIZE_DEFAULT)
752
- finalize = compile_affect (cb. finalize, cb, sys, default = SciMLBase. FINALIZE_DEFAULT)
745
+ affect_neg = if is_discrete (cb)
746
+ nothing
747
+ else
748
+ (cb. affect == cb. affect_neg) ? affect : compile_affect (cb. affect_neg, cb, sys, default = nothing )
749
+ end
750
+ init = compile_affect (cb. initialize, cb, sys, default = SciMLBase. INITIALIZE_DEFAULT, is_init = true )
751
+ final = compile_affect (cb. finalize, cb, sys, default = SciMLBase. FINALIZE_DEFAULT)
752
+
753
+ initialize = isnothing (cb. initialize) ? init : ((c, u, t, i) -> init (i))
754
+ finalize = isnothing (cb. finalize) ? final : ((c, u, t, i) -> final (i))
753
755
754
756
if is_discrete (cb)
755
757
if is_timed && conditions (cb) isa AbstractVector
@@ -784,32 +786,81 @@ Notes
784
786
- `kwargs` are passed through to `Symbolics.build_function`.
785
787
"""
786
788
function compile_affect (
787
- aff:: Union{Nothing, Affect} , cb:: AbstractCallback , sys:: AbstractSystem ; default = nothing , kwargs... )
788
- isnothing (aff) && return default
789
-
789
+ aff:: Union{Nothing, Affect} , cb:: AbstractCallback , sys:: AbstractSystem ; default = nothing , is_init = false , kwargs... )
790
790
save_idxs = if ! (has_index_cache (sys) && (ic = get_index_cache (sys)) != = nothing )
791
791
Int[]
792
792
else
793
793
get (ic. callback_to_clocks, cb, Int[])
794
794
end
795
795
796
- if aff isa AffectSystem
797
- affsys = system (aff)
798
- aff_map = aff_to_sys (aff)
799
- sys_map = Dict ([v => k for (k, v) in aff_map])
800
- reinit = has_alg_eqs (sys)
801
- ps_to_modify = discretes (aff)
802
- dvs_to_modify = setdiff (unknowns (aff), getfield .(observed (sys), :lhs ))
796
+ f = if isnothing (aff)
797
+ default
798
+ elseif aff isa AffectSystem
799
+ compile_equational_affect (aff, sys)
800
+ elseif aff isa FunctionalAffect || aff isa ImperativeAffect
801
+ compile_functional_affect (aff, sys; kwargs... )
802
+ end
803
+ wrap_save_discretes (f, save_idxs; is_init)
804
+ end
805
+
806
+ # Init can be: user defined function, nothing, or INITIALIZE_DEFAULT
807
+ function wrap_save_discretes (f, save_idxs; is_init = false )
808
+ if isempty (save_idxs) || f === SciMLBase. FINALIZE_DEFAULT || (isnothing (f) && ! is_init)
809
+ return f
810
+ elseif f === SciMLBase. INITIALIZE_DEFAULT
811
+ let save_idxs = save_idxs
812
+ (c, u, t, i) -> begin
813
+ f (c, u, t, i)
814
+ for idx in save_idxs
815
+ SciMLBase. save_discretes! (i, idx)
816
+ end
817
+ end
818
+ end
819
+ else
820
+ let save_idxs = save_idxs
821
+ (i) -> begin
822
+ isnothing (f) || f (i)
823
+ for idx in save_idxs
824
+ SciMLBase. save_discretes! (i, idx)
825
+ end
826
+ end
827
+ end
828
+ end
829
+ end
830
+
831
+ """
832
+ Initialize and Finalize for VectorContinuousCallback.
833
+ """
834
+ function compile_vector_optional_affect (funs, default)
835
+ all (isnothing, funs) && return default
836
+ return let funs = funs
837
+ function (cb, u, t, integ)
838
+ for func in funs
839
+ isnothing (func) ? continue : func (integ)
840
+ end
841
+ end
842
+ end
843
+ end
844
+
845
+ function compile_equational_affect (aff:: AffectSystem , sys; kwargs... )
846
+ affsys = system (aff)
847
+ aff_map = aff_to_sys (aff)
848
+ sys_map = Dict ([v => k for (k, v) in aff_map])
849
+ ps_to_modify = discretes (aff)
850
+ dvs_to_modify = setdiff (unknowns (aff), getfield .(observed (sys), :lhs ))
851
+ # TODO : Add an optimization for systems without algebraic equations
803
852
804
- function affect! (integrator)
853
+ return let dvs_to_modify = dvs_to_modify, aff_map = aff_map, sys_map = sys_map, affsys = affsys, ps_to_modify = ps_to_modify
854
+
855
+ @inline function affect! (integrator)
805
856
pmap = Pair[]
806
857
for pre_p in parameters (affsys)
807
858
p = only (arguments (unwrap (pre_p)))
808
859
pval = isparameter (p) ? integrator. ps[p] : integrator[p]
809
860
push! (pmap, pre_p => pval)
810
861
end
811
862
guesses = Pair[u => integrator[aff_map[u]] for u in unknowns (affsys)]
812
- affprob = ImplicitDiscreteProblem (affsys, Pair[], (0 , 1 ), pmap; guesses, build_initializeprob = reinit )
863
+ affprob = ImplicitDiscreteProblem (affsys, Pair[], (0 , 1 ), pmap; guesses, build_initializeprob = false )
813
864
814
865
affsol = init (affprob, SimpleIDSolve ())
815
866
for u in dvs_to_modify
@@ -818,28 +869,9 @@ function compile_affect(
818
869
for p in ps_to_modify
819
870
integrator. ps[p] = affsol[sys_map[p]]
820
871
end
821
- for idx in save_idxs
822
- SciMLBase. save_discretes! (integrator, idx)
823
- end
824
872
825
873
sys isa JumpSystem && reset_aggregated_jumps! (integrator)
826
874
end
827
- elseif aff isa FunctionalAffect || aff isa ImperativeAffect
828
- compile_functional_affect (aff, cb, sys; kwargs... )
829
- end
830
- end
831
-
832
- """
833
- Initialize and Finalize for VectorContinuousCallback.
834
- """
835
- function compile_vector_optional_affect (funs, default)
836
- all (isnothing, funs) && return default
837
- return let funs = funs
838
- function (cb, u, t, integ)
839
- for func in funs
840
- isnothing (func) ? continue : func (integ)
841
- end
842
- end
843
875
end
844
876
end
845
877
0 commit comments