Skip to content

Commit 98a377d

Browse files
Merge pull request #436 from avik-pal/ap/defaults
Support StaticArrays Properly
2 parents 9787717 + 21559ed commit 98a377d

File tree

7 files changed

+78
-10
lines changed

7 files changed

+78
-10
lines changed

Project.toml

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LinearSolve"
22
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
33
authors = ["SciML"]
4-
version = "2.20.0"
4+
version = "2.20.1"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -26,6 +26,7 @@ SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
2626
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2727
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2828
Sparspak = "e56a9233-b9d6-4f03-8d0f-1825330902ac"
29+
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2930
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
3031

3132
[weakdeps]
@@ -48,14 +49,14 @@ LinearSolveBandedMatricesExt = "BandedMatrices"
4849
LinearSolveBlockDiagonalsExt = "BlockDiagonals"
4950
LinearSolveCUDAExt = "CUDA"
5051
LinearSolveEnzymeExt = ["Enzyme", "EnzymeCore"]
52+
LinearSolveFastAlmostBandedMatricesExt = ["FastAlmostBandedMatrices"]
5153
LinearSolveHYPREExt = "HYPRE"
5254
LinearSolveIterativeSolversExt = "IterativeSolvers"
5355
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
5456
LinearSolveKrylovKitExt = "KrylovKit"
5557
LinearSolveMetalExt = "Metal"
5658
LinearSolvePardisoExt = "Pardiso"
5759
LinearSolveRecursiveArrayToolsExt = "RecursiveArrayTools"
58-
LinearSolveFastAlmostBandedMatricesExt = ["FastAlmostBandedMatrices"]
5960

6061
[compat]
6162
Aqua = "0.8"
@@ -101,6 +102,8 @@ SciMLOperators = "0.3"
101102
Setfield = "1"
102103
SparseArrays = "1.9"
103104
Sparspak = "0.3.6"
105+
StaticArraysCore = "1"
106+
StaticArrays = "1"
104107
Test = "1"
105108
UnPack = "1"
106109
julia = "1.9"
@@ -126,7 +129,8 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
126129
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
127130
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
128131
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
132+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
129133
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
130134

131135
[targets]
132-
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices"]
136+
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays"]

src/LinearSolve.jl

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ PrecompileTools.@recompile_invalidations begin
2626
using Requires
2727
import InteractiveUtils
2828

29+
import StaticArraysCore: StaticArray, SVector, MVector, SMatrix, MMatrix
30+
2931
using LinearAlgebra: BlasInt, LU
3032
using LinearAlgebra.LAPACK: require_one_based_indexing,
3133
chkfinite, chkstride1,

src/default.jl

+7
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,10 @@ function defaultalg(A, b, assump::OperatorAssumptions)
175175
DefaultAlgorithmChoice.LUFactorization
176176
end
177177

178+
# For static arrays GMRES allocates a lot. Use factorization
179+
elseif A isa StaticArray
180+
DefaultAlgorithmChoice.LUFactorization
181+
178182
# This catches the cases where a factorization overload could exist
179183
# For example, BlockBandedMatrix
180184
elseif A !== nothing && ArrayInterface.isstructured(A)
@@ -186,6 +190,9 @@ function defaultalg(A, b, assump::OperatorAssumptions)
186190
end
187191
elseif assump.condition === OperatorCondition.WellConditioned
188192
DefaultAlgorithmChoice.NormalCholeskyFactorization
193+
elseif A isa StaticArray
194+
# Static Array doesn't have QR() \ b defined
195+
DefaultAlgorithmChoice.SVDFactorization
189196
elseif assump.condition === OperatorCondition.IllConditioned
190197
if is_underdetermined(A)
191198
# Underdetermined

src/factorization.jl

+36-6
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ end
1010

1111
_ldiv!(x, A, b) = ldiv!(x, A, b)
1212

13+
_ldiv!(x, A, b::SVector) = (x .= A \ b)
14+
_ldiv!(::SVector, A, b::SVector) = (A \ b)
15+
_ldiv!(::SVector, A, b) = (A \ b)
16+
1317
function _ldiv!(x::Vector, A::Factorization, b::Vector)
1418
# workaround https://github.com/JuliaLang/julia/issues/43507
1519
# Fallback if working with non-square matrices
@@ -74,6 +78,8 @@ function do_factorization(alg::LUFactorization, A, b, u)
7478
if A isa AbstractSparseMatrixCSC
7579
return lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
7680
check = false)
81+
elseif !ArrayInterface.can_setindex(typeof(A))
82+
fact = lu(A, alg.pivot, check = false)
7783
else
7884
fact = lu!(A, alg.pivot, check = false)
7985
end
@@ -136,10 +142,14 @@ end
136142

137143
function do_factorization(alg::QRFactorization, A, b, u)
138144
A = convert(AbstractMatrix, A)
139-
if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AbstractGPUArray)
140-
fact = qr!(A, alg.pivot)
145+
if ArrayInterface.can_setindex(typeof(A))
146+
if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AbstractGPUArray)
147+
fact = qr!(A, alg.pivot)
148+
else
149+
fact = qr(A) # CUDA.jl does not allow other args!
150+
end
141151
else
142-
fact = qr(A) # CUDA.jl does not allow other args!
152+
fact = qr(A, alg.pivot)
143153
end
144154
return fact
145155
end
@@ -202,6 +212,16 @@ function do_factorization(alg::CholeskyFactorization, A, b, u)
202212
return fact
203213
end
204214

215+
function init_cacheval(alg::CholeskyFactorization, A::SMatrix{S1, S2}, b, u, Pl, Pr,
216+
maxiters::Int, abstol, reltol, verbose::Bool,
217+
assumptions::OperatorAssumptions) where {S1, S2}
218+
# StaticArrays doesn't have the pivot argument. Prevent generic fallback.
219+
# CholeskyFactorization is part of DefaultLinearSolver, so it is possible that `A` is
220+
# not Hermitian.
221+
(!issquare(A) || !ishermitian(A)) && return nothing
222+
cholesky(A)
223+
end
224+
205225
function init_cacheval(alg::CholeskyFactorization, A, b, u, Pl, Pr,
206226
maxiters::Int, abstol, reltol, verbose::Bool,
207227
assumptions::OperatorAssumptions)
@@ -276,11 +296,15 @@ SVDFactorization() = SVDFactorization(false, LinearAlgebra.DivideAndConquer())
276296

277297
function do_factorization(alg::SVDFactorization, A, b, u)
278298
A = convert(AbstractMatrix, A)
279-
fact = svd!(A; full = alg.full, alg = alg.alg)
299+
if ArrayInterface.can_setindex(typeof(A))
300+
fact = svd!(A; alg.full, alg.alg)
301+
else
302+
fact = svd(A; alg.full)
303+
end
280304
return fact
281305
end
282306

283-
function init_cacheval(alg::SVDFactorization, A::Matrix, b, u, Pl, Pr,
307+
function init_cacheval(alg::SVDFactorization, A::Union{Matrix, SMatrix}, b, u, Pl, Pr,
284308
maxiters::Int, abstol, reltol, verbose::Bool,
285309
assumptions::OperatorAssumptions)
286310
ArrayInterface.svd_instance(convert(AbstractMatrix, A))
@@ -882,7 +906,8 @@ end
882906
function init_cacheval(alg::NormalCholeskyFactorization, A, b, u, Pl, Pr,
883907
maxiters::Int, abstol, reltol, verbose::Bool,
884908
assumptions::OperatorAssumptions)
885-
ArrayInterface.cholesky_instance(convert(AbstractMatrix, A), alg.pivot)
909+
A_ = convert(AbstractMatrix, A)
910+
ArrayInterface.cholesky_instance(Symmetric((A)' * A, :L), alg.pivot)
886911
end
887912

888913
function init_cacheval(alg::NormalCholeskyFactorization,
@@ -1128,6 +1153,11 @@ function init_cacheval(::SparspakFactorization, A, b, u, Pl, Pr, maxiters::Int,
11281153
end
11291154
end
11301155

1156+
function init_cacheval(::SparspakFactorization, ::StaticArray, b, u, Pl, Pr,
1157+
maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
1158+
nothing
1159+
end
1160+
11311161
function SciMLBase.solve!(cache::LinearCache, alg::SparspakFactorization; kwargs...)
11321162
A = cache.A
11331163
if cache.isfresh

test/default_algs.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ solve(prob)
5050
A = rand(4, 4)
5151
b = rand(4)
5252
prob = LinearProblem(A, b)
53-
JET.@test_opt init(prob, nothing)
53+
VERSION v"1.10-" && JET.@test_opt init(prob, nothing)
5454
JET.@test_opt solve(prob, LUFactorization())
5555
JET.@test_opt solve(prob, GenericLUFactorization())
5656
@test_skip JET.@test_opt solve(prob, QRFactorization())

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ if GROUP == "All" || GROUP == "Core"
1717
@time @safetestset "Enzyme Derivative Rules" include("enzyme.jl")
1818
@time @safetestset "Traits" include("traits.jl")
1919
@time @safetestset "BandedMatrices" include("banded.jl")
20+
@time @safetestset "Static Arrays" include("static_arrays.jl")
2021
end
2122

2223
if GROUP == "LinearSolveCUDA"

test/static_arrays.jl

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
using LinearSolve, StaticArrays, LinearAlgebra
2+
3+
A = SMatrix{5, 5}(Hermitian(rand(5, 5) + I))
4+
b = SVector{5}(rand(5))
5+
6+
for alg in (nothing, LUFactorization(), SVDFactorization(), CholeskyFactorization(),
7+
KrylovJL_GMRES())
8+
sol = solve(LinearProblem(A, b), alg)
9+
@test norm(A * sol .- b) < 1e-10
10+
end
11+
12+
A = SMatrix{7, 5}(rand(7, 5))
13+
b = SVector{7}(rand(7))
14+
15+
for alg in (nothing, SVDFactorization(), KrylovJL_LSMR())
16+
@test_nowarn solve(LinearProblem(A, b), alg)
17+
end
18+
19+
A = SMatrix{5, 7}(rand(5, 7))
20+
b = SVector{5}(rand(5))
21+
22+
for alg in (nothing, SVDFactorization(), KrylovJL_LSMR())
23+
@test_nowarn solve(LinearProblem(A, b), alg)
24+
end

0 commit comments

Comments
 (0)