Skip to content

Commit 8f21ca3

Browse files
Merge pull request #160 from SciML/defaults
Many tweaks to the default algorithm choice
2 parents 2e1b5ed + ce5c593 commit 8f21ca3

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

src/default.jl

+20-9
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,11 @@ function defaultalg(A, b)
1414
ArrayInterfaceCore.can_setindex(b)
1515
if length(b) <= 10
1616
alg = GenericLUFactorization()
17-
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500))
17+
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) &&
18+
eltype(A) <: Union{Float32, Float64}
1819
alg = RFLUFactorization()
20+
#elseif A === nothing || A isa Matrix
21+
# alg = FastLUFactorization()
1922
else
2023
alg = LUFactorization()
2124
end
@@ -30,7 +33,7 @@ function defaultalg(A, b)
3033
elseif A isa SymTridiagonal
3134
alg = GenericFactorization(; fact_alg = ldlt!)
3235
elseif A isa SparseMatrixCSC
33-
alg = UMFPACKFactorization()
36+
alg = KLUFactorization()
3437

3538
# This catches the cases where a factorization overload could exist
3639
# For example, BlockBandedMatrix
@@ -40,7 +43,7 @@ function defaultalg(A, b)
4043
# This catches the case where A is a CuMatrix
4144
# Which does not have LU fully defined
4245
elseif A isa GPUArraysCore.AbstractGPUArray || b isa GPUArraysCore.AbstractGPUArray
43-
alg = QRFactorization(false)
46+
alg = LUFactorization()
4447

4548
# Not factorizable operator, default to only using A*x
4649
else
@@ -68,9 +71,13 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
6871
if length(b) <= 10
6972
alg = GenericLUFactorization()
7073
SciMLBase.solve(cache, alg, args...; kwargs...)
71-
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500))
74+
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) &&
75+
eltype(A) <: Union{Float32, Float64}
7276
alg = RFLUFactorization()
7377
SciMLBase.solve(cache, alg, args...; kwargs...)
78+
#elseif A isa Matrix
79+
# alg = FastLUFactorization()
80+
# SciMLBase.solve(cache, alg, args...; kwargs...)
7481
else
7582
alg = LUFactorization()
7683
SciMLBase.solve(cache, alg, args...; kwargs...)
@@ -89,7 +96,7 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
8996
alg = GenericFactorization(; fact_alg = ldlt!)
9097
SciMLBase.solve(cache, alg, args...; kwargs...)
9198
elseif A isa SparseMatrixCSC
92-
alg = UMFPACKFactorization()
99+
alg = KLUFactorization()
93100
SciMLBase.solve(cache, alg, args...; kwargs...)
94101

95102
# This catches the cases where a factorization overload could exist
@@ -101,7 +108,7 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
101108
# This catches the case where A is a CuMatrix
102109
# Which does not have LU fully defined
103110
elseif A isa GPUArraysCore.AbstractGPUArray
104-
alg = QRFactorization(false)
111+
alg = LUFactorization()
105112
SciMLBase.solve(cache, alg, args...; kwargs...)
106113

107114
# Not factorizable operator, default to only using A*x
@@ -126,9 +133,13 @@ function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters, abstol, reltol,
126133
if length(b) <= 10
127134
alg = GenericLUFactorization()
128135
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
129-
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500))
136+
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) &&
137+
eltype(A) <: Union{Float32, Float64}
130138
alg = RFLUFactorization()
131139
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
140+
#elseif A isa Matrix
141+
# alg = FastLUFactorization()
142+
# init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
132143
else
133144
alg = LUFactorization()
134145
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
@@ -147,7 +158,7 @@ function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters, abstol, reltol,
147158
alg = GenericFactorization(; fact_alg = ldlt!)
148159
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
149160
elseif A isa SparseMatrixCSC
150-
alg = UMFPACKFactorization()
161+
alg = KLUFactorization()
151162
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
152163

153164
# This catches the cases where a factorization overload could exist
@@ -159,7 +170,7 @@ function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters, abstol, reltol,
159170
# This catches the case where A is a CuMatrix
160171
# Which does not have LU fully defined
161172
elseif A isa GPUArraysCore.AbstractGPUArray
162-
alg = QRFactorization(false)
173+
alg = LUFactorization()
163174
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
164175

165176
# Not factorizable operator, default to only using A*x

0 commit comments

Comments
 (0)