@@ -8,6 +8,46 @@ to avoid allocations and does not require libblastrampoline.
8
8
"""
9
9
struct MKLLUFactorization <: AbstractFactorization end
10
10
11
+ function getrf! (A:: AbstractMatrix{<:ComplexF64} ;
12
+ ipiv = similar (A, BlasInt, min (size (A, 1 ), size (A, 2 ))),
13
+ info = Ref {BlasInt} (),
14
+ check = false )
15
+ require_one_based_indexing (A)
16
+ check && chkfinite (A)
17
+ chkstride1 (A)
18
+ m, n = size (A)
19
+ lda = max (1 , stride (A, 2 ))
20
+ if isempty (ipiv)
21
+ ipiv = similar (A, BlasInt, min (size (A, 1 ), size (A, 2 )))
22
+ end
23
+ ccall ((@blasfunc (zgetrf_), MKL_jll. libmkl_rt), Cvoid,
24
+ (Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64},
25
+ Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
26
+ m, n, A, lda, ipiv, info)
27
+ chkargsok (info[])
28
+ A, ipiv, info[], info # Error code is stored in LU factorization type
29
+ end
30
+
31
+ function getrf! (A:: AbstractMatrix{<:ComplexF32} ;
32
+ ipiv = similar (A, BlasInt, min (size (A, 1 ), size (A, 2 ))),
33
+ info = Ref {BlasInt} (),
34
+ check = false )
35
+ require_one_based_indexing (A)
36
+ check && chkfinite (A)
37
+ chkstride1 (A)
38
+ m, n = size (A)
39
+ lda = max (1 , stride (A, 2 ))
40
+ if isempty (ipiv)
41
+ ipiv = similar (A, BlasInt, min (size (A, 1 ), size (A, 2 )))
42
+ end
43
+ ccall ((@blasfunc (cgetrf_), MKL_jll. libmkl_rt), Cvoid,
44
+ (Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32},
45
+ Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
46
+ m, n, A, lda, ipiv, info)
47
+ chkargsok (info[])
48
+ A, ipiv, info[], info # Error code is stored in LU factorization type
49
+ end
50
+
11
51
function getrf! (A:: AbstractMatrix{<:Float64} ;
12
52
ipiv = similar (A, BlasInt, min (size (A, 1 ), size (A, 2 ))),
13
53
info = Ref {BlasInt} (),
@@ -48,6 +88,56 @@ function getrf!(A::AbstractMatrix{<:Float32};
48
88
A, ipiv, info[], info # Error code is stored in LU factorization type
49
89
end
50
90
91
+ function getrs! (trans:: AbstractChar ,
92
+ A:: AbstractMatrix{<:ComplexF64} ,
93
+ ipiv:: AbstractVector{BlasInt} ,
94
+ B:: AbstractVecOrMat{<:ComplexF64} ;
95
+ info = Ref {BlasInt} ())
96
+ require_one_based_indexing (A, ipiv, B)
97
+ LinearAlgebra. LAPACK. chktrans (trans)
98
+ chkstride1 (A, B, ipiv)
99
+ n = LinearAlgebra. checksquare (A)
100
+ if n != size (B, 1 )
101
+ throw (DimensionMismatch (" B has leading dimension $(size (B,1 )) , but needs $n " ))
102
+ end
103
+ if n != length (ipiv)
104
+ throw (DimensionMismatch (" ipiv has length $(length (ipiv)) , but needs to be $n " ))
105
+ end
106
+ nrhs = size (B, 2 )
107
+ ccall ((" zgetrs_" , MKL_jll. libmkl_rt), Cvoid,
108
+ (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt},
109
+ Ptr{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
110
+ trans, n, size (B, 2 ), A, max (1 , stride (A, 2 )), ipiv, B, max (1 , stride (B, 2 )), info,
111
+ 1 )
112
+ LinearAlgebra. LAPACK. chklapackerror (BlasInt (info[]))
113
+ B
114
+ end
115
+
116
+ function getrs! (trans:: AbstractChar ,
117
+ A:: AbstractMatrix{<:ComplexF32} ,
118
+ ipiv:: AbstractVector{BlasInt} ,
119
+ B:: AbstractVecOrMat{<:ComplexF32} ;
120
+ info = Ref {BlasInt} ())
121
+ require_one_based_indexing (A, ipiv, B)
122
+ LinearAlgebra. LAPACK. chktrans (trans)
123
+ chkstride1 (A, B, ipiv)
124
+ n = LinearAlgebra. checksquare (A)
125
+ if n != size (B, 1 )
126
+ throw (DimensionMismatch (" B has leading dimension $(size (B,1 )) , but needs $n " ))
127
+ end
128
+ if n != length (ipiv)
129
+ throw (DimensionMismatch (" ipiv has length $(length (ipiv)) , but needs to be $n " ))
130
+ end
131
+ nrhs = size (B, 2 )
132
+ ccall ((" cgetrs_" , MKL_jll. libmkl_rt), Cvoid,
133
+ (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt},
134
+ Ptr{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
135
+ trans, n, size (B, 2 ), A, max (1 , stride (A, 2 )), ipiv, B, max (1 , stride (B, 2 )), info,
136
+ 1 )
137
+ LinearAlgebra. LAPACK. chklapackerror (BlasInt (info[]))
138
+ B
139
+ end
140
+
51
141
function getrs! (trans:: AbstractChar ,
52
142
A:: AbstractMatrix{<:Float64} ,
53
143
ipiv:: AbstractVector{BlasInt} ,
@@ -106,12 +196,19 @@ const PREALLOCATED_MKL_LU = begin
106
196
luinst = ArrayInterface. lu_instance (A), Ref {BlasInt} ()
107
197
end
108
198
109
- function init_cacheval (alg:: MKLLUFactorization , A, b, u, Pl, Pr,
199
+ function LinearSolve . init_cacheval (alg:: MKLLUFactorization , A, b, u, Pl, Pr,
110
200
maxiters:: Int , abstol, reltol, verbose:: Bool ,
111
201
assumptions:: OperatorAssumptions )
112
202
PREALLOCATED_MKL_LU
113
203
end
114
204
205
+ function LinearSolve. init_cacheval (alg:: MKLLUFactorization , A:: AbstractMatrix{<:Union{Float32,ComplexF32,ComplexF64}} , b, u, Pl, Pr,
206
+ maxiters:: Int , abstol, reltol, verbose:: Bool ,
207
+ assumptions:: OperatorAssumptions )
208
+ A = rand (eltype (A), 0 , 0 )
209
+ ArrayInterface. lu_instance (A), Ref {BlasInt} ()
210
+ end
211
+
115
212
function SciMLBase. solve! (cache:: LinearCache , alg:: MKLLUFactorization ;
116
213
kwargs... )
117
214
A = cache. A
0 commit comments