Skip to content

Commit 873004a

Browse files
fix initializations
1 parent 571058c commit 873004a

File tree

5 files changed

+35
-54
lines changed

5 files changed

+35
-54
lines changed

src/common.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
9191
Pl = (Pl !== nothing) ? Pl : Identity()
9292
Pr = (Pr !== nothing) ? Pr : Identity()
9393

94-
cacheval = init_cacheval(alg, A, b, u0)
94+
cacheval = init_cacheval(alg, A, b, u0, Pl, Pr, maxiters, abstol, reltol, verbose)
9595
isfresh = true
9696
Tc = typeof(cacheval)
9797

src/default.jl

+10-10
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
5757
end
5858
end
5959

60-
function init_cacheval(alg::Nothing, A, b, u)
60+
function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
6161
if A isa DiffEqArrayOperator
6262
A = A.A
6363
end
@@ -71,43 +71,43 @@ function init_cacheval(alg::Nothing, A, b, u)
7171
(isopenblas() && size(A,1) <= 500)
7272
)
7373
alg = RFLUFactorization()
74-
init_cacheval(alg, A, b, u)
74+
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
7575
else
7676
alg = LUFactorization()
77-
init_cacheval(alg, A, b, u)
77+
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
7878
end
7979

8080
# These few cases ensure the choice is optimal without the
8181
# dynamic dispatching of factorize
8282
elseif A isa Tridiagonal
8383
alg = GenericFactorization(;fact_alg=lu!)
84-
init_cacheval(alg, A, b, u)
84+
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
8585
elseif A isa SymTridiagonal
8686
alg = GenericFactorization(;fact_alg=ldlt!)
87-
init_cacheval(alg, A, b, u)
87+
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
8888
elseif A isa SparseMatrixCSC
8989
alg = UMFPACKFactorization()
90-
init_cacheval(alg, A, b, u)
90+
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
9191

9292
# This catches the cases where a factorization overload could exist
9393
# For example, BlockBandedMatrix
9494
elseif ArrayInterface.isstructured(A)
9595
alg = GenericFactorization()
96-
init_cacheval(alg, A, b, u)
96+
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
9797

9898
# This catches the case where A is a CuMatrix
9999
# Which does not have LU fully defined
100100
elseif !(A isa AbstractDiffEqOperator)
101101
alg = QRFactorization()
102-
init_cacheval(alg, A, b, u)
102+
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
103103

104104
# Not factorizable operator, default to only using A*x
105105
# IterativeSolvers is faster on CPU but not GPU-compatible
106106
elseif cache.u isa Array
107107
alg = IterativeSolversJL_GMRES()
108-
init_cacheval(alg, A, b, u)
108+
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
109109
else
110110
alg = KrylovJL_GMRES()
111-
init_cacheval(alg, A, b, u)
111+
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
112112
end
113113
end

src/factorization.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ end
1010

1111
# Bad fallback: will fail if `A` is just a stand-in
1212
# This should instead just create the factorization type.
13-
init_cacheval(alg::AbstractFactorization, A, b, u) = do_factorization(alg, A, b, u)
13+
init_cacheval(alg::AbstractFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = do_factorization(alg, A, b, u)
1414

1515
## LU Factorizations
1616

@@ -38,7 +38,7 @@ function do_factorization(alg::LUFactorization, A, b, u)
3838
return fact
3939
end
4040

41-
init_cacheval(alg::LUFactorization, A, b, u) = ArrayInterface.lu_instance(A)
41+
init_cacheval(alg::LUFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A)
4242

4343
# This could be a GenericFactorization perhaps?
4444
Base.@kwdef struct UMFPACKFactorization <: AbstractFactorization

src/iterative_wrappers.jl

+9-16
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ function get_KrylovJL_solver(KrylovAlg)
6969
return KS
7070
end
7171

72-
function init_cacheval(alg::KrylovJL, A, b, u)
72+
function init_cacheval(alg::KrylovJL, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
7373

7474
KS = get_KrylovJL_solver(alg.KrylovAlg)
7575

@@ -101,7 +101,7 @@ end
101101

102102
function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)
103103
if cache.isfresh
104-
solver = init_cacheval(alg, cache.A, cache.b, cache.u)
104+
solver = init_cacheval(alg, cache.A, cache.b, cache.u, cache.Pl, cache.Pr, cache.maxiters, cache.abstol, cache.reltol, cache.verbose)
105105
cache = set_cacheval(cache, solver)
106106
end
107107

@@ -183,20 +183,13 @@ IterativeSolversJL_MINRES(args...;kwargs...) =
183183
generate_iterator=IterativeSolvers.minres_iterable!,
184184
kwargs...)
185185

186-
function init_cacheval(alg::IterativeSolversJL, cache::LinearCache)
187-
@unpack A, b, u = cache
188-
189-
Pl = get_preconditioner(alg.Pl, cache.Pl)
190-
Pr = get_preconditioner(alg.Pr, cache.Pr)
191-
192-
abstol = cache.abstol
193-
reltol = cache.reltol
194-
maxiter = cache.maxiters
195-
verbose = cache.verbose
186+
function init_cacheval(alg::IterativeSolversJL, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
187+
Pl = get_preconditioner(alg.Pl, Pl)
188+
Pr = get_preconditioner(alg.Pr, Pr)
196189

197190
restart = (alg.gmres_restart == 0) ? min(20, size(A,1)) : alg.gmres_restart
198191

199-
kwargs = (abstol=abstol, reltol=reltol, maxiter=maxiter,
192+
kwargs = (abstol=abstol, reltol=reltol, maxiter=maxiters,
200193
alg.kwargs...)
201194

202195
iterable = if alg.generate_iterator === IterativeSolvers.cg_iterator!
@@ -212,19 +205,19 @@ function init_cacheval(alg::IterativeSolversJL, cache::LinearCache)
212205
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
213206
alg.generate_iterator(u, A, b, alg.args...; Pl=Pl,
214207
abstol=abstol, reltol=reltol,
215-
max_mv_products=maxiter*2,
208+
max_mv_products=maxiters*2,
216209
alg.kwargs...)
217210
else # minres, qmr
218211
alg.generate_iterator(u, A, b, alg.args...;
219-
abstol=abstol, reltol=reltol, maxiter=maxiter,
212+
abstol=abstol, reltol=reltol, maxiter=maxiters,
220213
alg.kwargs...)
221214
end
222215
return iterable
223216
end
224217

225218
function SciMLBase.solve(cache::LinearCache, alg::IterativeSolversJL; kwargs...)
226219
if cache.isfresh
227-
solver = init_cacheval(alg, cache)
220+
solver = init_cacheval(alg, cache.A, cache.b, cache.u, cache.Pl, cache.Pr, cache.maxiters, cache.abstol, cache.reltol, cache.verbose)
228221
cache = set_cacheval(cache, solver)
229222
end
230223

src/pardiso.jl

+13-25
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,17 @@ Base.@kwdef struct PardisoJL <: SciMLLinearSolveAlgorithm
22
nprocs::Union{Int, Nothing} = nothing
33
solver_type::Union{Int, Pardiso.Solver, Nothing} = nothing
44
matrix_type::Union{Int, Pardiso.MatrixType, Nothing} = nothing
5-
fact_phase::Union{Int, Pardiso.Phase, Nothing} = nothing
6-
solve_phase::Union{Int, Pardiso.Phase, Nothing} = nothing
7-
release_phase::Union{Int, Nothing} = nothing
85
iparm::Union{Vector{Tuple{Int,Int}}, Nothing} = nothing
96
dparm::Union{Vector{Tuple{Int,Int}}, Nothing} = nothing
107
end
118

12-
MKLPardisoFactorize(;kwargs...) = PardisoJL(;fact_phase=Pardiso.NUM_FACT,
13-
solve_phase=Pardiso.SOLVE_ITERATIVE_REFINE,
14-
kwargs...)
15-
MKLPardisoIterate(;kwargs...) = PardisoJL(;solve_phase=Pardiso.NUM_FACT_SOLVE_REFINE,
16-
kwargs...)
9+
MKLPardisoFactorize(;kwargs...) = PardisoJL(;kwargs...)
10+
MKLPardisoIterate(;kwargs...) = PardisoJL(;kwargs...)
1711

1812
# TODO schur complement functionality
1913

20-
function init_cacheval(alg::PardisoJL, cache::LinearCache)
21-
@unpack nprocs, solver_type, matrix_type, fact_phase, solve_phase, iparm, dparm = alg
22-
@unpack A, b, u = cache
14+
function init_cacheval(alg::PardisoJL, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
15+
@unpack nprocs, solver_type, matrix_type, iparm, dparm = alg
2316

2417
if A isa DiffEqArrayOperator
2518
A = A.A
@@ -51,7 +44,7 @@ function init_cacheval(alg::PardisoJL, cache::LinearCache)
5144
error("Number type not supported by Pardiso")
5245
end
5346
end
54-
cache.verbose && Pardiso.set_msglvl!(solver, Pardiso.MESSAGE_LEVEL_ON)
47+
verbose && Pardiso.set_msglvl!(solver, Pardiso.MESSAGE_LEVEL_ON)
5548

5649
# pass in vector of tuples like [(iparm::Int, key::Int) ...]
5750
if iparm !== nothing
@@ -66,15 +59,8 @@ function init_cacheval(alg::PardisoJL, cache::LinearCache)
6659
end
6760
end
6861

69-
if (fact_phase !== nothing) | (solve_phase !== nothing)
70-
Pardiso.set_phase!(solver, Pardiso.ANALYSIS)
71-
Pardiso.pardiso(solver, u, A, b)
72-
end
73-
74-
if fact_phase !== nothing
75-
Pardiso.set_phase!(solver, fact_phase)
76-
Pardiso.pardiso(solver, u, A, b)
77-
end
62+
Pardiso.set_phase!(solver, Pardiso.ANALYSIS)
63+
Pardiso.pardiso(solver, u, A, b)
7864

7965
return solver
8066
end
@@ -86,15 +72,17 @@ function SciMLBase.solve(cache::LinearCache, alg::PardisoJL; kwargs...)
8672
end
8773

8874
if cache.isfresh
89-
solver = init_cacheval(alg, cache)
90-
cache = set_cacheval(cache, solver)
75+
Pardiso.set_phase!(cache.cacheval, Pardiso.NUM_FACT)
76+
Pardiso.pardiso(cache.cacheval, cache.u, cache.A, cache.b)
9177
end
9278

93-
alg.solve_phase !== nothing && Pardiso.set_phase!(cache.cacheval, alg.solve_phase)
79+
Pardiso.set_phase!(cache.cacheval, Pardiso.SOLVE_ITERATIVE_REFINE)
9480
Pardiso.pardiso(cache.cacheval, u, A, b)
95-
alg.release_phase !== nothing && Pardiso.set_phase!(cache.cacheval, alg.release_phase)
9681

9782
return SciMLBase.build_linear_solution(alg,cache.u,nothing,cache)
9883
end
9984

85+
# Add finalizer to release memory
86+
# Pardiso.set_phase!(cache.cacheval, Pardiso.RELEASE_ALL)
87+
10088
export PardisoJL, MKLPardisoFactorize, MKLPardisoIterate

0 commit comments

Comments
 (0)