Skip to content

Commit 8a65a34

Browse files
fix array fast path on matrices
1 parent 600ea21 commit 8a65a34

File tree

2 files changed

+28
-4
lines changed

2 files changed

+28
-4
lines changed

src/utils.jl

+13-4
Original file line numberDiff line numberDiff line change
@@ -50,19 +50,28 @@ function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr, ex
5050
ip_let_expr = Expr(:let, var_eqs, build_expr(:block, ip_sys_exprs))
5151

5252
tuple_sys_expr = build_expr(:tuple, [conv(rhs) for rhs rhss])
53-
vector_sys_expr = build_expr(:vect, [conv(rhs) for rhs rhss])
53+
54+
if rhss isa Matrix
55+
arr_sys_expr = build_expr(:vcat, [build_expr(:row,[conv(rhs) for rhs rhss[i,:]]) for i in 1:size(rhss,2)])
56+
elseif typeof(rhss) <: Array && !(typeof(rhss) <: Vector)
57+
vector_form = build_expr(:vect, [conv(rhs) for rhs rhss])
58+
arr_sys_expr = :(reshape($vector_form,$(size(rhss)...)))
59+
else # Vector
60+
arr_sys_expr = build_expr(:vect, [conv(rhs) for rhs rhss])
61+
end
62+
5463
let_expr = Expr(:let, var_eqs, tuple_sys_expr)
55-
vector_let_expr = Expr(:let, var_eqs, vector_sys_expr)
64+
arr_let_expr = Expr(:let, var_eqs, arr_sys_expr)
5665
bounds_block = checkbounds ? let_expr : :(@inbounds begin $let_expr end)
57-
vector_bounds_block = checkbounds ? vector_let_expr : :(@inbounds begin $vector_let_expr end)
66+
arr_bounds_block = checkbounds ? arr_let_expr : :(@inbounds begin $arr_let_expr end)
5867
ip_bounds_block = checkbounds ? ip_let_expr : :(@inbounds begin $ip_let_expr end)
5968

6069
fargs = ps == () ? :(u,$(args...)) : :(u,p,$(args...))
6170

6271
oop_ex = :(
6372
($(fargs.args...),) -> begin
6473
if $(fargs.args[1]) isa Array
65-
return $vector_bounds_block
74+
return $arr_bounds_block
6675
else
6776
X = $bounds_block
6877
end

test/direct.jl

+15
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,21 @@ end
2929
@test all(isequal.(ModelingToolkit.gradient(eqs[1],[x,y,z]),[σ * -1,σ,0]))
3030
@test all(isequal.(ModelingToolkit.hessian(eqs[1],[x,y,z]),0))
3131

32+
Joop,Jiip = eval.(ModelingToolkit.build_function(∂,[x,y,z],[σ,ρ,β],[t.op.name]))
33+
J = Joop([1.0,2.0,3.0],[1.0,2.0,3.0],1.0)
34+
@test J isa Matrix
35+
J2 = copy(J)
36+
Jiip(J2,[1.0,2.0,3.0],[1.0,2.0,3.0],1.0)
37+
@test J2 == J
38+
39+
∂3 = cat(∂,∂,dims=3)
40+
Joop,Jiip = eval.(ModelingToolkit.build_function(∂3,[x,y,z],[σ,ρ,β],[t.op.name]))
41+
J = Joop([1.0,2.0,3.0],[1.0,2.0,3.0],1.0)
42+
@test size(J) == (3,3,2)
43+
J2 = copy(J)
44+
Jiip(J2,[1.0,2.0,3.0],[1.0,2.0,3.0],1.0)
45+
@test J2 == J
46+
3247
# Function building
3348

3449
@parameters σ() ρ() β()

0 commit comments

Comments
 (0)