|
1 | 1 | module LinearSolveCUDAExt
|
2 | 2 |
|
3 |
| -using CUDA, LinearAlgebra, LinearSolve, SciMLBase |
| 3 | +using CUDA |
| 4 | +using LinearSolve |
| 5 | +using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterface |
4 | 6 | using SciMLBase: AbstractSciMLOperator
|
5 | 7 |
|
6 | 8 | function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactorization;
|
7 | 9 | kwargs...)
|
8 | 10 | if cache.isfresh
|
9 |
| - fact = LinearSolve.do_factorization(alg, CUDA.CuArray(cache.A), cache.b, cache.u) |
10 |
| - cache = LinearSolve.set_cacheval(cache, fact) |
| 11 | + fact = qr(CUDA.CuArray(cache.A)) |
| 12 | + cache.cacheval = fact |
11 | 13 | cache.isfresh = false
|
12 | 14 | end
|
13 |
| - |
14 |
| - copyto!(cache.u, cache.b) |
15 |
| - y = Array(ldiv!(cache.cacheval, CUDA.CuArray(cache.u))) |
| 15 | + y = Array(ldiv!(CUDA.CuArray(cache.u), cache.cacheval, CUDA.CuArray(cache.b))) |
| 16 | + cache.u .= y |
16 | 17 | SciMLBase.build_linear_solution(alg, y, nothing, cache)
|
17 | 18 | end
|
18 | 19 |
|
19 |
| -function LinearSolve.do_factorization(alg::CudaOffloadFactorization, A, b, u) |
20 |
| - A isa Union{AbstractMatrix, AbstractSciMLOperator} || |
21 |
| - error("LU is not defined for $(typeof(A))") |
22 |
| - |
23 |
| - if A isa Union{MatrixOperator, DiffEqArrayOperator} |
24 |
| - A = A.A |
25 |
| - end |
26 |
| - |
27 |
| - fact = qr(CUDA.CuArray(A)) |
28 |
| - return fact |
| 20 | +function LinearSolve.init_cacheval(alg::CudaOffloadFactorization, A, b, u, Pl, Pr, |
| 21 | + maxiters::Int, abstol, reltol, verbose::Bool, |
| 22 | + assumptions::OperatorAssumptions) |
| 23 | + qr(CUDA.CuArray(A)) |
29 | 24 | end
|
30 | 25 |
|
31 | 26 | end
|
0 commit comments