@@ -23,6 +23,10 @@ layout (constant_id = 1) const uint BM = 64;
23
23
layout (constant_id = 2) const uint BN = 64; // Column extent of the B-use and accumulator cooperative matrices; main() slices B and D at ic * BN
24
24
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
25
25
26
+ layout (constant_id = 4) const bool enable_smaller_matrices = false; // Specialization constant, off by default; gates the reduced-width tile branches in main()
27
+ const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN; // BN / 2 when enabled, otherwise BN; tile width used when ic * BN + BNover2 >= p.N and the BNover4 test did not match
28
+ const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN; // BN / 4 when enabled, otherwise BN; tile width used when ic * BN + BNover4 >= p.N (tested first)
29
+
26
30
layout (push_constant) uniform parameter
27
31
{
28
32
uint M;
@@ -168,15 +172,13 @@ void main() {
168
172
const uint end_k = min(p.K, (ik + 1) * p.k_split);
169
173
#endif
170
174
171
- coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
172
- sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
173
-
174
175
#ifdef MUL_MAT_ID
175
176
uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K;
176
177
uint pos_b = 0;
177
178
#else
178
179
uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K;
179
180
uint pos_b = batch_idx * p.batch_stride_b;
181
+ uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
180
182
#endif
181
183
182
184
uint stride_a = p.stride_a / QUANT_K;
@@ -197,6 +199,7 @@ void main() {
197
199
tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2);
198
200
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
199
201
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
202
+ tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);
200
203
201
204
#if QUANT_K > 1
202
205
tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K);
@@ -232,16 +235,54 @@ void main() {
232
235
tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);
233
236
234
237
uint k_iters = (end_k - start_k + BK - 1) / BK;
238
+ if (enable_smaller_matrices && ic * BN + BNover4 >= p.N) {
239
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
240
+ for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
235
241
236
- for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
242
+ coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
243
+ coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
237
244
238
- coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
239
- coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
245
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
246
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
247
+
248
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
249
+ }
250
+ coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);
251
+
252
+ coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose);
253
+ return;
254
+ } else if (enable_smaller_matrices && ic * BN + BNover2 >= p.N) {
255
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
256
+ for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
257
+
258
+ coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
259
+ coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
260
+
261
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
262
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
263
+
264
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
265
+ }
266
+ coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
267
+
268
+ coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose);
269
+ return;
270
+ } else {
271
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
272
+ for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
273
+
274
+ coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
275
+ coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
240
276
241
- coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
242
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
277
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
278
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
279
+
280
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
281
+ }
282
+ coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
243
283
244
- sum = coopMatMulAdd(mat_a, mat_b, sum);
284
+ coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
285
+ return;
245
286
}
246
287
} else
247
288
#endif // !defined(MUL_MAT_ID)
@@ -254,6 +295,9 @@ void main() {
254
295
255
296
tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);
256
297
298
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
299
+ sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
300
+
257
301
[[dont_unroll]]
258
302
for (uint block_k = start_k; block_k < end_k; block_k += BK) {
259
303
@@ -296,19 +340,16 @@ void main() {
296
340
sum = coopMatMulAdd(mat_a, mat_b, sum);
297
341
}
298
342
}
299
- }
300
343
301
- // Convert from ACC_TYPE to D_TYPE
302
- coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
303
- mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
344
+ // Convert from ACC_TYPE to D_TYPE
345
+ coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
346
+ mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
304
347
305
348
#ifdef MUL_MAT_ID
306
- // Call callback to store each element, remapping row through shared memory
307
- coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
349
+ // Call callback to store each element, remapping row through shared memory
350
+ coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
308
351
#else
309
- tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);
310
-
311
- uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
312
- coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
352
+ coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
313
353
#endif
354
+ }
314
355
}
0 commit comments