Skip to content

Commit e4e7d97

Browse files
Merge pull request #2946 from isaacsas/better_vrj_support
support JumpProblems over ODEProblems
2 parents c10d00b + 42ebdf4 commit e4e7d97

File tree

4 files changed

+232
-52
lines changed

4 files changed

+232
-52
lines changed

src/systems/dependency_graphs.jl

+8-8
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ equation_dependencies(jumpsys)
3636
equation_dependencies(jumpsys, variables = parameters(jumpsys))
3737
```
3838
"""
39-
function equation_dependencies(sys::AbstractSystem; variables = unknowns(sys))
40-
eqs = equations(sys)
39+
function equation_dependencies(sys::AbstractSystem; variables = unknowns(sys),
40+
eqs = equations(sys))
4141
deps = Set()
4242
depeqs_to_vars = Vector{Vector}(undef, length(eqs))
4343

@@ -114,8 +114,9 @@ digr = asgraph(jumpsys)
114114
```
115115
"""
116116
function asgraph(sys::AbstractSystem; variables = unknowns(sys),
117-
variablestoids = Dict(v => i for (i, v) in enumerate(variables)))
118-
asgraph(equation_dependencies(sys, variables = variables), variablestoids)
117+
variablestoids = Dict(v => i for (i, v) in enumerate(variables)),
118+
eqs = equations(sys))
119+
asgraph(equation_dependencies(sys; variables, eqs), variablestoids)
119120
end
120121

121122
"""
@@ -141,8 +142,7 @@ variable_dependencies(jumpsys)
141142
```
142143
"""
143144
function variable_dependencies(sys::AbstractSystem; variables = unknowns(sys),
144-
variablestoids = nothing)
145-
eqs = equations(sys)
145+
variablestoids = nothing, eqs = equations(sys))
146146
vtois = isnothing(variablestoids) ? Dict(v => i for (i, v) in enumerate(variables)) :
147147
variablestoids
148148

@@ -193,8 +193,8 @@ dg = asdigraph(digr, jumpsys)
193193
```
194194
"""
195195
function asdigraph(g::BipartiteGraph, sys::AbstractSystem; variables = unknowns(sys),
196-
equationsfirst = true)
197-
neqs = length(equations(sys))
196+
equationsfirst = true, eqs = equations(sys))
197+
neqs = length(eqs)
198198
nvars = length(variables)
199199
fadjlist = deepcopy(g.fadjlist)
200200
badjlist = deepcopy(g.badjlist)

src/systems/jumps/jumpsystem.jl

+64-5
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,10 @@ function JumpSystem(eqs, iv, unknowns, ps;
194194
metadata, gui_metadata, checks = checks)
195195
end
196196

197+
has_massactionjumps(js::JumpSystem) = !isempty(equations(js).x[1])
198+
has_constantratejumps(js::JumpSystem) = !isempty(equations(js).x[2])
199+
has_variableratejumps(js::JumpSystem) = !isempty(equations(js).x[3])
200+
197201
function generate_rate_function(js::JumpSystem, rate)
198202
consts = collect_constants(rate)
199203
if !isempty(consts) # The SymbolicUtils._build_function method of this case doesn't support postprocess_fbody
@@ -311,7 +315,7 @@ end
311315
```julia
312316
DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan,
313317
parammap = DiffEqBase.NullParameters;
314-
use_union = false,
318+
use_union = true,
315319
kwargs...)
316320
```
317321
@@ -331,7 +335,6 @@ dprob = DiscreteProblem(complete(js), u₀map, tspan, parammap)
331335
"""
332336
function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothing},
333337
parammap = DiffEqBase.NullParameters();
334-
checkbounds = false,
335338
use_union = true,
336339
eval_expression = false,
337340
eval_module = @__MODULE__,
@@ -385,7 +388,7 @@ struct DiscreteProblemExpr{iip} end
385388

386389
function DiscreteProblemExpr{iip}(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothing},
387390
parammap = DiffEqBase.NullParameters();
388-
use_union = false,
391+
use_union = true,
389392
kwargs...) where {iip}
390393
if !iscomplete(sys)
391394
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblemExpr`")
@@ -412,6 +415,60 @@ function DiscreteProblemExpr{iip}(sys::JumpSystem, u0map, tspan::Union{Tuple, No
412415
end
413416
end
414417

418+
"""
419+
```julia
420+
DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan,
421+
parammap = DiffEqBase.NullParameters;
422+
use_union = true,
423+
kwargs...)
424+
```
425+
426+
Generates a blank ODEProblem for a pure jump JumpSystem to utilize as its `prob.prob`. This
427+
is used in the case where there are no ODEs and no SDEs associated with the system but there
428+
are jumps with an explicit time dependency (i.e. `VariableRateJump`s). If no jumps have an
429+
explicit time dependence, i.e. all are `ConstantRateJump`s or `MassActionJump`s then
430+
`DiscreteProblem` should be preferred for performance reasons.
431+
432+
Continuing the example from the [`JumpSystem`](@ref) definition:
433+
434+
```julia
435+
using DiffEqBase, JumpProcesses
436+
u₀map = [S => 999, I => 1, R => 0]
437+
parammap = [β => 0.1 / 1000, γ => 0.01]
438+
tspan = (0.0, 250.0)
439+
oprob = ODEProblem(complete(js), u₀map, tspan, parammap)
440+
```
441+
"""
442+
function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothing},
443+
parammap = DiffEqBase.NullParameters();
444+
use_union = true,
445+
eval_expression = false,
446+
eval_module = @__MODULE__,
447+
kwargs...)
448+
if !iscomplete(sys)
449+
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblem`")
450+
end
451+
dvs = unknowns(sys)
452+
ps = parameters(sys)
453+
454+
defs = defaults(sys)
455+
defs = mergedefaults(defs, parammap, ps)
456+
defs = mergedefaults(defs, u0map, dvs)
457+
458+
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false)
459+
if has_index_cache(sys) && get_index_cache(sys) !== nothing
460+
p = MTKParameters(sys, parammap, u0map; eval_expression, eval_module)
461+
else
462+
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
463+
end
464+
465+
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
466+
467+
f = (du, u, p, t) -> (du .= 0; nothing)
468+
df = ODEFunction(f; sys, observed = observedfun)
469+
ODEProblem(df, u0, tspan, p; kwargs...)
470+
end
471+
415472
"""
416473
```julia
417474
DiffEqBase.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
@@ -449,10 +506,12 @@ function JumpProcesses.JumpProblem(js::JumpSystem, prob,
449506
error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")
450507
jset = JumpSet(Tuple(vrjs), Tuple(crjs), nothing, majs)
451508

509+
# dep graphs are only for constant rate jumps
510+
nonvrjs = ArrayPartition(eqs.x[1], eqs.x[2])
452511
if needs_vartojumps_map(aggregator) || needs_depgraph(aggregator) ||
453512
(aggregator isa JumpProcesses.NullAggregator)
454-
jdeps = asgraph(js)
455-
vdeps = variable_dependencies(js)
513+
jdeps = asgraph(js; eqs = nonvrjs)
514+
vdeps = variable_dependencies(js; eqs = nonvrjs)
456515
vtoj = jdeps.badjlist
457516
jtov = vdeps.badjlist
458517
jtoj = needs_depgraph(aggregator) ? eqeq_dependencies(jdeps, vdeps).fadjlist :

test/dep_graphs.jl

+80-6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using Test
2-
using ModelingToolkit, Graphs, JumpProcesses
2+
using ModelingToolkit, Graphs, JumpProcesses, RecursiveArrayTools
33
using ModelingToolkit: t_nounits as t, D_nounits as D
44
import ModelingToolkit: value
55

@@ -16,11 +16,11 @@ j₅ = ConstantRateJump(k1 * I, [R ~ R + 1])
1616
j₆ = VariableRateJump(k1 * k2 / (1 + t) * S, [S ~ S - 1, R ~ R + 1])
1717
eqs = [j₁, j₂, j₃, j₄, j₅, j₆]
1818
@named js = JumpSystem(eqs, t, [S, I, R], [k1, k2])
19-
S = value(S);
20-
I = value(I);
21-
R = value(R);
22-
k1 = value(k1);
23-
k2 = value(k2);
19+
S = value(S)
20+
I = value(I)
21+
R = value(R)
22+
k1 = value(k1)
23+
k2 = value(k2)
2424
# eq to vars they depend on
2525
eq_sdeps = [Variable[], [S], [S, I], [S, R], [I], [S]]
2626
eq_sidepsf = [Int[], [1], [1, 2], [1, 3], [2], [1]]
@@ -72,6 +72,80 @@ end
7272
dg4 = varvar_dependencies(depsbg, deps2)
7373
@test dg == dg4
7474

75+
# testing when ignoring VariableRateJumps
76+
let
77+
@parameters k1 k2
78+
@variables S(t) I(t) R(t)
79+
j₁ = MassActionJump(k1, [0 => 1], [S => 1])
80+
j₂ = MassActionJump(k1, [S => 1], [S => -1])
81+
j₃ = MassActionJump(k2, [S => 1, I => 1], [S => -1, I => 1])
82+
j₄ = MassActionJump(k2, [S => 2, R => 1], [R => -1])
83+
j₅ = ConstantRateJump(k1 * I, [R ~ R + 1])
84+
j₆ = VariableRateJump(k1 * k2 / (1 + t) * S, [S ~ S - 1, R ~ R + 1])
85+
eqs = [j₁, j₂, j₃, j₄, j₅, j₆]
86+
@named js = JumpSystem(eqs, t, [S, I, R], [k1, k2])
87+
S = value(S)
88+
I = value(I)
89+
R = value(R)
90+
k1 = value(k1)
91+
k2 = value(k2)
92+
# eq to vars they depend on
93+
eq_sdeps = [Variable[], [S], [S, I], [S, R], [I]]
94+
eq_sidepsf = [Int[], [1], [1, 2], [1, 3], [2]]
95+
eq_sidepsb = [[2, 3, 4], [3, 5], [4]]
96+
97+
# filter out vrjs in making graphs
98+
eqs = ArrayPartition(equations(js).x[1], equations(js).x[2])
99+
deps = equation_dependencies(js; eqs)
100+
@test length(deps) == length(eq_sdeps)
101+
@test all(i -> isequal(Set(eq_sdeps[i]), Set(deps[i])), 1:length(eqs))
102+
depsbg = asgraph(js; eqs)
103+
@test depsbg.fadjlist == eq_sidepsf
104+
@test depsbg.badjlist == eq_sidepsb
105+
106+
# eq to params they depend on
107+
eq_pdeps = [[k1], [k1], [k2], [k2], [k1]]
108+
eq_pidepsf = [[1], [1], [2], [2], [1]]
109+
eq_pidepsb = [[1, 2, 5], [3, 4]]
110+
deps = equation_dependencies(js; variables = parameters(js), eqs)
111+
@test length(deps) == length(eq_pdeps)
112+
@test all(i -> isequal(Set(eq_pdeps[i]), Set(deps[i])), 1:length(eqs))
113+
depsbg2 = asgraph(js; variables = parameters(js), eqs)
114+
@test depsbg2.fadjlist == eq_pidepsf
115+
@test depsbg2.badjlist == eq_pidepsb
116+
117+
# var to eqs that modify them
118+
s_eqdepsf = [[1, 2, 3], [3], [4, 5]]
119+
s_eqdepsb = [[1], [1], [1, 2], [3], [3]]
120+
ne = 6
121+
bg = BipartiteGraph(ne, s_eqdepsf, s_eqdepsb)
122+
deps2 = variable_dependencies(js; eqs)
123+
@test isequal(bg, deps2)
124+
125+
# eq to eqs that depend on them
126+
eq_eqdeps = [[2, 3, 4], [2, 3, 4], [2, 3, 4, 5], [4], [4], [2, 3, 4]]
127+
dg = SimpleDiGraph(5)
128+
for (eqidx, eqdeps) in enumerate(eq_eqdeps)
129+
for eqdepidx in eqdeps
130+
add_edge!(dg, eqidx, eqdepidx)
131+
end
132+
end
133+
dg3 = eqeq_dependencies(depsbg, deps2)
134+
@test dg == dg3
135+
136+
# var to vars that depend on them
137+
var_vardeps = [[1, 2, 3], [1, 2, 3], [3]]
138+
ne = 7
139+
dg = SimpleDiGraph(3)
140+
for (vidx, vdeps) in enumerate(var_vardeps)
141+
for vdepidx in vdeps
142+
add_edge!(dg, vidx, vdepidx)
143+
end
144+
end
145+
dg4 = varvar_dependencies(depsbg, deps2)
146+
@test dg == dg4
147+
end
148+
75149
#####################################
76150
# testing for ODE/SDEs
77151
#####################################

0 commit comments

Comments
 (0)