Skip to content

Commit 90f47ec

Browse files
Merge pull request #3007 from AayushSabharwal/as/save-discretes-init
feat: save discrete variables in callback init
2 parents 7eb5354 + 666ecef commit 90f47ec

File tree

2 files changed

+63
-8
lines changed

2 files changed

+63
-8
lines changed

src/systems/callbacks.jl

+60-5
Original file line numberDiff line numberDiff line change
@@ -565,8 +565,22 @@ function generate_single_rootfinding_callback(
565565
rf_oop(u, parameter_values(integ), t)
566566
end
567567
end
568+
569+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing &&
570+
(save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing
571+
initfn = let save_idxs = save_idxs
572+
function (cb, u, t, integrator)
573+
for idx in save_idxs
574+
SciMLBase.save_discretes!(integrator, idx)
575+
end
576+
end
577+
end
578+
else
579+
initfn = SciMLBase.INITIALIZE_DEFAULT
580+
end
568581
return ContinuousCallback(
569-
cond, affect_function.affect, affect_function.affect_neg, rootfind = cb.rootfind)
582+
cond, affect_function.affect, affect_function.affect_neg,
583+
rootfind = cb.rootfind, initialize = initfn)
570584
end
571585

572586
function generate_vector_rootfinding_callback(
@@ -618,8 +632,25 @@ function generate_vector_rootfinding_callback(
618632
affect_neg(integ)
619633
end
620634
end
635+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
636+
save_idxs = mapreduce(
637+
cb -> get(ic.callback_to_clocks, cb, Int[]), vcat, cbs; init = Int[])
638+
initfn = if isempty(save_idxs)
639+
SciMLBase.INITIALIZE_DEFAULT
640+
else
641+
let save_idxs = save_idxs
642+
function (cb, u, t, integrator)
643+
for idx in save_idxs
644+
SciMLBase.save_discretes!(integrator, idx)
645+
end
646+
end
647+
end
648+
end
649+
else
650+
initfn = SciMLBase.INITIALIZE_DEFAULT
651+
end
621652
return VectorContinuousCallback(
622-
cond, affect, affect_neg, length(eqs), rootfind = rootfind)
653+
cond, affect, affect_neg, length(eqs), rootfind = rootfind, initialize = initfn)
623654
end
624655

625656
"""
@@ -727,12 +758,24 @@ function generate_timed_callback(cb, sys, dvs, ps; postprocess_affect_expr! = no
727758
cond = condition(cb)
728759
as = compile_affect(affects(cb), cb, sys, dvs, ps; expression = Val{false},
729760
postprocess_affect_expr!, kwargs...)
761+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing &&
762+
(save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing
763+
initfn = let save_idxs = save_idxs
764+
function (cb, u, t, integrator)
765+
for idx in save_idxs
766+
SciMLBase.save_discretes!(integrator, idx)
767+
end
768+
end
769+
end
770+
else
771+
initfn = SciMLBase.INITIALIZE_DEFAULT
772+
end
730773
if cond isa AbstractVector
731774
# Preset Time
732-
return PresetTimeCallback(cond, as)
775+
return PresetTimeCallback(cond, as; initialize = initfn)
733776
else
734777
# Periodic
735-
return PeriodicCallback(as, cond)
778+
return PeriodicCallback(as, cond; initialize = initfn)
736779
end
737780
end
738781

@@ -745,7 +788,19 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! =
745788
c = compile_condition(cb, sys, dvs, ps; expression = Val{false}, kwargs...)
746789
as = compile_affect(affects(cb), cb, sys, dvs, ps; expression = Val{false},
747790
postprocess_affect_expr!, kwargs...)
748-
return DiscreteCallback(c, as)
791+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing &&
792+
(save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing
793+
initfn = let save_idxs = save_idxs
794+
function (cb, u, t, integrator)
795+
for idx in save_idxs
796+
SciMLBase.save_discretes!(integrator, idx)
797+
end
798+
end
799+
end
800+
else
801+
initfn = SciMLBase.INITIALIZE_DEFAULT
802+
end
803+
return DiscreteCallback(c, as; initialize = initfn)
749804
end
750805
end
751806

test/symbolic_events.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -882,7 +882,7 @@ end
882882
@test sort(canonicalize(Discrete(), prob.p)[1]) == [0.0, 1.0, 2.0]
883883
sol = solve(prob, Tsit5())
884884

885-
@test sol[a] == [-1.0]
886-
@test sol[b] == [5.0, 5.0]
887-
@test sol[c] == [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
885+
@test sol[a] == [1.0, -1.0]
886+
@test sol[b] == [2.0, 5.0, 5.0]
887+
@test sol[c] == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
888888
end

0 commit comments

Comments
 (0)