@@ -106,15 +106,25 @@ Affects (i.e. `affect` and `affect_neg`) can be specified as either:
106
106
+ `read_parameters` is a vector of the parameters that are *used* by `f!`. Their indices are passed to `f` in `p` similarly to the indices of `unknowns` passed in `u`.
107
107
+ `modified_parameters` is a vector of the parameters that are *modified* by `f!`. Note that a parameter will not appear in `p` if it only appears in `modified_parameters`; it must appear in both `parameters` and `modified_parameters` if it is used in the affect definition.
108
108
+ `ctx` is a user-defined context object passed to `f!` when invoked. This value is aliased for each problem.
109
+
110
+ DAEs will be reinitialized using `reinitializealg` (which defaults to `SciMLBase.CheckInit`) after callbacks are applied.
111
+ This reinitialization algorithm ensures that the DAE is satisfied after the callback runs. The default value of `CheckInit` will simply validate
112
+ that the newly-assigned values indeed satisfy the algebraic system; see the documentation on DAE initialization for a more detailed discussion of
113
+ initialization.
109
114
"""
110
115
struct SymbolicContinuousCallback
111
116
eqs:: Vector{Equation}
112
117
affect:: Union{Vector{Equation}, FunctionalAffect}
113
118
affect_neg:: Union{Vector{Equation}, FunctionalAffect, Nothing}
114
119
rootfind:: SciMLBase.RootfindOpt
115
- function SymbolicContinuousCallback (; eqs:: Vector{Equation} , affect = NULL_AFFECT,
116
- affect_neg = affect, rootfind = SciMLBase. LeftRootFind)
117
- new (eqs, make_affect (affect), make_affect (affect_neg), rootfind)
120
+ reinitializealg:: SciMLBase.DAEInitializationAlgorithm
121
+ function SymbolicContinuousCallback (;
122
+ eqs:: Vector{Equation} ,
123
+ affect = NULL_AFFECT,
124
+ affect_neg = affect,
125
+ rootfind = SciMLBase. LeftRootFind,
126
+ reinitializealg = SciMLBase. CheckInit ())
127
+ new (eqs, make_affect (affect), make_affect (affect_neg), rootfind, reinitializealg)
118
128
end # Default affect to nothing
119
129
end
120
130
make_affect (affect) = affect
@@ -183,6 +193,12 @@ function affect_negs(cbs::Vector{SymbolicContinuousCallback})
183
193
mapreduce (affect_negs, vcat, cbs, init = Equation[])
184
194
end
185
195
196
+ reinitialization_alg (cb:: SymbolicContinuousCallback ) = cb. reinitializealg
197
+ function reinitialization_algs (cbs:: Vector{SymbolicContinuousCallback} )
198
+ mapreduce (
199
+ reinitialization_alg, vcat, cbs, init = SciMLBase. DAEInitializationAlgorithm[])
200
+ end
201
+
186
202
namespace_affects (af:: Vector , s) = Equation[namespace_affect (a, s) for a in af]
187
203
namespace_affects (af:: FunctionalAffect , s) = namespace_affect (af, s)
188
204
namespace_affects (:: Nothing , s) = nothing
@@ -225,11 +241,13 @@ struct SymbolicDiscreteCallback
225
241
# TODO : Iterative
226
242
condition:: Any
227
243
affects:: Any
244
+ reinitializealg:: SciMLBase.DAEInitializationAlgorithm
228
245
229
- function SymbolicDiscreteCallback (condition, affects = NULL_AFFECT)
246
+ function SymbolicDiscreteCallback (
247
+ condition, affects = NULL_AFFECT, reinitializealg = SciMLBase. CheckInit ())
230
248
c = scalarize_condition (condition)
231
249
a = scalarize_affects (affects)
232
- new (c, a)
250
+ new (c, a, reinitializealg )
233
251
end # Default affect to nothing
234
252
end
235
253
@@ -286,6 +304,12 @@ function affects(cbs::Vector{SymbolicDiscreteCallback})
286
304
reduce (vcat, affects (cb) for cb in cbs; init = [])
287
305
end
288
306
307
+ reinitialization_alg (cb:: SymbolicDiscreteCallback ) = cb. reinitializealg
308
+ function reinitialization_algs (cbs:: Vector{SymbolicDiscreteCallback} )
309
+ mapreduce (
310
+ reinitialization_alg, vcat, cbs, init = SciMLBase. DAEInitializationAlgorithm[])
311
+ end
312
+
289
313
function namespace_callback (cb:: SymbolicDiscreteCallback , s):: SymbolicDiscreteCallback
290
314
af = affects (cb)
291
315
af = af isa AbstractVector ? namespace_affect .(af, Ref (s)) : namespace_affect (af, s)
@@ -579,13 +603,15 @@ function generate_single_rootfinding_callback(
579
603
initfn = SciMLBase. INITIALIZE_DEFAULT
580
604
end
581
605
return ContinuousCallback (
582
- cond, affect_function. affect, affect_function. affect_neg,
583
- rootfind = cb. rootfind, initialize = initfn)
606
+ cond, affect_function. affect, affect_function. affect_neg, rootfind = cb. rootfind,
607
+ initialize = initfn,
608
+ initializealg = reinitialization_alg (cb))
584
609
end
585
610
586
611
function generate_vector_rootfinding_callback (
587
612
cbs, sys:: AbstractODESystem , dvs = unknowns (sys),
588
- ps = parameters (sys); rootfind = SciMLBase. RightRootFind, kwargs... )
613
+ ps = parameters (sys); rootfind = SciMLBase. RightRootFind,
614
+ reinitialization = SciMLBase. CheckInit (), kwargs... )
589
615
eqs = map (cb -> flatten_equations (cb. eqs), cbs)
590
616
num_eqs = length .(eqs)
591
617
# fuse equations to create VectorContinuousCallback
@@ -650,7 +676,8 @@ function generate_vector_rootfinding_callback(
650
676
initfn = SciMLBase. INITIALIZE_DEFAULT
651
677
end
652
678
return VectorContinuousCallback (
653
- cond, affect, affect_neg, length (eqs), rootfind = rootfind, initialize = initfn)
679
+ cond, affect, affect_neg, length (eqs), rootfind = rootfind,
680
+ initialize = initfn, initializealg = reinitialization)
654
681
end
655
682
656
683
"""
@@ -690,18 +717,24 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
690
717
# group the cbs by what rootfind op they use
691
718
# groupby would be very useful here, but alas
692
719
cb_classes = Dict{
693
- @NamedTuple {rootfind:: SciMLBase.RootfindOpt }, Vector{SymbolicContinuousCallback}}()
720
+ @NamedTuple {
721
+ rootfind:: SciMLBase.RootfindOpt ,
722
+ reinitialization:: SciMLBase.DAEInitializationAlgorithm }, Vector{SymbolicContinuousCallback}}()
694
723
for cb in cbs
695
724
push! (
696
- get! (() -> SymbolicContinuousCallback[], cb_classes, (rootfind = cb. rootfind,)),
725
+ get! (() -> SymbolicContinuousCallback[], cb_classes,
726
+ (
727
+ rootfind = cb. rootfind,
728
+ reinitialization = reinitialization_alg (cb))),
697
729
cb)
698
730
end
699
731
700
732
# generate the callbacks out; we sort by the equivalence class to ensure a deterministic preference order
701
733
compiled_callbacks = map (collect (pairs (sort! (
702
734
OrderedDict (cb_classes); by = p -> p. rootfind)))) do (equiv_class, cbs_in_class)
703
735
return generate_vector_rootfinding_callback (
704
- cbs_in_class, sys, dvs, ps; rootfind = equiv_class. rootfind, kwargs... )
736
+ cbs_in_class, sys, dvs, ps; rootfind = equiv_class. rootfind,
737
+ reinitialization = equiv_class. reinitialization, kwargs... )
705
738
end
706
739
if length (compiled_callbacks) == 1
707
740
return compiled_callbacks[]
@@ -772,10 +805,12 @@ function generate_timed_callback(cb, sys, dvs, ps; postprocess_affect_expr! = no
772
805
end
773
806
if cond isa AbstractVector
774
807
# Preset Time
775
- return PresetTimeCallback (cond, as; initialize = initfn)
808
+ return PresetTimeCallback (
809
+ cond, as; initialize = initfn, initializealg = reinitialization_alg (cb))
776
810
else
777
811
# Periodic
778
- return PeriodicCallback (as, cond; initialize = initfn)
812
+ return PeriodicCallback (
813
+ as, cond; initialize = initfn, initializealg = reinitialization_alg (cb))
779
814
end
780
815
end
781
816
@@ -800,7 +835,8 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! =
800
835
else
801
836
initfn = SciMLBase. INITIALIZE_DEFAULT
802
837
end
803
- return DiscreteCallback (c, as; initialize = initfn)
838
+ return DiscreteCallback (
839
+ c, as; initialize = initfn, initializealg = reinitialization_alg (cb))
804
840
end
805
841
end
806
842
0 commit comments