Skip to content

Commit 834b255

Browse files
make sure to init the default algorithm and move first factorization
1 parent 6e4fa5a commit 834b255

File tree

2 files changed

+56
-1
lines changed

2 files changed

+56
-1
lines changed

src/common.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ function set_cacheval(cache::LinearCache, alg_cache)
6565
return cache
6666
end
6767

68-
init_cacheval(alg::Union{SciMLLinearSolveAlgorithm,Nothing}, A, b, u) = nothing
68+
init_cacheval(alg::SciMLLinearSolveAlgorithm, A, b, u) = nothing
6969

7070
SciMLBase.init(prob::LinearProblem, args...; kwargs...) = SciMLBase.init(prob,nothing,args...;kwargs...)
7171

src/default.jl

+55
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,58 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
5656
SciMLBase.solve(cache, alg, args...; kwargs...)
5757
end
5858
end
59+
60+
function init_cacheval(alg::Nothing, A, b, u)
61+
if A isa DiffEqArrayOperator
62+
A = A.A
63+
end
64+
65+
# Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when
66+
# it makes sense according to the benchmarks, which is dependent on
67+
# whether MKL or OpenBLAS is being used
68+
if A isa Matrix
69+
if eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64} &&
70+
ArrayInterface.can_setindex(b) && (size(A,1) <= 100 ||
71+
(isopenblas() && size(A,1) <= 500)
72+
)
73+
alg = RFLUFactorization()
74+
init_cacheval(alg, A, b, u)
75+
else
76+
alg = LUFactorization()
77+
init_cacheval(alg, A, b, u)
78+
end
79+
80+
# These few cases ensure the choice is optimal without the
81+
# dynamic dispatching of factorize
82+
elseif A isa Tridiagonal
83+
alg = GenericFactorization(;fact_alg=lu!)
84+
init_cacheval(alg, A, b, u)
85+
elseif A isa SymTridiagonal
86+
alg = GenericFactorization(;fact_alg=ldlt!)
87+
init_cacheval(alg, A, b, u)
88+
elseif A isa SparseMatrixCSC
89+
alg = UMFPACKFactorization()
90+
init_cacheval(alg, A, b, u)
91+
92+
# This catches the cases where a factorization overload could exist
93+
# For example, BlockBandedMatrix
94+
elseif ArrayInterface.isstructured(A)
95+
alg = GenericFactorization()
96+
init_cacheval(alg, A, b, u)
97+
98+
# This catches the case where A is a CuMatrix
99+
# Which does not have LU fully defined
100+
elseif !(A isa AbstractDiffEqOperator)
101+
alg = QRFactorization()
102+
init_cacheval(alg, A, b, u)
103+
104+
# Not factorizable operator, default to only using A*x
105+
# IterativeSolvers is faster on CPU but not GPU-compatible
106+
elseif cache.u isa Array
107+
alg = IterativeSolversJL_GMRES()
108+
init_cacheval(alg, A, b, u)
109+
else
110+
alg = KrylovJL_GMRES()
111+
init_cacheval(alg, A, b, u)
112+
end
113+
end

0 commit comments

Comments
 (0)