|
1 | 1 | module LinearSolve
|
2 | 2 |
|
3 |
| -using Base: cache_dependencies, Bool |
4 |
| -using SciMLBase: AbstractLinearAlgorithm, AbstractDiffEqOperator |
5 | 3 | using ArrayInterface: lu_instance
|
6 |
| -using UnPack |
7 |
| -using Reexport |
| 4 | +using Base: cache_dependencies, Bool |
| 5 | +using Krylov |
8 | 6 | using LinearAlgebra
|
| 7 | +using Reexport |
| 8 | +using SciMLBase: AbstractDiffEqOperator, AbstractLinearAlgorithm |
9 | 9 | using Setfield
|
10 |
| -@reexport using SciMLBase |
11 |
| - |
12 |
| -export LUFactorization, QRFactorization, SVDFactorization |
13 |
| - |
14 |
| -#mutable?# |
15 |
| -struct LinearCache{TA,Tb,Tp,Talg,Tc,Tr,Tl} |
16 |
| - A::TA |
17 |
| - b::Tb |
18 |
| - p::Tp |
19 |
| - alg::Talg |
20 |
| - cacheval::Tc |
21 |
| - isfresh::Bool |
22 |
| - Pr::Tr |
23 |
| - Pl::Tl |
24 |
| -end |
25 |
| - |
26 |
| -function set_A(cache, A) |
27 |
| - @set! cache.A = A |
28 |
| - @set! cache.isfresh = true |
29 |
| -end |
30 |
| - |
31 |
| -function set_b(cache, b) |
32 |
| - @set! cache.b = b |
33 |
| -end |
34 |
| - |
35 |
| -function set_p(cache, p) |
36 |
| - @set! cache.p = p |
37 |
| - # @set! cache.isfresh = true |
38 |
| -end |
39 |
| - |
40 |
| -function set_cacheval(cache::LinearCache,alg) |
41 |
| - if cache.isfresh |
42 |
| - @set! cache.cacheval = alg |
43 |
| - @set! cache.isfresh = false |
44 |
| - end |
45 |
| - return cache |
46 |
| -end |
47 |
| - |
48 |
| -function SciMLBase.init(prob::LinearProblem, alg; |
49 |
| - alias_A = false, alias_b = false, |
50 |
| - kwargs...) |
51 |
| - @unpack A, b, p = prob |
52 |
| - if alg isa LUFactorization |
53 |
| - fact = lu_instance(A) |
54 |
| - Tfact = typeof(fact) |
55 |
| - else |
56 |
| - fact = nothing |
57 |
| - Tfact = Any |
58 |
| - end |
59 |
| - Pr = nothing |
60 |
| - Pl = nothing |
61 |
| - |
62 |
| - A = alias_A ? A : copy(A) |
63 |
| - b = alias_b ? b : copy(b) |
64 |
| - |
65 |
| - cache = LinearCache{typeof(A),typeof(b),typeof(p),typeof(alg),Tfact,typeof(Pr),typeof(Pl)}( |
66 |
| - A, b, p, alg, fact, true, Pr, Pl |
67 |
| - ) |
68 |
| - return cache |
69 |
| -end |
70 |
| - |
71 |
| -SciMLBase.solve(prob::LinearProblem, alg; kwargs...) = solve(init(prob, alg; kwargs...)) |
72 |
| -SciMLBase.solve(cache) = solve(cache, cache.alg) |
73 |
| - |
74 |
| -struct LUFactorization{P} <: AbstractLinearAlgorithm |
75 |
| - pivot::P |
76 |
| -end |
77 |
| -function LUFactorization() |
78 |
| - pivot = @static if VERSION < v"1.7beta" |
79 |
| - Val(true) |
80 |
| - else |
81 |
| - RowMaximum() |
82 |
| - end |
83 |
| - LUFactorization(pivot) |
84 |
| -end |
85 |
| - |
86 |
| -function SciMLBase.solve(cache::LinearCache, alg::LUFactorization) |
87 |
| - cache.A isa Union{AbstractMatrix, AbstractDiffEqOperator} || error("LU is not defined for $(typeof(prob.A))") |
88 |
| - cache = set_cacheval(cache,lu!(cache.A, alg.pivot)) |
89 |
| - ldiv!(cache.cacheval, cache.b) |
90 |
| -end |
| 10 | +using UnPack |
91 | 11 |
|
92 |
| -struct QRFactorization{P} <: AbstractLinearAlgorithm |
93 |
| - pivot::P |
94 |
| - blocksize::Int |
95 |
| -end |
96 |
| -function QRFactorization() |
97 |
| - pivot = @static if VERSION < v"1.7beta" |
98 |
| - Val(false) |
99 |
| - else |
100 |
| - NoPivot() |
101 |
| - end |
102 |
| - QRFactorization(pivot, 16) |
103 |
| -end |
| 12 | +@reexport using SciMLBase |
104 | 13 |
|
105 |
| -function SciMLBase.solve(cache::LinearCache, alg::QRFactorization) |
106 |
| - cache.A isa Union{AbstractMatrix, AbstractDiffEqOperator} || error("QR is not defined for $(typeof(prob.A))") |
107 |
| - cache = set_cacheval(cache,qr!(cache.A.A, alg.pivot; blocksize=alg.blocksize)) |
108 |
| - ldiv!(cache.cacheval, cache.b) |
109 |
| -end |
| 14 | +abstract type SciMLLinearSolveAlgorithm end |
110 | 15 |
|
111 |
| -struct SVDFactorization{A} <: AbstractLinearAlgorithm |
112 |
| - full::Bool |
113 |
| - alg::A |
114 |
| -end |
115 |
| -SVDFactorization() = SVDFactorization(false, LinearAlgebra.DivideAndConquer()) |
| 16 | +include("common.jl") |
| 17 | +include("factorization.jl") |
| 18 | +include("krylov.jl") |
116 | 19 |
|
117 |
| -function SciMLBase.solve(cache::LinearCache, alg::SVDFactorization) |
118 |
| - cache.A isa Union{AbstractMatrix, AbstractDiffEqOperator} || error("SVD is not defined for $(typeof(prob.A))") |
119 |
| - cache = set_cacheval(cache,svd!(cache.A; full=alg.full, alg=alg.alg)) |
120 |
| - ldiv!(cache.cacheval, cache.b) |
121 |
| -end |
| 20 | +export LUFactorization, SVDFactorization, QRFactorization |
| 21 | +export KrylovJL |
122 | 22 |
|
123 | 23 | end
|
0 commit comments