Skip to content

Commit f2cbd38

Browse files
Merge pull request #213 from JuliaDiffEq/performance
speed up MTK OOP vector usage
2 parents 988fde4 + 5bb1631 commit f2cbd38

File tree

2 files changed

+40
-31
lines changed

2 files changed

+40
-31
lines changed

README.md

+30-26
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ Each operation builds an `Operation` type, and thus `eqs` is an array of
4444
analyzed by other programs. We can turn this into a `ODESystem` via:
4545

4646
```julia
47-
de = ODESystem(eqs)
47+
de = ODESystem(eqs, t, [x,y,z], [σ,ρ,β])
4848
```
4949

5050
where we tell it the variable types and ordering in the first version, or let it
@@ -54,49 +54,53 @@ generated code via:
5454

5555
```julia
5656
using MacroTools
57-
myode_oop = generate_function(de, [x,y,z], [σ,ρ,β])[1] # first one is the out-of-place function
57+
myode_oop = generate_function(de)[1] # first one is the out-of-place function
5858
MacroTools.striplines(myode_oop) # print without line numbers
5959

6060
#=
6161
:((u, p, t)->begin
62-
@inbounds begin
63-
X = @inbounds(begin
64-
let (x, y, z, σ, ρ, β) = (u[1], u[2], u[3], p[1], p[2], p[3])
65-
(σ * (y - x), x * (ρ - z) - y, x * y - β * z)
66-
end
67-
end)
68-
end
62+
if u isa Array
63+
return @inbounds(begin
64+
let (x, y, z, σ, ρ, β) = (u[1], u[2], u[3], p[1], p[2], p[3])
65+
[σ * (y - x), x * (ρ - z) - y, x * y - β * z]
66+
end
67+
end)
68+
else
69+
X = @inbounds(begin
70+
let (x, y, z, σ, ρ, β) = (u[1], u[2], u[3], p[1], p[2], p[3])
71+
(σ * (y - x), x * (ρ - z) - y, x * y - β * z)
72+
end
73+
end)
74+
end
6975
T = promote_type(map(typeof, X)...)
70-
convert.(T, X)
76+
map(T, X)
7177
construct = if u isa ModelingToolkit.StaticArrays.StaticArray
7278
ModelingToolkit.StaticArrays.similar_type(typeof(u), eltype(X))
7379
else
7480
x->begin
75-
du = similar(u, T, 3)
76-
vec(du) .= x
77-
du
81+
convert(typeof(u), x)
7882
end
7983
end
8084
construct(X)
8185
end)
8286
=#
8387

84-
myode_iip = generate_function(de, [x,y,z], [σ,ρ,β])[2] # second one is the in-place function
88+
myode_iip = generate_function(de)[2] # second one is the in-place function
8589
MacroTools.striplines(myode_iip) # print without line numbers
8690

8791
#=
88-
(var"##MTIIPVar#409", u, p, t)->begin
89-
@inbounds begin
90-
@inbounds begin
91-
let (x, y, z, σ, ρ, β) = (u[1], u[2], u[3], p[1], p[2], p[3])
92-
var"##MTIIPVar#409"[1] = σ * (y - x)
93-
var"##MTIIPVar#409"[2] = x * (ρ - z) - y
94-
var"##MTIIPVar#409"[3] = x * y - β * z
95-
end
96-
end
97-
end
98-
nothing
99-
end
92+
:((var"##MTIIPVar#793", u, p, t)->begin
93+
@inbounds begin
94+
@inbounds begin
95+
let (x, y, z, σ, ρ, β) = (u[1], u[2], u[3], p[1], p[2], p[3])
96+
var"##MTIIPVar#793"[1] = σ * (y - x)
97+
var"##MTIIPVar#793"[2] = x * (ρ - z) - y
98+
var"##MTIIPVar#793"[3] = x * y - β * z
99+
end
100+
end
101+
end
102+
nothing
103+
end)
100104
=#
101105
```
102106

src/utils.jl

+10-5
Original file line numberDiff line numberDiff line change
@@ -47,21 +47,26 @@ function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr, ex
4747
ip_sys_exprs = [:($X[$i] = $(conv(rhs))) for (i, rhs) enumerate(rhss)]
4848
ip_let_expr = Expr(:let, var_eqs, build_expr(:block, ip_sys_exprs))
4949

50-
sys_expr = build_expr(:tuple, [conv(rhs) for rhs rhss])
51-
let_expr = Expr(:let, var_eqs, sys_expr)
50+
tuple_sys_expr = build_expr(:tuple, [conv(rhs) for rhs rhss])
51+
vector_sys_expr = build_expr(:vect, [conv(rhs) for rhs rhss])
52+
let_expr = Expr(:let, var_eqs, tuple_sys_expr)
53+
vector_let_expr = Expr(:let, var_eqs, vector_sys_expr)
5254
bounds_block = checkbounds ? let_expr : :(@inbounds begin $let_expr end)
55+
vector_bounds_block = checkbounds ? vector_let_expr : :(@inbounds begin $vector_let_expr end)
5356
ip_bounds_block = checkbounds ? ip_let_expr : :(@inbounds begin $ip_let_expr end)
5457

5558
fargs = ps == () ? :(u,$(args...)) : :(u,p,$(args...))
5659

5760
oop_ex = :(
5861
($(fargs.args...),) -> begin
59-
@inbounds begin
62+
if $(fargs.args[1]) isa Array
63+
return $vector_bounds_block
64+
else
6065
X = $bounds_block
6166
end
6267
T = promote_type(map(typeof,X)...)
63-
convert.(T,X)
64-
construct = $(constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.similar_type(typeof(u), eltype(X)) : x->(du=similar(u, T, $(size(rhss)...)); vec(du) .= x; du)) : constructor)
68+
map(T,X)
69+
construct = $(constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.similar_type(typeof(u), eltype(X)) : x->convert(typeof(u),x)) : constructor)
6570
construct(X)
6671
end
6772
)

0 commit comments

Comments
 (0)