Skip to content

Commit 2139667

Browse files
authored
metal : fix out-of-bounds write (#11314)
ggml-ci
1 parent 80d0d6b · commit 2139667

File tree

1 file changed

+17
-14
lines changed

1 file changed

+17
-14
lines changed

ggml/src/ggml-metal/ggml-metal.metal

+17 -14
Original file line number | Diff line number | Diff line change
@@ -4416,7 +4416,6 @@ void kernel_mul_mv_q2_K_f32_impl(
44164416
device const half * dh = &x[ib].d;
44174417

44184418
for (int row = 0; row < N_DST; row++) {
4419-
44204419
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
44214420
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
44224421
for (int i = 0; i < 8; i += 2) {
@@ -4447,7 +4446,7 @@ void kernel_mul_mv_q2_K_f32_impl(
44474446

44484447
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
44494448

4450-
for (int row = 0; row < N_DST; ++row) {
4449+
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
44514450
all_sum = simd_sum(sumf[row]);
44524451
if (tiisg == 0) {
44534452
dst_f32[first_row + row] = all_sum;
@@ -4613,7 +4612,7 @@ void kernel_mul_mv_q3_K_f32_impl(
46134612
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
46144613

46154614
if (tiisg == 0) {
4616-
for (int row = 0; row < 2; ++row) {
4615+
for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
46174616
dst_f32[first_row + row] = sumf1[row];
46184617
}
46194618
}
@@ -4729,7 +4728,7 @@ void kernel_mul_mv_q4_K_f32_impl(
47294728

47304729
device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
47314730

4732-
for (int row = 0; row < N_DST; ++row) {
4731+
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
47334732
all_sum = simd_sum(sumf[row]);
47344733
if (tiisg == 0) {
47354734
dst_f32[first_row + row] = all_sum;
@@ -4861,7 +4860,7 @@ void kernel_mul_mv_q5_K_f32_impl(
48614860

48624861
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
48634862

4864-
for (int row = 0; row < 2; ++row) {
4863+
for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
48654864
const float tot = simd_sum(sumf[row]);
48664865
if (tiisg == 0) {
48674866
dst_f32[first_row + row] = tot;
@@ -4906,6 +4905,10 @@ void kernel_mul_mv_q6_K_f32_impl(
49064905

49074906
const int row = 2*r0 + sgitg;
49084907

4908+
if (row >= args.ne0) {
4909+
return;
4910+
}
4911+
49094912
const uint i12 = im%args.ne12;
49104913
const uint i13 = im/args.ne12;
49114914

@@ -5061,7 +5064,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
50615064

50625065
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
50635066

5064-
for (int row = 0; row < N_DST; ++row) {
5067+
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
50655068
all_sum = simd_sum(sumf[row]);
50665069
if (tiisg == 0) {
50675070
dst_f32[first_row + row] = all_sum * 0.25f;
@@ -5179,7 +5182,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
51795182

51805183
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
51815184

5182-
for (int row = 0; row < N_DST; ++row) {
5185+
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
51835186
all_sum = simd_sum(sumf[row]);
51845187
if (tiisg == 0) {
51855188
dst_f32[first_row + row] = all_sum * 0.25f;
@@ -5289,7 +5292,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
52895292

52905293
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
52915294

5292-
for (int row = 0; row < N_DST; ++row) {
5295+
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
52935296
all_sum = simd_sum(sumf[row]);
52945297
if (tiisg == 0) {
52955298
dst_f32[first_row + row] = all_sum * 0.5f;
@@ -5401,7 +5404,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
54015404

54025405
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
54035406

5404-
for (int row = 0; row < N_DST; ++row) {
5407+
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
54055408
all_sum = simd_sum(sumf[row]);
54065409
if (tiisg == 0) {
54075410
dst_f32[first_row + row] = all_sum;
@@ -5514,7 +5517,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
55145517

55155518
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
55165519

5517-
for (int row = 0; row < N_DST; ++row) {
5520+
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
55185521
all_sum = simd_sum(sumf[row]);
55195522
if (tiisg == 0) {
55205523
dst_f32[first_row + row] = all_sum * 0.25f;
@@ -5614,7 +5617,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
56145617

56155618
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
56165619

5617-
for (int row = 0; row < N_DST; ++row) {
5620+
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
56185621
all_sum = simd_sum(sumf[row]);
56195622
if (tiisg == 0) {
56205623
dst_f32[first_row + row] = all_sum;
@@ -5709,7 +5712,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
57095712

57105713
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
57115714

5712-
for (int row = 0; row < N_DST; ++row) {
5715+
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
57135716
all_sum = simd_sum(sumf[row]);
57145717
if (tiisg == 0) {
57155718
dst_f32[first_row + row] = all_sum;
@@ -5799,7 +5802,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
57995802

58005803
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
58015804

5802-
for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
5805+
for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
58035806
all_sum = simd_sum(sumf[row]);
58045807
if (tiisg == 0) {
58055808
dst_f32[first_row + row] = all_sum;
@@ -5888,7 +5891,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
58885891

58895892
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
58905893

5891-
for (int row = 0; row < 2; ++row) {
5894+
for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
58925895
all_sum = simd_sum(sumf[row]);
58935896
if (tiisg == 0) {
58945897
dst_f32[first_row + row] = all_sum;

0 commit comments

Comments
 (0)