Skip to content

Commit d914caa

Browse files
Merge pull request #418 from SciML/simplegmres
For SimpleGMRES we need to reinitialize some cache when `b` is set again
2 parents 7ee3fa2 + fc36891 commit d914caa

File tree

4 files changed

+24
-7
lines changed

4 files changed

+24
-7
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LinearSolve"
22
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
33
authors = ["SciML"]
4-
version = "2.16.0"
4+
version = "2.16.1"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/common.jl

+8
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,21 @@ end
8585
function Base.setproperty!(cache::LinearCache, name::Symbol, x)
8686
if name === :A
8787
setfield!(cache, :isfresh, true)
88+
elseif name === :b
89+
# In case there is something that needs to be done when b is updated
90+
update_cacheval!(cache, :b, x)
8891
elseif name === :cacheval && cache.alg isa DefaultLinearSolver
8992
@assert cache.cacheval isa DefaultLinearSolverInit
9093
return setfield!(cache.cacheval, Symbol(cache.alg.alg), x)
9194
end
9295
setfield!(cache, name, x)
9396
end
9497

98+
function update_cacheval!(cache::LinearCache, name::Symbol, x)
99+
return update_cacheval!(cache, cache.cacheval, name, x)
100+
end
101+
update_cacheval!(cache, cacheval, name::Symbol, x) = cacheval
102+
95103
init_cacheval(alg::SciMLLinearSolveAlgorithm, args...) = nothing
96104

97105
function SciMLBase.init(prob::LinearProblem, args...; kwargs...)

src/simplegmres.jl

+7
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,13 @@ end
7272
warm_start::Bool
7373
end
7474

75+
function update_cacheval!(cache::LinearCache, cacheval::SimpleGMRESCache, name::Symbol, x)
76+
(name != :b || cache.isfresh) && return cacheval
77+
vec(cacheval.w) .= vec(x)
78+
fill!(cacheval.x, 0)
79+
return cacheval
80+
end
81+
7582
"""
7683
(c, s, ρ) = _sym_givens(a, b)
7784

test/gpu/cuda.jl

+8-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using LinearSolve, CUDA, LinearAlgebra, SparseArrays
22
using Test
33

4+
CUDA.allowscalar(false)
5+
46
n = 8
57
A = Matrix(I, n, n)
68
b = ones(n)
@@ -25,19 +27,19 @@ function test_interface(alg, prob1, prob2)
2527
x2 = prob2.u0
2628

2729
y = solve(prob1, alg; cache_kwargs...)
28-
@test A1 * y ≈ b1
30+
@test CUDA.@allowscalar(Array(A1 * y) ≈ Array(b1))
2931

3032
cache = SciMLBase.init(prob1, alg; cache_kwargs...) # initialize cache
3133
solve!(cache)
32-
@test A1 * cache.u ≈ b1
34+
@test CUDA.@allowscalar(Array(A1 * cache.u) ≈ Array(b1))
3335

3436
cache.A = copy(A2)
3537
solve!(cache)
36-
@test A2 * cache.u ≈ b1
38+
@test CUDA.@allowscalar(Array(A2 * cache.u) ≈ Array(b1))
3739

3840
cache.b = copy(b2)
3941
solve!(cache)
40-
@test A2 * cache.u ≈ b2
42+
@test CUDA.@allowscalar(Array(A2 * cache.u) ≈ Array(b2))
4143

4244
return
4345
end
@@ -62,8 +64,8 @@ using BlockDiagonals
6264
A = BlockDiagonal([rand(2, 2) for _ in 1:3]) |> cu
6365
b = rand(size(A, 1)) |> cu
6466

65-
x1 = zero(b)
66-
x2 = zero(b)
67+
x1 = zero(b) |> cu
68+
x2 = zero(b) |> cu
6769
prob1 = LinearProblem(A, b, x1)
6870
prob2 = LinearProblem(A, b, x2)
6971

0 commit comments

Comments
 (0)