Skip to content

Commit 9b8e9c4

Browse files
Merge pull request #173 from JuliaDiffEq/distributed
compatibility with distributed
2 parents 711d1b8 + 9fb8eb8 commit 9b8e9c4

10 files changed

+113
-137
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
*.jl.cov
22
*.jl.*.cov
33
*.jl.mem
4+
Manifest.toml

Project.toml

+4-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ version = "0.6.4"
66
[deps]
77
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
88
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
9+
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
910
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
11+
GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb"
1012
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1113
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1214
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
@@ -18,7 +20,8 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1820
julia = "1"
1921

2022
[extras]
23+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
2124
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2225

2326
[targets]
24-
test = ["Test"]
27+
test = ["OrdinaryDiffEq", "Test"]

src/ModelingToolkit.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@ export @register, @I
88
export modelingtoolkitize
99

1010

11-
using DiffEqBase
11+
using DiffEqBase, Distributed
1212
using StaticArrays, LinearAlgebra
1313

1414
using MacroTools
1515
import MacroTools: splitdef, combinedef
16-
16+
import GeneralizedGenerated
1717
using DocStringExtensions
1818

1919
"""

src/direct.jl

-9
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,3 @@ end
3232
function simplified_expr(eq::Equation)
3333
Expr(:(=), simplified_expr(eq.lhs), simplified_expr(eq.rhs))
3434
end
35-
36-
macro I(ex)
37-
name = :ICompile
38-
ret = return quote
39-
macro $(esc(name))()
40-
esc($ex)
41-
end
42-
end
43-
end

src/systems/diffeqs/diffeqsystem.jl

+28-44
Original file line numberDiff line numberDiff line change
@@ -156,20 +156,20 @@ function (f::ODEToExpr)(O::Operation)
156156
isempty(O.args) && return O.op.name # 0-ary parameters
157157
return build_expr(:call, Any[O.op.name; f.(O.args)])
158158
end
159-
return build_expr(:call, Any[O.op; f.(O.args)])
159+
return build_expr(:call, Any[Symbol(O.op); f.(O.args)])
160160
end
161161
(f::ODEToExpr)(x) = convert(Expr, x)
162162

163-
function generate_jacobian(sys::ODESystem, dvs = sys.dvs, ps = sys.ps)
163+
function generate_jacobian(sys::ODESystem, dvs = sys.dvs, ps = sys.ps, expression = Val{true})
164164
jac = calculate_jacobian(sys)
165-
return build_function(jac, dvs, ps, (sys.iv.name,), ODEToExpr(sys))
165+
return build_function(jac, dvs, ps, (sys.iv.name,), ODEToExpr(sys), expression)
166166
end
167167

168-
function generate_function(sys::ODESystem, dvs = sys.dvs, ps = sys.ps)
168+
function generate_function(sys::ODESystem, dvs = sys.dvs, ps = sys.ps, expression = Val{true})
169169
rhss = [deq.rhs for deq sys.eqs]
170170
dvs′ = [clean(dv) for dv dvs]
171171
ps′ = [clean(p) for p ps]
172-
return build_function(rhss, dvs′, ps′, (sys.iv.name,), ODEToExpr(sys))
172+
return build_function(rhss, dvs′, ps′, (sys.iv.name,), ODEToExpr(sys), expression)
173173
end
174174

175175
function calculate_factorized_W(sys::ODESystem, simplify=true)
@@ -196,16 +196,16 @@ function calculate_factorized_W(sys::ODESystem, simplify=true)
196196
(Wfact,Wfact_t)
197197
end
198198

199-
function generate_factorized_W(sys::ODESystem, vs = sys.dvs, ps = sys.ps, simplify=true)
199+
function generate_factorized_W(sys::ODESystem, vs = sys.dvs, ps = sys.ps, simplify=true, expression = Val{true})
200200
(Wfact,Wfact_t) = calculate_factorized_W(sys,simplify)
201201
siz = size(Wfact)
202202
constructor = :(x -> begin
203203
A = SMatrix{$siz...}(x)
204204
StaticArrays.LU(LowerTriangular( SMatrix{$siz...}(UnitLowerTriangular(A)) ), UpperTriangular(A), SVector(ntuple(n->n, max($siz...))))
205205
end)
206206

207-
Wfact_func = build_function(Wfact , vs, ps, (:gam,:t), ODEToExpr(sys);constructor=constructor)
208-
Wfact_t_func = build_function(Wfact_t, vs, ps, (:gam,:t), ODEToExpr(sys);constructor=constructor)
207+
Wfact_func = build_function(Wfact , vs, ps, (:gam,:t), ODEToExpr(sys), expression;constructor=constructor)
208+
Wfact_t_func = build_function(Wfact_t, vs, ps, (:gam,:t), ODEToExpr(sys), expression;constructor=constructor)
209209

210210
return (Wfact_func, Wfact_t_func)
211211
end
@@ -217,53 +217,37 @@ Create an `ODEFunction` from the [`ODESystem`](@ref). The arguments `dvs` and `p
217217
are used to set the order of the dependent variable and parameter vectors,
218218
respectively.
219219
"""
220-
function DiffEqBase.ODEFunction{iip}(sys::ODESystem, dvs, ps,
221-
safe = Val{true};
220+
function DiffEqBase.ODEFunction{iip}(sys::ODESystem, dvs, ps;
222221
version = nothing,
223222
jac = false, Wfact = false) where {iip}
224-
_f = eval(generate_function(sys, dvs, ps))
225-
out_f_safe(u,p,t) = ModelingToolkit.fast_invokelatest(_f,typeof(u),u,p,t)
226-
out_f_safe(du,u,p,t) = ModelingToolkit.fast_invokelatest(_f,Nothing,du,u,p,t)
227-
out_f(u,p,t) = _f(u,p,t)
228-
out_f(du,u,p,t) = _f(du,u,p,t)
223+
f_oop,f_iip = generate_function(sys, dvs, ps, Val{false})
224+
225+
f(u,p,t) = f_oop(u,p,t)
226+
f(du,u,p,t) = f_iip(du,u,p,t)
229227

230228
if jac
231-
_jac = eval(generate_jacobian(sys, dvs, ps))
232-
jac_f_safe(u,p,t) = ModelingToolkit.fast_invokelatest(_jac,Matrix{eltype(u)},u,p,t)
233-
jac_f_safe(J,u,p,t) = ModelingToolkit.fast_invokelatest(_jac,Nothing,J,u,p,t)
234-
jac_f(u,p,t) = _jac(u,p,t)
235-
jac_f(J,u,p,t) = _jac(J,u,p,t)
229+
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps, Val{false})
230+
_jac(u,p,t) = jac_oop(u,p,t)
231+
_jac(J,u,p,t) = jac_iip(J,u,p,t)
236232
else
237-
jac_f_safe = nothing
238-
jac_f = nothing
233+
_jac = nothing
239234
end
240235

241236
if Wfact
242-
_Wfact,_Wfact_t = eval.(generate_factorized_W(sys, dvs, ps))
243-
Wfact_f_safe(u,p,gam,t) = ModelingToolkit.fast_invokelatest(_Wfact,Matrix{eltype(u)},u,p,gam,t)
244-
Wfact_f_safe(J,u,p,gam,t) = ModelingToolkit.fast_invokelatest(_Wfact,Nothing,J,u,p,gam,t)
245-
Wfact_f_t_safe(u,p,gam,t) = ModelingToolkit.fast_invokelatest(_Wfact_t,Matrix{eltype(u)},u,p,gam,t)
246-
Wfact_f_t_safe(J,u,p,gam,t) = ModelingToolkit.fast_invokelatest(_Wfact_t,Nothing,J,u,p,gam,t)
247-
Wfact_f(u,p,gam,t) = _Wfact(u,p,gam,t)
248-
Wfact_f(J,u,p,gam,t) = _Wfact(J,u,p,gam,t)
249-
Wfact_f_t(u,p,gam,t) = _Wfact_t(u,p,gam,t)
250-
Wfact_f_t(J,u,p,gam,t) = _Wfact_t(J,u,p,gam,t)
237+
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps, Val{false})
238+
Wfact_oop, Wfact_iip = tmp_Wfact
239+
Wfact_oop_t, Wfact_iip_t = tmp_Wfact_t
240+
_Wfact(u,p,t) = Wfact_oop(u,p,t)
241+
_Wfact(W,u,p,t) = Wfact_iip(W,u,p,t)
242+
_Wfact_t(u,p,t) = Wfact_oop_t(u,p,t)
243+
_Wfact_t(W,u,p,t) = Wfact_iip_t(W,u,p,t)
251244
else
252-
Wfact_f_safe = nothing
253-
Wfact_f_t_safe = nothing
254-
Wfact_f = nothing
255-
Wfact_f_t = nothing
245+
_Wfact,_Wfact_t = nothing,nothing
256246
end
257247

258-
if safe === Val{true}
259-
ODEFunction{iip}(out_f_safe,jac=jac_f_safe,
260-
Wfact = Wfact_f_safe,
261-
Wfact_t = Wfact_f_t_safe)
262-
else
263-
ODEFunction{iip}(out_f,jac=jac_f,
264-
Wfact = Wfact_f,
265-
Wfact_t = Wfact_f_t)
266-
end
248+
ODEFunction{iip}(f,jac=_jac,
249+
Wfact = _Wfact,
250+
Wfact_t = _Wfact_t)
267251
end
268252

269253
function DiffEqBase.ODEFunction(sys::ODESystem, args...; kwargs...)

src/utils.jl

+19-17
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ function flatten_expr!(x)
3131
x
3232
end
3333

34-
function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr; constructor=nothing)
34+
function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr, expression = Val{true}; constructor=nothing)
3535
_vs = map(x-> x isa Operation ? x.op : x, vs)
3636
_ps = map(x-> x isa Operation ? x.op : x, ps)
3737
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(_vs)]
@@ -40,28 +40,38 @@ function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr; co
4040

4141
var_eqs = Expr(:(=), build_expr(:tuple, ls), build_expr(:tuple, rs))
4242

43-
fname = gensym()
43+
fname = gensym(:ModelingToolkitFunction)
4444

45-
X = gensym()
45+
X = gensym(:MTIIPVar)
4646
ip_sys_exprs = [:($X[$i] = $(conv(rhs))) for (i, rhs) enumerate(rhss)]
4747
ip_let_expr = Expr(:let, var_eqs, build_expr(:block, ip_sys_exprs))
4848

4949
sys_expr = build_expr(:tuple, [conv(rhs) for rhs rhss])
5050
let_expr = Expr(:let, var_eqs, sys_expr)
5151

5252
fargs = ps == () ? :(u,$(args...)) : :(u,p,$(args...))
53-
quote
54-
function $fname($X,$(fargs.args...))
55-
$ip_let_expr
56-
nothing
57-
end
58-
function $fname($(fargs.args...))
53+
54+
oop_ex = :(
55+
($(fargs.args...),) -> begin
5956
X = $let_expr
6057
T = promote_type(map(typeof,X)...)
6158
convert.(T,X)
6259
construct = $(constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.similar_type(typeof(u), eltype(X)) : x->(du=similar(u, T, $(size(rhss)...)); vec(du) .= x; du)) : constructor)
6360
construct(X)
6461
end
62+
)
63+
64+
iip_ex = :(
65+
($X,$(fargs.args...)) -> begin
66+
$ip_let_expr
67+
nothing
68+
end
69+
)
70+
71+
if expression == Val{true}
72+
return oop_ex, iip_ex
73+
else
74+
return GeneralizedGenerated.mk_function(@__MODULE__,oop_ex), GeneralizedGenerated.mk_function(@__MODULE__,iip_ex)
6575
end
6676
end
6777

@@ -94,11 +104,3 @@ function vars!(vars, O)
94104

95105
return vars
96106
end
97-
98-
@inline @generated function fast_invokelatest(f, ::Type{rt}, args...) where rt
99-
tupargs = Expr(:tuple,(a==Nothing ? Int : a for a in args)...)
100-
quote
101-
_f = $(Expr(:cfunction, Base.CFunction, :f, rt, :((Core.svec)($((a==Nothing ? Int : a for a in args)...))), :(:ccall)))
102-
return ccall(_f.ptr,rt,$tupargs,$((:(getindex(args,$i) === nothing ? 0 : getindex(args,$i)) for i in 1:length(args))...))
103-
end
104-
end

test/direct.jl

+10-54
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,11 @@ end
3636
eqs =*(y-x),
3737
x*-z)-y,
3838
x*y - β*z]
39-
f = eval(ModelingToolkit.build_function(eqs,[x,y,z],[σ,ρ,β]))
39+
f1,f2 = ModelingToolkit.build_function(eqs,[x,y,z],[σ,ρ,β])
40+
f = eval(f1)
4041
out = [1.0,2,3]
4142
o1 = f([1.0,2,3],[1.0,2,3])
43+
f = eval(f2)
4244
f(out,[1.0,2,3],[1.0,2,3])
4345
@test all(o1 .== out)
4446

@@ -48,54 +50,23 @@ function test_worldage()
4850
eqs =*(y-x),
4951
x*-z)-y,
5052
x*y - β*z]
51-
_f = eval(ModelingToolkit.build_function(eqs,[x,y,z],[σ,ρ,β]))
52-
f(u,p) = ModelingToolkit.fast_invokelatest(_f,typeof(u),u,p)
53-
f(du,u,p) = ModelingToolkit.fast_invokelatest(_f,Nothing,du,u,p)
53+
f, f_iip = ModelingToolkit.build_function(eqs,[x,y,z],[σ,ρ,β],(),ModelingToolkit.simplified_expr,Val{false})
5454
out = [1.0,2,3]
5555
o1 = f([1.0,2,3],[1.0,2,3])
56-
f(out,[1.0,2,3],[1.0,2,3])
56+
f_iip(out,[1.0,2,3],[1.0,2,3])
5757
end
5858
test_worldage()
5959

60-
mac = @I begin
61-
@parameters σ() ρ() β()
62-
@variables x() y() z()
63-
64-
eqs =*(y-x),
65-
x*-z)-y,
66-
x*y - β*z]
67-
ModelingToolkit.build_function(eqs,[x,y,z],[σ,ρ,β])
68-
end
69-
f = @ICompile
70-
out = [1.0,2,3]
71-
o1 = f([1.0,2,3],[1.0,2,3])
72-
f(out,[1.0,2,3],[1.0,2,3])
73-
@test all(o1 .== out)
74-
75-
mac = @I begin
76-
@parameters σ ρ β
77-
@variables x y z
78-
79-
eqs =*(y-x),
80-
x*-z)-y,
81-
x*y - β*z]
82-
= ModelingToolkit.jacobian(eqs,[x,y,z])
83-
ModelingToolkit.build_function(∂,[x,y,z],[σ,ρ,β])
84-
end
85-
f = @ICompile
86-
out = zeros(3,3)
87-
o1 = f([1.0,2,3],[1.0,2,3])
88-
f(out,[1.0,2,3],[1.0,2,3])
89-
@test all(out .== o1)
90-
9160
## No parameters
9261
@variables x y z
9362
eqs = [(y-x)^2,
9463
x*(x-z)-y,
9564
x*y - y*z]
96-
f = eval(ModelingToolkit.build_function(eqs,[x,y,z]))
65+
f1,f2 = ModelingToolkit.build_function(eqs,[x,y,z])
66+
f = eval(f1)
9767
out = zeros(3)
9868
o1 = f([1.0,2,3])
69+
f = eval(f2)
9970
f(out,[1.0,2,3])
10071
@test all(out .== o1)
10172

@@ -104,24 +75,9 @@ function test_worldage()
10475
eqs = [(y-x)^2,
10576
x*(x-z)-y,
10677
x*y - y*z]
107-
_f = eval(ModelingToolkit.build_function(eqs,[x,y,z]))
108-
f(u) = ModelingToolkit.fast_invokelatest(_f,typeof(u),u)
109-
f(du,u) = ModelingToolkit.fast_invokelatest(_f,Nothing,du,u)
78+
f, f_iip = ModelingToolkit.build_function(eqs,[x,y,z],(),(),ModelingToolkit.simplified_expr,Val{false})
11079
out = zeros(3)
11180
o1 = f([1.0,2,3])
112-
f(out,[1.0,2,3])
81+
f_iip(out,[1.0,2,3])
11382
end
11483
test_worldage()
115-
116-
mac = @I begin
117-
@variables x y z
118-
eqs = [(y-x)^2,
119-
x*(x-z)-y,
120-
x*y - y*z]
121-
ModelingToolkit.build_function(eqs,[x,y,z])
122-
end
123-
f = @ICompile
124-
out = zeros(3)
125-
o1 = f([1.0,2,3])
126-
f(out,[1.0,2,3])
127-
@test all(out .== o1)

test/distributed.jl

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
using Distributed
2+
# add processes to workspace
3+
addprocs(2)
4+
5+
@everywhere using ModelingToolkit, OrdinaryDiffEq
6+
7+
# create the Lorenz system
8+
@everywhere @parameters t σ ρ β
9+
@everywhere @variables x(t) y(t) z(t)
10+
@everywhere @derivatives D'~t
11+
12+
@everywhere eqs = [D(x) ~ σ*(y-x),
13+
D(y) ~ x*-z)-y,
14+
D(z) ~ x*y - β*z]
15+
16+
@everywhere de = ODESystem(eqs)
17+
@everywhere ode_func = ODEFunction(de, [x,y,z], [σ, ρ, β])
18+
19+
@everywhere u0 = [19.,20.,50.]
20+
@everywhere params = [16.,45.92,4]
21+
22+
@everywhere ode_prob = ODEProblem(ode_func, u0, (0., 10.),params)
23+
24+
@everywhere begin
25+
26+
using OrdinaryDiffEq
27+
using ModelingToolkit
28+
29+
function solve_lorenz(ode_problem)
30+
print(solve(ode_problem,Tsit5()))
31+
end
32+
end
33+
34+
solve_lorenz(ode_prob)
35+
36+
future = @spawn solve_lorenz(ode_prob)
37+
fetch(future)

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ using ModelingToolkit, Test
55
@testset "Simplify Test" begin include("simplify.jl") end
66
@testset "Direct Usage Test" begin include("direct.jl") end
77
@testset "System Construction Test" begin include("system_construction.jl") end
8+
@testset "Distributed Test" begin include("distributed.jl") end

0 commit comments

Comments
 (0)