Skip to content

Commit 571058c

Browse files
force type stability in init
1 parent 834b255 commit 571058c

File tree

3 files changed

+22
-16
lines changed

3 files changed

+22
-16
lines changed

src/common.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
9292
Pr = (Pr !== nothing) ? Pr : Identity()
9393

9494
cacheval = init_cacheval(alg, A, b, u0)
95-
isfresh = cacheval === nothing
96-
Tc = isfresh ? Any : typeof(cacheval)
95+
isfresh = true
96+
Tc = typeof(cacheval)
9797

9898
A = alias_A ? A : deepcopy(A)
9999
b = alias_b ? b : deepcopy(b)

src/factorization.jl

+15-9
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
function SciMLBase.solve(cache::LinearCache, alg::AbstractFactorization; kwargs...)
22
if cache.isfresh
3-
fact = init_cacheval(alg, cache.A, cache.b, cache.u)
3+
fact = do_factorization(alg, cache.A, cache.b, cache.u)
44
cache = set_cacheval(cache, fact)
55
end
66

77
y = ldiv!(cache.u, cache.cacheval, cache.b)
88
SciMLBase.build_linear_solution(alg,y,nothing,cache)
99
end
1010

11+
# Bad fallback: will fail if `A` is just a stand-in
12+
# This should instead just create the factorization type.
13+
init_cacheval(alg::AbstractFactorization, A, b, u) = do_factorization(alg, A, b, u)
14+
1115
## LU Factorizations
1216

1317
struct LUFactorization{P} <: AbstractFactorization
@@ -23,7 +27,7 @@ function LUFactorization()
2327
LUFactorization(pivot)
2428
end
2529

26-
function init_cacheval(alg::LUFactorization, A, b, u)
30+
function do_factorization(alg::LUFactorization, A, b, u)
2731
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
2832
error("LU is not defined for $(typeof(A))")
2933

@@ -34,12 +38,14 @@ function init_cacheval(alg::LUFactorization, A, b, u)
3438
return fact
3539
end
3640

41+
init_cacheval(alg::LUFactorization, A, b, u) = ArrayInterface.lu_instance(A)
42+
3743
# This could be a GenericFactorization perhaps?
3844
Base.@kwdef struct UMFPACKFactorization <: AbstractFactorization
3945
reuse_symbolic::Bool = true
4046
end
4147

42-
function init_cacheval(::UMFPACKFactorization, A, b, u)
48+
function do_factorization(::UMFPACKFactorization, A, b, u)
4349
if A isa AbstractDiffEqOperator
4450
A = A.A
4551
end
@@ -62,7 +68,7 @@ function SciMLBase.solve(cache::LinearCache, alg::UMFPACKFactorization)
6268
SuiteSparse.UMFPACK.umfpack_symbolic!(cache.cacheval)
6369
fact = lu!(cache.cacheval, A)
6470
else
65-
fact = init_cacheval(alg, A, cache.b, cache.u)
71+
fact = do_factorization(alg, A, cache.b, cache.u)
6672
end
6773
cache = set_cacheval(cache, fact)
6874
end
@@ -75,7 +81,7 @@ Base.@kwdef struct KLUFactorization <: AbstractFactorization
7581
reuse_symbolic::Bool = true
7682
end
7783

78-
function init_cacheval(::KLUFactorization, A, b, u)
84+
function do_factorization(::KLUFactorization, A, b, u)
7985
if A isa AbstractDiffEqOperator
8086
A = A.A
8187
end
@@ -98,7 +104,7 @@ function SciMLBase.solve(cache::LinearCache, alg::KLUFactorization)
98104
KLU.klu_analyze!(cache.cacheval)
99105
fact = klu!(cache.cacheval, A)
100106
else
101-
fact = init_cacheval(alg, A, cache.b, cache.u)
107+
fact = do_factorization(alg, A, cache.b, cache.u)
102108
end
103109
cache = set_cacheval(cache, fact)
104110
end
@@ -123,7 +129,7 @@ function QRFactorization()
123129
QRFactorization(pivot, 16)
124130
end
125131

126-
function init_cacheval(alg::QRFactorization, A, b, u)
132+
function do_factorization(alg::QRFactorization, A, b, u)
127133
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
128134
error("QR is not defined for $(typeof(A))")
129135

@@ -143,7 +149,7 @@ end
143149

144150
SVDFactorization() = SVDFactorization(false, LinearAlgebra.DivideAndConquer())
145151

146-
function init_cacheval(alg::SVDFactorization, A, b, u)
152+
function do_factorization(alg::SVDFactorization, A, b, u)
147153
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
148154
error("SVD is not defined for $(typeof(A))")
149155

@@ -164,7 +170,7 @@ end
164170
GenericFactorization(;fact_alg = LinearAlgebra.factorize) =
165171
GenericFactorization(fact_alg)
166172

167-
function init_cacheval(alg::GenericFactorization, A, b, u)
173+
function do_factorization(alg::GenericFactorization, A, b, u)
168174
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
169175
error("GenericFactorization is not defined for $(typeof(A))")
170176

test/runtests.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -43,24 +43,24 @@ end
4343
y = solve(prob1)
4444
@test A1 * y b1
4545

46-
_prob = LinearProblem(SymTridiagonal(A1.A), b1; u0=x1)
46+
_prob = LinearProblem(SymTridiagonal(A1), b1; u0=x1)
4747
y = solve(_prob)
4848
@test A1 * y b1
4949

50-
_prob = LinearProblem(Tridiagonal(A1.A), b1; u0=x1)
50+
_prob = LinearProblem(Tridiagonal(A1), b1; u0=x1)
5151
y = solve(_prob)
5252
@test A1 * y b1
5353

54-
_prob = LinearProblem(Symmetric(A1.A), b1; u0=x1)
54+
_prob = LinearProblem(Symmetric(A1), b1; u0=x1)
5555
y = solve(_prob)
5656
@test A1 * y b1
5757

58-
_prob = LinearProblem(Hermitian(A1.A), b1; u0=x1)
58+
_prob = LinearProblem(Hermitian(A1), b1; u0=x1)
5959
y = solve(_prob)
6060
@test A1 * y b1
6161

6262

63-
_prob = LinearProblem(sparse(A1.A), b1; u0=x1)
63+
_prob = LinearProblem(sparse(A1), b1; u0=x1)
6464
y = solve(_prob)
6565
@test A1 * y b1
6666
end

0 commit comments

Comments
 (0)