Spaces:
Sleeping
Sleeping
// kernel parameters for mat-vec threadgroups
//
// N_R0: number of src0 rows to process per simdgroup
// N_SG: number of simdgroups per threadgroup
//
// TODO: for optimal performance, these should become a function of the device and the work size
// kernel argument structs
//
// - element counters (e.g. ne00) typically use int32_t to reduce register usage;
//   however, be careful of int overflows when using them in the kernel implementation
//
// - strides (e.g. nb00) use uint64_t
// Arguments for the concat kernel.
// ne* are element counters and nb* are strides; the 0x / 1x / bare suffix groups
// presumably describe src0, src1 and dst. Member order defines the layout —
// presumably shared with the kernel side, so do not reorder.
typedef struct {
    int32_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne10, ne11, ne12, ne13;
    uint64_t nb10, nb11, nb12, nb13;
    int32_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
    int32_t  dim;  // presumably the concatenation axis — TODO confirm against the kernel
} ggml_metal_kargs_concat;
// Arguments for the "bin" kernels (presumably binary element-wise ops).
// ne* are element counters and nb* are strides; the 0x / 1x / bare suffix groups
// presumably describe src0, src1 and dst. Member order defines the layout.
typedef struct {
    int32_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne10, ne11, ne12, ne13;
    uint64_t nb10, nb11, nb12, nb13;
    int32_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
    uint64_t offs;   // an offset; exact meaning not visible here — TODO confirm
    uint64_t o1[8];  // eight offset slots; presumably per-src1 offsets — TODO confirm
} ggml_metal_kargs_bin;
// Arguments for the add_id kernel.
//
// Strides are declared uint64_t (they were size_t). This matches the header
// convention ("strides use uint64_t") and every other kargs struct here.
// A fixed-width type also keeps the host-side layout independent of the host's
// size_t width, which matters if this struct is copied verbatim to the kernel.
// On 64-bit hosts the size and offsets are unchanged.
typedef struct {
    int64_t  ne0;
    int64_t  ne1;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb11;
    uint64_t nb21;
} ggml_metal_kargs_add_id;
// Arguments for the repeat kernel.
// ne0x/nb0x presumably describe the source and ne*/nb* the destination;
// ne* are element counters and nb* are strides. Member order defines the layout.
typedef struct {
    int32_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
} ggml_metal_kargs_repeat;
// Arguments for the cpy kernel.
// Unlike most structs here, the element counters are 64-bit. ne0x/nb0x
// presumably describe the source and ne*/nb* the destination.
typedef struct {
    int64_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int64_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
} ggml_metal_kargs_cpy;
// Arguments for the set kernel.
// ne1x/nb1x presumably describe the source being written; nb1..nb3 are
// destination strides. Member order defines the layout.
typedef struct {
    int64_t  ne10, ne11, ne12;
    uint64_t nb10, nb11, nb12, nb13;
    uint64_t nb1,  nb2,  nb3;
    uint64_t offs;     // offset — exact meaning not visible here
    bool     inplace;  // flag read by the kernel; semantics not visible here
} ggml_metal_kargs_set;
// Arguments for the rope kernel.
// Shape/stride groups come first, then the scalar parameters. The float and
// sect_* fields are consumed by the kernel; their exact math is not visible here.
typedef struct {
    int32_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
    int32_t  n_past, n_dims, n_ctx_orig;
    float    freq_base, freq_scale;
    float    ext_factor, attn_factor;
    float    beta_fast, beta_slow;
    int32_t  sect_0, sect_1, sect_2, sect_3;
} ggml_metal_kargs_rope;
// Arguments for the flash_attn_ext kernel.
// Counter/stride groups are presumably for Q (0x), K (1x), V (2x) and mask (3x),
// followed by the output dims and scalar parameters. Member order defines the layout.
typedef struct {
    int32_t  ne01, ne02, ne03;
    uint64_t nb01, nb02, nb03;
    int32_t  ne11;
    int32_t  ne_12_2; // assume K and V are same shape
    int32_t  ne_12_3;
    uint64_t nb11, nb12, nb13;
    uint64_t nb21, nb22, nb23;
    int32_t  ne32, ne33;
    uint64_t nb31, nb32, nb33;
    int32_t  ne1, ne2;
    float    scale, max_bias;
    float    m0, m1;
    int32_t  n_head_log2;
    float    logit_softcap;
} ggml_metal_kargs_flash_attn_ext;
// Arguments for the mul_mm kernel.
// r2/r3 are 16-bit; they are presumably broadcast ratios — TODO confirm.
// Member order defines the layout.
typedef struct {
    int32_t  ne00, ne02;
    uint64_t nb01, nb02, nb03;
    int32_t  ne12;
    uint64_t nb10, nb11, nb12, nb13;
    int32_t  ne0, ne1;
    int16_t  r2, r3;
} ggml_metal_kargs_mul_mm;
// Arguments for the mul_mv (matrix-vector) kernels.
// Member order defines the layout; r2/r3 are 16-bit, presumably broadcast ratios.
typedef struct {
    int32_t  ne00, ne01, ne02;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne10, ne11, ne12;
    uint64_t nb10, nb11, nb12, nb13;
    int32_t  ne0, ne1;
    int16_t  r2, r3;
} ggml_metal_kargs_mul_mv;
// Arguments for the extended mul_mv kernels.
// Same layout prefix as the plain mul_mv args, followed by three 16-bit tuning
// fields (nsg, nxpsg, r1ptg) whose exact meaning is not visible here.
typedef struct {
    int32_t  ne00, ne01, ne02;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne10, ne11, ne12;
    uint64_t nb10, nb11, nb12, nb13;
    int32_t  ne0, ne1;
    int16_t  r2, r3;
    int16_t  nsg, nxpsg, r1ptg;
} ggml_metal_kargs_mul_mv_ext;
// Arguments for the first mapping pass of mul_mm_id.
// The field comments below come from the original declarations.
typedef struct {
    int32_t  ne10;
    int32_t  ne11;        // n_expert_used (bcast)
    uint64_t nb11, nb12;
    int32_t  neh11;       // n_tokens
    uint64_t nbh11;
    int32_t  ne20;        // n_expert_used
    uint64_t nb21;
} ggml_metal_kargs_mul_mm_id_map0;
// Arguments for the second mapping pass of mul_mm_id.
// Member order defines the layout.
typedef struct {
    int32_t  ne20;        // n_expert_used
    int32_t  neh0, neh1;
    uint64_t nbh1, nbh2;
    int32_t  ne0;
    uint64_t nb1, nb2;
} ggml_metal_kargs_mul_mm_id_map1;
// Arguments for the mul_mm_id kernel.
// The neh*/nbh* fields presumably describe an intermediate buffer — TODO confirm.
// Member order defines the layout.
typedef struct {
    int32_t  ne00, ne02;
    uint64_t nb01, nb02, nb03;
    int32_t  neh12;
    uint64_t nbh10, nbh11, nbh12, nbh13;
    int32_t  neh0, neh1;
    int16_t  r2, r3;
} ggml_metal_kargs_mul_mm_id;
// Arguments for the mul_mv_id kernel.
// nei*/nbi* presumably describe the id tensor, followed by the src0, src1 and
// dst groups. Member order defines the layout.
typedef struct {
    int32_t  nei0, nei1;
    uint64_t nbi1;
    int32_t  ne00, ne01, ne02;
    uint64_t nb00, nb01, nb02;
    int32_t  ne10, ne11, ne12, ne13;
    uint64_t nb10, nb11, nb12;
    int32_t  ne0, ne1;
    uint64_t nb1;
} ggml_metal_kargs_mul_mv_id;
// Arguments for the norm kernel.
// ne00_4 is presumably ne00 / 4 for vectorized access — TODO confirm.
typedef struct {
    int32_t  ne00, ne00_4;
    uint64_t nb01;
    float    eps;
} ggml_metal_kargs_norm;
// Arguments for the rms_norm kernel.
// The 3-element nef*/nbf* arrays hold counters and strides for three additional
// operands; which operands they are is not visible here.
typedef struct {
    int32_t  ne00, ne00_4;
    uint64_t nb1, nb2, nb3;
    float    eps;
    int32_t  nef1[3], nef2[3], nef3[3];
    uint64_t nbf1[3], nbf2[3], nbf3[3];
} ggml_metal_kargs_rms_norm;
// Arguments for the l2_norm kernel (same shape as the norm kernel args).
// ne00_4 is presumably ne00 / 4 for vectorized access — TODO confirm.
typedef struct {
    int32_t  ne00, ne00_4;
    uint64_t nb01;
    float    eps;
} ggml_metal_kargs_l2_norm;
// Arguments for the group_norm kernel.
// Counters are 64-bit here; the group count and epsilon trail the shape fields.
typedef struct {
    int64_t  ne00, ne01, ne02;
    uint64_t nb00, nb01, nb02;
    int32_t  n_groups;
    float    eps;
} ggml_metal_kargs_group_norm;
// Arguments for the conv_transpose_1d kernel.
// IC/IL/K/s0 are presumably input channels, input length, kernel size and
// stride (inferred from names only), followed by two strides.
typedef struct {
    int32_t  IC, IL, K, s0;
    uint64_t nb0, nb1;
} ggml_metal_kargs_conv_transpose_1d;
// Arguments for the im2col kernel.
// Two 64-bit offsets, followed by 32-bit geometry: input size, stride (s*),
// padding (p*), dilation (d*), batch and kernel size. These meanings are
// inferred from the names; the last field is documented in the original.
typedef struct {
    uint64_t ofs0, ofs1;
    int32_t  IW, IH, CHW;
    int32_t  s0, s1;
    int32_t  p0, p1;
    int32_t  d0, d1;
    int32_t  N, KH, KW;
    int32_t  KHW; // KH * KW, pre-computed on CPU to save GPU resources
} ggml_metal_kargs_im2col;
// Arguments for the glu kernels (presumably gated linear units).
// Counters and strides alternate, so members cannot be grouped by type without
// changing the layout.
typedef struct {
    int32_t  ne00;
    uint64_t nb01;
    int32_t  ne10;
    uint64_t nb11;
    int32_t  ne0;
    uint64_t nb1;
    int32_t  i00, i10;     // presumably start offsets into src0/src1 — TODO confirm
    float    alpha, limit; // extra scalar parameters; usage not visible here
} ggml_metal_kargs_glu;
// Arguments for the sum_rows kernel.
// Three 64-bit counter/stride groups (0x, 1x, bare suffix); member order
// defines the layout.
typedef struct {
    int64_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int64_t  ne10, ne11, ne12, ne13;
    uint64_t nb10, nb11, nb12, nb13;
    int64_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
} ggml_metal_kargs_sum_rows;
// Arguments for the soft_max kernel.
// Counter/stride groups presumably for src0 (0x) and an optional mask (1x),
// then dst strides, then scalar parameters. Member order defines the layout.
typedef struct {
    int32_t  ne00, ne01, ne02;
    uint64_t nb01, nb02, nb03;
    int32_t  ne11, ne12, ne13;
    uint64_t nb11, nb12, nb13;
    uint64_t nb1,  nb2,  nb3;
    float    scale, max_bias;
    float    m0, m1;
    int32_t  n_head_log2;
} ggml_metal_kargs_soft_max;
// Arguments for the diag_mask_inf kernel.
// n_past is declared int32_t instead of plain int, matching the fixed-width
// convention used for scalar fields elsewhere in this header. A fixed-width
// type keeps the field's size explicit wherever the struct is compiled.
typedef struct {
    int64_t ne00;
    int64_t ne01;
    int32_t n_past;  // meaning inferred from the name only — TODO confirm against the kernel
} ggml_metal_kargs_diag_mask_inf;
// Arguments for the ssm_conv kernel.
// 64-bit counter/stride groups for two inputs (0x, 1x) and the output (bare);
// member order defines the layout.
typedef struct {
    int64_t  ne00, ne01, ne02;
    uint64_t nb00, nb01, nb02;
    int64_t  ne10, ne11;
    uint64_t nb10, nb11;
    int64_t  ne0,  ne1,  ne2;
    uint64_t nb0,  nb1,  nb2;
} ggml_metal_kargs_ssm_conv;
// Arguments for the ssm_scan kernel.
// Seven 64-bit dimension/offset scalars, followed by the strides of several
// operands, grouped by operand index (nbX1..nbX3). Member order defines the layout.
typedef struct {
    int64_t  d_state, d_inner;
    int64_t  n_head, n_group;
    int64_t  n_seq_tokens, n_seqs;
    int64_t  s_off;
    uint64_t nb01, nb02, nb03;
    uint64_t nb11, nb12, nb13;
    uint64_t nb21, nb22;
    uint64_t nb31;
    uint64_t nb41, nb42, nb43;
    uint64_t nb51, nb52, nb53;
} ggml_metal_kargs_ssm_scan;
// Arguments for the get_rows kernel.
// ne00/nb0x presumably describe the source, ne10/nb1x the row indices and
// nb1/nb2 the destination. Member order defines the layout.
typedef struct {
    int64_t  ne00;
    uint64_t nb01, nb02;
    int64_t  ne10;
    uint64_t nb10, nb11;
    uint64_t nb1,  nb2;
} ggml_metal_kargs_get_rows;
// Arguments for the set_rows kernel.
// nk0 is an element counter whose exact meaning is not visible here; the rest
// follow the ne/nb naming. Member order defines the layout.
typedef struct {
    int32_t  nk0, ne01;
    uint64_t nb01, nb02, nb03;
    int32_t  ne11, ne12;
    uint64_t nb10, nb11, nb12;
    uint64_t nb1,  nb2,  nb3;
} ggml_metal_kargs_set_rows;
// Arguments for the upscale kernel.
// Source (0x) and destination (bare) shape/stride groups, followed by four float
// factors sf0..sf3 (presumably per-dimension scale factors — TODO confirm).
typedef struct {
    int64_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int64_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
    float    sf0,  sf1,  sf2,  sf3;
} ggml_metal_kargs_upscale;
// Arguments for the pad kernel.
// Source (0x) and destination (bare) 64-bit shape/stride groups; member order
// defines the layout.
typedef struct {
    int64_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int64_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
} ggml_metal_kargs_pad;
// Arguments for the pad_reflect_1d kernel.
// Same shape/stride prefix as the pad args, plus two 32-bit padding amounts
// p0/p1 (presumably left/right — TODO confirm).
typedef struct {
    int64_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int64_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
    int32_t  p0,   p1;
} ggml_metal_kargs_pad_reflect_1d;
// Arguments for the timestep_embedding kernel.
// dim and max_period are declared int32_t instead of plain int, matching the
// fixed-width convention used for scalar fields elsewhere in this header.
// A fixed-width type keeps the field sizes explicit wherever the struct is compiled.
typedef struct {
    uint64_t nb1;
    int32_t  dim;
    int32_t  max_period;
} ggml_metal_kargs_timestep_embedding;
// Arguments for the leaky_relu kernel: a single float, `slope`
// (presumably the multiplier applied to negative inputs — TODO confirm).
typedef struct { float slope; } ggml_metal_kargs_leaky_relu;
// Arguments for the argsort kernel: the column count and a padded column count
// (the padding rule is not visible here).
typedef struct {
    int64_t ncols, ncols_pad;
} ggml_metal_kargs_argsort;
// Arguments for the arange kernel: an element count plus the float start and
// step values.
typedef struct {
    int64_t ne0;
    float   start, step;
} ggml_metal_kargs_arange;
// Arguments for the pool_2d kernel.
// 32-bit window parameters (k* kernel, s* stride, p* padding — inferred from
// names), followed by 64-bit input/output extents and a total work-item count.
typedef struct {
    int32_t k0, k1;
    int32_t s0, s1;
    int32_t p0, p1;
    int64_t IH, IW;
    int64_t OH, OW;
    int64_t parallel_elements;
} ggml_metal_kargs_pool_2d;