Skip to content

Commit 0bc65ee

Browse files
Merge pull request #64 from SciML/init
make sure to init the default algorithm and move first factorization
2 parents 6e4fa5a + 873004a commit 0bc65ee

File tree

6 files changed

+101
-59
lines changed

6 files changed

+101
-59
lines changed

src/common.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ function set_cacheval(cache::LinearCache, alg_cache)
6565
return cache
6666
end
6767

68-
init_cacheval(alg::Union{SciMLLinearSolveAlgorithm,Nothing}, A, b, u) = nothing
68+
init_cacheval(alg::SciMLLinearSolveAlgorithm, A, b, u) = nothing
6969

7070
SciMLBase.init(prob::LinearProblem, args...; kwargs...) = SciMLBase.init(prob,nothing,args...;kwargs...)
7171

@@ -91,9 +91,9 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
9191
Pl = (Pl !== nothing) ? Pl : Identity()
9292
Pr = (Pr !== nothing) ? Pr : Identity()
9393

94-
cacheval = init_cacheval(alg, A, b, u0)
95-
isfresh = cacheval === nothing
96-
Tc = isfresh ? Any : typeof(cacheval)
94+
cacheval = init_cacheval(alg, A, b, u0, Pl, Pr, maxiters, abstol, reltol, verbose)
95+
isfresh = true
96+
Tc = typeof(cacheval)
9797

9898
A = alias_A ? A : deepcopy(A)
9999
b = alias_b ? b : deepcopy(b)

src/default.jl

+55
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,58 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
5656
SciMLBase.solve(cache, alg, args...; kwargs...)
5757
end
5858
end
59+
60+
function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
61+
if A isa DiffEqArrayOperator
62+
A = A.A
63+
end
64+
65+
# Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when
66+
# it makes sense according to the benchmarks, which is dependent on
67+
# whether MKL or OpenBLAS is being used
68+
if A isa Matrix
69+
if eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64} &&
70+
ArrayInterface.can_setindex(b) && (size(A,1) <= 100 ||
71+
(isopenblas() && size(A,1) <= 500)
72+
)
73+
alg = RFLUFactorization()
74+
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
75+
else
76+
alg = LUFactorization()
77+
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
78+
end
79+
80+
# These few cases ensure the choice is optimal without the
81+
# dynamic dispatching of factorize
82+
elseif A isa Tridiagonal
83+
alg = GenericFactorization(;fact_alg=lu!)
84+
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
85+
elseif A isa SymTridiagonal
86+
alg = GenericFactorization(;fact_alg=ldlt!)
87+
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
88+
elseif A isa SparseMatrixCSC
89+
alg = UMFPACKFactorization()
90+
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
91+
92+
# This catches the cases where a factorization overload could exist
93+
# For example, BlockBandedMatrix
94+
elseif ArrayInterface.isstructured(A)
95+
alg = GenericFactorization()
96+
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
97+
98+
# This catches the case where A is a CuMatrix
99+
# Which does not have LU fully defined
100+
elseif !(A isa AbstractDiffEqOperator)
101+
alg = QRFactorization()
102+
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
103+
104+
# Not factorizable operator, default to only using A*x
105+
# IterativeSolvers is faster on CPU but not GPU-compatible
106+
elseif cache.u isa Array
107+
alg = IterativeSolversJL_GMRES()
108+
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
109+
else
110+
alg = KrylovJL_GMRES()
111+
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
112+
end
113+
end

src/factorization.jl

+15-9
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
function SciMLBase.solve(cache::LinearCache, alg::AbstractFactorization; kwargs...)
22
if cache.isfresh
3-
fact = init_cacheval(alg, cache.A, cache.b, cache.u)
3+
fact = do_factorization(alg, cache.A, cache.b, cache.u)
44
cache = set_cacheval(cache, fact)
55
end
66

77
y = ldiv!(cache.u, cache.cacheval, cache.b)
88
SciMLBase.build_linear_solution(alg,y,nothing,cache)
99
end
1010

11+
# Bad fallback: will fail if `A` is just a stand-in
12+
# This should instead just create the factorization type.
13+
init_cacheval(alg::AbstractFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = do_factorization(alg, A, b, u)
14+
1115
## LU Factorizations
1216

1317
struct LUFactorization{P} <: AbstractFactorization
@@ -23,7 +27,7 @@ function LUFactorization()
2327
LUFactorization(pivot)
2428
end
2529

26-
function init_cacheval(alg::LUFactorization, A, b, u)
30+
function do_factorization(alg::LUFactorization, A, b, u)
2731
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
2832
error("LU is not defined for $(typeof(A))")
2933

@@ -34,12 +38,14 @@ function init_cacheval(alg::LUFactorization, A, b, u)
3438
return fact
3539
end
3640

41+
init_cacheval(alg::LUFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A)
42+
3743
# This could be a GenericFactorization perhaps?
3844
Base.@kwdef struct UMFPACKFactorization <: AbstractFactorization
3945
reuse_symbolic::Bool = true
4046
end
4147

42-
function init_cacheval(::UMFPACKFactorization, A, b, u)
48+
function do_factorization(::UMFPACKFactorization, A, b, u)
4349
if A isa AbstractDiffEqOperator
4450
A = A.A
4551
end
@@ -62,7 +68,7 @@ function SciMLBase.solve(cache::LinearCache, alg::UMFPACKFactorization)
6268
SuiteSparse.UMFPACK.umfpack_symbolic!(cache.cacheval)
6369
fact = lu!(cache.cacheval, A)
6470
else
65-
fact = init_cacheval(alg, A, cache.b, cache.u)
71+
fact = do_factorization(alg, A, cache.b, cache.u)
6672
end
6773
cache = set_cacheval(cache, fact)
6874
end
@@ -75,7 +81,7 @@ Base.@kwdef struct KLUFactorization <: AbstractFactorization
7581
reuse_symbolic::Bool = true
7682
end
7783

78-
function init_cacheval(::KLUFactorization, A, b, u)
84+
function do_factorization(::KLUFactorization, A, b, u)
7985
if A isa AbstractDiffEqOperator
8086
A = A.A
8187
end
@@ -98,7 +104,7 @@ function SciMLBase.solve(cache::LinearCache, alg::KLUFactorization)
98104
KLU.klu_analyze!(cache.cacheval)
99105
fact = klu!(cache.cacheval, A)
100106
else
101-
fact = init_cacheval(alg, A, cache.b, cache.u)
107+
fact = do_factorization(alg, A, cache.b, cache.u)
102108
end
103109
cache = set_cacheval(cache, fact)
104110
end
@@ -123,7 +129,7 @@ function QRFactorization()
123129
QRFactorization(pivot, 16)
124130
end
125131

126-
function init_cacheval(alg::QRFactorization, A, b, u)
132+
function do_factorization(alg::QRFactorization, A, b, u)
127133
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
128134
error("QR is not defined for $(typeof(A))")
129135

@@ -143,7 +149,7 @@ end
143149

144150
SVDFactorization() = SVDFactorization(false, LinearAlgebra.DivideAndConquer())
145151

146-
function init_cacheval(alg::SVDFactorization, A, b, u)
152+
function do_factorization(alg::SVDFactorization, A, b, u)
147153
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
148154
error("SVD is not defined for $(typeof(A))")
149155

@@ -164,7 +170,7 @@ end
164170
GenericFactorization(;fact_alg = LinearAlgebra.factorize) =
165171
GenericFactorization(fact_alg)
166172

167-
function init_cacheval(alg::GenericFactorization, A, b, u)
173+
function do_factorization(alg::GenericFactorization, A, b, u)
168174
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
169175
error("GenericFactorization is not defined for $(typeof(A))")
170176

src/iterative_wrappers.jl

+9-16
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ function get_KrylovJL_solver(KrylovAlg)
6969
return KS
7070
end
7171

72-
function init_cacheval(alg::KrylovJL, A, b, u)
72+
function init_cacheval(alg::KrylovJL, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
7373

7474
KS = get_KrylovJL_solver(alg.KrylovAlg)
7575

@@ -101,7 +101,7 @@ end
101101

102102
function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)
103103
if cache.isfresh
104-
solver = init_cacheval(alg, cache.A, cache.b, cache.u)
104+
solver = init_cacheval(alg, cache.A, cache.b, cache.u, cache.Pl, cache.Pr, cache.maxiters, cache.abstol, cache.reltol, cache.verbose)
105105
cache = set_cacheval(cache, solver)
106106
end
107107

@@ -183,20 +183,13 @@ IterativeSolversJL_MINRES(args...;kwargs...) =
183183
generate_iterator=IterativeSolvers.minres_iterable!,
184184
kwargs...)
185185

186-
function init_cacheval(alg::IterativeSolversJL, cache::LinearCache)
187-
@unpack A, b, u = cache
188-
189-
Pl = get_preconditioner(alg.Pl, cache.Pl)
190-
Pr = get_preconditioner(alg.Pr, cache.Pr)
191-
192-
abstol = cache.abstol
193-
reltol = cache.reltol
194-
maxiter = cache.maxiters
195-
verbose = cache.verbose
186+
function init_cacheval(alg::IterativeSolversJL, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
187+
Pl = get_preconditioner(alg.Pl, Pl)
188+
Pr = get_preconditioner(alg.Pr, Pr)
196189

197190
restart = (alg.gmres_restart == 0) ? min(20, size(A,1)) : alg.gmres_restart
198191

199-
kwargs = (abstol=abstol, reltol=reltol, maxiter=maxiter,
192+
kwargs = (abstol=abstol, reltol=reltol, maxiter=maxiters,
200193
alg.kwargs...)
201194

202195
iterable = if alg.generate_iterator === IterativeSolvers.cg_iterator!
@@ -212,19 +205,19 @@ function init_cacheval(alg::IterativeSolversJL, cache::LinearCache)
212205
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
213206
alg.generate_iterator(u, A, b, alg.args...; Pl=Pl,
214207
abstol=abstol, reltol=reltol,
215-
max_mv_products=maxiter*2,
208+
max_mv_products=maxiters*2,
216209
alg.kwargs...)
217210
else # minres, qmr
218211
alg.generate_iterator(u, A, b, alg.args...;
219-
abstol=abstol, reltol=reltol, maxiter=maxiter,
212+
abstol=abstol, reltol=reltol, maxiter=maxiters,
220213
alg.kwargs...)
221214
end
222215
return iterable
223216
end
224217

225218
function SciMLBase.solve(cache::LinearCache, alg::IterativeSolversJL; kwargs...)
226219
if cache.isfresh
227-
solver = init_cacheval(alg, cache)
220+
solver = init_cacheval(alg, cache.A, cache.b, cache.u, cache.Pl, cache.Pr, cache.maxiters, cache.abstol, cache.reltol, cache.verbose)
228221
cache = set_cacheval(cache, solver)
229222
end
230223

src/pardiso.jl

+13-25
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,17 @@ Base.@kwdef struct PardisoJL <: SciMLLinearSolveAlgorithm
22
nprocs::Union{Int, Nothing} = nothing
33
solver_type::Union{Int, Pardiso.Solver, Nothing} = nothing
44
matrix_type::Union{Int, Pardiso.MatrixType, Nothing} = nothing
5-
fact_phase::Union{Int, Pardiso.Phase, Nothing} = nothing
6-
solve_phase::Union{Int, Pardiso.Phase, Nothing} = nothing
7-
release_phase::Union{Int, Nothing} = nothing
85
iparm::Union{Vector{Tuple{Int,Int}}, Nothing} = nothing
96
dparm::Union{Vector{Tuple{Int,Int}}, Nothing} = nothing
107
end
118

12-
MKLPardisoFactorize(;kwargs...) = PardisoJL(;fact_phase=Pardiso.NUM_FACT,
13-
solve_phase=Pardiso.SOLVE_ITERATIVE_REFINE,
14-
kwargs...)
15-
MKLPardisoIterate(;kwargs...) = PardisoJL(;solve_phase=Pardiso.NUM_FACT_SOLVE_REFINE,
16-
kwargs...)
9+
MKLPardisoFactorize(;kwargs...) = PardisoJL(;kwargs...)
10+
MKLPardisoIterate(;kwargs...) = PardisoJL(;kwargs...)
1711

1812
# TODO schur complement functionality
1913

20-
function init_cacheval(alg::PardisoJL, cache::LinearCache)
21-
@unpack nprocs, solver_type, matrix_type, fact_phase, solve_phase, iparm, dparm = alg
22-
@unpack A, b, u = cache
14+
function init_cacheval(alg::PardisoJL, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
15+
@unpack nprocs, solver_type, matrix_type, iparm, dparm = alg
2316

2417
if A isa DiffEqArrayOperator
2518
A = A.A
@@ -51,7 +44,7 @@ function init_cacheval(alg::PardisoJL, cache::LinearCache)
5144
error("Number type not supported by Pardiso")
5245
end
5346
end
54-
cache.verbose && Pardiso.set_msglvl!(solver, Pardiso.MESSAGE_LEVEL_ON)
47+
verbose && Pardiso.set_msglvl!(solver, Pardiso.MESSAGE_LEVEL_ON)
5548

5649
# pass in vector of tuples like [(iparm::Int, key::Int) ...]
5750
if iparm !== nothing
@@ -66,15 +59,8 @@ function init_cacheval(alg::PardisoJL, cache::LinearCache)
6659
end
6760
end
6861

69-
if (fact_phase !== nothing) | (solve_phase !== nothing)
70-
Pardiso.set_phase!(solver, Pardiso.ANALYSIS)
71-
Pardiso.pardiso(solver, u, A, b)
72-
end
73-
74-
if fact_phase !== nothing
75-
Pardiso.set_phase!(solver, fact_phase)
76-
Pardiso.pardiso(solver, u, A, b)
77-
end
62+
Pardiso.set_phase!(solver, Pardiso.ANALYSIS)
63+
Pardiso.pardiso(solver, u, A, b)
7864

7965
return solver
8066
end
@@ -86,15 +72,17 @@ function SciMLBase.solve(cache::LinearCache, alg::PardisoJL; kwargs...)
8672
end
8773

8874
if cache.isfresh
89-
solver = init_cacheval(alg, cache)
90-
cache = set_cacheval(cache, solver)
75+
Pardiso.set_phase!(cache.cacheval, Pardiso.NUM_FACT)
76+
Pardiso.pardiso(cache.cacheval, cache.u, cache.A, cache.b)
9177
end
9278

93-
alg.solve_phase !== nothing && Pardiso.set_phase!(cache.cacheval, alg.solve_phase)
79+
Pardiso.set_phase!(cache.cacheval, Pardiso.SOLVE_ITERATIVE_REFINE)
9480
Pardiso.pardiso(cache.cacheval, u, A, b)
95-
alg.release_phase !== nothing && Pardiso.set_phase!(cache.cacheval, alg.release_phase)
9681

9782
return SciMLBase.build_linear_solution(alg,cache.u,nothing,cache)
9883
end
9984

85+
# Add finalizer to release memory
86+
# Pardiso.set_phase!(cache.cacheval, Pardiso.RELEASE_ALL)
87+
10088
export PardisoJL, MKLPardisoFactorize, MKLPardisoIterate

test/runtests.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -43,24 +43,24 @@ end
4343
y = solve(prob1)
4444
@test A1 * y b1
4545

46-
_prob = LinearProblem(SymTridiagonal(A1.A), b1; u0=x1)
46+
_prob = LinearProblem(SymTridiagonal(A1), b1; u0=x1)
4747
y = solve(_prob)
4848
@test A1 * y b1
4949

50-
_prob = LinearProblem(Tridiagonal(A1.A), b1; u0=x1)
50+
_prob = LinearProblem(Tridiagonal(A1), b1; u0=x1)
5151
y = solve(_prob)
5252
@test A1 * y b1
5353

54-
_prob = LinearProblem(Symmetric(A1.A), b1; u0=x1)
54+
_prob = LinearProblem(Symmetric(A1), b1; u0=x1)
5555
y = solve(_prob)
5656
@test A1 * y b1
5757

58-
_prob = LinearProblem(Hermitian(A1.A), b1; u0=x1)
58+
_prob = LinearProblem(Hermitian(A1), b1; u0=x1)
5959
y = solve(_prob)
6060
@test A1 * y b1
6161

6262

63-
_prob = LinearProblem(sparse(A1.A), b1; u0=x1)
63+
_prob = LinearProblem(sparse(A1), b1; u0=x1)
6464
y = solve(_prob)
6565
@test A1 * y b1
6666
end

0 commit comments

Comments
 (0)