Skip to content

Commit dbcb972

Browse files
Merge pull request #550 from j-fu/jf/PardisoAbstractSparse
Fix Pardiso extension for the case of an AbstractSparseMatrixCSC
2 parents c20ca2d + c88b634 commit dbcb972

File tree

3 files changed

+40
-4
lines changed

3 files changed

+40
-4
lines changed

ext/LinearSolvePardisoExt.jl

+2-3
Original file line numberDiff line numberDiff line change
@@ -134,12 +134,11 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::PardisoJL; kwargs
134134
if cache.isfresh
135135
phase = alg.cache_analysis ? Pardiso.NUM_FACT : Pardiso.ANALYSIS_NUM_FACT
136136
Pardiso.set_phase!(cache.cacheval, phase)
137-
Pardiso.pardiso(cache.cacheval, A, eltype(A)[])
137+
Pardiso.pardiso(cache.cacheval, SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)), eltype(A)[])
138138
cache.isfresh = false
139139
end
140140
Pardiso.set_phase!(cache.cacheval, Pardiso.SOLVE_ITERATIVE_REFINE)
141-
Pardiso.pardiso(cache.cacheval, u, A, b)
142-
141+
Pardiso.pardiso(cache.cacheval, u, SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)), b)
143142
return SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
144143
end
145144

src/extension_algs.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ All values default to `nothing` and the solver internally determines the values
217217
given the input types, and these keyword arguments are only for overriding the
218218
default handling process. This should not be required by most users.
219219
"""
220-
struct PardisoJL{T1, T2} <: LinearSolve.SciMLLinearSolveAlgorithm
220+
struct PardisoJL{T1, T2} <: AbstractSparseFactorization
221221
nprocs::Union{Int, Nothing}
222222
solver_type::T1
223223
matrix_type::T2

test/pardiso/pardiso.jl

+37
Original file line numberDiff line numberDiff line change
@@ -177,3 +177,40 @@ for solver in solvers
177177
@test Pardiso.get_iparm(solver, i) == iparm[i][2]
178178
end
179179
end
180+
181+
@testset "AbstractSparseMatrixCSC" begin
182+
struct MySparseMatrixCSC2{Tv, Ti} <: SparseArrays.AbstractSparseMatrixCSC{Tv, Ti}
183+
csc::SparseMatrixCSC{Tv, Ti}
184+
end
185+
186+
Base.size(m::MySparseMatrixCSC2) = size(m.csc)
187+
SparseArrays.getcolptr(m::MySparseMatrixCSC2) = SparseArrays.getcolptr(m.csc)
188+
SparseArrays.rowvals(m::MySparseMatrixCSC2) = SparseArrays.rowvals(m.csc)
189+
SparseArrays.nonzeros(m::MySparseMatrixCSC2) = SparseArrays.nonzeros(m.csc)
190+
191+
for alg in algs
192+
N = 100
193+
u0 = ones(N)
194+
A0 = spdiagm(1 => -ones(N - 1), 0 => fill(10.0, N), -1 => -ones(N - 1))
195+
b0 = A0 * u0
196+
B0 = MySparseMatrixCSC2(A0)
197+
A1 = spdiagm(1 => -ones(N - 1), 0 => fill(100.0, N), -1 => -ones(N - 1))
198+
b1=A1*u0
199+
B1= MySparseMatrixCSC2(A1)
200+
201+
202+
pr = LinearProblem(B0, b0)
203+
# test default algorithn
204+
u=solve(pr,alg)
205+
@test norm(u - u0, Inf) < 1.0e-13
206+
207+
# test factorization with reinit!
208+
pr = LinearProblem(B0, b0)
209+
cache=init(pr,alg)
210+
u=solve!(cache)
211+
@test norm(u - u0, Inf) < 1.0e-13
212+
reinit!(cache; A=B1, b=b1)
213+
u=solve!(cache)
214+
@test norm(u - u0, Inf) < 1.0e-13
215+
end
216+
end

0 commit comments

Comments
 (0)