Skip to content

Commit 4ee3e1e

Browse files
Merge pull request #400 from avik-pal/ap/needs_square_A
Add `needs_square_A` trait
2 parents a880003 + 12c7ed9 commit 4ee3e1e

File tree

5 files changed

+84
-6
lines changed

5 files changed

+84
-6
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LinearSolve"
22
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
33
authors = ["SciML"]
4-
version = "2.11.1"
4+
version = "2.12.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/LinearSolve.jl

+29
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ needs_concrete_A(alg::AbstractKrylovSubspaceMethod) = false
6262
needs_concrete_A(alg::AbstractSolveFunction) = false
6363

6464
# Util
65+
is_underdetermined(x) = false
66+
is_underdetermined(A::AbstractMatrix) = size(A, 1) < size(A, 2)
67+
is_underdetermined(A::AbstractSciMLOperator) = size(A, 1) < size(A, 2)
6568

6669
_isidentity_struct(A) = false
6770
_isidentity_struct::Number) = isone(λ)
@@ -96,6 +99,7 @@ EnumX.@enumx DefaultAlgorithmChoice begin
9699
NormalCholeskyFactorization
97100
AppleAccelerateLUFactorization
98101
MKLLUFactorization
102+
QRFactorizationPivoted
99103
end
100104

101105
struct DefaultLinearSolver <: SciMLLinearSolveAlgorithm
@@ -143,6 +147,31 @@ end
143147
include("factorization_sparse.jl")
144148
end
145149

150+
# Solver Specific Traits
151+
## Needs Square Matrix
152+
"""
153+
needs_square_A(alg)
154+
155+
Returns `true` if the algorithm requires a square matrix.
156+
"""
157+
needs_square_A(::Nothing) = false # Linear Solve automatically will use a correct alg!
158+
needs_square_A(alg::SciMLLinearSolveAlgorithm) = true
159+
for alg in (:QRFactorization, :FastQRFactorization, :NormalCholeskyFactorization,
160+
:NormalBunchKaufmanFactorization)
161+
@eval needs_square_A(::$(alg)) = false
162+
end
163+
for kralg in (Krylov.lsmr!, Krylov.craigmr!)
164+
@eval needs_square_A(::KrylovJL{$(typeof(kralg))}) = false
165+
end
166+
for alg in (:LUFactorization, :FastLUFactorization, :SVDFactorization,
167+
:GenericFactorization, :GenericLUFactorization, :SimpleLUFactorization,
168+
:RFLUFactorization, :UMFPACKFactorization, :KLUFactorization, :SparspakFactorization,
169+
:DiagonalFactorization, :CholeskyFactorization, :BunchKaufmanFactorization,
170+
:CHOLMODFactorization, :LDLtFactorization, :AppleAccelerateLUFactorization,
171+
:MKLLUFactorization, :MetalLUFactorization)
172+
@eval needs_square_A(::$(alg)) = true
173+
end
174+
146175
const IS_OPENBLAS = Ref(true)
147176
isopenblas() = IS_OPENBLAS[]
148177

src/default.jl

+28-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
needs_concrete_A(alg::DefaultLinearSolver) = true
22
mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12,
3-
T13, T14, T15, T16, T17, T18}
3+
T13, T14, T15, T16, T17, T18, T19}
44
LUFactorization::T1
55
QRFactorization::T2
66
DiagonalFactorization::T3
@@ -19,6 +19,7 @@ mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
1919
NormalCholeskyFactorization::T16
2020
AppleAccelerateLUFactorization::T17
2121
MKLLUFactorization::T18
22+
QRFactorizationPivoted::T19
2223
end
2324

2425
# Legacy fallback
@@ -168,8 +169,8 @@ function defaultalg(A, b, assump::OperatorAssumptions)
168169
(A === nothing ? eltype(b) <: Union{Float32, Float64} :
169170
eltype(A) <: Union{Float32, Float64})
170171
DefaultAlgorithmChoice.RFLUFactorization
171-
#elseif A === nothing || A isa Matrix
172-
# alg = FastLUFactorization()
172+
#elseif A === nothing || A isa Matrix
173+
# alg = FastLUFactorization()
173174
elseif usemkl && (A === nothing ? eltype(b) <: Union{Float32, Float64} :
174175
eltype(A) <: Union{Float32, Float64})
175176
DefaultAlgorithmChoice.MKLLUFactorization
@@ -199,9 +200,19 @@ function defaultalg(A, b, assump::OperatorAssumptions)
199200
elseif assump.condition === OperatorCondition.WellConditioned
200201
DefaultAlgorithmChoice.NormalCholeskyFactorization
201202
elseif assump.condition === OperatorCondition.IllConditioned
202-
DefaultAlgorithmChoice.QRFactorization
203+
if is_underdetermined(A)
204+
# Underdetermined
205+
DefaultAlgorithmChoice.QRFactorizationPivoted
206+
else
207+
DefaultAlgorithmChoice.QRFactorization
208+
end
203209
elseif assump.condition === OperatorCondition.VeryIllConditioned
204-
DefaultAlgorithmChoice.QRFactorization
210+
if is_underdetermined(A)
211+
# Underdetermined
212+
DefaultAlgorithmChoice.QRFactorizationPivoted
213+
else
214+
DefaultAlgorithmChoice.QRFactorization
215+
end
205216
elseif assump.condition === OperatorCondition.SuperIllConditioned
206217
DefaultAlgorithmChoice.SVDFactorization
207218
else
@@ -247,6 +258,12 @@ function algchoice_to_alg(alg::Symbol)
247258
NormalCholeskyFactorization()
248259
elseif alg === :AppleAccelerateLUFactorization
249260
AppleAccelerateLUFactorization()
261+
elseif alg === :QRFactorizationPivoted
262+
@static if VERSION v"1.7beta"
263+
QRFactorization(ColumnNorm())
264+
else
265+
QRFactorization(Val(true))
266+
end
250267
else
251268
error("Algorithm choice symbol $alg not allowed in the default")
252269
end
@@ -311,6 +328,12 @@ function defaultalg_symbol(::Type{T}) where {T}
311328
end
312329
defaultalg_symbol(::Type{<:GenericFactorization{typeof(ldlt!)}}) = :LDLtFactorization
313330

331+
@static if VERSION >= v"1.7"
332+
defaultalg_symbol(::Type{<:QRFactorization{ColumnNorm}}) = :QRFactorizationPivoted
333+
else
334+
defaultalg_symbol(::Type{<:QRFactorization{Val{true}}}) = :QRFactorizationPivoted
335+
end
336+
314337
"""
315338
if alg.alg === DefaultAlgorithmChoice.LUFactorization
316339
SciMLBase.solve!(cache, LUFactorization(), args...; kwargs...))

src/factorization.jl

+10
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,16 @@ function QRFactorization(inplace = true)
158158
QRFactorization(pivot, 16, inplace)
159159
end
160160

161+
@static if VERSION v"1.7beta"
162+
function QRFactorization(pivot::LinearAlgebra.PivotingStrategy, inplace::Bool = true)
163+
QRFactorization(pivot, 16, inplace)
164+
end
165+
else
166+
function QRFactorization(pivot::Val, inplace::Bool = true)
167+
QRFactorization(pivot, 16, inplace)
168+
end
169+
end
170+
161171
function do_factorization(alg::QRFactorization, A, b, u)
162172
A = convert(AbstractMatrix, A)
163173
if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AbstractGPUArray)

test/nonsquare.jl

+16
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@ b = rand(m)
88
prob = LinearProblem(A, b)
99
res = A \ b
1010
@test solve(prob).u res
11+
@test !LinearSolve.needs_square_A(QRFactorization())
1112
@test solve(prob, QRFactorization()) res
13+
@test !LinearSolve.needs_square_A(FastQRFactorization())
14+
@test solve(prob, FastQRFactorization()) res
15+
@test !LinearSolve.needs_square_A(KrylovJL_LSMR())
1216
@test solve(prob, KrylovJL_LSMR()) res
1317

1418
A = sprand(m, n, 0.5)
@@ -23,6 +27,7 @@ A = sprand(n, m, 0.5)
2327
b = rand(n)
2428
prob = LinearProblem(A, b)
2529
res = Matrix(A) \ b
30+
@test !LinearSolve.needs_square_A(KrylovJL_CRAIGMR())
2631
@test solve(prob, KrylovJL_CRAIGMR()) res
2732

2833
A = sprandn(1000, 100, 0.1)
@@ -35,7 +40,9 @@ A = randn(1000, 100)
3540
b = randn(1000)
3641
@test isapprox(solve(LinearProblem(A, b)).u, Symmetric(A' * A) \ (A' * b))
3742
solve(LinearProblem(A, b)).u;
43+
@test !LinearSolve.needs_square_A(NormalCholeskyFactorization())
3844
solve(LinearProblem(A, b), (LinearSolve.NormalCholeskyFactorization())).u;
45+
@test !LinearSolve.needs_square_A(NormalBunchKaufmanFactorization())
3946
solve(LinearProblem(A, b), (LinearSolve.NormalBunchKaufmanFactorization())).u;
4047
solve(LinearProblem(A, b),
4148
assumptions = (OperatorAssumptions(false;
@@ -49,3 +56,12 @@ solve(LinearProblem(A, b), (LinearSolve.NormalCholeskyFactorization())).u;
4956
solve(LinearProblem(A, b),
5057
assumptions = (OperatorAssumptions(false;
5158
condition = OperatorCondition.WellConditioned))).u;
59+
60+
# Underdetermined
61+
m, n = 2, 3
62+
63+
A = rand(m, n)
64+
b = rand(m)
65+
prob = LinearProblem(A, b)
66+
res = A \ b
67+
@test solve(prob).u res

0 commit comments

Comments
 (0)