
Commit 8862e42

MQ-mengqing authored and mglambda committed
ggml : optimize and build warning fix for LoongArch (ggml-org#11709)
* ggml : optimize convert f32<->f16 for loongarch_asx
* ggml : optimize loongarch_asx extend i16,i8,u8 to i32,i16
* ggml : Fix warnings when run cpu CI locally on LoongArch
1 parent 8af0ab8 commit 8862e42

File tree: 3 files changed, +22 / -57 lines changed

ggml/src/ggml-cpu/ggml-cpu-impl.h

+6 -12
@@ -360,21 +360,15 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b)
 #endif
 
 #if defined(__loongarch_asx)
-
-typedef union {
-    int32_t i;
-    float f;
-} ft_union;
-
 /* float type data load instructions */
-static __m128 __lsx_vreplfr2vr_s(float val) {
-    ft_union fi_tmpval = {.f = val};
-    return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i);
+static __m128 __lsx_vreplfr2vr_s(const float val) {
+    v4f32 res = {val, val, val, val};
+    return (__m128)res;
 }
 
-static __m256 __lasx_xvreplfr2vr_s(float val) {
-    ft_union fi_tmpval = {.f = val};
-    return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i);
+static __m256 __lasx_xvreplfr2vr_s(const float val) {
+    v8f32 res = {val, val, val, val, val, val, val, val};
+    return (__m256)res;
 }
 #endif
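
For reference, a minimal scalar sketch of what the rewritten helper does (the type and
function names below are illustrative, not ggml identifiers): broadcast one float into
every 32-bit lane of the vector directly, with no int/float union pun in between.

    typedef struct { float lane[8]; } f32x8_ref;       // stand-in for an 8-lane f32 vector

    static f32x8_ref broadcast_f32x8_ref(float val) {
        f32x8_ref v;
        for (int i = 0; i < 8; i++) {
            v.lane[i] = val;                            // same effect as v8f32 res = {val, ..., val}
        }
        return v;
    }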

ggml/src/ggml-cpu/ggml-cpu-quants.c

+7 -30
@@ -501,30 +501,15 @@ static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
 }
 
 static __m256i lasx_extu8_16(__m128i a) {
-    __m128i zero = __lsx_vldi(0);
-    __m128i vlo = __lsx_vilvl_b(zero, a);
-    __m128i vhi = __lsx_vilvh_b(zero, a);
-    return lasx_set_q(vhi, vlo);
+    return __lasx_vext2xv_hu_bu(____m256i(a));
 }
 
 static __m256i lasx_ext8_16(__m128i a) {
-    __m128i sign = __lsx_vslti_b(a, 0);
-    __m128i vlo = __lsx_vilvl_b(sign, a);
-    __m128i vhi = __lsx_vilvh_b(sign, a);
-    return lasx_set_q(vhi, vlo);
+    return __lasx_vext2xv_h_b(____m256i(a));
 }
 
 static __m256i lasx_ext16_32(__m128i a) {
-    __m256i tmp1;
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7);
-    return tmp1;
+    return __lasx_vext2xv_w_h(____m256i(a));
 }
 
 static __m128i lasx_extracti128( __m256i a, int pos) {
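
A scalar reference (not ggml code; the names are illustrative) for the widening that the
single __lasx_vext2xv_* intrinsics now perform in place of the removed interleave/insert
sequences: each narrow element is sign- or zero-extended into the corresponding wider lane.

    #include <stdint.h>

    static void extu8_16_ref(const uint8_t src[16], uint16_t dst[16]) {
        for (int i = 0; i < 16; i++) dst[i] = (uint16_t)src[i];   // zero-extend u8 -> u16
    }

    static void ext8_16_ref(const int8_t src[16], int16_t dst[16]) {
        for (int i = 0; i < 16; i++) dst[i] = (int16_t)src[i];    // sign-extend i8 -> i16
    }

    static void ext16_32_ref(const int16_t src[8], int32_t dst[8]) {
        for (int i = 0; i < 8; i++) dst[i] = (int32_t)src[i];     // sign-extend i16 -> i32
    }
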
@@ -592,12 +577,10 @@ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
 // horizontally add 8 floats
 static inline float hsum_float_8(const __m256 x) {
     __m128 res = lasx_extractf128(x, 1);
-    ft_union tmp;
     res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));
     res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
     res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
-    tmp.i = __lsx_vpickve2gr_w(res, 0);
-    return tmp.f;
+    return ((v4f32)res)[0];
 }
 
 // horizontally add 8 int32_t
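
The ((v4f32)res)[0] expression relies on the GCC/Clang vector extension that lets vector
values be subscripted like arrays, which is what makes the old ft_union int/float pun
unnecessary. A minimal standalone sketch (the type name is illustrative, not the one from
the LoongArch headers):

    typedef float f32x4_ref __attribute__((vector_size(16)));    // four packed floats

    static float lane0_ref(f32x4_ref v) {
        return v[0];    // read lane 0 directly; no union or integer round trip needed
    }
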
@@ -939,7 +922,6 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
 
 #elif defined(__loongarch_asx)
     for (int i = 0; i < nb; i++) {
-        ft_union fi;
         __m256 v0 = (__m256)__lasx_xvld( x , 0);
         __m256 v1 = (__m256)__lasx_xvld( x , 32);
         __m256 v2 = (__m256)__lasx_xvld( x , 64);
@@ -957,8 +939,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
         max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
         __m128 tmp = max4;
         max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vinsgr2vr_w(tmp, __lsx_vpickve2gr_w( max4, 1 ), 0 ));
-        fi.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
-        const float max_scalar = fi.f;
+        const float max_scalar = ((v4f32)max4)[0];
 
         // Quantize these floats
         const float d = max_scalar / 127.f;
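
For context, a scalar sketch of the q8_0 quantization step that surrounds this change
(a simplified reference, not the commit's code): find the largest magnitude in the block,
derive the scale d = amax / 127, and map each float to a signed 8-bit value.

    #include <math.h>
    #include <stdint.h>

    static void quantize_block_q8_0_ref(const float *x, int n, float *d_out, int8_t *q) {
        float amax = 0.0f;                                   // largest magnitude in the block
        for (int i = 0; i < n; i++) amax = fmaxf(amax, fabsf(x[i]));
        const float d  = amax / 127.f;                       // same scale as max_scalar / 127.f above
        const float id = (d != 0.0f) ? 1.0f / d : 0.0f;      // guard against an all-zero block
        for (int i = 0; i < n; i++) q[i] = (int8_t)roundf(x[i] * id);
        *d_out = d;
    }
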
@@ -1263,7 +1244,6 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
 
 #elif defined(__loongarch_asx)
     for (int i = 0; i < nb; i++) {
-        ft_union ft;
         __m256 v0 = (__m256)__lasx_xvld( x , 0 );
         __m256 v1 = (__m256)__lasx_xvld( x , 32 );
         __m256 v2 = (__m256)__lasx_xvld( x , 64 );
@@ -1281,8 +1261,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
         max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
         __m128 tmp = max4;
         max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 ));
-        ft.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
-        const float max_scalar = ft.f;
+        const float max_scalar = ((v4f32)max4)[0];
 
         // Quantize these floats
         const float d = max_scalar / 127.f;
@@ -6154,9 +6133,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);
 
 
-    ft_union fi;
-    fi.i = __lsx_vpickve2gr_w(acc_m, 0);
-    *s = hsum_float_8(acc) + fi.f ;
+    *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
 #else
 
     const uint8_t * scales = (const uint8_t*)&utmp[0];

ggml/src/ggml-cpu/ggml-cpu.c

+9 -15
@@ -1078,29 +1078,23 @@ do { \
 #define GGML_F16_STEP 32
 #define GGML_F16_EPR 8
 
-// F16 arithmetic is not supported by AVX, so we use F32 instead
+// F16 arithmetic is not supported by LASX, so we use F32 instead
 
 #define GGML_F32Cx8 __m256
 #define GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
 #define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
 
 static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) {
-    float tmp[8];
-
-    for (int i = 0; i < 8; i++) {
-        tmp[i] = GGML_FP16_TO_FP32(x[i]);
-    }
-
-    return (__m256)__lasx_xvld(tmp, 0);
+    __m256i a;
+    memcpy(&a, x, sizeof(ggml_fp16_t) * 8);
+    a = __lasx_xvpermi_d(a, 0 | (1 << 4));
+    return __lasx_xvfcvtl_s_h(a);
 }
-static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
-    float arr[8];
 
-    __lasx_xvst(y, arr, 0);
-
-    for (int i = 0; i < 8; i++) {
-        x[i] = GGML_FP32_TO_FP16(arr[i]);
-    }
+static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
+    __m256i a = __lasx_xvfcvt_h_s(y, y);
+    a = __lasx_xvpermi_d(a, 0 | (2 << 2));
+    memcpy(x, &a, sizeof(ggml_fp16_t) * 8);
 }
 #define GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x)
 #define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
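
A scalar sketch of the conversion the rewritten helpers now do with two vector instructions
each (illustrative only; it assumes a compiler with the _Float16 extension, which is not
what ggml relies on here): widen eight half-precision values to float on load, and narrow
eight floats back to half precision on store.

    static void f32cx8_load_ref(const _Float16 *x, float y[8]) {
        for (int i = 0; i < 8; i++) y[i] = (float)x[i];      // fp16 -> fp32, lane by lane
    }

    static void f32cx8_store_ref(_Float16 *x, const float y[8]) {
        for (int i = 0; i < 8; i++) x[i] = (_Float16)y[i];   // fp32 -> fp16, lane by lane
    }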
