Skip to content

Commit 2d9e975

Browse files
Merge pull request #336 from isaacsas/mass_action_jumps
add mass action jump support
2 parents 20ce5cc + 491b4f5 commit 2d9e975

File tree

3 files changed

+80
-4
lines changed

3 files changed

+80
-4
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
2727
[compat]
2828
ArrayInterface = "2.8"
2929
DiffEqBase = "6.28"
30-
DiffEqJump = "6.6"
30+
DiffEqJump = "6.6.2"
3131
DiffRules = "0.1, 1.0"
3232
DocStringExtensions = "0.7, 0.8"
3333
GeneralizedGenerated = "0.1.4, 0.2"

src/systems/jumps/jumpsystem.jl

+45-2
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,34 @@ function assemble_crj(js, crj, statetoid)
4242
ConstantRateJump(rate, affect)
4343
end
4444

45+
function assemble_maj(js, maj::MassActionJump{U,Vector{Pair{V,W}},Vector{Pair{V2,W2}}},
46+
statetoid, ptoid, p, pcontext) where {U,V,W,V2,W2}
47+
sr = maj.scaled_rates
48+
if sr isa Operation || sr isa Variable
49+
pval = Base.eval(pcontext, Expr(maj.scaled_rates))
50+
else
51+
pval = maj.scaled_rates
52+
end
53+
54+
rs = Vector{Pair{Int,W}}()
55+
for (spec,stoich) in maj.reactant_stoch
56+
if iszero(spec)
57+
push!(rs, 0 => stoich)
58+
else
59+
push!(rs, statetoid[convert(Variable,spec)] => stoich)
60+
end
61+
end
62+
sort!(rs)
63+
64+
ns = Vector{Pair{Int,W2}}()
65+
for (spec,stoich) in maj.net_stoch
66+
iszero(spec) && error("Net stoichiometry can not have a species labelled 0.")
67+
push!(ns, statetoid[convert(Variable,spec)] => stoich)
68+
end
69+
sort!(ns)
70+
71+
MassActionJump(pval, rs, ns, scale_rates = false)
72+
end
4573

4674
"""
4775
```julia
@@ -68,17 +96,32 @@ Generates a JumpProblem from a JumpSystem.
6896
function DiffEqJump.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
6997
vrjs = Vector{VariableRateJump}()
7098
crjs = Vector{ConstantRateJump}()
99+
majs = Vector{MassActionJump}()
100+
pvars = parameters(js)
71101
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js)))
102+
ptoid = Dict(convert(Variable,par) => i for (i,par) in enumerate(parameters(js)))
103+
104+
# for mass action jumps might need to evaluate parameter expressions
105+
# populate dummy module with params as local variables
106+
# (for eval-ing parameter expressions)
107+
param_context = Module()
108+
for (i, pval) in enumerate(prob.p)
109+
psym = Symbol(pvars[i])
110+
Base.eval(param_context, :($psym = $pval))
111+
end
112+
72113
for j in equations(js)
73114
if j isa ConstantRateJump
74115
push!(crjs, assemble_crj(js, j, statetoid))
75116
elseif j isa VariableRateJump
76117
push!(vrjs, assemble_vrj(js, j, statetoid))
118+
elseif j isa MassActionJump
119+
push!(majs, assemble_maj(js, j, statetoid, ptoid, prob.p, param_context))
77120
else
78-
(j isa MassActionJump) && error("Generation of JumpProblems with MassActionJumps is not yet supported.")
121+
error("JumpSystems should only contain Constant, Variable or Mass Action Jumps.")
79122
end
80123
end
81124
((prob isa DiscreteProblem) && !isempty(vrjs)) && error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")
82-
jset = JumpSet(Tuple(vrjs), Tuple(crjs), nothing, nothing)
125+
jset = JumpSet(Tuple(vrjs), Tuple(crjs), nothing, isempty(majs) ? nothing : majs)
83126
JumpProblem(prob, aggregator, jset)
84127
end

test/jumpsystem.jl

+34-1
Original file line numberDiff line numberDiff line change
@@ -105,4 +105,37 @@ jprob = JumpProblem(prob,Direct(),jset, save_positions=(false,false))
105105
m2 = getmean(jprob,Nsims)
106106

107107
# test JumpSystem solution agrees with direct version
108-
@test abs(m-m2) ./ m < .01
108+
@test abs(m-m2)/m < .01
109+
110+
111+
# mass action jump tests for SIR model
112+
maj1 = MassActionJump(2*β/2, [S => 1, I => 1], [S => -1, I => 1])
113+
maj2 = MassActionJump(γ, [I => 1], [I => -1, R => 1])
114+
js3 = JumpSystem([maj1,maj2], t, [S,I,R], [β,γ])
115+
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js)))
116+
ptoid = Dict(convert(Variable,par) => i for (i,par) in enumerate(parameters(js)))
117+
dprob = DiscreteProblem(js3, u₀map, tspan, parammap)
118+
jprob = JumpProblem(js3, dprob, Direct())
119+
m3 = getmean(jprob,Nsims)
120+
@test abs(m-m3)/m < .01
121+
122+
# mass action jump tests for other reaction types (zero order, decay)
123+
maj1 = MassActionJump(2.0, [0 => 1], [S => 1])
124+
maj2 = MassActionJump(γ, [S => 1], [S => -1])
125+
js4 = JumpSystem([maj1,maj2], t, [S], [β,γ])
126+
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js)))
127+
ptoid = Dict(convert(Variable,par) => i for (i,par) in enumerate(parameters(js)))
128+
dprob = DiscreteProblem(js4, [S => 999], (0,1000.), [β => 100.=> .01])
129+
jprob = JumpProblem(js4, dprob, Direct())
130+
m4 = getmean(jprob,Nsims)
131+
@test abs(m4 - 2.0/.01)*.01/2.0 < .01
132+
133+
# test second order rx runs
134+
maj1 = MassActionJump(2.0, [0 => 1], [S => 1])
135+
maj2 = MassActionJump(γ, [S => 2], [S => -1])
136+
js4 = JumpSystem([maj1,maj2], t, [S], [β,γ])
137+
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js)))
138+
ptoid = Dict(convert(Variable,par) => i for (i,par) in enumerate(parameters(js)))
139+
dprob = DiscreteProblem(js4, [S => 999], (0,1000.), [β => 100.=> .01])
140+
jprob = JumpProblem(js4, dprob, Direct())
141+
sol = solve(jprob, SSAStepper())

0 commit comments

Comments
 (0)