Justine Tunney committed
Commit c78b872 · 1 Parent(s): b441739

ggml : rewrite silu and softmax for cpu (llama/7154)


This change upstreams llamafile's vectorized expf() functions. This lets
us compute softmax and silu more accurately than the short[65536] lookup
table that GGML previously used to make these operations go faster. We can
support aarch64 and sse2+ with a worst-case rounding error of 2 ulp. It
makes "make -j8 tests && ./tests/test-backend-ops -o SOFT_MAX -b CPU perf"
go 1.5x faster for SSE2+FMA, 1.9x faster for AVX2+FMA, and 2.1x on AVX512.
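
All four SIMD variants added below (NEON, AVX512, AVX2, SSE2) implement the same
range-reduction scheme: exp(x) = 2^n * exp(r), with n = round(x * log2(e)) and
r = x - n*ln(2) small enough for a short polynomial. The following scalar sketch
is illustrative only, not part of the commit: the name expf_sketch and the use of
rintf/ldexpf are assumptions for readability, and it omits the overflow/underflow
tails that the real kernels handle branchlessly with integer bit tricks.

    #include <math.h>

    // Illustrative scalar model of the vectorized ggml_v_expf kernels,
    // valid roughly for |x| < 87 and using the same constants.
    static float expf_sketch(float x) {
        float n = rintf(x * 0x1.715476p+0f);   // n = round(x * log2(e))
        float r = fmaf(n, -0x1.62e4p-1f, x);   // r = x - n*ln(2), high part
        r = fmaf(n, -0x1.7f7d1cp-20f, r);      //                  low part
        float u = r * r;                       // degree-5 polynomial for exp(r) - 1
        float p = fmaf(fmaf(fmaf(0x1.0e4020p-7f, r, 0x1.573e2ep-5f), u,
                            fmaf(0x1.555e66p-3f, r, 0x1.fffdb6p-2f)), u,
                       0x1.ffffecp-1f * r);
        return ldexpf(1.0f + p, (int) n);      // exp(x) = 2^n * exp(r)
    }

Silu then falls out as x/(1 + exp(-x)) on whole vectors (ggml_v_silu), and the new
ggml_vec_soft_max_f32 helper uses the same exp to write exp(x[i] - max) into the
output row while accumulating the row sum, which is what replaces the per-element
ggml_table_exp_f16 lookups in the hunks below.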

Files changed (1)
  1. ggml.c +283 -193
ggml.c CHANGED
@@ -165,9 +165,6 @@ void ggml_print_backtrace(void) {
 #define GGML_DEBUG 0
 #define GGML_GELU_FP16
 #define GGML_GELU_QUICK_FP16
-#define GGML_SILU_FP16
-// #define GGML_CROSS_ENTROPY_EXP_FP16
-// #define GGML_FLASH_ATTN_EXP_FP16

 #define GGML_SOFT_MAX_UNROLL 4
 #define GGML_VEC_DOT_UNROLL 2
@@ -318,12 +315,6 @@ static ggml_fp16_t ggml_table_gelu_f16[1 << 16];
 // precomputed quick gelu table for f16 (128 KB)
 static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];

-// precomputed silu table for f16 (128 KB)
-static ggml_fp16_t ggml_table_silu_f16[1 << 16];
-
-// precomputed exp table for f16 (128 KB)
-static ggml_fp16_t ggml_table_exp_f16[1 << 16];
-
 // precomputed f32 table for f16 (256 KB) (ggml-impl.h)
 float ggml_table_f32_f16[1 << 16];

@@ -2085,52 +2076,291 @@ inline static float ggml_silu_f32(float x) {
     return x/(1.0f + expf(-x));
 }

-//inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
-//    const uint16_t * i16 = (const uint16_t *) x;
-//    for (int i = 0; i < n; ++i) {
-//        y[i] = ggml_table_silu_f16[i16[i]];
-//    }
-//}
-
-#ifdef GGML_SILU_FP16
-inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
-    uint16_t t;
-    for (int i = 0; i < n; ++i) {
-        ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
-        memcpy(&t, &fp16, sizeof(uint16_t));
-        y[i] = GGML_FP16_TO_FP32(ggml_table_silu_f16[t]);
-    }
-}
+#if defined(__ARM_NEON)
+
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+inline static float32x4_t ggml_v_expf(float32x4_t x) {
+    const float32x4_t r = vdupq_n_f32(0x1.8p23f);
+    const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
+    const float32x4_t n = vsubq_f32(z, r);
+    const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
+                                    vdupq_n_f32(0x1.7f7d1cp-20f));
+    const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
+    const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
+    const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
+    const float32x4_t u = vmulq_f32(b, b);
+    const float32x4_t j = vfmaq_f32(
+        vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
+        vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
+                  vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
+    if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
+        return vfmaq_f32(k, j, k);
+    const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
+    const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
+    const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
+    return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
+                     vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
+}
+
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static float32x4_t ggml_v_silu(float32x4_t x) {
+    const float32x4_t one = vdupq_n_f32(1.0f);
+    const float32x4_t zero = vdupq_n_f32(0.0f);
+    const float32x4_t neg_x = vsubq_f32(zero, x);
+    const float32x4_t exp_neg_x = ggml_v_expf(neg_x);
+    const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
+    return vdivq_f32(x, one_plus_exp_neg_x);
+}
+
+#elif defined(__AVX512F__) && defined(__AVX512DQ__)
+
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+inline static __m512 ggml_v_expf(__m512 x) {
+    const __m512 r = _mm512_set1_ps(0x1.8p23f);
+    const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
+    const __m512 n = _mm512_sub_ps(z, r);
+    const __m512 b = _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
+                                      _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
+    const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23);
+    const __m512 k = _mm512_castsi512_ps(_mm512_add_epi32(e, _mm512_castps_si512(_mm512_set1_ps(1))));
+    const __mmask16 c = _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(126), _CMP_GT_OQ);
+    const __m512 u = _mm512_mul_ps(b, b);
+    const __m512 j = _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
+                                                                     _mm512_set1_ps(0x1.573e2ep-5f)), u,
+                                                     _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
+                                                                     _mm512_set1_ps(0x1.fffdb6p-2f))),
+                                     u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b));
+    if (_mm512_kortestz(c, c))
+        return _mm512_fmadd_ps(j, k, k);
+    const __m512i g = _mm512_and_si512(
+        _mm512_movm_epi32(_mm512_cmp_ps_mask(n, _mm512_setzero_ps(), _CMP_LE_OQ)),
+        _mm512_set1_epi32(0x82000000u));
+    const __m512 s1 =
+        _mm512_castsi512_ps(_mm512_add_epi32(g, _mm512_set1_epi32(0x7f000000u)));
+    const __m512 s2 = _mm512_castsi512_ps(_mm512_sub_epi32(e, g));
+    const __mmask16 d =
+        _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
+    return _mm512_mask_blend_ps(
+        d, _mm512_mask_blend_ps(
+               c, _mm512_fmadd_ps(k, j, k),
+               _mm512_mul_ps(_mm512_fmadd_ps(s2, j, s2), s1)),
+        _mm512_mul_ps(s1, s1));
+}
+
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static __m512 ggml_v_silu(__m512 x) {
+    const __m512 one = _mm512_set1_ps(1);
+    const __m512 zero = _mm512_setzero_ps();
+    const __m512 neg_x = _mm512_sub_ps(zero, x);
+    const __m512 exp_neg_x = ggml_v_expf(neg_x);
+    const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
+    return _mm512_div_ps(x, one_plus_exp_neg_x);
+}
+
+#elif defined(__AVX2__) && defined(__FMA__)
+
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+inline static __m256 ggml_v_expf(__m256 x) {
+    const __m256 r = _mm256_set1_ps(0x1.8p23f);
+    const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
+    const __m256 n = _mm256_sub_ps(z, r);
+    const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
+                                      _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
+    const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
+    const __m256 k = _mm256_castsi256_ps(
+        _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
+    const __m256i c = _mm256_castps_si256(
+        _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
+                      _mm256_set1_ps(126), _CMP_GT_OQ));
+    const __m256 u = _mm256_mul_ps(b, b);
+    const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
+                                                                     _mm256_set1_ps(0x1.573e2ep-5f)), u,
+                                                     _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
+                                                                     _mm256_set1_ps(0x1.fffdb6p-2f))),
+                                     u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
+    if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
+        return _mm256_fmadd_ps(j, k, k);
+    const __m256i g = _mm256_and_si256(
+        _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
+        _mm256_set1_epi32(0x82000000u));
+    const __m256 s1 =
+        _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
+    const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
+    const __m256i d = _mm256_castps_si256(
+        _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
+                      _mm256_set1_ps(192), _CMP_GT_OQ));
+    return _mm256_or_ps(
+        _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
+        _mm256_andnot_ps(
+            _mm256_castsi256_ps(d),
+            _mm256_or_ps(
+                _mm256_and_ps(_mm256_castsi256_ps(c),
+                              _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
+                _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
+}
+
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static __m256 ggml_v_silu(__m256 x) {
+    const __m256 one = _mm256_set1_ps(1);
+    const __m256 zero = _mm256_setzero_ps();
+    const __m256 neg_x = _mm256_sub_ps(zero, x);
+    const __m256 exp_neg_x = ggml_v_expf(neg_x);
+    const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
+    return _mm256_div_ps(x, one_plus_exp_neg_x);
+}
+
+#elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
+
+#if defined(__FMA__)
+#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z)
+#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z)
 #else
-inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
-    for (int i = 0; i < n; ++i) {
+#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z)
+#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
+#endif
+
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+inline static __m128 ggml_v_expf(__m128 x) {
+    const __m128 r = _mm_set1_ps(0x1.8p23f);
+    const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r);
+    const __m128 n = _mm_sub_ps(z, r);
+    const __m128 b =
+        NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x));
+    const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23);
+    const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1))));
+    const __m128i c =
+        _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126)));
+    const __m128 u = _mm_mul_ps(b, b);
+    const __m128 j =
+        MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u,
+                        MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))),
+                u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b));
+    if (!_mm_movemask_epi8(c))
+        return MADD128(j, k, k);
+    const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())),
+                                    _mm_set1_epi32(0x82000000u));
+    const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u)));
+    const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g));
+    const __m128i d =
+        _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192)));
+    return _mm_or_ps(
+        _mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)),
+        _mm_andnot_ps(_mm_castsi128_ps(d),
+                      _mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)),
+                                _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k)))));
+}
+
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static __m128 ggml_v_silu(__m128 x) {
+    const __m128 one = _mm_set1_ps(1);
+    const __m128 zero = _mm_setzero_ps();
+    const __m128 neg_x = _mm_sub_ps(zero, x);
+    const __m128 exp_neg_x = ggml_v_expf(neg_x);
+    const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x);
+    return _mm_div_ps(x, one_plus_exp_neg_x);
+}
+
+#endif // __ARM_NEON / __AVX2__ / __SSE2__
+
+static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
+    int i = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+    for (; i + 15 < n; i += 16) {
+        _mm512_storeu_ps(y + i, ggml_v_silu(_mm512_loadu_ps(x + i)));
+    }
+#elif defined(__AVX2__) && defined(__FMA__)
+    for (; i + 7 < n; i += 8) {
+        _mm256_storeu_ps(y + i, ggml_v_silu(_mm256_loadu_ps(x + i)));
+    }
+#elif defined(__SSE2__)
+    for (; i + 3 < n; i += 4) {
+        _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
+    }
+#elif defined(__ARM_NEON)
+    for (; i + 3 < n; i += 4) {
+        vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
+    }
+#endif
+    for (; i < n; ++i) {
         y[i] = ggml_silu_f32(x[i]);
     }
 }
-#endif
+
+static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
+    int i = 0;
+    ggml_float sum = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+    for (; i + 15 < n; i += 16) {
+        __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
+                                               _mm512_set1_ps(max)));
+        _mm512_storeu_ps(y + i, val);
+        sum += (ggml_float)_mm512_reduce_add_ps(val);
+    }
+#elif defined(__AVX2__) && defined(__FMA__)
+    for (; i + 7 < n; i += 8) {
+        __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
+                                               _mm256_set1_ps(max)));
+        _mm256_storeu_ps(y + i, val);
+        __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
+                                 _mm256_castps256_ps128(val));
+        val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
+        val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
+        sum += (ggml_float)_mm_cvtss_f32(val2);
+    }
+#elif defined(__SSE2__)
+    for (; i + 3 < n; i += 4) {
+        __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),
+                                            _mm_set1_ps(max)));
+        _mm_storeu_ps(y + i, val);
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+        val = _mm_add_ps(val, _mm_movehl_ps(val, val));
+        val = _mm_add_ss(val, _mm_movehdup_ps(val));
+#else
+        __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
+        val = _mm_add_ps(val, tmp);
+        tmp = _mm_movehl_ps(tmp, val);
+        val = _mm_add_ss(val, tmp);
 #endif
+        sum += (ggml_float)_mm_cvtss_f32(val);
+    }
+#elif defined(__ARM_NEON)
+    for (; i + 3 < n; i += 4) {
+        float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
+                                                vdupq_n_f32(max)));
+        vst1q_f32(y + i, val);
+        sum += (ggml_float)vaddvq_f32(val);
+    }
+#endif
+    for (; i < n; ++i) {
+        float val = expf(x[i] - max);
+        sum += (ggml_float)val;
+        y[i] = val;
+    }
+    return sum;
+}

 inline static float ggml_silu_backward_f32(float x, float dy) {
     const float s = 1.0f/(1.0f + expf(-x));
     return dy*s*(1.0f + x*(1.0f - s));
 }

-#ifdef GGML_SILU_FP16
-inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
-    for (int i = 0; i < n; ++i) {
-        // we did not use x[i] to compute forward silu but its f16 equivalent
-        // take derivative at f16 of x[i]:
-        ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
-        float usedx = GGML_FP16_TO_FP32(fp16);
-        dx[i] = ggml_silu_backward_f32(usedx, dy[i]);
-    }
-}
-#else
 inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
     for (int i = 0; i < n; ++i) {
         dx[i] = ggml_silu_backward_f32(x[i], dy[i]);
     }
 }
-#endif

 inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
 #ifndef GGML_USE_ACCELERATE
@@ -2922,8 +3152,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
             float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
             ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
             ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
-            ggml_table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
-            ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
         }

         const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
@@ -13600,22 +13828,7 @@ static void ggml_compute_forward_soft_max_f32(
         float max = -INFINITY;
         ggml_vec_max_f32(nc, &max, wp);

-        ggml_float sum = 0.0;
-
-        uint16_t scvt;
-        for (int i = 0; i < nc; i++) {
-            if (wp[i] == -INFINITY) {
-                dp[i] = 0.0f;
-            } else {
-                // const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max);
-                ggml_fp16_t s = GGML_FP32_TO_FP16(wp[i] - max);
-                memcpy(&scvt, &s, sizeof(scvt));
-                const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
-                sum += (ggml_float)val;
-                dp[i] = val;
-            }
-        }
-
+        ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
         assert(sum > 0.0);

         sum = 1.0/sum;
@@ -15374,37 +15587,7 @@ static void ggml_compute_forward_flash_attn_f32(
                     vvexpf(S, S, &Mup);
                     ggml_vec_sum_f32(Mup, &sum, S);
 #else
-                    uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
-                    ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
-
-                    for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
-                        if (i >= masked_begin) {
-                            break;
-                        }
-                        float * SS = S + i;
-
-                        for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
-                            if (i + j >= masked_begin) {
-                                break;
-                            } else if (SS[j] == -INFINITY) {
-                                SS[j] = 0.0f;
-                            } else {
-#ifndef GGML_FLASH_ATTN_EXP_FP16
-                                const float val = expf(SS[j] - max);
-#else
-                                ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
-                                memcpy(&scvt[j], &s, sizeof(uint16_t));
-                                const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
-#endif
-                                sump[j] += (ggml_float)val;
-                                SS[j] = val;
-                            }
-                        }
-                    }
-
-                    for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
-                        sum += sump[i];
-                    }
+                    sum = ggml_vec_soft_max_f32(Mup, S, S, max);
 #endif
                 }

@@ -15586,28 +15769,7 @@ static void ggml_compute_forward_flash_attn_f16(
                     vvexpf(S, S, &Mup);
                     ggml_vec_sum_f32(Mup, &sum, S);
 #else
-                    uint16_t scvt[GGML_SOFT_MAX_UNROLL];
-                    ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
-
-                    for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
-                        float * SS = S + i;
-
-                        for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
-                            if (SS[j] == -INFINITY) {
-                                SS[j] = 0.0f;
-                            } else {
-                                ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
-                                memcpy(&scvt[j], &s, sizeof(uint16_t));
-                                const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
-                                sump[j] += (ggml_float)val;
-                                SS[j] = val;
-                            }
-                        }
-                    }
-
-                    for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
-                        sum += sump[i];
-                    }
+                    sum = ggml_vec_soft_max_f32(Mup, S, S, max);
 #endif
                 }

@@ -16234,38 +16396,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
                             vvexpf(SM, SM, &Mup);
                             ggml_vec_sum_f32(Mup, &sum, SM);
 #else
-                            uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
-                            ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
-
-                            for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
-                                if (i >= masked_begin) {
-                                    break;
-                                }
-                                float * SR = S + i;
-                                float * SW = SM + i;
-
-                                for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
-                                    if (i + j >= masked_begin) {
-                                        break;
-                                    } else if (SR[j] == -INFINITY) {
-                                        SW[j] = 0.0f;
-                                    } else {
-#ifndef GGML_FLASH_ATTN_EXP_FP16
-                                        const float val = expf(SR[j] - max);
-#else
-                                        ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
-                                        memcpy(&scvt[j], &s, sizeof(uint16_t));
-                                        const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
-#endif
-                                        sump[j] += (ggml_float)val;
-                                        SW[j] = val;
-                                    }
-                                }
-                            }
-
-                            for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
-                                sum += sump[i];
-                            }
+                            sum = ggml_vec_soft_max_f32(Mup, SM, S, max);
 #endif
                         }

@@ -17291,35 +17422,15 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
             assert(!isnan(s1[i]));
         }
 #endif
-        // soft_max
-        ggml_float sum = 0.0;
-        {
-            float max = -INFINITY;
-            ggml_vec_max_f32(nc, &max, s0);

-            uint16_t scvt; UNUSED(scvt);
-            for (int i = 0; i < nc; i++) {
-                if (s0[i] == -INFINITY) {
-                    st[i] = 0.0f;
-                } else {
-#ifndef GGML_CROSS_ENTROPY_EXP_FP16
-                    const float s = s0[i] - max;
-                    const float val = expf(s);
-#else
-                    ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
-                    memcpy(&scvt, &s, sizeof(scvt));
-                    const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
-#endif
-                    sum += (ggml_float)val;
-                    st[i] = val;
-                }
-            }
+        // soft_max
+        float max = -INFINITY;
+        ggml_vec_max_f32(nc, &max, s0);
+        ggml_float sum = ggml_vec_soft_max_f32(nc, st, s0, max);
+        assert(sum > 0.0);
+        sum = (1.0 - eps) / sum;

-            assert(sum > 0.0);
-            // sum = 1.0/sum;
-        }
         // avoid log(0) by rescaling from [0..1] to [eps..1]
-        sum = (1.0 - eps) / sum;
         ggml_vec_scale_f32(nc, st, sum);
         ggml_vec_add1_f32(nc, st, st, eps);
         ggml_vec_log_f32(nc, st, st);
@@ -17409,32 +17520,11 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
 #endif

         // soft_max
-        ggml_float sum = 0.0;
-        {
-            float max = -INFINITY;
-            ggml_vec_max_f32(nc, &max, s0);
-
-            uint16_t scvt; UNUSED(scvt);
-            for (int i = 0; i < nc; i++) {
-                if (s0[i] == -INFINITY) {
-                    ds0[i] = 0.0f;
-                } else {
-#ifndef GGML_CROSS_ENTROPY_EXP_FP16
-                    const float s = s0[i] - max;
-                    const float val = expf(s);
-#else
-                    ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
-                    memcpy(&scvt, &s, sizeof(scvt));
-                    const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
-#endif
-                    sum += (ggml_float)val;
-                    ds0[i] = val;
-                }
-            }
-
-            assert(sum > 0.0);
-            sum = (1.0 - eps)/sum;
-        }
+        float max = -INFINITY;
+        ggml_vec_max_f32(nc, &max, s0);
+        ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
+        assert(sum > 0.0);
+        sum = (1.0 - eps) / sum;

         // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
         ggml_vec_scale_f32(nc, ds0, sum);