@@ -501,30 +501,15 @@ static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
501
501
}
502
502
503
503
// Zero-extend the 16 unsigned 8-bit lanes of a 128-bit vector into the 16
// unsigned 16-bit lanes of a 256-bit vector (LASX analogue of _mm256_cvtepu8_epi16).
// The single vext2xv instruction replaces the old interleave-with-zero sequence.
static __m256i lasx_extu8_16(__m128i a) {
    return __lasx_vext2xv_hu_bu(____m256i(a));
}
509
506
510
507
// Sign-extend the 16 signed 8-bit lanes of a 128-bit vector into the 16
// signed 16-bit lanes of a 256-bit vector (LASX analogue of _mm256_cvtepi8_epi16).
// The single vext2xv instruction replaces the old sign-mask/interleave sequence.
static __m256i lasx_ext8_16(__m128i a) {
    return __lasx_vext2xv_h_b(____m256i(a));
}
516
510
517
511
// Sign-extend the 8 signed 16-bit lanes of a 128-bit vector into the 8
// signed 32-bit lanes of a 256-bit vector (LASX analogue of _mm256_cvtepi16_epi32).
// The single vext2xv instruction replaces the old per-lane extract/insert loop,
// which also read an uninitialized accumulator before the first insert.
static __m256i lasx_ext16_32(__m128i a) {
    return __lasx_vext2xv_w_h(____m256i(a));
}
529
514
530
515
static __m128i lasx_extracti128 ( __m256i a , int pos ) {
@@ -592,12 +577,10 @@ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
592
577
// horizontally add 8 floats
593
578
static inline float hsum_float_8 (const __m256 x ) {
594
579
__m128 res = lasx_extractf128 (x , 1 );
595
- ft_union tmp ;
596
580
res = __lsx_vfadd_s (res , lasx_extractf128 (x , 0 ));
597
581
res = __lsx_vfadd_s (res , (__m128 )__lsx_vpickod_d ((__m128i )res , (__m128i )res ));
598
582
res = __lsx_vfadd_s (res , (__m128 )__lsx_vinsgr2vr_w (__lsx_vldi (0 ), __lsx_vpickve2gr_w (res , 1 ), 0 ));
599
- tmp .i = __lsx_vpickve2gr_w (res , 0 );
600
- return tmp .f ;
583
+ return ((v4f32 )res )[0 ];
601
584
}
602
585
603
586
// horizontally add 8 int32_t
@@ -939,7 +922,6 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
939
922
940
923
#elif defined(__loongarch_asx )
941
924
for (int i = 0 ; i < nb ; i ++ ) {
942
- ft_union fi ;
943
925
__m256 v0 = (__m256 )__lasx_xvld ( x , 0 );
944
926
__m256 v1 = (__m256 )__lasx_xvld ( x , 32 );
945
927
__m256 v2 = (__m256 )__lasx_xvld ( x , 64 );
@@ -957,8 +939,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
957
939
max4 = __lsx_vfmax_s ( max4 , (__m128 )__lsx_vpickod_d ((__m128i ) max4 , (__m128i )max4 ) );
958
940
__m128 tmp = max4 ;
959
941
max4 = __lsx_vfmax_s ( max4 , (__m128 )__lsx_vinsgr2vr_w (tmp , __lsx_vpickve2gr_w ( max4 , 1 ), 0 ));
960
- fi .i = __lsx_vpickve2gr_w ( (__m128i )max4 , 0 );
961
- const float max_scalar = fi .f ;
942
+ const float max_scalar = ((v4f32 )max4 )[0 ];
962
943
963
944
// Quantize these floats
964
945
const float d = max_scalar / 127.f ;
@@ -1263,7 +1244,6 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
1263
1244
1264
1245
#elif defined(__loongarch_asx )
1265
1246
for (int i = 0 ; i < nb ; i ++ ) {
1266
- ft_union ft ;
1267
1247
__m256 v0 = (__m256 )__lasx_xvld ( x , 0 );
1268
1248
__m256 v1 = (__m256 )__lasx_xvld ( x , 32 );
1269
1249
__m256 v2 = (__m256 )__lasx_xvld ( x , 64 );
@@ -1281,8 +1261,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
1281
1261
max4 = __lsx_vfmax_s ( max4 , (__m128 )__lsx_vpickod_d ((__m128i ) max4 , (__m128i )max4 ) );
1282
1262
__m128 tmp = max4 ;
1283
1263
max4 = __lsx_vfmax_s ( max4 , (__m128 )__lsx_vextrins_w ((__m128i )tmp , (__m128i )max4 , 0x10 ));
1284
- ft .i = __lsx_vpickve2gr_w ( (__m128i )max4 , 0 );
1285
- const float max_scalar = ft .f ;
1264
+ const float max_scalar = ((v4f32 )max4 )[0 ];
1286
1265
1287
1266
// Quantize these floats
1288
1267
const float d = max_scalar / 127.f ;
@@ -6154,9 +6133,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
6154
6133
acc_m = __lsx_vfadd_s (acc_m , (__m128 )tmp1 );
6155
6134
6156
6135
6157
- ft_union fi ;
6158
- fi .i = __lsx_vpickve2gr_w (acc_m , 0 );
6159
- * s = hsum_float_8 (acc ) + fi .f ;
6136
+ * s = hsum_float_8 (acc ) + ((v4f32 )acc_m )[0 ];
6160
6137
#else
6161
6138
6162
6139
const uint8_t * scales = (const uint8_t * )& utmp [0 ];
0 commit comments