
Commit 484a8ab

vulkan: Add N/2 and N/4 optimized paths in coopmat2 shader (#12312)
1 parent cf2270e commit 484a8ab

File tree: 2 files changed, +72 -31 lines

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp


ggml/src/ggml-vulkan/ggml-vulkan.cpp

+12 -12
@@ -1597,33 +1597,33 @@ static void ggml_vk_load_shaders(vk_device& device) {
     uint32_t l_align, m_align, s_align;
     if (device->coopmat2) {
         // spec constants and tile sizes for non-quant matmul/matmul_id
-        l_warptile = { 256, 128, 256, 64 };
-        m_warptile = { 256, 128, 128, 64 };
-        s_warptile = { 128, 64, 64, 64 };
+        l_warptile = { 256, 128, 256, 64, 1 };
+        m_warptile = { 256, 128, 128, 64, 0 };
+        s_warptile = { 128, 64, 64, 64, 0 };
         l_wg_denoms = {128, 256, 1 };
         m_wg_denoms = {128, 128, 1 };
         s_wg_denoms = { 64, 64, 1 };

         // spec constants and tile sizes for quant matmul (non-Qi_K)
-        l_warptile_mmq = { 256, 128, 256, 64 };
-        m_warptile_mmq = { 256, 128, 128, 64 };
-        s_warptile_mmq = { 256, 32, 64, 128 };
+        l_warptile_mmq = { 256, 128, 256, 64, 1 };
+        m_warptile_mmq = { 256, 128, 128, 64, 1 };
+        s_warptile_mmq = { 256, 32, 64, 128, 0 };
         l_mmq_wg_denoms = { 128, 256, 1 };
         m_mmq_wg_denoms = { 128, 128, 1 };
         s_mmq_wg_denoms = { 32, 64, 1 };

         // spec constants and tile sizes for quant matmul (Qi_K)
-        l_warptile_mmq_k = { 256, 64, 128, 64 };
-        m_warptile_mmq_k = { 256, 32, 64, 64 };
-        s_warptile_mmq_k = { 256, 32, 32, 128 };
+        l_warptile_mmq_k = { 256, 64, 128, 64, 1 };
+        m_warptile_mmq_k = { 256, 32, 64, 64, 0 };
+        s_warptile_mmq_k = { 256, 32, 32, 128, 0 };
         l_mmq_wg_denoms_k = { 64, 128, 1 };
         m_mmq_wg_denoms_k = { 32, 64, 1 };
         s_mmq_wg_denoms_k = { 32, 32, 1 };

         // spec constants and tile sizes for quant matmul_id
-        l_warptile_mmqid = { 256, 128, 64, 16 };
-        m_warptile_mmqid = { 256, 128, 64, 16 };
-        s_warptile_mmqid = { 256, 128, 64, 16 };
+        l_warptile_mmqid = { 256, 128, 64, 16, 0 };
+        m_warptile_mmqid = { 256, 128, 64, 16, 0 };
+        s_warptile_mmqid = { 256, 128, 64, 16, 0 };
         l_mmqid_wg_denoms = { 128, 64, 1 };
         m_mmqid_wg_denoms = { 128, 64, 1 };
         s_mmqid_wg_denoms = { 128, 64, 1 };
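The only change in this file is a fifth element appended to each warptile initializer. It lines up with the new specialization constant declared at constant_id = 4 in the shader below (enable_smaller_matrices): the large tiles and most medium tiles opt in with 1, while the small tiles and the matmul_id tiles stay at 0. As a hedged illustration of how such an array of values is typically fed to a shader, here is a generic Vulkan sketch; it is not ggml-vulkan's actual pipeline-creation code, and make_spec_info is a hypothetical helper name:

    // Illustrative only: generic mapping of a uint32_t array onto specialization
    // constants with constant_id 0..N-1, so the i-th array element drives the spec
    // constant declared with constant_id = i (index 4 here would be the new
    // enable_smaller_matrices flag). Not ggml-vulkan's real plumbing.
    #include <vector>
    #include <vulkan/vulkan.h>

    static VkSpecializationInfo make_spec_info(const std::vector<uint32_t> & values,
                                               std::vector<VkSpecializationMapEntry> & entries) {
        entries.resize(values.size());
        for (uint32_t i = 0; i < values.size(); ++i) {
            entries[i] = { /*constantID*/ i,
                           /*offset    */ i * (uint32_t)sizeof(uint32_t),
                           /*size      */ sizeof(uint32_t) };
        }
        VkSpecializationInfo info{};
        info.mapEntryCount = (uint32_t)entries.size();
        info.pMapEntries   = entries.data();
        info.dataSize      = values.size() * sizeof(uint32_t);
        info.pData         = values.data();
        return info;   // passed as pSpecializationInfo in VkPipelineShaderStageCreateInfo
    }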

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp

+60 -19
@@ -23,6 +23,10 @@ layout (constant_id = 1) const uint BM = 64;
 layout (constant_id = 2) const uint BN = 64;
 layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant

+layout (constant_id = 4) const bool enable_smaller_matrices = false;
+const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN;
+const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN;
+
 layout (push_constant) uniform parameter
 {
     uint M;
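BNover2 and BNover4 fold back to BN whenever the feature is off, so pipelines compiled with enable_smaller_matrices = false keep their old shapes. When it is on, a workgroup whose column tile reaches past the end of the output can accumulate into a BN/2- or BN/4-wide cooperative matrix, since the narrower tile already covers every remaining valid column. A small standalone C++ sketch of the selection predicate that appears further down in this file (pick_tile_width is a hypothetical name, used only for illustration):

    #include <cstdint>
    #include <cstdio>

    // Mirrors the branch added to mul_mm_cm2.comp: returns the accumulator width the
    // workgroup at column-tile index `ic` would use for an output with N columns.
    static uint32_t pick_tile_width(uint32_t ic, uint32_t BN, uint32_t N, bool enable_smaller_matrices) {
        const uint32_t BNover4 = enable_smaller_matrices ? BN / 4 : BN;
        const uint32_t BNover2 = enable_smaller_matrices ? BN / 2 : BN;
        if (enable_smaller_matrices && ic * BN + BNover4 >= N) return BNover4; // N/4 path
        if (enable_smaller_matrices && ic * BN + BNover2 >= N) return BNover2; // N/2 path
        return BN;                                                             // full-width path
    }

    int main() {
        // Example: BN = 64 and an output with N = 80 columns, i.e. two column tiles.
        for (uint32_t ic = 0; ic < 2; ++ic) {
            std::printf("ic=%u -> tile width %u\n", ic, pick_tile_width(ic, 64, 80, true));
        }
        // ic=0 covers columns [0,64)  and needs the full 64-wide tile.
        // ic=1 covers columns [64,80): only 16 valid columns remain, so the BN/4 (=16) path suffices.
        return 0;
    }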
@@ -168,15 +172,13 @@ void main() {
     const uint end_k = min(p.K, (ik + 1) * p.k_split);
 #endif

-    coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
-    sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
-
 #ifdef MUL_MAT_ID
     uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K;
     uint pos_b = 0;
 #else
     uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K;
     uint pos_b = batch_idx * p.batch_stride_b;
+    uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
 #endif

     uint stride_a = p.stride_a / QUANT_K;
@@ -197,6 +199,7 @@ void main() {
     tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2);
     tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
     tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+    tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);

 #if QUANT_K > 1
     tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K);
@@ -232,16 +235,54 @@ void main() {
         tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);

         uint k_iters = (end_k - start_k + BK - 1) / BK;
+        if (enable_smaller_matrices && ic * BN + BNover4 >= p.N) {
+            coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
+            for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {

-        for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;

-            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
-            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
+
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
+            }
+            coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);
+
+            coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose);
+            return;
+        } else if (enable_smaller_matrices && ic * BN + BNover2 >= p.N) {
+            coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
+            for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+
+                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
+
+                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
+
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
+            }
+            coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
+
+            coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose);
+            return;
+        } else {
+            coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
+            for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+
+                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;

-            coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
-            coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
+                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
+
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
+            }
+            coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);

-            sum = coopMatMulAdd(mat_a, mat_b, sum);
+            coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
+            return;
         }
     } else
 #endif // !defined(MUL_MAT_ID)
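All three branches run the same K loop and differ only in the column width of mat_b, the accumulator and the stored tile, so a partially filled last column tile does a quarter or half of the multiply-accumulate and store work. A rough standalone model of how much narrower that last tile's accumulator becomes for a few hypothetical output widths (illustrative numbers only, not measurements from this commit):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t BN = 64;
        const uint32_t Ns[] = { 1, 8, 16, 33, 80, 256 };   // hypothetical output widths
        for (uint32_t N : Ns) {
            const uint32_t tiles     = (N + BN - 1) / BN;      // column tiles launched
            const uint32_t last_cols = N - (tiles - 1) * BN;   // valid columns in the last tile
            // Width the last tile's accumulator needs under the new paths
            // (equivalent to the ic * BN + BNoverX >= p.N checks with ic = tiles - 1):
            uint32_t width = BN;
            if      (last_cols <= BN / 4) width = BN / 4;      // N/4 path
            else if (last_cols <= BN / 2) width = BN / 2;      // N/2 path
            std::printf("N=%3u: last tile has %2u valid columns -> accumulator width %2u (was %u)\n",
                        N, last_cols, width, BN);
        }
        return 0;
    }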
@@ -254,6 +295,9 @@ void main() {

         tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);

+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
+        sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
+
         [[dont_unroll]]
         for (uint block_k = start_k; block_k < end_k; block_k += BK) {

@@ -296,19 +340,16 @@ void main() {
                 sum = coopMatMulAdd(mat_a, mat_b, sum);
             }
         }
-    }

-    // Convert from ACC_TYPE to D_TYPE
-    coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
-    mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
+        // Convert from ACC_TYPE to D_TYPE
+        coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
+        mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);

 #ifdef MUL_MAT_ID
-    // Call callback to store each element, remapping row through shared memory
-    coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
+        // Call callback to store each element, remapping row through shared memory
+        coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
 #else
-    tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);
-
-    uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
-    coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
+        coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
 #endif
+    }
 }
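This last hunk mostly re-indents the shared epilogue: the tensorLayoutD stride setup and the pos_d computation that used to live here were hoisted earlier in main() so the new early-return N/4, N/2 and full-width paths can reuse them. For reference, a C++ mirror of the destination offset formula taken from the shader (the dst_offset wrapper itself is hypothetical):

    #include <cstdint>

    // Mirrors: uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
    // Each k-split index ik addresses its own region of batch_stride_d * num_workgroups_z
    // elements, and batches within that region are spaced batch_stride_d apart.
    static uint32_t dst_offset(uint32_t batch_idx, uint32_t ik,
                               uint32_t batch_stride_d, uint32_t num_workgroups_z) {
        return batch_idx * batch_stride_d + ik * batch_stride_d * num_workgroups_z;
    }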
