Skip to content

Commit eaf7ae1

Browse files
committed
refactor: make iv, algeeqs kwargs
1 parent bc62e44 commit eaf7ae1

File tree

5 files changed

+757
-751
lines changed

5 files changed

+757
-751
lines changed

src/systems/callbacks.jl

+36-34
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ unknowns(a::AffectSystem) = a.unknowns
7777
parameters(a::AffectSystem) = a.parameters
7878
aff_to_sys(a::AffectSystem) = a.aff_to_sys
7979
previous_vals(a::AffectSystem) = parameters(system(a))
80-
updated_vals(a::AffectSystem) = unknowns(system(a))
8180

8281
function Base.show(iio::IO, aff::AffectSystem)
8382
eqs = vcat(equations(system(aff)), observed(system(aff)))
@@ -148,7 +147,8 @@ haspre(O) = recursive_hasoperator(Pre, O)
148147
const Affect = Union{AffectSystem, FunctionalAffect, ImperativeAffect}
149148

150149
"""
151-
SymbolicContinuousCallback(eqs::Vector{Equation}, affect, affect_neg, rootfind)
150+
SymbolicContinuousCallback(eqs::Vector{Equation}, affect = nothing, iv = nothing;
151+
affect_neg = affect, initialize = nothing, finalize = nothing, rootfind = SciMLBase.LeftRootFind, algeeqs = Equation[])
152152
153153
A [`ContinuousCallback`](@ref SciMLBase.ContinuousCallback) specified symbolically. Takes a vector of equations `eq`
154154
as well as the positive-edge `affect` and negative-edge `affect_neg` that apply when *any* of `eq` are satisfied.
@@ -203,32 +203,31 @@ struct SymbolicContinuousCallback <: AbstractCallback
203203

204204
function SymbolicContinuousCallback(
205205
conditions::Union{Equation, Vector{Equation}},
206-
affect = nothing, iv = nothing;
206+
affect = nothing;
207207
affect_neg = affect,
208208
initialize = nothing,
209209
finalize = nothing,
210210
rootfind = SciMLBase.LeftRootFind,
211+
iv = nothing,
211212
algeeqs = Equation[])
212213

213-
affect isa AbstractVector && isnothing(iv) && @warn "No independent variable specified. If t appears in an affect equation explicitly, like x ~ t + 1, then this must be specified. Otherwise this can be disregarded."
214214
conditions = (conditions isa AbstractVector) ? conditions : [conditions]
215-
new(conditions, make_affect(affect, iv; algeeqs), make_affect(affect_neg, iv; algeeqs),
216-
make_affect(initialize, iv; algeeqs), make_affect(finalize, iv; algeeqs), rootfind)
215+
new(conditions, make_affect(affect; iv, algeeqs), make_affect(affect_neg; iv, algeeqs),
216+
make_affect(initialize; iv, algeeqs), make_affect(finalize; iv, algeeqs), rootfind)
217217
end # Default affect to nothing
218218
end
219219

220-
SymbolicContinuousCallback(p::Pair) = SymbolicContinuousCallback(p[1], p[2])
221-
SymbolicContinuousCallback(cb::SymbolicContinuousCallback, args...) = cb
220+
SymbolicContinuousCallback(p::Pair, args...; kwargs...) = SymbolicContinuousCallback(p[1], p[2])
221+
SymbolicContinuousCallback(cb::SymbolicContinuousCallback, args...; kwargs...) = cb
222222

223-
make_affect(affect::Nothing, iv; kwargs...) = nothing
224-
make_affect(affect::Tuple, iv; kwargs...) = FunctionalAffect(affect...)
225-
make_affect(affect::NamedTuple, iv; kwargs...) = FunctionalAffect(; affect...)
226-
make_affect(affect::FunctionalAffect, iv; kwargs...) = affect
227-
make_affect(affect::AffectSystem, iv; kwargs...) = affect
223+
make_affect(affect::Nothing; kwargs...) = nothing
224+
make_affect(affect::Tuple; kwargs...) = FunctionalAffect(affect...)
225+
make_affect(affect::NamedTuple; kwargs...) = FunctionalAffect(; affect...)
226+
make_affect(affect::Affect; kwargs...) = affect
228227

229-
function make_affect(affect::Vector{Equation}, iv; algeeqs = Equation[])
228+
function make_affect(affect::Vector{Equation}; iv = nothing, algeeqs = Equation[])
230229
isempty(affect) && return nothing
231-
isempty(algeeqs) && @warn "No algebraic equations were found. If the system has no algebraic equations, this can be disregarded. Otherwise consider passing in `algeeqs` to the SymbolicContinuousCallbacks constructor."
230+
isempty(algeeqs) && @warn "No algebraic equations were found. If the system has no algebraic equations, this can be disregarded. Otherwise pass in `algeeqs` to the SymbolicContinuousCallback constructor."
232231

233232
affect = scalarize(affect)
234233
dvs = OrderedSet()
@@ -243,6 +242,7 @@ function make_affect(affect::Vector{Equation}, iv; algeeqs = Equation[])
243242
end
244243
if isnothing(iv)
245244
iv = isempty(dvs) ? iv : only(arguments(dvs[1]))
245+
isnothing(iv) && @warn "No independent variable specified and could not be inferred. If the iv appears in an affect equation explicitly, like x ~ t + 1, then it must be specified as an argument to the SymbolicContinuousCallback or SymbolicDiscreteCallback constructor. Otherwise this warning can be disregarded."
246246
end
247247

248248
# System parameters should become unknowns in the ImplicitDiscreteSystem.
@@ -271,21 +271,21 @@ function make_affect(affect::Vector{Equation}, iv; algeeqs = Equation[])
271271
# get accessed parameters p from Pre(p) in the callback parameters
272272
params = filter(isparameter, map(x -> only(arguments(unwrap(x))), cb_params))
273273
# add unknowns to the map
274-
for u in unknowns(affectsys)
274+
for u in dvs
275275
aff_map[u] = u
276276
end
277277

278-
return AffectSystem(affectsys, unknowns(affectsys), params, discretes, aff_map)
278+
return AffectSystem(affectsys, collect(dvs), params, discretes, aff_map)
279279
end
280280

281-
function make_affect(affect)
281+
function make_affect(affect; kwargs...)
282282
error("Malformed affect $(affect). This should be a vector of equations or a tuple specifying a functional affect.")
283283
end
284284

285285
"""
286286
Generate continuous callbacks.
287287
"""
288-
function SymbolicContinuousCallbacks(events, algeeqs::Vector{Equation} = Equation[], iv = nothing)
288+
function SymbolicContinuousCallbacks(events; algeeqs::Vector{Equation} = Equation[], iv = nothing)
289289
callbacks = SymbolicContinuousCallback[]
290290
isnothing(events) && return callbacks
291291

@@ -294,8 +294,7 @@ function SymbolicContinuousCallbacks(events, algeeqs::Vector{Equation} = Equatio
294294

295295
for event in events
296296
cond, affs = event isa Pair ? (event[1], event[2]) : (event, nothing)
297-
affect = make_affect(affs, iv; algeeqs)
298-
push!(callbacks, SymbolicContinuousCallback(cond, affect))
297+
push!(callbacks, SymbolicContinuousCallback(cond, affs; iv, algeeqs))
299298
end
300299
callbacks
301300
end
@@ -380,14 +379,19 @@ end
380379

381380
# TODO: Iterative callbacks
382381
"""
383-
SymbolicDiscreteCallback(conditions::Vector{Equation}, affect)
382+
SymbolicDiscreteCallback(conditions::Vector{Equation}, affect = nothing, iv = nothing;
383+
initialize = nothing, finalize = nothing, algeeqs = Equation[])
384384
385385
A callback that triggers at the first timestep that the conditions are satisfied.
386386
387387
The condition can be one of:
388388
- Δt::Real - periodic events with period Δt
389389
- ts::Vector{Real} - events trigger at these preset times given by `ts`
390390
- eqs::Vector{Equation} - events trigger when the condition evaluates to true
391+
392+
Arguments:
393+
- iv: The independent variable of the system. This must be specified if the independent variable appaers in one of the equations explicitly, as in x ~ t + 1.
394+
- algeeqs: Algebraic equations of the system that must be satisfied after the callback occurs.
391395
"""
392396
struct SymbolicDiscreteCallback <: AbstractCallback
393397
conditions::Any
@@ -397,19 +401,18 @@ struct SymbolicDiscreteCallback <: AbstractCallback
397401

398402
function SymbolicDiscreteCallback(
399403
condition, affect = nothing;
400-
initialize = nothing, finalize = nothing)
404+
initialize = nothing, finalize = nothing, iv = nothing, algeeqs = Equation[])
401405
c = is_timed_condition(condition) ? condition : value(scalarize(condition))
402406

403-
isnothing(iv) && @warn "No independent variable specified. If t appears in an affect equation explicitly, like x ~ t + 1, then this must be specified. Otherwise this can be disregarded."
404-
new(c, make_affect(affect), make_affect(initialize),
405-
make_affect(finalize))
407+
new(c, make_affect(affect; iv, algeeqs), make_affect(initialize; iv, algeeqs),
408+
make_affect(finalize; iv, algeeqs))
406409
end # Default affect to nothing
407410
end
408411

409412
"""
410413
Generate discrete callbacks.
411414
"""
412-
function SymbolicDiscreteCallbacks(events, algeeqs::Vector{Equation} = Equation[], iv = nothing)
415+
function SymbolicDiscreteCallbacks(events; algeeqs::Vector{Equation} = Equation[], iv = nothing)
413416
callbacks = SymbolicDiscreteCallback[]
414417

415418
isnothing(events) && return callbacks
@@ -418,8 +421,7 @@ function SymbolicDiscreteCallbacks(events, algeeqs::Vector{Equation} = Equation[
418421

419422
for event in events
420423
cond, affs = event isa Pair ? (event[1], event[2]) : (event, nothing)
421-
affect = make_affect(affs, iv; algeeqs)
422-
push!(callbacks, SymbolicDiscreteCallback(cond, affect))
424+
push!(callbacks, SymbolicDiscreteCallback(cond, affs; iv, algeeqs))
423425
end
424426
callbacks
425427
end
@@ -801,12 +803,13 @@ function compile_affect(
801803

802804
ps = parameters(aff)
803805
dvs = unknowns(aff)
806+
dvs_to_modify = setdiff(dvs, getfield.(observed(sys), :lhs))
804807

805808
if aff isa AffectSystem
806809
affsys = system(aff)
807810
aff_map = aff_to_sys(aff)
808811
sys_map = Dict([v => k for (k, v) in aff_map])
809-
build_initializeprob = has_alg_eqs(sys)
812+
reinit = has_alg_eqs(sys)
810813

811814
function affect!(integrator)
812815
pmap = Pair[]
@@ -815,12 +818,11 @@ function compile_affect(
815818
pval = isparameter(p) ? integrator.ps[p] : integrator[p]
816819
push!(pmap, pre_p => pval)
817820
end
818-
guesses = Pair[u => integrator[aff_map[u]] for u in updated_vals(aff)]
819-
affprob = ImplicitDiscreteProblem(affsys, Pair[], (0, 1), pmap; guesses, build_initializeprob)
821+
guesses = Pair[u => integrator[aff_map[u]] for u in unknowns(affsys)]
822+
affprob = ImplicitDiscreteProblem(affsys, Pair[], (0, 1), pmap; guesses, build_initializeprob = reinit)
820823

821824
affsol = init(affprob, SimpleIDSolve())
822-
for u in unknowns(aff)
823-
@show u
825+
for u in dvs_to_modify
824826
integrator[u] = affsol[sys_map[u]]
825827
end
826828
for p in discretes(aff)

src/systems/diffeqs/odesystem.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -305,9 +305,9 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
305305
throw(ArgumentError("System names must be unique."))
306306
end
307307

308-
alg_eqs = filter(eq -> eq.lhs isa Union{Symbolic, Number} && !isdiffeq(eq), deqs)
309-
cont_callbacks = SymbolicContinuousCallbacks(continuous_events, alg_eqs, iv)
310-
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events, alg_eqs, iv)
308+
algeeqs = filter(eq -> eq.lhs isa Union{Symbolic, Number} && !isdiffeq(eq), deqs)
309+
cont_callbacks = SymbolicContinuousCallbacks(continuous_events; algeeqs, iv)
310+
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events; algeeqs, iv)
311311

312312
if is_dde === nothing
313313
is_dde = _check_if_dde(deqs, iv′, systems)

src/systems/diffeqs/sdesystem.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -264,9 +264,9 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
264264
Wfact = RefValue(EMPTY_JAC)
265265
Wfact_t = RefValue(EMPTY_JAC)
266266

267-
alg_eqs = filter(eq -> eq.lhs isa Union{Symbolic, Number} && !isdiffeq(eq), deqs)
268-
cont_callbacks = SymbolicContinuousCallbacks(continuous_events, alg_eqs, iv)
269-
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events, alg_eqs, iv)
267+
algeeqs = filter(eq -> eq.lhs isa Union{Symbolic, Number} && !isdiffeq(eq), deqs)
268+
cont_callbacks = SymbolicContinuousCallbacks(continuous_events; algeeqs, iv)
269+
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events; algeeqs, iv)
270270
if is_dde === nothing
271271
is_dde = _check_if_dde(deqs, iv′, systems)
272272
end

src/systems/jumps/jumpsystem.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -236,8 +236,8 @@ function JumpSystem(eqs, iv, unknowns, ps;
236236
end
237237
end
238238

239-
cont_callbacks = SymbolicContinuousCallbacks(continuous_events, Equation[])
240-
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events, Equation[])
239+
cont_callbacks = SymbolicContinuousCallbacks(continuous_events; iv)
240+
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events; iv)
241241

242242
JumpSystem{typeof(ap)}(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
243243
ap, iv′, us′, ps′, var_to_name, observed, name, description, systems,

0 commit comments

Comments
 (0)