Skip to content

Commit fbcbef4

Browse files
optimize sparse matrix formation
1 parent 8a65a34 commit fbcbef4

File tree

3 files changed

+20
-3
lines changed

3 files changed

+20
-3
lines changed

src/ModelingToolkit.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ export modelingtoolkitize
99

1010

1111
using DiffEqBase, Distributed
12-
using StaticArrays, LinearAlgebra
12+
using StaticArrays, LinearAlgebra, SparseArrays
1313
using Latexify
1414

1515
using MacroTools

src/utils.jl

+9-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,12 @@ function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr, ex
4646
fname = gensym(:ModelingToolkitFunction)
4747

4848
X = gensym(:MTIIPVar)
49-
ip_sys_exprs = [:($X[$i] = $(conv(rhs))) for (i, rhs) enumerate(rhss)]
49+
if rhss isa SparseMatrixCSC
50+
ip_sys_exprs = [:($X.nzval[$i] = $(conv(rhs))) for (i, rhs) enumerate(rhss.nzval)]
51+
else
52+
ip_sys_exprs = [:($X[$i] = $(conv(rhs))) for (i, rhs) enumerate(rhss)]
53+
end
54+
5055
ip_let_expr = Expr(:let, var_eqs, build_expr(:block, ip_sys_exprs))
5156

5257
tuple_sys_expr = build_expr(:tuple, [conv(rhs) for rhs rhss])
@@ -56,6 +61,9 @@ function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr, ex
5661
elseif typeof(rhss) <: Array && !(typeof(rhss) <: Vector)
5762
vector_form = build_expr(:vect, [conv(rhs) for rhs rhss])
5863
arr_sys_expr = :(reshape($vector_form,$(size(rhss)...)))
64+
elseif rhss isa SparseMatrixCSC
65+
vector_form = build_expr(:vect, [conv(rhs) for rhs nonzeros(rhss)])
66+
arr_sys_expr = :(SparseMatrixCSC{eltype(u),Int}($(size(rhss)...), $(rhss.colptr), $(rhss.rowval), $vector_form))
5967
else # Vector
6068
arr_sys_expr = build_expr(:vect, [conv(rhs) for rhs rhss])
6169
end

test/direct.jl

+10-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ModelingToolkit, StaticArrays, LinearAlgebra
1+
using ModelingToolkit, StaticArrays, LinearAlgebra, SparseArrays
22
using DiffEqBase
33
using Test
44

@@ -44,6 +44,15 @@ J2 = copy(J)
4444
Jiip(J2,[1.0,2.0,3.0],[1.0,2.0,3.0],1.0)
4545
@test J2 == J
4646

47+
s∂ = sparse(∂)
48+
@test nnz(s∂) == 8
49+
Joop,Jiip = eval.(ModelingToolkit.build_function(s∂,[x,y,z],[σ,ρ,β],[t.op.name]))
50+
J = Joop([1.0,2.0,3.0],[1.0,2.0,3.0],1.0)
51+
length(nonzeros(s∂)) == 8
52+
J2 = copy(J)
53+
Jiip(J2,[1.0,2.0,3.0],[1.0,2.0,3.0],1.0)
54+
@test J2 == J
55+
4756
# Function building
4857

4958
@parameters σ() ρ() β()

0 commit comments

Comments
 (0)