Skip to content

Commit 296b142

Browse files
Merge pull request #388 from SciML/gpu_default
Fix GPU tests
2 parents a53f644 + b8ef3e4 commit 296b142

File tree

2 files changed

+19
-24
lines changed

2 files changed

+19
-24
lines changed

ext/LinearSolveCUDAExt.jl

+11-16
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,26 @@
11
module LinearSolveCUDAExt
22

3-
using CUDA, LinearAlgebra, LinearSolve, SciMLBase
3+
using CUDA
4+
using LinearSolve
5+
using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterface
46
using SciMLBase: AbstractSciMLOperator
57

68
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactorization;
79
kwargs...)
810
if cache.isfresh
9-
fact = LinearSolve.do_factorization(alg, CUDA.CuArray(cache.A), cache.b, cache.u)
10-
cache = LinearSolve.set_cacheval(cache, fact)
11+
fact = qr(CUDA.CuArray(cache.A))
12+
cache.cacheval = fact
1113
cache.isfresh = false
1214
end
13-
14-
copyto!(cache.u, cache.b)
15-
y = Array(ldiv!(cache.cacheval, CUDA.CuArray(cache.u)))
15+
y = Array(ldiv!(CUDA.CuArray(cache.u), cache.cacheval, CUDA.CuArray(cache.b)))
16+
cache.u .= y
1617
SciMLBase.build_linear_solution(alg, y, nothing, cache)
1718
end
1819

19-
function LinearSolve.do_factorization(alg::CudaOffloadFactorization, A, b, u)
20-
A isa Union{AbstractMatrix, AbstractSciMLOperator} ||
21-
error("LU is not defined for $(typeof(A))")
22-
23-
if A isa Union{MatrixOperator, DiffEqArrayOperator}
24-
A = A.A
25-
end
26-
27-
fact = qr(CUDA.CuArray(A))
28-
return fact
20+
function LinearSolve.init_cacheval(alg::CudaOffloadFactorization, A, b, u, Pl, Pr,
21+
maxiters::Int, abstol, reltol, verbose::Bool,
22+
assumptions::OperatorAssumptions)
23+
qr(CUDA.CuArray(A))
2924
end
3025

3126
end

test/gpu/cuda.jl

+8-8
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,16 @@ function test_interface(alg, prob1, prob2)
2828
@test A1 * y b1
2929

3030
cache = SciMLBase.init(prob1, alg; cache_kwargs...) # initialize cache
31-
y = solve(cache)
32-
@test A1 * y b1
31+
solve!(cache)
32+
@test A1 * cache.u b1
3333

34-
cache = LinearSolve.set_A(cache, copy(A2))
35-
y = solve(cache)
36-
@test A2 * y b1
34+
cache.A = copy(A2)
35+
solve!(cache)
36+
@test A2 * cache.u b1
3737

38-
cache = LinearSolve.set_b(cache, b2)
39-
y = solve(cache)
40-
@test A2 * y b2
38+
cache.b = copy(b2)
39+
solve!(cache)
40+
@test A2 * cache.u b2
4141

4242
return
4343
end

0 commit comments

Comments
 (0)