@@ -69,6 +69,7 @@ struct AffectSystem
69
69
discretes:: Vector
70
70
""" Maps the symbols of unknowns/observed in the ImplicitDiscreteSystem to its corresponding unknown/parameter in the parent system."""
71
71
aff_to_sys:: Dict
72
+ explicit:: Bool
72
73
end
73
74
74
75
system (a:: AffectSystem ) = a. system
@@ -77,6 +78,7 @@ unknowns(a::AffectSystem) = a.unknowns
77
78
parameters (a:: AffectSystem ) = a. parameters
78
79
aff_to_sys (a:: AffectSystem ) = a. aff_to_sys
79
80
previous_vals (a:: AffectSystem ) = parameters (system (a))
81
+ is_explicit (a:: AffectSystem ) = a. explicit
80
82
81
83
function Base. show (iio:: IO , aff:: AffectSystem )
82
84
eqs = vcat (equations (system (aff)), observed (system (aff)))
@@ -105,6 +107,8 @@ Base.nameof(::Pre) = :Pre
105
107
Base. show (io:: IO , x:: Pre ) = print (io, " Pre" )
106
108
input_timedomain (:: Pre , _ = nothing ) = ContinuousClock ()
107
109
output_timedomain (:: Pre , _ = nothing ) = ContinuousClock ()
110
+ unPre (x:: Num ) = unPre (unwrap (x))
111
+ unPre (x:: BasicSymbolic ) = operation (x) isa Pre ? only (arguments (x)) : x
108
112
109
113
function (p:: Pre )(x)
110
114
iw = Symbolics. iswrapped (x)
@@ -229,24 +233,28 @@ function make_affect(affect::Vector{Equation}; iv = nothing, algeeqs = Equation[
229
233
isempty (affect) && return nothing
230
234
isempty (algeeqs) && @warn " No algebraic equations were found. If the system has no algebraic equations, this can be disregarded. Otherwise pass in `algeeqs` to the SymbolicContinuousCallback constructor."
231
235
236
+ explicit = true
232
237
affect = scalarize (affect)
233
238
dvs = OrderedSet ()
234
239
params = OrderedSet ()
240
+ params = OrderedSet ()
235
241
for eq in affect
236
242
if ! haspre (eq) && ! (symbolic_type (eq. rhs) === NotSymbolic ())
237
243
@warn " Affect equation $eq has no `Pre` operator. As such it will be interpreted as an algebraic equation to be satisfied after the callback. If you intended to use the value of a variable x before the affect, use Pre(x)."
244
+ explicit = false
238
245
end
239
246
collect_vars! (dvs, params, eq, iv; op = Pre)
240
247
end
241
248
for eq in algeeqs
242
249
collect_vars! (dvs, params, eq, iv)
250
+ expilcit = false
243
251
end
244
252
if isnothing (iv)
245
253
iv = isempty (dvs) ? iv : only (arguments (dvs[1 ]))
246
254
isnothing (iv) && @warn " No independent variable specified and could not be inferred. If the iv appears in an affect equation explicitly, like x ~ t + 1, then it must be specified as an argument to the SymbolicContinuousCallback or SymbolicDiscreteCallback constructor. Otherwise this warning can be disregarded."
247
255
end
248
256
249
- # System parameters should become unknowns in the ImplicitDiscreteSystem.
257
+ # Parameters in affect equations should become unknowns in the ImplicitDiscreteSystem.
250
258
cb_params = Any[]
251
259
discretes = Any[]
252
260
p_as_dvs = Any[]
@@ -268,15 +276,15 @@ function make_affect(affect::Vector{Equation}; iv = nothing, algeeqs = Equation[
268
276
aff_map = Dict (zip (p_as_dvs, discretes))
269
277
rev_map = Dict ([v => k for (k, v) in aff_map])
270
278
affect = Symbolics. substitute (affect, rev_map)
271
- @mtkbuild affectsys = ImplicitDiscreteSystem (vcat (affect, algeeqs), iv, collect (union (dvs, p_as_dvs)), cb_params)
279
+ @named affectsys = ImplicitDiscreteSystem (vcat (affect, algeeqs), iv, collect (union (dvs, p_as_dvs)), cb_params)
272
280
# get accessed parameters p from Pre(p) in the callback parameters
273
- params = filter (isparameter, map (x -> only ( arguments ( unwrap (x)) ), cb_params))
281
+ params = filter (isparameter, map (x -> unPre (x ), cb_params))
274
282
# add unknowns to the map
275
283
for u in dvs
276
284
aff_map[u] = u
277
285
end
278
286
279
- return AffectSystem (affectsys, collect (dvs), params, discretes, aff_map)
287
+ return AffectSystem (affectsys, collect (dvs), params, discretes, aff_map, explicit )
280
288
end
281
289
282
290
function make_affect (affect; kwargs... )
468
476
# ######### Namespacing Utilities ###########
469
477
# ###########################################
470
478
471
- function namespace_affect (affect:: FunctionalAffect , s)
479
+ function namespace_affects (affect:: FunctionalAffect , s)
472
480
FunctionalAffect (func (affect),
473
481
renamespace .((s,), unknowns (affect)),
474
482
unknowns_syms (affect),
@@ -478,35 +486,35 @@ function namespace_affect(affect::FunctionalAffect, s)
478
486
context (affect))
479
487
end
480
488
481
- function namespace_affect (affect:: AffectSystem , s)
489
+ function namespace_affects (affect:: AffectSystem , s)
482
490
AffectSystem (renamespace (s, system (affect)),
483
491
renamespace .((s,), unknowns (affect)),
484
492
renamespace .((s,), parameters (affect)),
485
493
renamespace .((s,), discretes (affect)),
486
- Dict ([k => renamespace (s, v) for (k, v) in aff_to_sys (affect)]))
494
+ Dict ([k => renamespace (s, v) for (k, v) in aff_to_sys (affect)]), is_explicit (affect) )
487
495
end
488
- namespace_affect (af:: Nothing , s) = nothing
496
+ namespace_affects (af:: Nothing , s) = nothing
489
497
490
498
function namespace_callback (cb:: SymbolicContinuousCallback , s):: SymbolicContinuousCallback
491
499
SymbolicContinuousCallback (
492
500
namespace_equation .(equations (cb), (s,)),
493
- namespace_affect (affects (cb), s),
494
- affect_neg = namespace_affect (affect_negs (cb), s),
495
- initialize = namespace_affect (initialize_affects (cb), s),
496
- finalize = namespace_affect (finalize_affects (cb), s),
501
+ namespace_affects (affects (cb), s),
502
+ affect_neg = namespace_affects (affect_negs (cb), s),
503
+ initialize = namespace_affects (initialize_affects (cb), s),
504
+ finalize = namespace_affects (finalize_affects (cb), s),
497
505
rootfind = cb. rootfind)
498
506
end
499
507
500
- function namespace_condition (condition, s)
508
+ function namespace_conditions (condition, s)
501
509
is_timed_condition (condition) ? condition : namespace_expr (condition, s)
502
510
end
503
511
504
512
function namespace_callback (cb:: SymbolicDiscreteCallback , s):: SymbolicDiscreteCallback
505
513
SymbolicDiscreteCallback (
506
- namespace_condition ( condition (cb), s),
514
+ namespace_conditions ( conditions (cb), s),
507
515
namespace_affects (affects (cb), s),
508
- namespace_affects (initialize_affects (cb), s),
509
- namespace_affects (finalize_affects (cb), s))
516
+ initialize = namespace_affects (initialize_affects (cb), s),
517
+ finalize = namespace_affects (finalize_affects (cb), s))
510
518
end
511
519
512
520
function Base. hash (cb:: SymbolicContinuousCallback , s:: UInt )
@@ -623,8 +631,6 @@ function compile_condition(cbs::Union{AbstractCallback, Vector{<:AbstractCallbac
623
631
end
624
632
end
625
633
end
626
-
627
- cond
628
634
end
629
635
630
636
"""
@@ -707,12 +713,12 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
707
713
inits = []
708
714
finals = []
709
715
for cb in cbs
710
- affect = compile_affect (cb. affect, cb, sys, default = EMPTY_FUNCTION)
716
+ affect = compile_affect (cb. affect, cb, sys, default = EMPTY_FUNCTION, kwargs ... )
711
717
push! (affects, affect)
712
- affect_neg = (cb. affect_neg === cb. affect) ? affect : compile_affect (cb. affect_neg, cb, sys, default = EMPTY_FUNCTION)
718
+ affect_neg = (cb. affect_neg === cb. affect) ? affect : compile_affect (cb. affect_neg, cb, sys, default = EMPTY_FUNCTION, kwargs ... )
713
719
push! (affect_negs, affect_neg)
714
- push! (inits, compile_affect (cb. initialize, cb, sys; default = nothing , is_init = true ))
715
- push! (finals, compile_affect (cb. finalize, cb, sys; default = nothing ))
720
+ push! (inits, compile_affect (cb. initialize, cb, sys; default = nothing , is_init = true ), kwargs ... )
721
+ push! (finals, compile_affect (cb. finalize, cb, sys; default = nothing ), kwargs ... )
716
722
end
717
723
718
724
# Since there may be different number of conditions and affects,
@@ -729,8 +735,8 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
729
735
isnothing (f) && return
730
736
f (integ)
731
737
end
732
- initialize = compile_vector_optional_affect (inits, SciMLBase. INITIALIZE_DEFAULT)
733
- finalize = compile_vector_optional_affect (finals, SciMLBase. FINALIZE_DEFAULT)
738
+ initialize = wrap_vector_optional_affect (inits, SciMLBase. INITIALIZE_DEFAULT)
739
+ finalize = wrap_vector_optional_affect (finals, SciMLBase. FINALIZE_DEFAULT)
734
740
735
741
return VectorContinuousCallback (
736
742
trigger, affect, affect_neg, length (eqs); initialize, finalize,
@@ -743,14 +749,14 @@ function generate_callback(cb, sys; kwargs...)
743
749
ps = parameters (sys; initial_parameters = true )
744
750
745
751
trigger = is_timed ? conditions (cb) : compile_condition (cb, sys, dvs, ps; kwargs... )
746
- affect = compile_affect (cb. affect, cb, sys, default = EMPTY_FUNCTION)
752
+ affect = compile_affect (cb. affect, cb, sys, default = EMPTY_FUNCTION, kwargs ... )
747
753
affect_neg = if is_discrete (cb)
748
754
nothing
749
755
else
750
- (cb. affect === cb. affect_neg) ? affect : compile_affect (cb. affect_neg, cb, sys, default = EMPTY_FUNCTION)
756
+ (cb. affect === cb. affect_neg) ? affect : compile_affect (cb. affect_neg, cb, sys, default = EMPTY_FUNCTION, kwargs ... )
751
757
end
752
- init = compile_affect (cb. initialize, cb, sys, default = SciMLBase. INITIALIZE_DEFAULT, is_init = true )
753
- final = compile_affect (cb. finalize, cb, sys, default = SciMLBase. FINALIZE_DEFAULT)
758
+ init = compile_affect (cb. initialize, cb, sys, default = SciMLBase. INITIALIZE_DEFAULT, is_init = true , kwargs ... )
759
+ final = compile_affect (cb. finalize, cb, sys, default = SciMLBase. FINALIZE_DEFAULT, kwargs ... )
754
760
755
761
initialize = isnothing (cb. initialize) ? init : ((c, u, t, i) -> init (i))
756
762
finalize = isnothing (cb. finalize) ? final : ((c, u, t, i) -> final (i))
@@ -795,32 +801,29 @@ function compile_affect(
795
801
get (ic. callback_to_clocks, cb, Int[])
796
802
end
797
803
798
- f = if isnothing (aff)
799
- default
804
+ if isnothing (aff)
805
+ full_args = is_init && (default === SciMLBase. INITIALIZE_DEFAULT)
806
+ is_init ? wrap_save_discretes (f, save_idxs; full_args) : default
800
807
elseif aff isa AffectSystem
801
- compile_equational_affect (aff, sys)
808
+ f = compile_equational_affect (aff, sys; kwargs... )
809
+ wrap_save_discretes (f, save_idxs)
802
810
elseif aff isa FunctionalAffect || aff isa ImperativeAffect
803
- compile_functional_affect (aff, sys; kwargs... )
811
+ f = compile_functional_affect (aff, sys; kwargs... )
812
+ wrap_save_discretes (f, save_idxs; full_args = true )
804
813
end
805
- wrap_save_discretes (f, save_idxs; is_init)
806
814
end
807
815
808
- # Init can be: user defined function, nothing, or INITIALIZE_DEFAULT
809
- function wrap_save_discretes (f, save_idxs; is_init = false )
810
- if isempty (save_idxs) || f === SciMLBase. FINALIZE_DEFAULT || (isnothing (f) && ! is_init)
811
- return f
812
- elseif f === SciMLBase. INITIALIZE_DEFAULT
813
- let save_idxs = save_idxs
814
- (c, u, t, i) -> begin
815
- f (c, u, t, i)
816
+ function wrap_save_discretes (f, save_idxs; full_args = false )
817
+ let save_idxs = save_idxs
818
+ if full_args
819
+ return (c, u, t, i) -> begin
820
+ isnothing (f) || f (c, u, t, i)
816
821
for idx in save_idxs
817
822
SciMLBase. save_discretes! (i, idx)
818
823
end
819
824
end
820
- end
821
- else
822
- let save_idxs = save_idxs
823
- (i) -> begin
825
+ else
826
+ return (i) -> begin
824
827
isnothing (f) || f (i)
825
828
for idx in save_idxs
826
829
SciMLBase. save_discretes! (i, idx)
@@ -831,9 +834,9 @@ function wrap_save_discretes(f, save_idxs; is_init = false)
831
834
end
832
835
833
836
"""
834
- Initialize and Finalize for VectorContinuousCallback.
837
+ Initialize and finalize for VectorContinuousCallback.
835
838
"""
836
- function compile_vector_optional_affect (funs, default)
839
+ function wrap_vector_optional_affect (funs, default)
837
840
all (isnothing, funs) && return default
838
841
return let funs = funs
839
842
function (cb, u, t, integ)
@@ -844,35 +847,71 @@ function compile_vector_optional_affect(funs, default)
844
847
end
845
848
end
846
849
847
- function compile_equational_affect (aff:: AffectSystem , sys; kwargs... )
850
+ function add_integrator_header (
851
+ sys:: AbstractSystem , integrator = gensym (:MTKIntegrator ), out = :u )
852
+ expr -> Func ([DestructuredArgs (expr. args, integrator, inds = [:u , :p , :t ])], [],
853
+ expr. body),
854
+ expr -> Func (
855
+ [DestructuredArgs (expr. args, integrator, inds = [out, :u , :p , :t ])], [],
856
+ expr. body)
857
+ end
858
+
859
+ """
860
+ Compile an affect defined by a set of equations. Systems with algebraic equations will solve implicit discrete problems to obtain their next state. Systems without will generate functions that perform explicit updates.
861
+ """
862
+ function compile_equational_affect (aff:: AffectSystem , sys; reset_jumps = false , kwargs... )
848
863
affsys = system (aff)
849
- aff_map = aff_to_sys (aff)
850
- sys_map = Dict ([v => k for (k, v) in aff_map])
851
- ps_to_modify = discretes (aff)
852
- dvs_to_modify = setdiff (unknowns (aff), getfield .(observed (sys), :lhs ))
853
- # TODO : Add an optimization for systems without algebraic equations
854
-
855
- return let dvs_to_modify = dvs_to_modify, aff_map = aff_map, sys_map = sys_map, affsys = affsys, ps_to_modify = ps_to_modify
856
-
857
- @inline function affect! (integrator)
858
- pmap = Pair[]
859
- for pre_p in parameters (affsys)
860
- p = only (arguments (unwrap (pre_p)))
861
- pval = isparameter (p) ? integrator. ps[p] : integrator[p]
862
- push! (pmap, pre_p => pval)
863
- end
864
- guesses = Pair[u => integrator[aff_map[u]] for u in unknowns (affsys)]
865
- affprob = ImplicitDiscreteProblem (affsys, Pair[], (integrator. t, integrator. t), pmap; guesses, build_initializeprob = false )
864
+ reinit = has_alg_equations (sys) || has_alg_equations (affsys)
865
+ ps_to_update = discretes (aff)
866
+ dvs_to_update = setdiff (unknowns (aff), getfield .(observed (sys), :lhs ))
867
+
868
+ if is_explicit (aff)
869
+ update_eqs = equations (affsys)
870
+ update_eqs = Symbolics. fast_substitute (equations, Dict ([p => unPre (p) for p in parameters (affsys)]))
871
+ rhss = map (x -> x. rhs, update_eqs)
872
+ lhss = map (x -> x. lhs, update_eqs)
873
+ is_p = [lhs ∈ ps_to_update for lhs in lhss]
874
+
875
+ dvs = unknowns (sys)
876
+ ps = parameters (sys)
877
+ t = get_iv (sys)
878
+
879
+ u_idxs = indexin ((@view lhss[.! is_p]), dvs)
880
+ p_idxs = indexin ((@view lhss[is_p]), ps)
881
+ _ps = reorder_parameters (sys, ps)
882
+ integ = gensym (:MTKIntegrator )
883
+
884
+ u_up, u_up! = build_function_wrapper (sys, (@view rhss[.! is_p]), dvs, _ps... , t; wrap_code = add_integrator_header (sys, integ, :u ), expression = Val{false }, outputidxs = u_idxs)
885
+ p_up, p_up! = build_function_wrapper (sys, (@view rhss[is_p]), dvs, _ps... , t; wrap_code = add_integrator_header (sys, integ, :p ), expression = Val{false }, outputidxs = p_idxs)
886
+
887
+ return (integ) -> begin
888
+ u_up! (integ)
889
+ p_up! (integ)
890
+ reset_jumps && reset_aggregated_jumps! (integ)
891
+ end
892
+ else
893
+ aff_map = aff_to_sys (aff)
894
+ sys_map = Dict ([v => k for (k, v) in aff_map])
895
+
896
+ return let dvs_to_update = dvs_to_update, aff_map = aff_map, sys_map = sys_map, affsys = affsys, ps_to_update = ps_to_update
897
+ (integ) -> begin
898
+ pmap = Pair[]
899
+ for pre_p in parameters (affsys)
900
+ p = unPre (pre_p)
901
+ pval = isparameter (p) ? integ. ps[p] : integ[p]
902
+ push! (pmap, pre_p => pval)
903
+ end
904
+ guesses = Pair[u => integ[aff_map[u]] for u in unknowns (affsys)]
905
+ affprob = ImplicitDiscreteProblem (affsys, Pair[], (integ. t, integ. t), pmap; guesses, build_initializeprob = false )
866
906
867
- affsol = init (affprob, SimpleIDSolve ())
868
- for u in dvs_to_modify
869
- integrator[u] = affsol[sys_map[u]]
870
- end
871
- for p in ps_to_modify
872
- integrator. ps[p] = affsol[sys_map[p]]
907
+ affsol = init (affprob, SimpleIDSolve ())
908
+ for u in dvs_to_update
909
+ integ[u] = affsol[sys_map[u]]
910
+ end
911
+ for p in ps_to_update
912
+ integ. ps[p] = affsol[sys_map[p]]
913
+ end
873
914
end
874
-
875
- sys isa JumpSystem && reset_aggregated_jumps! (integrator)
876
915
end
877
916
end
878
917
end
0 commit comments