@@ -4416,7 +4416,6 @@ void kernel_mul_mv_q2_K_f32_impl(
4416
4416
device const half * dh = &x[ib].d ;
4417
4417
4418
4418
for (int row = 0 ; row < N_DST; row++) {
4419
-
4420
4419
float4 acc1 = {0 .f , 0 .f , 0 .f , 0 .f };
4421
4420
float4 acc2 = {0 .f , 0 .f , 0 .f , 0 .f };
4422
4421
for (int i = 0 ; i < 8 ; i += 2 ) {
@@ -4447,7 +4446,7 @@ void kernel_mul_mv_q2_K_f32_impl(
4447
4446
4448
4447
device float * dst_f32 = (device float *) dst + (uint64_t )im*args.ne0 *args.ne1 + (uint64_t )r1*args.ne0 ;
4449
4448
4450
- for (int row = 0 ; row < N_DST; ++row) {
4449
+ for (int row = 0 ; row < N_DST && first_row + row < args. ne0 ; ++row) {
4451
4450
all_sum = simd_sum (sumf[row]);
4452
4451
if (tiisg == 0 ) {
4453
4452
dst_f32[first_row + row] = all_sum;
@@ -4613,7 +4612,7 @@ void kernel_mul_mv_q3_K_f32_impl(
4613
4612
device float * dst_f32 = (device float *) dst + (uint64_t )im*args.ne0 *args.ne1 + (uint64_t )r1*args.ne0 ;
4614
4613
4615
4614
if (tiisg == 0 ) {
4616
- for (int row = 0 ; row < 2 ; ++row) {
4615
+ for (int row = 0 ; row < 2 && first_row + row < args. ne0 ; ++row) {
4617
4616
dst_f32[first_row + row] = sumf1[row];
4618
4617
}
4619
4618
}
@@ -4729,7 +4728,7 @@ void kernel_mul_mv_q4_K_f32_impl(
4729
4728
4730
4729
device float * dst_f32 = (device float *) dst + (int64_t )im*args.ne0 *args.ne1 + (int64_t )r1*args.ne0 ;
4731
4730
4732
- for (int row = 0 ; row < N_DST; ++row) {
4731
+ for (int row = 0 ; row < N_DST && first_row + row < args. ne0 ; ++row) {
4733
4732
all_sum = simd_sum (sumf[row]);
4734
4733
if (tiisg == 0 ) {
4735
4734
dst_f32[first_row + row] = all_sum;
@@ -4861,7 +4860,7 @@ void kernel_mul_mv_q5_K_f32_impl(
4861
4860
4862
4861
device float * dst_f32 = (device float *) dst + (uint64_t )im*args.ne0 *args.ne1 + (uint64_t )r1*args.ne0 ;
4863
4862
4864
- for (int row = 0 ; row < 2 ; ++row) {
4863
+ for (int row = 0 ; row < 2 && first_row + row < args. ne0 ; ++row) {
4865
4864
const float tot = simd_sum (sumf[row]);
4866
4865
if (tiisg == 0 ) {
4867
4866
dst_f32[first_row + row] = tot;
@@ -4906,6 +4905,10 @@ void kernel_mul_mv_q6_K_f32_impl(
4906
4905
4907
4906
const int row = 2 *r0 + sgitg;
4908
4907
4908
+ if (row >= args.ne0 ) {
4909
+ return ;
4910
+ }
4911
+
4909
4912
const uint i12 = im%args.ne12 ;
4910
4913
const uint i13 = im/args.ne12 ;
4911
4914
@@ -5061,7 +5064,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
5061
5064
5062
5065
device float * dst_f32 = (device float *) dst + (uint64_t )im*args.ne0 *args.ne1 + (uint64_t )r1*args.ne0 ;
5063
5066
5064
- for (int row = 0 ; row < N_DST; ++row) {
5067
+ for (int row = 0 ; row < N_DST && first_row + row < args. ne0 ; ++row) {
5065
5068
all_sum = simd_sum (sumf[row]);
5066
5069
if (tiisg == 0 ) {
5067
5070
dst_f32[first_row + row] = all_sum * 0 .25f ;
@@ -5179,7 +5182,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
5179
5182
5180
5183
device float * dst_f32 = (device float *) dst + (uint64_t )im*args.ne0 *args.ne1 + (uint64_t )r1*args.ne0 ;
5181
5184
5182
- for (int row = 0 ; row < N_DST; ++row) {
5185
+ for (int row = 0 ; row < N_DST && first_row + row < args. ne0 ; ++row) {
5183
5186
all_sum = simd_sum (sumf[row]);
5184
5187
if (tiisg == 0 ) {
5185
5188
dst_f32[first_row + row] = all_sum * 0 .25f ;
@@ -5289,7 +5292,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
5289
5292
5290
5293
device float * dst_f32 = (device float *) dst + (uint64_t )im*args.ne0 *args.ne1 + (uint64_t )r1*args.ne0 ;
5291
5294
5292
- for (int row = 0 ; row < N_DST; ++row) {
5295
+ for (int row = 0 ; row < N_DST && first_row + row < args. ne0 ; ++row) {
5293
5296
all_sum = simd_sum (sumf[row]);
5294
5297
if (tiisg == 0 ) {
5295
5298
dst_f32[first_row + row] = all_sum * 0 .5f ;
@@ -5401,7 +5404,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
5401
5404
5402
5405
device float * dst_f32 = (device float *) dst + (uint64_t )im*args.ne0 *args.ne1 + (uint64_t )r1*args.ne0 ;
5403
5406
5404
- for (int row = 0 ; row < N_DST; ++row) {
5407
+ for (int row = 0 ; row < N_DST && first_row + row < args. ne0 ; ++row) {
5405
5408
all_sum = simd_sum (sumf[row]);
5406
5409
if (tiisg == 0 ) {
5407
5410
dst_f32[first_row + row] = all_sum;
@@ -5514,7 +5517,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
5514
5517
5515
5518
device float * dst_f32 = (device float *) dst + (uint64_t )im*args.ne0 *args.ne1 + (uint64_t )r1*args.ne0 ;
5516
5519
5517
- for (int row = 0 ; row < N_DST; ++row) {
5520
+ for (int row = 0 ; row < N_DST && first_row + row < args. ne0 ; ++row) {
5518
5521
all_sum = simd_sum (sumf[row]);
5519
5522
if (tiisg == 0 ) {
5520
5523
dst_f32[first_row + row] = all_sum * 0 .25f ;
@@ -5614,7 +5617,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
5614
5617
5615
5618
device float * dst_f32 = (device float *) dst + (uint64_t )im*args.ne0 *args.ne1 + (uint64_t )r1*args.ne0 ;
5616
5619
5617
- for (int row = 0 ; row < N_DST; ++row) {
5620
+ for (int row = 0 ; row < N_DST && first_row + row < args. ne0 ; ++row) {
5618
5621
all_sum = simd_sum (sumf[row]);
5619
5622
if (tiisg == 0 ) {
5620
5623
dst_f32[first_row + row] = all_sum;
@@ -5709,7 +5712,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
5709
5712
5710
5713
device float * dst_f32 = (device float *) dst + (uint64_t )im*args.ne0 *args.ne1 + (uint64_t )r1*args.ne0 ;
5711
5714
5712
- for (int row = 0 ; row < N_DST; ++row) {
5715
+ for (int row = 0 ; row < N_DST && first_row + row < args. ne0 ; ++row) {
5713
5716
all_sum = simd_sum (sumf[row]);
5714
5717
if (tiisg == 0 ) {
5715
5718
dst_f32[first_row + row] = all_sum;
@@ -5799,7 +5802,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
5799
5802
5800
5803
device float * dst_f32 = (device float *) dst + (uint64_t )im*args.ne0 *args.ne1 + (uint64_t )r1*args.ne0 ;
5801
5804
5802
- for (int row = 0 ; row < 2 && first_row + row < args.ne01 ; ++row) {
5805
+ for (int row = 0 ; row < 2 && first_row + row < args.ne0 ; ++row) {
5803
5806
all_sum = simd_sum (sumf[row]);
5804
5807
if (tiisg == 0 ) {
5805
5808
dst_f32[first_row + row] = all_sum;
@@ -5888,7 +5891,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
5888
5891
5889
5892
device float * dst_f32 = (device float *) dst + (uint64_t )im*args.ne0 *args.ne1 + (uint64_t )r1*args.ne0 ;
5890
5893
5891
- for (int row = 0 ; row < 2 ; ++row) {
5894
+ for (int row = 0 ; row < 2 && first_row + row < args. ne0 ; ++row) {
5892
5895
all_sum = simd_sum (sumf[row]);
5893
5896
if (tiisg == 0 ) {
5894
5897
dst_f32[first_row + row] = all_sum;
0 commit comments