Skip to content

Commit dacd618

Browse files
Merge pull request #83 from SciML/abstractmatrix
do AbstractMatrix conversion before factorization
2 parents 8723422 + 0c2a193 commit dacd618

File tree

2 files changed

+19
-47
lines changed

2 files changed

+19
-47
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LinearSolve"
22
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
33
authors = ["SciML"]
4-
version = "1.2.4"
4+
version = "1.2.5"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/factorization.jl

+18-46
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ end
1717

1818
# Bad fallback: will fail if `A` is just a stand-in
1919
# This should instead just create the factorization type.
20-
init_cacheval(alg::AbstractFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = do_factorization(alg, A, b, u)
20+
init_cacheval(alg::AbstractFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = do_factorization(alg, convert(AbstractMatrix,A), b, u)
2121

2222
## LU Factorizations
2323

@@ -35,28 +35,24 @@ function LUFactorization()
3535
end
3636

3737
function do_factorization(alg::LUFactorization, A, b, u)
38-
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
39-
error("LU is not defined for $(typeof(A))")
40-
41-
if A isa DiffEqArrayOperator
42-
A = A.A
43-
end
38+
A = convert(AbstractMatrix,A)
4439
if A isa SparseMatrixCSC
45-
fact = lu(A, alg.pivot)
40+
return lu(A)
4641
else
4742
fact = lu!(A, alg.pivot)
4843
end
4944
return fact
5045
end
5146

52-
init_cacheval(alg::LUFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A)
47+
init_cacheval(alg::LUFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A))
5348

5449
# This could be a GenericFactorization perhaps?
5550
Base.@kwdef struct UMFPACKFactorization <: AbstractFactorization
5651
reuse_symbolic::Bool = true
5752
end
5853

5954
function init_cacheval(alg::UMFPACKFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
55+
A = convert(AbstractMatrix,A)
6056
zerobased = SparseArrays.getcolptr(A)[1] == 0
6157
res = SuiteSparse.UMFPACK.UmfpackLU(C_NULL, C_NULL, size(A, 1), size(A, 2),
6258
zerobased ? copy(SparseArrays.getcolptr(A)) : SuiteSparse.decrement(SparseArrays.getcolptr(A)),
@@ -67,9 +63,7 @@ function init_cacheval(alg::UMFPACKFactorization, A, b, u, Pl, Pr, maxiters, abs
6763
end
6864

6965
function do_factorization(::UMFPACKFactorization, A, b, u)
70-
if A isa DiffEqArrayOperator
71-
A = A.A
72-
end
66+
A = convert(AbstractMatrix,A)
7367
if A isa SparseMatrixCSC
7468
return lu(A)
7569
else
@@ -79,9 +73,7 @@ end
7973

8074
function SciMLBase.solve(cache::LinearCache, alg::UMFPACKFactorization)
8175
A = cache.A
82-
if A isa DiffEqArrayOperator
83-
A = A.A
84-
end
76+
A = convert(AbstractMatrix,A)
8577
if cache.isfresh
8678
if cache.cacheval !== nothing && alg.reuse_symbolic
8779
# If we have a cacheval already, run umfpack_symbolic to ensure the symbolic factorization exists
@@ -103,13 +95,11 @@ Base.@kwdef struct KLUFactorization <: AbstractFactorization
10395
end
10496

10597
function init_cacheval(alg::KLUFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
106-
return KLU.KLUFactorization(A) # this takes care of the copy internally.
98+
return KLU.KLUFactorization(convert(AbstractMatrix,A)) # this takes care of the copy internally.
10799
end
108100

109101
function do_factorization(::KLUFactorization, A, b, u)
110-
if A isa DiffEqArrayOperator
111-
A = A.A
112-
end
102+
A = convert(AbstractMatrix,A)
113103
if A isa SparseMatrixCSC
114104
return klu(A)
115105
else
@@ -119,9 +109,7 @@ end
119109

120110
function SciMLBase.solve(cache::LinearCache, alg::KLUFactorization)
121111
A = cache.A
122-
if A isa DiffEqArrayOperator
123-
A = A.A
124-
end
112+
A = convert(AbstractMatrix,A)
125113
if cache.isfresh
126114
if cache.cacheval !== nothing && alg.reuse_symbolic
127115
# If we have a cacheval already, run umfpack_symbolic to ensure the symbolic factorization exists
@@ -159,12 +147,7 @@ function QRFactorization(inplace = true)
159147
end
160148

161149
function do_factorization(alg::QRFactorization, A, b, u)
162-
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
163-
error("QR is not defined for $(typeof(A))")
164-
165-
if A isa DiffEqArrayOperator
166-
A = A.A
167-
end
150+
A = convert(AbstractMatrix,A)
168151
if alg.inplace
169152
fact = qr!(A, alg.pivot)
170153
else
@@ -183,13 +166,7 @@ end
183166
SVDFactorization() = SVDFactorization(false, LinearAlgebra.DivideAndConquer())
184167

185168
function do_factorization(alg::SVDFactorization, A, b, u)
186-
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
187-
error("SVD is not defined for $(typeof(A))")
188-
189-
if A isa DiffEqArrayOperator
190-
A = A.A
191-
end
192-
169+
A = convert(AbstractMatrix,A)
193170
fact = svd!(A; full = alg.full, alg = alg.alg)
194171
return fact
195172
end
@@ -204,18 +181,13 @@ GenericFactorization(;fact_alg = LinearAlgebra.factorize) =
204181
GenericFactorization(fact_alg)
205182

206183
function do_factorization(alg::GenericFactorization, A, b, u)
207-
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
208-
error("GenericFactorization is not defined for $(typeof(A))")
209-
210-
if A isa DiffEqArrayOperator
211-
A = A.A
212-
end
184+
A = convert(AbstractMatrix,A)
213185
fact = alg.fact_alg(A)
214186
return fact
215187
end
216188

217-
init_cacheval(alg::GenericFactorization{typeof(lu)}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A)
218-
init_cacheval(alg::GenericFactorization{typeof(lu!)}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A)
189+
init_cacheval(alg::GenericFactorization{typeof(lu)}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A))
190+
init_cacheval(alg::GenericFactorization{typeof(lu!)}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A))
219191

220192
init_cacheval(alg::GenericFactorization{typeof(lu)}, A::StridedMatrix{<:LinearAlgebra.BlasFloat}, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A)
221193
init_cacheval(alg::GenericFactorization{typeof(lu!)}, A::StridedMatrix{<:LinearAlgebra.BlasFloat}, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A)
@@ -245,13 +217,13 @@ end
245217
# Fallback, tries to make nonsingular and just factorizes
246218
# Try to never use it.
247219
function init_cacheval(alg::Union{QRFactorization,SVDFactorization,GenericFactorization}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
248-
newA = copy(A)
220+
newA = copy(convert(AbstractMatrix,A))
249221
fill!(newA,true)
250222
do_factorization(alg, newA, b, u)
251223
end
252224

253225
## RFLUFactorization
254226

255227
RFLUFactorization() = GenericFactorization(;fact_alg=RecursiveFactorization.lu!)
256-
init_cacheval(alg::GenericFactorization{typeof(RecursiveFactorization.lu!)}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A)
257-
init_cacheval(alg::GenericFactorization{typeof(RecursiveFactorization.lu!)}, A::StridedMatrix{<:LinearAlgebra.BlasFloat}, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A)
228+
init_cacheval(alg::GenericFactorization{typeof(RecursiveFactorization.lu!)}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A))
229+
init_cacheval(alg::GenericFactorization{typeof(RecursiveFactorization.lu!)}, A::StridedMatrix{<:LinearAlgebra.BlasFloat}, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A))

0 commit comments

Comments
 (0)