Skip to content

Commit a9b5581

Browse files
Merge pull request #415 from SciML/complex
Handle complex number dispatches in AppleAccelerate and MKL
2 parents d863895 + 06e6f81 commit a9b5581

File tree

4 files changed

+200
-9
lines changed

4 files changed

+200
-9
lines changed

src/appleaccelerate.jl

+97
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,46 @@ function appleaccelerate_isavailable()
2626
return true
2727
end
2828

29+
"""
    aa_getrf!(A::AbstractMatrix{<:ComplexF64}; ipiv, info, check = false)

Run the LAPACK routine `zgetrf_` from `libacc` on `A` in place (LU factorization
with pivoting, per the LAPACK specification).

# Keywords
- `ipiv`: `Cint` pivot buffer; if it is empty a new one of length
  `min(size(A, 1), size(A, 2))` is allocated.
- `info`: `Ref{Cint}` that receives the status code written by `zgetrf_`.
- `check`: when `true`, `chkfinite(A)` is called before the factorization.

# Returns
The tuple `(A, ipiv, BlasInt(info[]), info)`.

# Throws
`ArgumentError` when `zgetrf_` reports a negative status code.
"""
function aa_getrf!(A::AbstractMatrix{<:ComplexF64};
        ipiv = similar(A, Cint, min(size(A, 1), size(A, 2))),
        info = Ref{Cint}(),
        check = false)
    require_one_based_indexing(A)
    check && chkfinite(A)
    chkstride1(A)
    m, n = size(A)
    lda = max(1, stride(A, 2))
    if isempty(ipiv)
        ipiv = similar(A, Cint, min(m, n))
    end
    ccall(("zgetrf_", libacc), Cvoid,
        (Ref{Cint}, Ref{Cint}, Ptr{ComplexF64},
            Ref{Cint}, Ptr{Cint}, Ptr{Cint}),
        m, n, A, lda, ipiv, info)
    # The message names the routine that was actually called (zgetrf_, not dgetrf_).
    info[] < 0 && throw(ArgumentError("Invalid arguments sent to LAPACK zgetrf_"))
    # Error code is stored in LU factorization type
    return A, ipiv, BlasInt(info[]), info
end
48+
49+
"""
    aa_getrf!(A::AbstractMatrix{<:ComplexF32}; ipiv, info, check = false)

Run the LAPACK routine `cgetrf_` from `libacc` on `A` in place (LU factorization
with pivoting, per the LAPACK specification).

# Keywords
- `ipiv`: `Cint` pivot buffer; if it is empty a new one of length
  `min(size(A, 1), size(A, 2))` is allocated.
- `info`: `Ref{Cint}` that receives the status code written by `cgetrf_`.
- `check`: when `true`, `chkfinite(A)` is called before the factorization.

# Returns
The tuple `(A, ipiv, BlasInt(info[]), info)`.

# Throws
`ArgumentError` when `cgetrf_` reports a negative status code.
"""
function aa_getrf!(A::AbstractMatrix{<:ComplexF32};
        ipiv = similar(A, Cint, min(size(A, 1), size(A, 2))),
        info = Ref{Cint}(),
        check = false)
    require_one_based_indexing(A)
    check && chkfinite(A)
    chkstride1(A)
    m, n = size(A)
    lda = max(1, stride(A, 2))
    if isempty(ipiv)
        ipiv = similar(A, Cint, min(m, n))
    end
    ccall(("cgetrf_", libacc), Cvoid,
        (Ref{Cint}, Ref{Cint}, Ptr{ComplexF32},
            Ref{Cint}, Ptr{Cint}, Ptr{Cint}),
        m, n, A, lda, ipiv, info)
    # The message names the routine that was actually called (cgetrf_, not dgetrf_).
    info[] < 0 && throw(ArgumentError("Invalid arguments sent to LAPACK cgetrf_"))
    # Error code is stored in LU factorization type
    return A, ipiv, BlasInt(info[]), info
end
68+
2969
function aa_getrf!(A::AbstractMatrix{<:Float64};
3070
ipiv = similar(A, Cint, min(size(A, 1), size(A, 2))),
3171
info = Ref{Cint}(),
@@ -67,6 +107,55 @@ function aa_getrf!(A::AbstractMatrix{<:Float32};
67107
A, ipiv, BlasInt(info[]), info #Error code is stored in LU factorization type
68108
end
69109

110+
"""
    aa_getrs!(trans, A::AbstractMatrix{<:ComplexF64}, ipiv, B; info = Ref{Cint}())

Call the LAPACK routine `zgetrs_` from `libacc` with the factors `A`, the pivots
`ipiv` and the right-hand side(s) `B`; per the LAPACK specification the solution
is written into `B`.

# Arguments
- `trans`: transposition flag, validated with `LinearAlgebra.LAPACK.chktrans`.
- `A`: square matrix (checked with `checksquare`).
- `ipiv`: `Cint` pivot vector whose length must equal the order of `A`.
- `B`: vector or matrix whose first dimension must equal the order of `A`.

# Keywords
- `info`: `Ref{Cint}` that receives the status code written by `zgetrs_`.

# Returns
`B`.

# Throws
- `DimensionMismatch` if `B` or `ipiv` do not match the order of `A`.
- Whatever `LinearAlgebra.LAPACK.chklapackerror` raises for the reported status.
"""
function aa_getrs!(trans::AbstractChar,
        A::AbstractMatrix{<:ComplexF64},
        ipiv::AbstractVector{Cint},
        B::AbstractVecOrMat{<:ComplexF64};
        info = Ref{Cint}())
    require_one_based_indexing(A, ipiv, B)
    LinearAlgebra.LAPACK.chktrans(trans)
    chkstride1(A, B, ipiv)
    n = LinearAlgebra.checksquare(A)
    if n != size(B, 1)
        throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
    end
    if n != length(ipiv)
        throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
    end
    nrhs = size(B, 2)
    # The trailing `Clong` value 1 is presumably the hidden Fortran string length
    # of `trans` — confirm against the platform calling convention.
    ccall(("zgetrs_", libacc), Cvoid,
        (Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{ComplexF64}, Ref{Cint},
            Ptr{Cint}, Ptr{ComplexF64}, Ref{Cint}, Ptr{Cint}, Clong),
        trans, n, nrhs, A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
        1)
    LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
    # Return the right-hand side buffer, as the ComplexF32 method does.
    return B
end
133+
134+
"""
    aa_getrs!(trans, A::AbstractMatrix{<:ComplexF32}, ipiv, B; info = Ref{Cint}())

Call the LAPACK routine `cgetrs_` from `libacc` with the factors `A`, the pivots
`ipiv` and the right-hand side(s) `B`, then return `B`.

`trans` is validated with `LinearAlgebra.LAPACK.chktrans`, `A` must be square, and
both `size(B, 1)` and `length(ipiv)` must equal the order of `A`; otherwise a
`DimensionMismatch` is thrown. The status code written into `info` is passed to
`LinearAlgebra.LAPACK.chklapackerror`.
"""
function aa_getrs!(trans::AbstractChar,
        A::AbstractMatrix{<:ComplexF32},
        ipiv::AbstractVector{Cint},
        B::AbstractVecOrMat{<:ComplexF32};
        info = Ref{Cint}())
    require_one_based_indexing(A, ipiv, B)
    LinearAlgebra.LAPACK.chktrans(trans)
    chkstride1(A, B, ipiv)
    n = LinearAlgebra.checksquare(A)
    n == size(B, 1) ||
        throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
    n == length(ipiv) ||
        throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
    nrhs = size(B, 2)
    lda = max(1, stride(A, 2))
    ldb = max(1, stride(B, 2))
    ccall(("cgetrs_", libacc), Cvoid,
        (Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{ComplexF32}, Ref{Cint},
            Ptr{Cint}, Ptr{ComplexF32}, Ref{Cint}, Ptr{Cint}, Clong),
        trans, n, nrhs, A, lda, ipiv, B, ldb, info,
        1)
    LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
    return B
end
158+
70159
function aa_getrs!(trans::AbstractChar,
71160
A::AbstractMatrix{<:Float64},
72161
ipiv::AbstractVector{Cint},
@@ -134,6 +223,14 @@ function LinearSolve.init_cacheval(alg::AppleAccelerateLUFactorization, A, b, u,
134223
PREALLOCATED_APPLE_LU
135224
end
136225

226+
# Cache constructor for `AppleAccelerateLUFactorization` when `A` has element type
# Float32, ComplexF32 or ComplexF64. Returns a tuple of an `LU` object (built from
# `ArrayInterface.lu_instance` of a 0×0 matrix, with an empty `Cint` pivot vector)
# and a fresh `Ref{Cint}`.
function LinearSolve.init_cacheval(alg::AppleAccelerateLUFactorization,
        A::AbstractMatrix{<:Union{Float32, ComplexF32, ComplexF64}}, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    empty_mat = rand(eltype(A), 0, 0)
    template = ArrayInterface.lu_instance(empty_mat)
    pivots = similar(empty_mat, Cint, 0)
    return LU(template.factors, pivots, template.info), Ref{Cint}()
end
233+
137234
function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerateLUFactorization;
138235
kwargs...)
139236
A = cache.A

src/default.jl

+3-7
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,7 @@ function defaultalg(A, b, assump::OperatorAssumptions)
162162
__conditioning(assump) === OperatorCondition.WellConditioned)
163163
if length(b) <= 10
164164
DefaultAlgorithmChoice.GenericLUFactorization
165-
elseif VERSION >= v"1.8" && appleaccelerate_isavailable() &&
166-
(A === nothing ? eltype(b) <: Union{Float32, Float64} :
167-
eltype(A) <: Union{Float32, Float64})
165+
elseif VERSION >= v"1.8" && appleaccelerate_isavailable()
168166
DefaultAlgorithmChoice.AppleAccelerateLUFactorization
169167
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500) ||
170168
(usemkl && length(b) <= 200)) &&
@@ -173,8 +171,7 @@ function defaultalg(A, b, assump::OperatorAssumptions)
173171
DefaultAlgorithmChoice.RFLUFactorization
174172
#elseif A === nothing || A isa Matrix
175173
# alg = FastLUFactorization()
176-
elseif usemkl && (A === nothing ? eltype(b) <: Union{Float32, Float64} :
177-
eltype(A) <: Union{Float32, Float64})
174+
elseif usemkl
178175
DefaultAlgorithmChoice.MKLLUFactorization
179176
else
180177
DefaultAlgorithmChoice.LUFactorization
@@ -183,8 +180,7 @@ function defaultalg(A, b, assump::OperatorAssumptions)
183180
DefaultAlgorithmChoice.QRFactorization
184181
elseif __conditioning(assump) === OperatorCondition.SuperIllConditioned
185182
DefaultAlgorithmChoice.SVDFactorization
186-
elseif usemkl && (A === nothing ? eltype(b) <: Union{Float32, Float64} :
187-
eltype(A) <: Union{Float32, Float64})
183+
elseif usemkl
188184
DefaultAlgorithmChoice.MKLLUFactorization
189185
else
190186
DefaultAlgorithmChoice.LUFactorization

src/mkl.jl

+98-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,46 @@ to avoid allocations and does not require libblastrampoline.
88
"""
99
struct MKLLUFactorization <: AbstractFactorization end
1010

11+
"""
    getrf!(A::AbstractMatrix{<:ComplexF64}; ipiv, info, check = false)

Run the LAPACK routine `zgetrf_` (resolved through `@blasfunc`) from
`MKL_jll.libmkl_rt` on `A` in place.

`ipiv` is a `BlasInt` pivot buffer that is reallocated with length
`min(size(A, 1), size(A, 2))` when empty; `info` receives the status code;
`check = true` calls `chkfinite(A)` first. The status is passed to `chkargsok`.

Returns the tuple `(A, ipiv, info[], info)`.
"""
function getrf!(A::AbstractMatrix{<:ComplexF64};
        ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
        info = Ref{BlasInt}(),
        check = false)
    require_one_based_indexing(A)
    check && chkfinite(A)
    chkstride1(A)
    nrows, ncols = size(A)
    leading_dim = max(1, stride(A, 2))
    isempty(ipiv) && (ipiv = similar(A, BlasInt, min(nrows, ncols)))
    ccall((@blasfunc(zgetrf_), MKL_jll.libmkl_rt), Cvoid,
        (Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64},
            Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
        nrows, ncols, A, leading_dim, ipiv, info)
    status = info[]
    chkargsok(status)
    # Error code is stored in LU factorization type
    return A, ipiv, status, info
end
30+
31+
"""
    getrf!(A::AbstractMatrix{<:ComplexF32}; ipiv, info, check = false)

Run the LAPACK routine `cgetrf_` (resolved through `@blasfunc`) from
`MKL_jll.libmkl_rt` on `A` in place.

`ipiv` is a `BlasInt` pivot buffer that is reallocated with length
`min(size(A, 1), size(A, 2))` when empty; `info` receives the status code;
`check = true` calls `chkfinite(A)` first. The status is passed to `chkargsok`.

Returns the tuple `(A, ipiv, info[], info)`.
"""
function getrf!(A::AbstractMatrix{<:ComplexF32};
        ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
        info = Ref{BlasInt}(),
        check = false)
    require_one_based_indexing(A)
    check && chkfinite(A)
    chkstride1(A)
    nrows, ncols = size(A)
    leading_dim = max(1, stride(A, 2))
    isempty(ipiv) && (ipiv = similar(A, BlasInt, min(nrows, ncols)))
    ccall((@blasfunc(cgetrf_), MKL_jll.libmkl_rt), Cvoid,
        (Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32},
            Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
        nrows, ncols, A, leading_dim, ipiv, info)
    status = info[]
    chkargsok(status)
    # Error code is stored in LU factorization type
    return A, ipiv, status, info
end
50+
1151
function getrf!(A::AbstractMatrix{<:Float64};
1252
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
1353
info = Ref{BlasInt}(),
@@ -48,6 +88,56 @@ function getrf!(A::AbstractMatrix{<:Float32};
4888
A, ipiv, info[], info #Error code is stored in LU factorization type
4989
end
5090

91+
"""
    getrs!(trans, A::AbstractMatrix{<:ComplexF64}, ipiv, B; info = Ref{BlasInt}())

Call the LAPACK routine `zgetrs_` from `MKL_jll.libmkl_rt` with the factors `A`,
the pivots `ipiv` and the right-hand side(s) `B`; per the LAPACK specification the
solution is written into `B`.

# Arguments
- `trans`: transposition flag, validated with `LinearAlgebra.LAPACK.chktrans`.
- `A`: square matrix (checked with `checksquare`).
- `ipiv`: `BlasInt` pivot vector whose length must equal the order of `A`.
- `B`: vector or matrix whose first dimension must equal the order of `A`.

# Keywords
- `info`: `Ref{BlasInt}` that receives the status code written by the routine.

# Returns
`B`.

# Throws
- `DimensionMismatch` if `B` or `ipiv` do not match the order of `A`.
- Whatever `LinearAlgebra.LAPACK.chklapackerror` raises for the reported status.
"""
function getrs!(trans::AbstractChar,
        A::AbstractMatrix{<:ComplexF64},
        ipiv::AbstractVector{BlasInt},
        B::AbstractVecOrMat{<:ComplexF64};
        info = Ref{BlasInt}())
    require_one_based_indexing(A, ipiv, B)
    LinearAlgebra.LAPACK.chktrans(trans)
    chkstride1(A, B, ipiv)
    n = LinearAlgebra.checksquare(A)
    if n != size(B, 1)
        throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
    end
    if n != length(ipiv)
        throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
    end
    nrhs = size(B, 2)
    # Resolve the symbol through @blasfunc, exactly as the getrf! wrappers in this
    # file do, so the routine's integer width agrees with the BlasInt arguments.
    # The trailing `Clong` value 1 is presumably the hidden Fortran string length
    # of `trans` — confirm against the platform calling convention.
    ccall((@blasfunc(zgetrs_), MKL_jll.libmkl_rt), Cvoid,
        (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt},
            Ptr{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
        trans, n, nrhs, A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
        1)
    LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
    return B
end
115+
116+
"""
    getrs!(trans, A::AbstractMatrix{<:ComplexF32}, ipiv, B; info = Ref{BlasInt}())

Call the LAPACK routine `cgetrs_` from `MKL_jll.libmkl_rt` with the factors `A`,
the pivots `ipiv` and the right-hand side(s) `B`; per the LAPACK specification the
solution is written into `B`.

# Arguments
- `trans`: transposition flag, validated with `LinearAlgebra.LAPACK.chktrans`.
- `A`: square matrix (checked with `checksquare`).
- `ipiv`: `BlasInt` pivot vector whose length must equal the order of `A`.
- `B`: vector or matrix whose first dimension must equal the order of `A`.

# Keywords
- `info`: `Ref{BlasInt}` that receives the status code written by the routine.

# Returns
`B`.

# Throws
- `DimensionMismatch` if `B` or `ipiv` do not match the order of `A`.
- Whatever `LinearAlgebra.LAPACK.chklapackerror` raises for the reported status.
"""
function getrs!(trans::AbstractChar,
        A::AbstractMatrix{<:ComplexF32},
        ipiv::AbstractVector{BlasInt},
        B::AbstractVecOrMat{<:ComplexF32};
        info = Ref{BlasInt}())
    require_one_based_indexing(A, ipiv, B)
    LinearAlgebra.LAPACK.chktrans(trans)
    chkstride1(A, B, ipiv)
    n = LinearAlgebra.checksquare(A)
    if n != size(B, 1)
        throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
    end
    if n != length(ipiv)
        throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
    end
    nrhs = size(B, 2)
    # Resolve the symbol through @blasfunc, exactly as the getrf! wrappers in this
    # file do, so the routine's integer width agrees with the BlasInt arguments.
    # The trailing `Clong` value 1 is presumably the hidden Fortran string length
    # of `trans` — confirm against the platform calling convention.
    ccall((@blasfunc(cgetrs_), MKL_jll.libmkl_rt), Cvoid,
        (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt},
            Ptr{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
        trans, n, nrhs, A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
        1)
    LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
    return B
end
140+
51141
function getrs!(trans::AbstractChar,
52142
A::AbstractMatrix{<:Float64},
53143
ipiv::AbstractVector{BlasInt},
@@ -106,12 +196,19 @@ const PREALLOCATED_MKL_LU = begin
106196
luinst = ArrayInterface.lu_instance(A), Ref{BlasInt}()
107197
end
108198

109-
function init_cacheval(alg::MKLLUFactorization, A, b, u, Pl, Pr,
199+
# Generic cache constructor for `MKLLUFactorization`: every argument is ignored and
# the module-level constant `PREALLOCATED_MKL_LU` is handed back.
function LinearSolve.init_cacheval(alg::MKLLUFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    return PREALLOCATED_MKL_LU
end
114204

205+
# Cache constructor for `MKLLUFactorization` when `A` has element type Float32,
# ComplexF32 or ComplexF64. Returns a tuple of `ArrayInterface.lu_instance` applied
# to a 0×0 matrix of that element type and a fresh `Ref{BlasInt}`.
function LinearSolve.init_cacheval(alg::MKLLUFactorization,
        A::AbstractMatrix{<:Union{Float32, ComplexF32, ComplexF64}}, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    empty_mat = rand(eltype(A), 0, 0)
    lu_template = ArrayInterface.lu_instance(empty_mat)
    return lu_template, Ref{BlasInt}()
end
211+
115212
function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization;
116213
kwargs...)
117214
A = cache.A

test/basictests.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -235,11 +235,12 @@ end
235235
for alg in test_algs
236236
@testset "$alg" begin
237237
test_interface(alg, prob1, prob2)
238-
VERSION >= v"1.9" && (alg isa MKLLUFactorization || test_interface(alg, prob3, prob4))
238+
VERSION >= v"1.9" && test_interface(alg, prob3, prob4)
239239
end
240240
end
241241
if LinearSolve.appleaccelerate_isavailable()
242242
test_interface(AppleAccelerateLUFactorization(), prob1, prob2)
243+
test_interface(AppleAccelerateLUFactorization(), prob3, prob4)
243244
end
244245
end
245246

0 commit comments

Comments
 (0)