Skip to content

Commit 5f06c73

Browse files
build_function from ODESystem
1 parent 53ac691 commit 5f06c73

File tree

4 files changed

+29
-6
lines changed

4 files changed

+29
-6
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ModelingToolkit"
22
uuid = "961ee093-0014-501f-94e3-6117800e7a78"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "1.4.1"
4+
version = "1.4.2"
55

66
[deps]
77
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"

src/build_function.jl

+15-4
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ function _build_function(target::JuliaTarget, rhss, vs, ps = (), args = (),
9292
end
9393
end
9494

95+
get_varnumber(varop::Operation,vars::Vector{Operation}) = findfirst(x->isequal(x,varop),vars)
96+
get_varnumber(varop::Operation,vars::Vector{Variable}) = findfirst(x->isequal(x,varop.op),vars)
97+
9598
function numbered_expr(O::Equation,args...;kwargs...)
9699
:($(numbered_expr(O.lhs,args...;kwargs...)) = $(numbered_expr(O.rhs,args...;kwargs...)))
97100
end
@@ -101,12 +104,12 @@ function numbered_expr(O::Operation,vars,parameters;
101104
varname=:u,paramname=:p)
102105
if isa(O.op, ModelingToolkit.Differential)
103106
varop = O.args[1]
104-
i = findfirst(x->isequal(x,varop),vars)
107+
i = get_varnumber(varop,vars)
105108
return :($derivname[$i])
106109
elseif isa(O.op, ModelingToolkit.Variable)
107-
i = findfirst(x->isequal(x,O),vars)
110+
i = get_varnumber(O,vars)
108111
if i == nothing
109-
i = findfirst(x->isequal(x,O),parameters)
112+
i = get_varnumber(O,parameters)
110113
return :($paramname[$i])
111114
else
112115
return :($varname[$i])
@@ -116,7 +119,15 @@ function numbered_expr(O::Operation,vars,parameters;
116119
[numbered_expr(x,vars,parameters;derivname=derivname,
117120
varname=varname,paramname=paramname) for x in O.args]...)
118121
end
119-
function numbered_expr(de::ModelingToolkit.DiffEq,vars,parameters;
122+
123+
function numbered_expr(de::ModelingToolkit.DiffEq,vars::Vector{Variable},parameters;
124+
derivname=:du,varname=:u,paramname=:p)
125+
i = findfirst(x->isequal(x.name,de.x.name),vars)
126+
:($derivname[$i] = $(numbered_expr(de.rhs,vars,parameters;
127+
derivname=derivname,
128+
varname=varname,paramname=paramname)))
129+
end
130+
function numbered_expr(de::ModelingToolkit.DiffEq,vars::Vector{Operation},parameters;
120131
derivname=:du,varname=:u,paramname=:p)
121132
i = findfirst(x->isequal(x.op.name,de.x.name),vars)
122133
:($derivname[$i] = $(numbered_expr(de.rhs,vars,parameters;

src/systems/diffeqs/diffeqsystem.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -289,9 +289,12 @@ $(SIGNATURES)
289289
Generate `ODESystem`, dependent variables, and parameters from an `ODEProblem`.
290290
"""
291291
function modelingtoolkitize(prob::DiffEqBase.ODEProblem)
292+
prob.f isa DiffEqBase.AbstractParameterizedFunction &&
293+
return (prob.f.sys, prob.f.sys.dvs, prob.f.sys.ps)
292294
@parameters t
293295
vars = [Variable(:x, i)(t) for i in eachindex(prob.u0)]
294-
params = [Variable(,i; known = true)() for i in eachindex(prob.p)]
296+
params = prob.p isa DiffEqBase.NullParameters ? [] :
297+
[Variable(,i; known = true)() for i in eachindex(prob.p)]
295298
@derivatives D'~t
296299

297300
rhs = [D(var) for var in vars]

test/build_targets.jl

+9
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,12 @@ sys = ODESystem(eqs,t,[x,y],[a])
3636

3737
@test ModelingToolkit.build_function(eqs,[x,y],[a],t,target = ModelingToolkit.MATLABTarget()) ==
3838
ModelingToolkit.build_function(sys.eqs,[x,y],[a],t,target = ModelingToolkit.MATLABTarget())
39+
40+
@test ModelingToolkit.build_function(eqs,[x,y],[a],t,target = ModelingToolkit.CTarget()) ==
41+
ModelingToolkit.build_function(sys.eqs,sys.dvs,sys.ps,sys.iv,target = ModelingToolkit.CTarget())
42+
43+
@test ModelingToolkit.build_function(eqs,[x,y],[a],t,target = ModelingToolkit.StanTarget()) ==
44+
ModelingToolkit.build_function(sys.eqs,sys.dvs,sys.ps,sys.iv,target = ModelingToolkit.StanTarget())
45+
46+
@test ModelingToolkit.build_function(eqs,[x,y],[a],t,target = ModelingToolkit.MATLABTarget()) ==
47+
ModelingToolkit.build_function(sys.eqs,sys.dvs,sys.ps,sys.iv,target = ModelingToolkit.MATLABTarget())

0 commit comments

Comments
 (0)