Skip to content

Commit 2afc31d

Browse files
committed
For SimpleGMRES we need to reinitialize some cache when b is set again
1 parent a9b5581 commit 2afc31d

File tree

4 files changed

+22
-7
lines changed

4 files changed

+22
-7
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LinearSolve"
22
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
33
authors = ["SciML"]
4-
version = "2.15.0"
4+
version = "2.15.1"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/common.jl

+8
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,21 @@ end
8585
function Base.setproperty!(cache::LinearCache, name::Symbol, x)
8686
if name === :A
8787
setfield!(cache, :isfresh, true)
88+
elseif name === :b
89+
# In case there is something that needs to be done when b is updated
90+
update_cacheval!(cache, :b, x)
8891
elseif name === :cacheval && cache.alg isa DefaultLinearSolver
8992
@assert cache.cacheval isa DefaultLinearSolverInit
9093
return setfield!(cache.cacheval, Symbol(cache.alg.alg), x)
9194
end
9295
setfield!(cache, name, x)
9396
end
9497

98+
function update_cacheval!(cache::LinearCache, name::Symbol, x)
99+
return update_cacheval!(cache, cache.cacheval, name, x)
100+
end
101+
update_cacheval!(cache, cacheval, name::Symbol, x) = cacheval
102+
95103
init_cacheval(alg::SciMLLinearSolveAlgorithm, args...) = nothing
96104

97105
function SciMLBase.init(prob::LinearProblem, args...; kwargs...)

src/simplegmres.jl

+7
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,13 @@ end
7272
warm_start::Bool
7373
end
7474

75+
function update_cacheval!(cache::LinearCache, cacheval::SimpleGMRESCache, name::Symbol, x)
76+
(name != :b || cache.isfresh) && return cacheval
77+
vec(cacheval.w) .= vec(x)
78+
fill!(cacheval.x, 0)
79+
return cacheval
80+
end
81+
7582
"""
7683
(c, s, ρ) = _sym_givens(a, b)
7784

test/gpu/cuda.jl

+6-6
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,19 @@ function test_interface(alg, prob1, prob2)
2525
x2 = prob2.u0
2626

2727
y = solve(prob1, alg; cache_kwargs...)
28-
@test A1 * y b1
28+
@test Array(A1 * y) Array(b1)
2929

3030
cache = SciMLBase.init(prob1, alg; cache_kwargs...) # initialize cache
3131
solve!(cache)
32-
@test A1 * cache.u b1
32+
@test Array(A1 * cache.u) Array(b1)
3333

3434
cache.A = copy(A2)
3535
solve!(cache)
36-
@test A2 * cache.u b1
36+
@test Array(A2 * cache.u) Array(b1)
3737

3838
cache.b = copy(b2)
3939
solve!(cache)
40-
@test A2 * cache.u b2
40+
@test Array(A2 * cache.u) Array(b2)
4141

4242
return
4343
end
@@ -62,8 +62,8 @@ using BlockDiagonals
6262
A = BlockDiagonal([rand(2, 2) for _ in 1:3]) |> cu
6363
b = rand(size(A, 1)) |> cu
6464

65-
x1 = zero(b)
66-
x2 = zero(b)
65+
x1 = zero(b) |> cu
66+
x2 = zero(b) |> cu
6767
prob1 = LinearProblem(A, b, x1)
6868
prob2 = LinearProblem(A, b, x2)
6969

0 commit comments

Comments
 (0)