katsu560 committed
Commit 00ac035 · 1 parent: 71bf396

add AVX support

Files changed (3)
  1. ggml.c +166 -0
  2. ggml.h +1 -0
  3. whisper.cpp +1 -0
ggml.c CHANGED
 
@@ -372,6 +372,49 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
 
     sumf = _mm_cvtss_f32(r1);
 
+    // leftovers
+    for (int i = n32; i < n; ++i) {
+        sumf += x[i]*y[i];
+    }
+#elif defined(__AVX__)
+    // AVX 256-bit
+    const int n32 = (n & ~31);
+
+    __m256 sum0 = _mm256_setzero_ps();
+    __m256 sum1 = _mm256_setzero_ps();
+    __m256 sum2 = _mm256_setzero_ps();
+    __m256 sum3 = _mm256_setzero_ps();
+
+    __m256 x0, x1, x2, x3;
+    __m256 y0, y1, y2, y3;
+
+    for (int i = 0; i < n32; i += 32) {
+        x0 = _mm256_loadu_ps(x + i + 0);
+        x1 = _mm256_loadu_ps(x + i + 8);
+        x2 = _mm256_loadu_ps(x + i + 16);
+        x3 = _mm256_loadu_ps(x + i + 24);
+
+        y0 = _mm256_loadu_ps(y + i + 0);
+        y1 = _mm256_loadu_ps(y + i + 8);
+        y2 = _mm256_loadu_ps(y + i + 16);
+        y3 = _mm256_loadu_ps(y + i + 24);
+
+        sum0 = _mm256_add_ps(_mm256_mul_ps(x0, y0), sum0);
+        sum1 = _mm256_add_ps(_mm256_mul_ps(x1, y1), sum1);
+        sum2 = _mm256_add_ps(_mm256_mul_ps(x2, y2), sum2);
+        sum3 = _mm256_add_ps(_mm256_mul_ps(x3, y3), sum3);
+    }
+
+    sum0 = _mm256_add_ps(sum0, sum1);
+    sum2 = _mm256_add_ps(sum2, sum3);
+    sum0 = _mm256_add_ps(sum0, sum2);
+
+    const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0), _mm256_extractf128_ps(sum0, 1));
+    const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4));
+    const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2));
+
+    sumf = _mm_cvtss_f32(r1);
+
     // leftovers
     for (int i = n32; i < n; ++i) {
         sumf += x[i]*y[i];
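For reference: the r4/r2/r1 tail of this hunk is a standard horizontal sum that folds the 256-bit accumulator down to one float. A self-contained sketch of the same reduction (the helper name hsum256_ps is ours, not from the patch):

#include <immintrin.h>

// Fold a __m256 to a scalar: add the high and low 128-bit lanes,
// then pairwise-reduce the remaining four floats, exactly as the
// r4/r2/r1 sequence in the hunk above does.
static inline float hsum256_ps(__m256 v) {
    const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
    const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4));
    const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2));
    return _mm_cvtss_f32(r1);
}

The four separate accumulators (sum0..sum3) exist to hide the latency of the dependent vector adds; they are combined only once, after the main loop.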
 
@@ -569,6 +612,50 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
 
     sumf = _mm_cvtss_f32(r1);
 
+    // leftovers
+    for (int i = n32; i < n; ++i) {
+        //GGML_ASSERT(false);
+        sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
+    }
+#elif defined(__AVX__)
+    // AVX 256-bit
+    const int n32 = (n & ~31);
+
+    __m256 sum0 = _mm256_setzero_ps();
+    __m256 sum1 = _mm256_setzero_ps();
+    __m256 sum2 = _mm256_setzero_ps();
+    __m256 sum3 = _mm256_setzero_ps();
+
+    __m256 x0, x1, x2, x3;
+    __m256 y0, y1, y2, y3;
+
+    for (int i = 0; i < n32; i += 32) {
+        x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 )));
+        x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 )));
+        x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16)));
+        x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24)));
+
+        y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 )));
+        y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 )));
+        y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16)));
+        y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24)));
+
+        sum0 = _mm256_add_ps(_mm256_mul_ps(x0, y0), sum0);
+        sum1 = _mm256_add_ps(_mm256_mul_ps(x1, y1), sum1);
+        sum2 = _mm256_add_ps(_mm256_mul_ps(x2, y2), sum2);
+        sum3 = _mm256_add_ps(_mm256_mul_ps(x3, y3), sum3);
+    }
+
+    const __m256 sum01 = _mm256_add_ps(sum0, sum1);
+    const __m256 sum23 = _mm256_add_ps(sum2, sum3);
+    const __m256 sum0123 = _mm256_add_ps(sum01, sum23);
+
+    const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0123), _mm256_extractf128_ps(sum0123, 1));
+    const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4));
+    const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2));
+
+    sumf = _mm_cvtss_f32(r1);
+
     // leftovers
     for (int i = n32; i < n; ++i) {
         //GGML_ASSERT(false);
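One caveat worth flagging: _mm256_cvtph_ps (VCVTPH2PS) belongs to the F16C extension, which __AVX__ alone does not imply, so this path assumes the compiler was also invoked with F16C enabled (e.g. -mf16c on gcc/clang). A minimal sketch of the load-and-widen step with a stricter guard (helper name ours):

#include <immintrin.h>
#include <stdint.h>

#if defined(__F16C__)
// Load 8 half-precision values and widen them to 8 floats.
// Guarded on __F16C__ because VCVTPH2PS is not a plain-AVX instruction.
static inline __m256 load8_fp16(const uint16_t * p) {
    return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
}
#endif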
 
@@ -698,6 +785,41 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
         _mm256_storeu_ps(y + i + 24, y3);
     }
 
+    // leftovers
+    for (int i = n32; i < n; ++i) {
+        y[i] += x[i]*v;
+    }
+#elif defined(__AVX__)
+    // AVX 256-bit
+    const int n32 = (n & ~31);
+
+    const __m256 v4 = _mm256_set1_ps(v);
+
+    __m256 x0, x1, x2, x3;
+    __m256 y0, y1, y2, y3;
+
+    for (int i = 0; i < n32; i += 32) {
+        x0 = _mm256_loadu_ps(x + i + 0);
+        x1 = _mm256_loadu_ps(x + i + 8);
+        x2 = _mm256_loadu_ps(x + i + 16);
+        x3 = _mm256_loadu_ps(x + i + 24);
+
+        y0 = _mm256_loadu_ps(y + i + 0);
+        y1 = _mm256_loadu_ps(y + i + 8);
+        y2 = _mm256_loadu_ps(y + i + 16);
+        y3 = _mm256_loadu_ps(y + i + 24);
+
+        y0 = _mm256_add_ps(_mm256_mul_ps(x0, v4), y0);
+        y1 = _mm256_add_ps(_mm256_mul_ps(x1, v4), y1);
+        y2 = _mm256_add_ps(_mm256_mul_ps(x2, v4), y2);
+        y3 = _mm256_add_ps(_mm256_mul_ps(x3, v4), y3);
+
+        _mm256_storeu_ps(y + i + 0, y0);
+        _mm256_storeu_ps(y + i + 8, y1);
+        _mm256_storeu_ps(y + i + 16, y2);
+        _mm256_storeu_ps(y + i + 24, y3);
+    }
+
     // leftovers
     for (int i = n32; i < n; ++i) {
         y[i] += x[i]*v;
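Since plain AVX has no fused multiply-add, the patch expresses y += x*v as a separate _mm256_mul_ps and _mm256_add_ps. For contrast only (not part of this commit), the same 8-wide step on an FMA-capable (AVX2-class) target collapses to a single instruction:

#include <immintrin.h>

// One 8-float step of y += x*v. With FMA available, the multiply
// and add fuse into a single _mm256_fmadd_ps; otherwise this falls
// back to the mul+add pair the patch uses.
static inline void mad8_f32(float * y, const float * x, const __m256 v) {
#if defined(__FMA__)
    _mm256_storeu_ps(y, _mm256_fmadd_ps(_mm256_loadu_ps(x), v, _mm256_loadu_ps(y)));
#else
    _mm256_storeu_ps(y, _mm256_add_ps(_mm256_mul_ps(_mm256_loadu_ps(x), v), _mm256_loadu_ps(y)));
#endif
}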
 
@@ -859,6 +981,42 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
         _mm_storeu_si128((__m128i*)(y + i + 24), _mm256_cvtps_ph(y3, 0));
     }
 
+    // leftovers
+    for (int i = n32; i < n; ++i) {
+        GGML_ASSERT(false);
+        y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
+    }
+#elif defined(__AVX__)
+    // AVX 256-bit
+    const int n32 = (n & ~31);
+
+    const __m256 v8 = _mm256_set1_ps(v);
+
+    __m256 x0, x1, x2, x3;
+    __m256 y0, y1, y2, y3;
+
+    for (int i = 0; i < n32; i += 32) {
+        y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 )));
+        y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 )));
+        y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16)));
+        y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24)));
+
+        x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 )));
+        x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 )));
+        x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16)));
+        x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24)));
+
+        y0 = _mm256_add_ps(_mm256_mul_ps(x0, v8), y0);
+        y1 = _mm256_add_ps(_mm256_mul_ps(x1, v8), y1);
+        y2 = _mm256_add_ps(_mm256_mul_ps(x2, v8), y2);
+        y3 = _mm256_add_ps(_mm256_mul_ps(x3, v8), y3);
+
+        _mm_storeu_si128((__m128i*)(y + i + 0 ), _mm256_cvtps_ph(y0, 0));
+        _mm_storeu_si128((__m128i*)(y + i + 8 ), _mm256_cvtps_ph(y1, 0));
+        _mm_storeu_si128((__m128i*)(y + i + 16), _mm256_cvtps_ph(y2, 0));
+        _mm_storeu_si128((__m128i*)(y + i + 24), _mm256_cvtps_ph(y3, 0));
+    }
+
     // leftovers
     for (int i = n32; i < n; ++i) {
         GGML_ASSERT(false);
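The f16 variant round-trips every block through fp32: widen y and x, multiply-add, then narrow back. The immediate 0 passed to _mm256_cvtps_ph selects round-to-nearest-even (_MM_FROUND_TO_NEAREST_INT). A sketch of one 8-element step under the same F16C assumption as above (helper name ours):

#include <immintrin.h>
#include <stdint.h>

#if defined(__F16C__)
// y[0..7] += x[0..7] * v, with x and y stored as IEEE half floats.
static inline void mad8_fp16(uint16_t * y, const uint16_t * x, const __m256 v) {
    __m256 yf       = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)y));
    const __m256 xf = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)x));
    yf = _mm256_add_ps(_mm256_mul_ps(xf, v), yf);
    _mm_storeu_si128((__m128i *)y, _mm256_cvtps_ph(yf, 0)); // 0 = round to nearest even
}
#endif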
 
@@ -8081,6 +8239,14 @@ enum ggml_opt_result ggml_opt(
 
 ////////////////////////////////////////////////////////////////////////////////
 
+int ggml_cpu_has_avx(void) {
+#if defined(__AVX__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_avx2(void) {
 #if defined(__AVX2__)
     return 1;
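Note that, like the neighbouring ggml_cpu_has_avx2(), this is a compile-time probe: it reports whether the binary was built with __AVX__ defined (e.g. via -mavx), not whether the CPU at runtime supports AVX. A minimal caller, assuming only the declarations this commit adds to ggml.h:

#include <stdio.h>
#include "ggml.h"

int main(void) {
    // 1 if the binary was compiled with the feature enabled, 0 otherwise.
    printf("AVX  = %d\n", ggml_cpu_has_avx());
    printf("AVX2 = %d\n", ggml_cpu_has_avx2());
    return 0;
}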
ggml.h CHANGED
@@ -723,6 +723,7 @@ enum ggml_opt_result ggml_opt(
     // system info
     //
 
+    int ggml_cpu_has_avx(void);
     int ggml_cpu_has_avx2(void);
     int ggml_cpu_has_avx512(void);
     int ggml_cpu_has_neon(void);
whisper.cpp CHANGED
@@ -3041,6 +3041,7 @@ const char * whisper_print_system_info() {
     static std::string s;
 
     s = "";
+    s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
     s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
     s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
     s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";