lhez, Skyler Szot, Shangqing Gu, Alexander Angus, Hongqiang Wang, and Max Krasnyansky committed
Commit 226358f · 1 parent: 25882f6

ggml : add opencl backend (skip) (llama/10693)


---------

Co-authored-by: Skyler Szot <[email protected]>
Co-authored-by: Shangqing Gu <[email protected]>
Co-authored-by: Alexander Angus <[email protected]>
Co-authored-by: Hongqiang Wang <[email protected]>
Co-authored-by: Max Krasnyansky <[email protected]>

ggml/src/ggml-opencl/CMakeLists.txt ADDED
@@ -0,0 +1,147 @@
+find_package(OpenCL REQUIRED)
+find_package(Python3 REQUIRED)
+
+set(TARGET_NAME ggml-opencl)
+
+ggml_add_backend_library(${TARGET_NAME}
+                         ggml-opencl.cpp
+                         ../../include/ggml-opencl.h)
+target_link_libraries(${TARGET_NAME} PRIVATE ${OpenCL_LIBRARIES})
+target_include_directories(${TARGET_NAME} PRIVATE ${OpenCL_INCLUDE_DIRS})
+
+if (GGML_OPENCL_PROFILING)
+    message(STATUS "OpenCL profiling enabled (increases CPU overhead)")
+    add_compile_definitions(GGML_OPENCL_PROFILING)
+endif ()
+
+add_compile_definitions(GGML_OPENCL_SOA_Q)
+
+if (GGML_OPENCL_USE_ADRENO_KERNELS)
+    message(STATUS "OpenCL will use matmul kernels optimized for Adreno")
+    add_compile_definitions(GGML_OPENCL_USE_ADRENO_KERNELS)
+endif ()
+
+if (GGML_OPENCL_EMBED_KERNELS)
+    add_compile_definitions(GGML_OPENCL_EMBED_KERNELS)
+
+    set(OPENCL_CL_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl.cl.h")
+    set(OPENCL_MM_CL_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_mm.cl.h")
+    set(OPENCL_CVT_CL_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_cvt.cl.h")
+
+    set(OPENCL_GEMV_NOSHUFFLE_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_gemv_noshuffle.cl.h")
+    set(OPENCL_GEMV_NOSHUFFLE_GENERAL_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_gemv_noshuffle_general.cl.h")
+    set(OPENCL_MUL_MAT_Ab_Bi_8x4_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_mul_mat_Ab_Bi_8x4.cl.h")
+    set(OPENCL_TRANSPOSE_16_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_transpose_16.cl.h")
+    set(OPENCL_TRANSPOSE_32_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_transpose_32.cl.h")
+    set(OPENCL_TRANSPOSE_32_16_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_transpose_32_16.cl.h")
+
+    set(EMBED_KERNEL_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/kernels/embed_kernel.py")
+    file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/autogenerated")
+
+    include_directories("${CMAKE_BINARY_DIR}/autogenerated")
+
+    # Python must be accessible from the command line
+    add_custom_command(
+        OUTPUT ${OPENCL_CL_SOURCE_EMBED}
+        COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
+                ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl.cl
+                ${OPENCL_CL_SOURCE_EMBED}
+        DEPENDS kernels/ggml-opencl.cl ${EMBED_KERNEL_SCRIPT}
+        COMMENT "Generate ggml-opencl.cl.h"
+    )
+
+    add_custom_command(
+        OUTPUT ${OPENCL_MM_CL_SOURCE_EMBED}
+        COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
+                ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_mm.cl
+                ${OPENCL_MM_CL_SOURCE_EMBED}
+        DEPENDS kernels/ggml-opencl_mm.cl ${EMBED_KERNEL_SCRIPT}
+        COMMENT "Generate ggml-opencl_mm.cl.h"
+    )
+
+    add_custom_command(
+        OUTPUT ${OPENCL_CVT_CL_SOURCE_EMBED}
+        COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
+                ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_cvt.cl
+                ${OPENCL_CVT_CL_SOURCE_EMBED}
+        DEPENDS kernels/ggml-opencl_cvt.cl ${EMBED_KERNEL_SCRIPT}
+        COMMENT "Generate ggml-opencl_cvt.cl.h"
+    )
+
+    add_custom_command(
+        OUTPUT ${OPENCL_GEMV_NOSHUFFLE_SOURCE_EMBED}
+        COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
+                ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_gemv_noshuffle.cl
+                ${OPENCL_GEMV_NOSHUFFLE_SOURCE_EMBED}
+        DEPENDS kernels/ggml-opencl_gemv_noshuffle.cl ${EMBED_KERNEL_SCRIPT}
+        COMMENT "Generate ggml-opencl_gemv_noshuffle.cl.h"
+    )
+
+    add_custom_command(
+        OUTPUT ${OPENCL_GEMV_NOSHUFFLE_GENERAL_SOURCE_EMBED}
+        COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
+                ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_gemv_noshuffle_general.cl
+                ${OPENCL_GEMV_NOSHUFFLE_GENERAL_SOURCE_EMBED}
+        DEPENDS kernels/ggml-opencl_gemv_noshuffle_general.cl ${EMBED_KERNEL_SCRIPT}
+        COMMENT "Generate ggml-opencl_gemv_noshuffle_general.cl.h"
+    )
+
+    add_custom_command(
+        OUTPUT ${OPENCL_MUL_MAT_Ab_Bi_8x4_SOURCE_EMBED}
+        COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
+                ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl
+                ${OPENCL_MUL_MAT_Ab_Bi_8x4_SOURCE_EMBED}
+        DEPENDS kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl ${EMBED_KERNEL_SCRIPT}
+        COMMENT "Generate ggml-opencl_mul_mat_Ab_Bi_8x4.cl.h"
+    )
+
+    add_custom_command(
+        OUTPUT ${OPENCL_TRANSPOSE_16_SOURCE_EMBED}
+        COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
+                ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_transpose_16.cl
+                ${OPENCL_TRANSPOSE_16_SOURCE_EMBED}
+        DEPENDS kernels/ggml-opencl_transpose_16.cl ${EMBED_KERNEL_SCRIPT}
+        COMMENT "Generate ggml-opencl_transpose_16.cl.h"
+    )
+
+    add_custom_command(
+        OUTPUT ${OPENCL_TRANSPOSE_32_SOURCE_EMBED}
+        COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
+                ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_transpose_32.cl
+                ${OPENCL_TRANSPOSE_32_SOURCE_EMBED}
+        DEPENDS kernels/ggml-opencl_transpose_32.cl ${EMBED_KERNEL_SCRIPT}
+        COMMENT "Generate ggml-opencl_transpose_32.cl.h"
+    )
+
+    add_custom_command(
+        OUTPUT ${OPENCL_TRANSPOSE_32_16_SOURCE_EMBED}
+        COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
+                ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_transpose_32_16.cl
+                ${OPENCL_TRANSPOSE_32_16_SOURCE_EMBED}
+        DEPENDS kernels/ggml-opencl_transpose_32_16.cl ${EMBED_KERNEL_SCRIPT}
+        COMMENT "Generate ggml-opencl_transpose_32_16.cl.h"
+    )
+
+    target_sources(${TARGET_NAME} PRIVATE
+        ${OPENCL_CL_SOURCE_EMBED}
+        ${OPENCL_MM_CL_SOURCE_EMBED}
+        ${OPENCL_CVT_CL_SOURCE_EMBED}
+        ${OPENCL_GEMV_NOSHUFFLE_SOURCE_EMBED}
+        ${OPENCL_GEMV_NOSHUFFLE_GENERAL_SOURCE_EMBED}
+        ${OPENCL_MUL_MAT_Ab_Bi_8x4_SOURCE_EMBED}
+        ${OPENCL_TRANSPOSE_16_SOURCE_EMBED}
+        ${OPENCL_TRANSPOSE_32_SOURCE_EMBED}
+        ${OPENCL_TRANSPOSE_32_16_SOURCE_EMBED})
+else ()
+    # copy the OpenCL kernel sources to the binary directory
+    configure_file(kernels/ggml-opencl.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl.cl COPYONLY)
+    configure_file(kernels/ggml-opencl_mm.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_mm.cl COPYONLY)
+    configure_file(kernels/ggml-opencl_cvt.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_cvt.cl COPYONLY)
+
+    configure_file(kernels/ggml-opencl_gemv_noshuffle.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_gemv_noshuffle.cl COPYONLY)
+    configure_file(kernels/ggml-opencl_gemv_noshuffle_general.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_gemv_noshuffle_general.cl COPYONLY)
+    configure_file(kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_mul_mat_Ab_Bi_8x4.cl COPYONLY)
+    configure_file(kernels/ggml-opencl_transpose_16.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_transpose_16.cl COPYONLY)
+    configure_file(kernels/ggml-opencl_transpose_32.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_transpose_32.cl COPYONLY)
+    configure_file(kernels/ggml-opencl_transpose_32_16.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_transpose_32_16.cl COPYONLY)
+endif ()
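
How these options get turned on is left to the enclosing ggml build; the snippet below is a minimal, hypothetical sketch of a parent CMakeLists enabling the backend. The GGML_OPENCL gate and the option descriptions are assumptions for illustration and are not part of this commit.

# Hypothetical parent-build snippet (not from this commit); only the option
# names referenced by the CMakeLists.txt above are taken from the diff.
option(GGML_OPENCL                    "ggml: use the OpenCL backend"               OFF)
option(GGML_OPENCL_PROFILING          "ggml: enable OpenCL kernel profiling"       OFF)
option(GGML_OPENCL_EMBED_KERNELS      "ggml: embed the .cl kernels in the binary"  ON)
option(GGML_OPENCL_USE_ADRENO_KERNELS "ggml: use Adreno-optimized matmul kernels"  ON)

if (GGML_OPENCL)
    add_subdirectory(ggml-opencl)  # builds the ggml-opencl target defined above
endif ()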
ggml/src/ggml-opencl/ggml-opencl.cpp ADDED
The diff for this file is too large to render. See raw diff
 
ggml/src/ggml-opencl/kernels/embed_kernel.py ADDED
@@ -0,0 +1,26 @@
+#
+
+import sys
+import logging
+logger = logging.getLogger("opencl-embed-kernel")
+
+
+def main():
+    logging.basicConfig(level=logging.INFO)
+
+    if len(sys.argv) != 3:
+        logger.info("Usage: python embed_kernel.py <input_file> <output_file>")
+        sys.exit(1)
+
+    ifile = open(sys.argv[1], "r")
+    ofile = open(sys.argv[2], "w")
+
+    for i in ifile:
+        ofile.write('R"({})"\n'.format(i))
+
+    ifile.close()
+    ofile.close()
+
+
+if __name__ == "__main__":
+    main()
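
The script wraps each line of a .cl file in a C++ raw string literal, so the generated header is a sequence of adjacent literals that the compiler concatenates back into the original kernel source. A minimal sketch of how such a header can be consumed on the host side follows; the variable name and surrounding guard are illustrative only, since the actual use in ggml-opencl.cpp is in the diff that is too large to render above.

// Illustrative only: consuming a header produced by embed_kernel.py.
// Each generated line is a raw string literal such as R"(kernel void ...)";
// adjacent literals concatenate into one C string holding the whole .cl file.
#ifdef GGML_OPENCL_EMBED_KERNELS
static const char * kernel_src =
#include "ggml-opencl.cl.h"
;
#endif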
ggml/src/ggml-opencl/kernels/ggml-opencl.cl ADDED
@@ -0,0 +1,2683 @@
+#ifdef cl_khr_fp16
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#elif defined(cl_amd_fp16)
+#pragma OPENCL EXTENSION cl_amd_fp16 : enable
+#else
+#error "Half precision floating point not supported by OpenCL implementation on your device."
+#endif
+
+#ifdef cl_khr_subgroups
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+#elif defined(cl_intel_subgroups)
+#pragma OPENCL EXTENSION cl_intel_subgroups : enable
+#else
+#error "Subgroup not supported on your device."
+#endif
+
+#ifdef cl_intel_required_subgroup_size
+// Always use subgroup size of 32 on Intel.
+#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
+#define INTEL_GPU 1
+#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
+#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
+#elif defined(cl_qcom_reqd_sub_group_size)
+// Always use subgroup size of 64 on Adreno.
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define ADRENO_GPU 1
+#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#else
+// TODO: do not know how to choose subgroup size on other GPUs.
+#error "Selecting subgroup size is not supported on your device."
+#endif
+
+#define QK4_0 32
+#define QR4_0 2
+#define QK4_1 32
+#define QR4_1 2
+#define QK5_0 32
+#define QR5_0 2
+#define QK5_1 32
+#define QR5_1 2
+#define QK8_0 32
+#define QR8_0 1
+#define QK_K 256
+#define K_QUANTS_PER_ITERATION 2
+
+typedef char int8_t;
+typedef uchar uint8_t;
+typedef short int16_t;
+typedef ushort uint16_t;
+typedef int int32_t;
+typedef uint uint32_t;
+
+//------------------------------------------------------------------------------
+// block_q4_0
+//------------------------------------------------------------------------------
+struct block_q4_0
+{
+    half d;
+    uint8_t qs[QK4_0 / 2];
+};
+
+//------------------------------------------------------------------------------
+// block_q4_1
+//------------------------------------------------------------------------------
+struct block_q4_1
+{
+    half d;
+    half m;
+    uint8_t qs[QK4_1 / 2];
+};
+
+//------------------------------------------------------------------------------
+// block_q5_0
+//------------------------------------------------------------------------------
+struct block_q5_0
+{
+    half d;
+    uint32_t qh;
+    uint8_t qs[QK5_0 / 2];
+};
+
+//------------------------------------------------------------------------------
+// block_q5_1
+//------------------------------------------------------------------------------
+struct block_q5_1
+{
+    half d;
+    half m;
+    uint32_t qh;
+    uint8_t qs[QK5_1 / 2];
+};
+
+//------------------------------------------------------------------------------
+// block_q8_0
+//------------------------------------------------------------------------------
+struct block_q8_0
+{
+    half d;
+    int8_t qs[QK8_0];
+};
+
+//------------------------------------------------------------------------------
+// block_q2_K
+//------------------------------------------------------------------------------
+struct block_q2_K
+{
+    uint8_t scales[16];
+    uint8_t qs[64];
+    half d;
+    half dmin;
+};
+
+//------------------------------------------------------------------------------
+// block_q3_K
+//------------------------------------------------------------------------------
+struct block_q3_K
+{
+    uint8_t hmask[32];
+    uint8_t qs[64];
+    uint8_t scales[12];
+    half d;
+};
+
+//------------------------------------------------------------------------------
+// block_q4_K
+//------------------------------------------------------------------------------
+struct block_q4_K
+{
+    half d;
+    half dmin;
+    uint8_t scales[12];
+    uint8_t qs[128];
+};
+
+//------------------------------------------------------------------------------
+// block_q5_K
+//------------------------------------------------------------------------------
+struct block_q5_K
+{
+    half d;
+    half dmin;
+    uint8_t scales[12];
+    uint8_t qh[32];
+    uint8_t qs[128];
+};
+
+//------------------------------------------------------------------------------
+// block_q6_K
+//------------------------------------------------------------------------------
+struct block_q6_K
+{
+    uint8_t ql[128];
+    uint8_t qh[64];
+    int8_t scales[16];
+    half d;
+};
+
+//------------------------------------------------------------------------------
+// dequantize_q4_0_f32, dequantize_q4_0_f16
+//------------------------------------------------------------------------------
+void dequantize_q4_0_f32(global struct block_q4_0 * xb, short il, float16 * reg) {
+    global ushort * qs = ((global ushort *)xb + 1);
+    float d1 = il ? (xb->d / 16.h) : xb->d;
+    float d2 = d1 / 256.f;
+    float md = -8.h * xb->d;
+    ushort mask0 = il ? 0x00F0 : 0x000F;
+    ushort mask1 = mask0 << 8;
+
+    reg->s0 = d1 * (qs[0] & mask0) + md;
+    reg->s1 = d2 * (qs[0] & mask1) + md;
+
+    reg->s2 = d1 * (qs[1] & mask0) + md;
+    reg->s3 = d2 * (qs[1] & mask1) + md;
+
+    reg->s4 = d1 * (qs[2] & mask0) + md;
+    reg->s5 = d2 * (qs[2] & mask1) + md;
+
+    reg->s6 = d1 * (qs[3] & mask0) + md;
+    reg->s7 = d2 * (qs[3] & mask1) + md;
+
+    reg->s8 = d1 * (qs[4] & mask0) + md;
+    reg->s9 = d2 * (qs[4] & mask1) + md;
+
+    reg->sa = d1 * (qs[5] & mask0) + md;
+    reg->sb = d2 * (qs[5] & mask1) + md;
+
+    reg->sc = d1 * (qs[6] & mask0) + md;
+    reg->sd = d2 * (qs[6] & mask1) + md;
+
+    reg->se = d1 * (qs[7] & mask0) + md;
+    reg->sf = d2 * (qs[7] & mask1) + md;
+}
+
+void dequantize_q4_0_f16(global struct block_q4_0 * xb, short il, half16 * reg) {
+    global ushort * qs = ((global ushort *)xb + 1);
+    half d1 = il ? (xb->d / 16.h) : xb->d;
+    half d2 = d1 / 256.h;
+    half md = -8.h * xb->d;
+    ushort mask0 = il ? 0x00F0 : 0x000F;
+    ushort mask1 = mask0 << 8;
+
+    reg->s0 = d1 * (qs[0] & mask0) + md;
+    reg->s1 = d2 * (qs[0] & mask1) + md;
+
+    reg->s2 = d1 * (qs[1] & mask0) + md;
+    reg->s3 = d2 * (qs[1] & mask1) + md;
+
+    reg->s4 = d1 * (qs[2] & mask0) + md;
+    reg->s5 = d2 * (qs[2] & mask1) + md;
+
+    reg->s6 = d1 * (qs[3] & mask0) + md;
+    reg->s7 = d2 * (qs[3] & mask1) + md;
+
+    reg->s8 = d1 * (qs[4] & mask0) + md;
+    reg->s9 = d2 * (qs[4] & mask1) + md;
+
+    reg->sa = d1 * (qs[5] & mask0) + md;
+    reg->sb = d2 * (qs[5] & mask1) + md;
+
+    reg->sc = d1 * (qs[6] & mask0) + md;
+    reg->sd = d2 * (qs[6] & mask1) + md;
+
+    reg->se = d1 * (qs[7] & mask0) + md;
+    reg->sf = d2 * (qs[7] & mask1) + md;
+}
+
+//------------------------------------------------------------------------------
+// add
+//------------------------------------------------------------------------------
+
+// general-purpose kernel for addition of two tensors
+// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
+// cons: not very efficient
+kernel void kernel_add(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global char * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne10,
+        int ne11,
+        int ne12,
+        int ne13,
+        ulong nb10,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst = dst + offsetd;
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    int i13 = i03 % ne13;
+    int i12 = i02 % ne12;
+    int i11 = i01 % ne11;
+
+    global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const int i10 = i0 % ne10;
+        *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) + *((global float *)(src1_ptr + i10*nb10));
+    }
+}
+
+// assumption: src1 is a row
+// broadcast src1 into src0
+kernel void kernel_add_row(
+        global float4 * src0,
+        ulong offset0,
+        global float4 * src1,
+        ulong offset1,
+        global float4 * dst,
+        ulong offsetd,
+        int ne
+) {
+    src0 = (global float4*)((global char*)src0 + offset0);
+    src1 = (global float4*)((global char*)src1 + offset1);
+    dst = (global float4*)((global char*)dst + offsetd);
+
+    // This performs better than using %.
+    uint gid = get_global_id(0);
+    uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
+    dst[gid] = src0[gid] + src1[idx1];
+}
+
+//------------------------------------------------------------------------------
+// mul
+//------------------------------------------------------------------------------
+kernel void kernel_mul(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global char * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne10,
+        int ne11,
+        int ne12,
+        int ne13,
+        ulong nb10,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst = dst + offsetd;
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    int i13 = i03 % ne13;
+    int i12 = i02 % ne12;
+    int i11 = i01 % ne11;
+
+    global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const int i10 = i0 % ne10;
+        *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) * *((global float *)(src1_ptr + i10*nb10));
+    }
+}
+
+// assumption: src1 is a row
+// broadcast src1 into src0
+kernel void kernel_mul_row(
+        global float4 * src0,
+        ulong offset0,
+        global float4 * src1,
+        ulong offset1,
+        global float4 * dst,
+        ulong offsetd,
+        int ne
+) {
+    src0 = (global float4*)((global char*)src0 + offset0);
+    src1 = (global float4*)((global char*)src1 + offset1);
+    dst = (global float4*)((global char*)dst + offsetd);
+
+    // This performs better than using %.
+    uint gid = get_global_id(0);
+    uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
+    dst[gid] = src0[gid] * src1[idx1];
+}
+
+//------------------------------------------------------------------------------
+// scale
+//------------------------------------------------------------------------------
+kernel void kernel_scale(
+        global float4 * src0,
+        ulong offset0,
+        global float4 * dst,
+        ulong offsetd,
+        float scale
+) {
+    src0 = (global float4*)((global char*)src0 + offset0);
+    dst = (global float4*)((global char*)dst + offsetd);
+    dst[get_global_id(0)] = src0[get_global_id(0)] * scale;
+}
+
+//------------------------------------------------------------------------------
+// gelu
+//------------------------------------------------------------------------------
+#define GELU_COEF_A 0.044715f
+#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
+
+kernel void kernel_gelu(
+        global float * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    float x = src0[get_global_id(0)];
+
+    dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_4(
+        global float4 * src0,
+        ulong offset0,
+        global float4 * dst,
+        ulong offsetd
+) {
+    src0 = (global float4*)((global char*)src0 + offset0);
+    dst = (global float4*)((global char*)dst + offsetd);
+
+    float4 x = src0[get_global_id(0)];
+
+    dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+//------------------------------------------------------------------------------
+// silu
+//------------------------------------------------------------------------------
+kernel void kernel_silu(
+        global float * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    float x = src0[get_global_id(0)];
+    dst[get_global_id(0)] = x / (1.0f + exp(-x));
+}
+
+kernel void kernel_silu_4(
+        global float4 * src0,
+        ulong offset0,
+        global float4 * dst,
+        ulong offsetd
+) {
+    src0 = (global float4*)((global char*)src0 + offset0);
+    dst = (global float4*)((global char*)dst + offsetd);
+
+    float4 x = src0[get_global_id(0)];
+    dst[get_global_id(0)] = x / (1.0f + exp(-x));
+}
+
+//------------------------------------------------------------------------------
+// relu
+//------------------------------------------------------------------------------
+kernel void kernel_relu(
+        global float * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    dst[get_global_id(0)] = fmax(0.0f, src0[get_global_id(0)]);
+}
+
+//------------------------------------------------------------------------------
+// clamp
+//------------------------------------------------------------------------------
+kernel void kernel_clamp(
+        global float * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd,
+        float min,
+        float max
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    dst[get_global_id(0)] = src0[get_global_id(0)] < min ?
+        min :
+        (src0[get_global_id(0)] > max ? max : src0[get_global_id(0)]);
+}
+
+//------------------------------------------------------------------------------
+// norm
+//------------------------------------------------------------------------------
+kernel void kernel_norm(
+        global void * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        ulong nb01,
+        float eps,
+        local float * sum
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    dst = (global void*)((global char*)dst + offsetd);
+
+    global float * x = (global float *) ((global char *) src0 + get_group_id(0)*nb01);
+
+    // MEAN
+    // parallel sum
+    sum[get_local_id(0)] = 0.0f;
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        sum[get_local_id(0)] += x[i00];
+    }
+    // reduce
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (uint i = get_local_size(0)/2; i > 0; i /= 2) {
+        if (get_local_id(0) < i) {
+            sum[get_local_id(0)] += sum[get_local_id(0) + i];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    float mean = sum[0] / ne00;
+
+    // recenter and VARIANCE
+    barrier(CLK_LOCAL_MEM_FENCE);
+    global float * y = dst + get_group_id(0)*ne00;
+    sum[get_local_id(0)] = 0.0f;
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        y[i00] = x[i00] - mean;
+        sum[get_local_id(0)] += y[i00] * y[i00];
+    }
+
+    // reduce
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (uint i = get_local_size(0)/2; i > 0; i /= 2) {
+        if (get_local_id(0) < i) {
+            sum[get_local_id(0)] += sum[get_local_id(0) + i];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    float variance = sum[0] / ne00;
+
+    float scale = 1.0f/sqrt(variance + eps);
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        y[i00] = y[i00] * scale;
+    }
+}
+
+//------------------------------------------------------------------------------
+// rms_norm
+//------------------------------------------------------------------------------
+// This kernel depends on subgroup size.
+kernel void kernel_rms_norm(
+        global void * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        ulong nb01,
+        float eps,
+        local float * sum // Note, the size depends on number of subgroups
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    global float4 * x = (global float4 *) ((global char *) src0 + get_group_id(0)*nb01);
+    global float * x_scalar = (global float *) x;
+    float4 sumf = 0;
+    float all_sum = 0;
+
+    // parallel sum
+    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+        sumf += x[i00] * x[i00];
+    }
+    all_sum = sumf.s0 + sumf.s1 + sumf.s2 + sumf.s3;
+    all_sum = sub_group_reduce_add(all_sum);
+    if (get_sub_group_local_id() == 0) {
+        sum[get_sub_group_id()] = all_sum;
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+    // broadcast
+    for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
+        if (get_local_id(0) < i) {
+            sum[get_local_id(0)] += sum[get_local_id(0) + i];
+        }
+    }
+    if (get_local_id(0) == 0) {
+        for (int i = 4 * (ne00 / 4); i < ne00; i++) {
+            sum[0] += x_scalar[i];
+        }
+        sum[0] /= ne00;
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    const float mean = sum[0];
+    const float scale = 1.0f/sqrt(mean + eps);
+
+    global float4 * y = (global float4 *) (dst + get_group_id(0)*ne00);
+    global float * y_scalar = (global float *) y;
+    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+        y[i00] = x[i00] * scale;
+    }
+    if (get_local_id(0) == 0) {
+        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
+            y_scalar[i00] = x_scalar[i00] * scale;
+        }
+    }
+}
+
+//------------------------------------------------------------------------------
+// diag_mask_inf kernels
+//------------------------------------------------------------------------------
+kernel void kernel_diag_mask_inf(
+        global float * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int n_past
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i02 = get_global_id(2);
+    int i01 = get_global_id(1);
+    int i00 = get_global_id(0);
+
+    if (i00 > n_past + i01) {
+        dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
+    } else {
+        dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+    }
+}
+
+kernel void kernel_diag_mask_inf_8(
+        global float4 * src0,
+        ulong offset0,
+        global float4 * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int n_past
+) {
+    src0 = (global float4*)((global char*)src0 + offset0);
+    dst = (global float4*)((global char*)dst + offsetd);
+
+    int i = 2*get_global_id(0);
+
+    dst[i+0] = src0[i+0];
+    dst[i+1] = src0[i+1];
+    int i4 = 4*i;
+    int i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
+    int i01 = i4/(ne00); i4 -= i01*ne00;
+    int i00 = i4;
+    for (int k = 3; k >= 0; --k) {
+        if (i00 + 4 + k <= n_past + i01) {
+            break;
+        }
+        (&dst[i+1])[k] = -INFINITY;
+        if (i00 + k > n_past + i01) {
+            (&dst[i])[k] = -INFINITY;
+        }
+    }
+}
+
+//------------------------------------------------------------------------------
+// softmax
+//------------------------------------------------------------------------------
+kernel void kernel_soft_max(
+        global float * src0,
+        ulong offset0,
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        float scale,
+        float max_bias,
+        float m0,
+        float m1,
+        int n_head_log2
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    global float * pmask = src1 != src0 ? src1 + i01*ne00 : 0;
+    global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    float slope = 1.0f;
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        int h = i02;
+
+        float base = h < n_head_log2 ? m0 : m1;
+        int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = pow(base, exp);
+    }
+
+    // parallel max
+    float lmax = -INFINITY;
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
+    }
+    float max = sub_group_reduce_max(lmax);
+
+    // parallel sum
+    float lsum = 0.0f;
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max);
+        lsum += exp_psrc0;
+        // Remember the result of exp here. exp is expensive, so we really do not
+        // wish to compute it twice.
+        pdst[i00] = exp_psrc0;
+    }
+
+    const float sum = sub_group_reduce_add(lsum);
+
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        pdst[i00] /= sum;
+    }
+}
+
+#ifdef ADRENO_GPU
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_soft_max_4(
+        global float * src0,
+        ulong offset0,
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        float scale,
+        float max_bias,
+        float m0,
+        float m1,
+        int n_head_log2
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i01*ne00) : 0;
+    global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+
+    float slope = 1.0f;
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        int h = i02;
+
+        float base = h < n_head_log2 ? m0 : m1;
+        int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = pow(base, exp);
+    }
+
+    // parallel max
+    float4 lmax4 = -INFINITY;
+    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+        lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
+    }
+    float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3));
+
+    const float max = sub_group_reduce_max(lmax);
+
+    // parallel sum
+    float4 lsum4 = 0.0f;
+    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max);
+        lsum4 += exp_psrc4;
+        pdst4[i00] = exp_psrc4;
+    }
+    float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
+
+    const float sum = sub_group_reduce_add(lsum);
+
+    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+        pdst4[i00] /= sum;
+    }
+}
+
814
+ //------------------------------------------------------------------------------
815
+ // kernel_rope
816
+ //------------------------------------------------------------------------------
817
+ float rope_yarn_ramp(float low, float high, int i0) {
818
+ const float y = (i0 / 2 - low) / max(0.001f, high - low);
819
+ return 1.0f - min(1.0f, max(0.0f, y));
820
+ }
821
+
822
+ // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
823
+ // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
824
+ float2 rope_yarn(
825
+ float theta_extrap, float freq_scale, float2 corr_dims, int i0, float ext_factor, float mscale
826
+ ) {
827
+ // Get n-d rotational scaling corrected for extrapolation
828
+ float theta_interp = freq_scale * theta_extrap;
829
+ float theta = theta_interp;
830
+ if (ext_factor != 0.0f) {
831
+ float ramp_mix = rope_yarn_ramp(corr_dims.s0, corr_dims.s1, i0) * ext_factor;
832
+ theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
833
+
834
+ // Get n-d magnitude scaling corrected for interpolation
835
+ mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
836
+ }
837
+ return (float2)(cos(theta) * mscale, sin(theta) * mscale);
838
+ }
839
+
840
+ // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
841
+ // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
842
+ float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
843
+ return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
844
+ }
845
+
846
+ float2 rope_yarn_corr_dims(
847
+ int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow
848
+ ) {
849
+ // start and end correction dims
850
+ return (float2)(
851
+ max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))),
852
+ min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)))
853
+ );
854
+ }
855
+
856
+ kernel void kernel_rope_norm_f32(
857
+ global void * src0,
858
+ ulong offset0,
859
+ global int * src1,
860
+ ulong offset1,
861
+ global float * src2,
862
+ ulong offset2,
863
+ global float * dst,
864
+ ulong offsetd,
865
+ int ne00,
866
+ int ne01,
867
+ int ne02,
868
+ int ne03,
869
+ ulong nb00,
870
+ ulong nb01,
871
+ ulong nb02,
872
+ ulong nb03,
873
+ int ne0,
874
+ int ne1,
875
+ int ne2,
876
+ int ne3,
877
+ ulong nb0,
878
+ ulong nb1,
879
+ ulong nb2,
880
+ ulong nb3,
881
+ int n_past,
882
+ int n_dims,
883
+ int n_ctx_orig,
884
+ float freq_base,
885
+ float freq_scale,
886
+ float ext_factor,
887
+ float attn_factor,
888
+ float beta_fast,
889
+ float beta_slow
890
+ ) {
891
+ src0 = (global void*)((global char*)src0 + offset0);
892
+ src1 = (global int*)((global char*)src1 + offset1);
893
+ src2 = (global float*)((global char*)src2 + offset2);
894
+ dst = (global float*)((global char*)dst + offsetd);
895
+
896
+ int i3 = get_group_id(2);
897
+ int i2 = get_group_id(1);
898
+ int i1 = get_group_id(0);
899
+
900
+ float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
901
+
902
+ global int * pos = src1;
903
+
904
+ float theta_base = (float) pos[i2];
905
+ float inv_ndims = -1.f/n_dims;
906
+
907
+ for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
908
+ if (i0 < n_dims) {
909
+ int ic = i0/2;
910
+
911
+ float theta = theta_base * pow(freq_base, inv_ndims*i0);
912
+
913
+ float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
914
+
915
+ float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
916
+
917
+ global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
918
+ global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
919
+
920
+ float x0 = src[0];
921
+ float x1 = src[1];
922
+
923
+ dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
924
+ dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
925
+ } else {
926
+ global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
927
+ global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
928
+
929
+ dst_data[0] = src[0];
930
+ dst_data[1] = src[1];
931
+ }
932
+ }
933
+ }
934
+
935
+ kernel void kernel_rope_norm_f16(
936
+ global void * src0,
937
+ ulong offset0,
938
+ global int * src1,
939
+ ulong offset1,
940
+ global float * src2,
941
+ ulong offset2,
942
+ global float * dst,
943
+ ulong offsetd,
944
+ int ne00,
945
+ int ne01,
946
+ int ne02,
947
+ int ne03,
948
+ ulong nb00,
949
+ ulong nb01,
950
+ ulong nb02,
951
+ ulong nb03,
952
+ int ne0,
953
+ int ne1,
954
+ int ne2,
955
+ int ne3,
956
+ ulong nb0,
957
+ ulong nb1,
958
+ ulong nb2,
959
+ ulong nb3,
960
+ int n_past,
961
+ int n_dims,
962
+ int n_ctx_orig,
963
+ float freq_base,
964
+ float freq_scale,
965
+ float ext_factor,
966
+ float attn_factor,
967
+ float beta_fast,
968
+ float beta_slow
969
+ ) {
970
+ src0 = (global void*)((global char*)src0 + offset0);
971
+ src1 = (global int*)((global char*)src1 + offset1);
972
+ src2 = (global float*)((global char*)src2 + offset2);
973
+ dst = (global float*)((global char*)dst + offsetd);
974
+
975
+ int i3 = get_group_id(2);
976
+ int i2 = get_group_id(1);
977
+ int i1 = get_group_id(0);
978
+
979
+ float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
980
+
981
+ global int * pos = src1;
982
+
983
+ float theta_base = (float) pos[i2];
984
+ float inv_ndims = -1.f/n_dims;
985
+
986
+ for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
987
+ if (i0 < n_dims) {
988
+ int ic = i0/2;
989
+
990
+ float theta = theta_base * pow(freq_base, inv_ndims*i0);
991
+
992
+ float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
993
+
994
+ float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
995
+
996
+ global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
997
+ global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
998
+
999
+ float x0 = src[0];
1000
+ float x1 = src[1];
1001
+
1002
+ dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
1003
+ dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
1004
+ } else {
1005
+ global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1006
+ global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1007
+
1008
+ dst_data[0] = src[0];
1009
+ dst_data[1] = src[1];
1010
+ }
1011
+ }
1012
+ }
1013
+
1014
+ kernel void kernel_rope_neox_f32(
1015
+ global void * src0,
1016
+ ulong offset0,
1017
+ global int * src1,
1018
+ ulong offset1,
1019
+ global float * src2,
1020
+ ulong offset2,
1021
+ global float * dst,
1022
+ ulong offsetd,
1023
+ int ne00,
1024
+ int ne01,
1025
+ int ne02,
1026
+ int ne03,
1027
+ ulong nb00,
1028
+ ulong nb01,
1029
+ ulong nb02,
1030
+ ulong nb03,
1031
+ int ne0,
1032
+ int ne1,
1033
+ int ne2,
1034
+ int ne3,
1035
+ ulong nb0,
1036
+ ulong nb1,
1037
+ ulong nb2,
1038
+ ulong nb3,
1039
+ int n_past,
1040
+ int n_dims,
1041
+ int n_ctx_orig,
1042
+ float freq_base,
1043
+ float freq_scale,
1044
+ float ext_factor,
1045
+ float attn_factor,
1046
+ float beta_fast,
1047
+ float beta_slow
1048
+ ) {
1049
+ src0 = (global void*)((global char*)src0 + offset0);
1050
+ src1 = (global int*)((global char*)src1 + offset1);
1051
+ src2 = (global float*)((global char*)src2 + offset2);
1052
+ dst = (global float*)((global char*)dst + offsetd);
1053
+
1054
+ int i3 = get_group_id(2);
1055
+ int i2 = get_group_id(1);
1056
+ int i1 = get_group_id(0);
1057
+
1058
+ float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
1059
+
1060
+ global int * pos = src1;
1061
+
1062
+ float theta_base = (float) pos[i2];
1063
+ float inv_ndims = -1.f/n_dims;
1064
+
1065
+ for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
1066
+ if (i0 < n_dims) {
1067
+ int ic = i0/2;
1068
+
1069
+ const float theta = theta_base * pow(freq_base, inv_ndims*i0);
1070
+
1071
+ const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
1072
+
1073
+ float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
1074
+
1075
+ global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
1076
+ global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
1077
+
1078
+ const float x0 = src[0];
1079
+ const float x1 = src[n_dims/2];
1080
+
1081
+ dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
1082
+ dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
1083
+ } else {
1084
+ global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1085
+ global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1086
+
1087
+ dst_data[0] = src[0];
1088
+ dst_data[1] = src[1];
1089
+ }
1090
+ }
1091
+ }
1092
+
1093
+ kernel void kernel_rope_neox_f16(
1094
+ global void * src0,
1095
+ ulong offset0,
1096
+ global int * src1,
1097
+ ulong offset1,
1098
+ global float * src2,
1099
+ ulong offset2,
1100
+ global float * dst,
1101
+ ulong offsetd,
1102
+ int ne00,
1103
+ int ne01,
1104
+ int ne02,
1105
+ int ne03,
1106
+ ulong nb00,
1107
+ ulong nb01,
1108
+ ulong nb02,
1109
+ ulong nb03,
1110
+ int ne0,
1111
+ int ne1,
1112
+ int ne2,
1113
+ int ne3,
1114
+ ulong nb0,
1115
+ ulong nb1,
1116
+ ulong nb2,
1117
+ ulong nb3,
1118
+ int n_past,
1119
+ int n_dims,
1120
+ int n_ctx_orig,
1121
+ float freq_base,
1122
+ float freq_scale,
1123
+ float ext_factor,
1124
+ float attn_factor,
1125
+ float beta_fast,
1126
+ float beta_slow
1127
+ ) {
1128
+ src0 = (global void*)((global char*)src0 + offset0);
1129
+ src1 = (global int*)((global char*)src1 + offset1);
1130
+ src2 = (global float*)((global char*)src2 + offset2);
1131
+ dst = (global float*)((global char*)dst + offsetd);
1132
+
1133
+ int i3 = get_group_id(2);
1134
+ int i2 = get_group_id(1);
1135
+ int i1 = get_group_id(0);
1136
+
1137
+ float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
1138
+
1139
+ global int * pos = src1;
1140
+
1141
+ float theta_base = (float) pos[i2];
1142
+ float inv_ndims = -1.f/n_dims;
1143
+
1144
+ for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
1145
+ if (i0 < n_dims) {
1146
+ int ic = i0/2;
1147
+
1148
+ const float theta = theta_base * pow(freq_base, inv_ndims*i0);
1149
+
1150
+ const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
1151
+
1152
+ float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
1153
+
1154
+ global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
1155
+ global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
1156
+
1157
+ const float x0 = src[0];
1158
+ const float x1 = src[n_dims/2];
1159
+
1160
+ dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
1161
+ dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
1162
+ } else {
1163
+ global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1164
+ global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1165
+
1166
+ dst_data[0] = src[0];
1167
+ dst_data[1] = src[1];
1168
+ }
1169
+ }
1170
+ }
1171
+
1172
+ //------------------------------------------------------------------------------
1173
+ // cpy
1174
+ //------------------------------------------------------------------------------
1175
+
1176
+ kernel void kernel_cpy_f16_f16(
1177
+ global half * src0,
1178
+ ulong offset0,
1179
+ global half * dst,
1180
+ ulong offsetd,
1181
+ int ne00,
1182
+ int ne01,
1183
+ int ne02,
1184
+ int ne03,
1185
+ ulong nb00,
1186
+ ulong nb01,
1187
+ ulong nb02,
1188
+ ulong nb03,
1189
+ int ne0,
1190
+ int ne1,
1191
+ int ne2,
1192
+ int ne3,
1193
+ ulong nb0,
1194
+ ulong nb1,
1195
+ ulong nb2,
1196
+ ulong nb3
1197
+ ) {
1198
+ src0 = (global half*)((global char*)src0 + offset0);
1199
+ dst = (global half*)((global char*)dst + offsetd);
1200
+
1201
+ int i03 = get_group_id(2);
1202
+ int i02 = get_group_id(1);
1203
+ int i01 = get_group_id(0);
1204
+
1205
+ int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1206
+
1207
+ int i3 = n / (ne2*ne1*ne0);
1208
+ int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1209
+ int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1210
+ int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1211
+
1212
+ global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1213
+
1214
+ for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
1215
+ global const half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1216
+ dst_data[i00] = src[0];
1217
+ }
1218
+ }
1219
+
1220
+ kernel void kernel_cpy_f16_f32(
1221
+ global half * src0,
1222
+ ulong offset0,
1223
+ global float * dst,
1224
+ ulong offsetd,
1225
+ int ne00,
1226
+ int ne01,
1227
+ int ne02,
1228
+ int ne03,
1229
+ ulong nb00,
1230
+ ulong nb01,
1231
+ ulong nb02,
1232
+ ulong nb03,
1233
+ int ne0,
1234
+ int ne1,
1235
+ int ne2,
1236
+ int ne3,
1237
+ ulong nb0,
1238
+ ulong nb1,
1239
+ ulong nb2,
1240
+ ulong nb3
1241
+ ) {
1242
+
1243
+ src0 = (global half*)((global char*)src0 + offset0);
1244
+ dst = (global float*)((global char*)dst + offsetd);
1245
+
1246
+ int i03 = get_group_id(2);
1247
+ int i02 = get_group_id(1);
1248
+ int i01 = get_group_id(0);
1249
+
1250
+ int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1251
+
1252
+ int i3 = n / (ne2*ne1*ne0);
1253
+ int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1254
+ int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1255
+ int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1256
+
1257
+ global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1258
+
1259
+ for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
1260
+ global half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1261
+ dst_data[i00] = src[0];
1262
+ }
1263
+ }
1264
+
1265
+ kernel void kernel_cpy_f32_f16(
1266
+ global float * src0,
1267
+ ulong offset0,
1268
+ global half * dst,
1269
+ ulong offsetd,
1270
+ int ne00,
1271
+ int ne01,
1272
+ int ne02,
1273
+ int ne03,
1274
+ ulong nb00,
1275
+ ulong nb01,
1276
+ ulong nb02,
1277
+ ulong nb03,
1278
+ int ne0,
1279
+ int ne1,
1280
+ int ne2,
1281
+ int ne3,
1282
+ ulong nb0,
1283
+ ulong nb1,
1284
+ ulong nb2,
1285
+ ulong nb3
1286
+ ) {
1287
+ src0 = (global float*)((global char*)src0 + offset0);
1288
+ dst = (global half*)((global char*)dst + offsetd);
1289
+
1290
+ int i03 = get_group_id(2);
1291
+ int i02 = get_group_id(1);
1292
+ int i01 = get_group_id(0);
1293
+
1294
+ int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1295
+
1296
+ int i3 = n / (ne2*ne1*ne0);
1297
+ int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1298
+ int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1299
+ int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1300
+
1301
+ global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1302
+
1303
+ for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
1304
+ global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1305
+
1306
+ dst_data[i00] = src[0];
1307
+ }
1308
+ }
1309
+
1310
+ kernel void kernel_cpy_f32_f32(
1311
+ global float * src0,
1312
+ ulong offset0,
1313
+ global float * dst,
1314
+ ulong offsetd,
1315
+ int ne00,
1316
+ int ne01,
1317
+ int ne02,
1318
+ int ne03,
1319
+ ulong nb00,
1320
+ ulong nb01,
1321
+ ulong nb02,
1322
+ ulong nb03,
1323
+ int ne0,
1324
+ int ne1,
1325
+ int ne2,
1326
+ int ne3,
1327
+ ulong nb0,
1328
+ ulong nb1,
1329
+ ulong nb2,
1330
+ ulong nb3
1331
+ ) {
1332
+ src0 = (global float*)((global char*)src0 + offset0);
1333
+ dst = (global float*)((global char*)dst + offsetd);
1334
+
1335
+ int i03 = get_group_id(2);
1336
+ int i02 = get_group_id(1);
1337
+ int i01 = get_group_id(0);
1338
+
1339
+ int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1340
+
1341
+ int i3 = n / (ne2*ne1*ne0);
1342
+ int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1343
+ int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1344
+ int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1345
+
1346
+ global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1347
+
1348
+ for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
1349
+ global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1350
+
1351
+ dst_data[i00] = src[0];
1352
+ }
1353
+ }
1354
+
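Note on the index arithmetic in the kernel_cpy_* kernels above: each work-group copies one source row (i01, i02, i03); n is the flattened element index of that row's first element, which is re-expressed in destination coordinates (i0, i1, i2, i3) using the destination shape and then turned into a byte offset with the destination strides. A minimal host-side sketch (plain C, not part of the kernel sources in this commit; the shape values are made up) checks that for a contiguous float destination the round trip reproduces n*sizeof(float):

#include <assert.h>
#include <stddef.h>

int main(void) {
    int ne0 = 5, ne1 = 3, ne2 = 2, ne3 = 4;               // hypothetical dst shape
    size_t nb0 = sizeof(float);                           // contiguous dst strides
    size_t nb1 = nb0*ne0, nb2 = nb1*ne1, nb3 = nb2*ne2;

    for (int n = 0; n < ne0*ne1*ne2*ne3; ++n) {
        int i3 =  n / (ne2*ne1*ne0);
        int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
        int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
        int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
        size_t off = i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
        assert(off == (size_t)n * sizeof(float));         // same element, byte-addressed
    }
    return 0;
}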
1355
+ //------------------------------------------------------------------------------
1356
+ // get_rows
1357
+ //------------------------------------------------------------------------------
1358
+ kernel void kernel_get_rows_f32(
1359
+ global void * src0,
1360
+ ulong offset0,
1361
+ global int * src1,
1362
+ ulong offset1,
1363
+ global float * dst,
1364
+ ulong offsetd,
1365
+ int ne00,
1366
+ ulong nb01,
1367
+ ulong nb02,
1368
+ int ne10,
1369
+ ulong nb10,
1370
+ ulong nb11,
1371
+ ulong nb1,
1372
+ ulong nb2
1373
+ ) {
1374
+ src0 = (global void*)((global char*)src0 + offset0);
1375
+ src1 = (global int*)((global char*)src1 + offset1);
1376
+ dst = (global float*)((global char*)dst + offsetd);
1377
+
1378
+ int i10 = get_group_id(0);
1379
+ int i11 = get_group_id(1);
1380
+
1381
+ int r = ((global int *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
1382
+
1383
+ int i02 = i11;
1384
+
1385
+ for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
1386
+ ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] =
1387
+ ((global float *) ((global char *) src0 + r*nb01 + i02*nb02))[ind];
1388
+ }
1389
+ }
1390
+
1391
+ kernel void kernel_get_rows_f16(
1392
+ global void * src0,
1393
+ ulong offset0,
1394
+ global int * src1,
1395
+ ulong offset1,
1396
+ global float * dst,
1397
+ ulong offsetd,
1398
+ int ne00,
1399
+ ulong nb01,
1400
+ ulong nb02,
1401
+ int ne10,
1402
+ ulong nb10,
1403
+ ulong nb11,
1404
+ ulong nb1,
1405
+ ulong nb2
1406
+ ) {
1407
+ src0 = (global void*)((global char*)src0 + offset0);
1408
+ src1 = (global int*)((global char*)src1 + offset1);
1409
+ dst = (global float*)((global char*)dst + offsetd);
1410
+
1411
+ int i10 = get_group_id(0);
1412
+ int i11 = get_group_id(1);
1413
+
1414
+ int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
1415
+
1416
+ int i02 = i11;
1417
+
1418
+ for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
1419
+ ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] =
1420
+ ((global half *) ((global char *) src0 + r*nb01 + i02*nb02))[ind];
1421
+ }
1422
+ }
1423
+
1424
+ kernel void kernel_get_rows_q4_0(
1425
+ global void * src0,
1426
+ ulong offset0,
1427
+ global int * src1,
1428
+ ulong offset1,
1429
+ global float * dst,
1430
+ ulong offsetd,
1431
+ int ne00,
1432
+ ulong nb01,
1433
+ ulong nb02,
1434
+ int ne10,
1435
+ ulong nb10,
1436
+ ulong nb11,
1437
+ ulong nb1,
1438
+ ulong nb2
1439
+ ) {
1440
+ src0 = (global void*)((global char*)src0 + offset0);
1441
+ src1 = (global int*)((global char*)src1 + offset1);
1442
+ dst = (global float*)((global char*)dst + offsetd);
1443
+
1444
+ const int NL = 2;
1445
+
1446
+ int i10 = get_group_id(0);
1447
+ int i11 = get_group_id(1);
1448
+
1449
+ int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
1450
+
1451
+ int i02 = i11;
1452
+
1453
+ for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) {
1454
+ float16 temp;
1455
+ dequantize_q4_0_f32(
1456
+ ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02)) + ind/NL, ind%NL, &temp);
1457
+ *(((global float16 *) ((global char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
1458
+ }
1459
+ }
1460
+
1461
+ //------------------------------------------------------------------------------
1462
+ // mul_mat_f32_f32
1463
+ //------------------------------------------------------------------------------
1464
+ #define N_F32_F32 4
1465
+
1466
+ kernel void kernel_mul_mat_f32_f32(
1467
+ global char * src0,
1468
+ ulong offset0,
1469
+ global char * src1,
1470
+ ulong offset1,
1471
+ global float * dst,
1472
+ ulong offsetd,
1473
+ int ne00,
1474
+ int ne01,
1475
+ int ne02,
1476
+ ulong nb00,
1477
+ ulong nb01,
1478
+ ulong nb02,
1479
+ ulong nb03,
1480
+ int ne10,
1481
+ int ne11,
1482
+ int ne12,
1483
+ ulong nb10,
1484
+ ulong nb11,
1485
+ ulong nb12,
1486
+ ulong nb13,
1487
+ int ne0,
1488
+ int ne1,
1489
+ int r2,
1490
+ int r3
1491
+ ) {
1492
+ src0 = (global char*)((global char*)src0 + offset0);
1493
+ src1 = (global char*)((global char*)src1 + offset1);
1494
+ dst = (global float*)((global char*)dst + offsetd);
1495
+
1496
+ int r0 = get_group_id(0);
1497
+ int rb = get_group_id(1)*N_F32_F32;
1498
+ int im = get_group_id(2);
1499
+
1500
+ int i12 = im%ne12;
1501
+ int i13 = im/ne12;
1502
+
1503
+ ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
1504
+
1505
+ global float * x = (global float *) (src0 + offset_src0);
1506
+
1507
+ if (ne00 < 128) {
1508
+ for (int row = 0; row < N_F32_F32; ++row) {
1509
+ int r1 = rb + row;
1510
+ if (r1 >= ne11) {
1511
+ break;
1512
+ }
1513
+
1514
+ ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
1515
+
1516
+ global float * y = (global float *) (src1 + offset_src1);
1517
+
1518
+ float sumf = 0;
1519
+ for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) {
1520
+ sumf += (float) x[i] * (float) y[i];
1521
+ }
1522
+
1523
+ float all_sum = sub_group_reduce_add(sumf);
1524
+ if (get_sub_group_local_id() == 0) {
1525
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1526
+ }
1527
+ }
1528
+ } else {
1529
+ global float4 * x4 = (global float4 *)x;
1530
+ for (int row = 0; row < N_F32_F32; ++row) {
1531
+ int r1 = rb + row;
1532
+ if (r1 >= ne11) {
1533
+ break;
1534
+ }
1535
+
1536
+ ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
1537
+
1538
+ global float * y = (global float *) (src1 + offset_src1);
1539
+ global float4 * y4 = (global float4 *) y;
1540
+
1541
+ float sumf = 0;
1542
+ for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {
1543
+ sumf += (float) x4[i].s0 * y4[i].s0;
1544
+ sumf += (float) x4[i].s1 * y4[i].s1;
1545
+ sumf += (float) x4[i].s2 * y4[i].s2;
1546
+ sumf += (float) x4[i].s3 * y4[i].s3;
1547
+ }
1548
+
1549
+ float all_sum = sub_group_reduce_add(sumf);
1550
+ if (get_sub_group_local_id() == 0) {
1551
+ for (int i = 4*(ne00/4); i < ne00; ++i) {
1552
+ all_sum += (float) x[i] * y[i];
1553
+ }
1554
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1555
+ }
1556
+ }
1557
+ }
1558
+ }
1559
+
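The large-row path of kernel_mul_mat_f32_f32 above reads float4 vectors and lets subgroup lane 0 pick up the 0..3 leftover elements starting at 4*(ne00/4). A standalone sketch of that split (plain C, not from the commit; the array contents are arbitrary small integers so the comparison is exact):

#include <assert.h>

static float dot_ref(const float *x, const float *y, int n) {
    float s = 0.f;
    for (int i = 0; i < n; ++i) s += x[i]*y[i];
    return s;
}

static float dot_vec4_tail(const float *x, const float *y, int n) {
    float s = 0.f;
    for (int i = 0; i < n/4; ++i) {            // the "float4" body: 4 elements per step
        s += x[4*i+0]*y[4*i+0] + x[4*i+1]*y[4*i+1]
           + x[4*i+2]*y[4*i+2] + x[4*i+3]*y[4*i+3];
    }
    for (int i = 4*(n/4); i < n; ++i) {        // scalar tail when n is not a multiple of 4
        s += x[i]*y[i];
    }
    return s;
}

int main(void) {
    float x[131], y[131];
    for (int i = 0; i < 131; ++i) {
        x[i] = (float)(i % 7) - 3.0f;          // small integer values keep the sums exact
        y[i] = (float)(i % 5) - 2.0f;
    }
    assert(dot_ref(x, y, 131) == dot_vec4_tail(x, y, 131));
    return 0;
}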
1560
+ //------------------------------------------------------------------------------
1561
+ // mul_mat_f16_f16
1562
+ //------------------------------------------------------------------------------
1563
+ #define N_F16_F16 4
1564
+
1565
+ kernel void kernel_mul_mat_f16_f16(
1566
+ global char * src0,
1567
+ ulong offset0,
1568
+ global char * src1,
1569
+ ulong offset1,
1570
+ global float * dst,
1571
+ ulong offsetd,
1572
+ int ne00,
1573
+ int ne01,
1574
+ int ne02,
1575
+ ulong nb00,
1576
+ ulong nb01,
1577
+ ulong nb02,
1578
+ ulong nb03,
1579
+ int ne10,
1580
+ int ne11,
1581
+ int ne12,
1582
+ ulong nb10,
1583
+ ulong nb11,
1584
+ ulong nb12,
1585
+ ulong nb13,
1586
+ int ne0,
1587
+ int ne1,
1588
+ int r2,
1589
+ int r3)
1590
+ {
1591
+ src0 = (global char*)((global char*)src0 + offset0);
1592
+ src1 = (global char*)((global char*)src1 + offset1);
1593
+ dst = (global float*)((global char*)dst + offsetd);
1594
+
1595
+ int r0 = get_group_id(0);
1596
+ int rb = get_group_id(1)*N_F16_F16;
1597
+ int im = get_group_id(2);
1598
+
1599
+ int i12 = im%ne12;
1600
+ int i13 = im/ne12;
1601
+
1602
+ ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
1603
+
1604
+ global half * x = (global half *) (src0 + offset_src0);
1605
+
1606
+ if (ne00 < 128) {
1607
+ for (int row = 0; row < N_F16_F16; ++row) {
1608
+ int r1 = rb + row;
1609
+ if (r1 >= ne11) {
1610
+ break;
1611
+ }
1612
+
1613
+ ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
1614
+
1615
+ global half * y = (global half *) (src1 + offset_src1);
1616
+
1617
+ float sumf = 0;
1618
+ for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) {
1619
+ sumf += (half) x[i] * (half) y[i];
1620
+ }
1621
+
1622
+ float all_sum = sub_group_reduce_add(sumf);
1623
+ if (get_sub_group_local_id() == 0) {
1624
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1625
+ }
1626
+ }
1627
+ } else {
1628
+ global half4 * x4 = (global half4 *)x;
1629
+ for (int row = 0; row < N_F16_F16; ++row) {
1630
+ int r1 = rb + row;
1631
+ if (r1 >= ne11) {
1632
+ break;
1633
+ }
1634
+
1635
+ ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
1636
+
1637
+ global half * y = (global half *) (src1 + offset_src1);
1638
+ global half4 * y4 = (global half4 *) y;
1639
+
1640
+ float sumf = 0;
1641
+ for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {
1642
+ sumf += (half) x4[i].s0 * y4[i].s0;
1643
+ sumf += (half) x4[i].s1 * y4[i].s1;
1644
+ sumf += (half) x4[i].s2 * y4[i].s2;
1645
+ sumf += (half) x4[i].s3 * y4[i].s3;
1646
+ }
1647
+
1648
+ float all_sum = sub_group_reduce_add(sumf);
1649
+ if (get_sub_group_local_id() == 0) {
1650
+ for (int i = 4*(ne00/4); i < ne00; ++i) {
1651
+ all_sum += (half) x[i] * y[i];
1652
+ }
1653
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1654
+ }
1655
+ }
1656
+ }
1657
+ }
1658
+
1659
+ //------------------------------------------------------------------------------
1660
+ // mul_mat_f16_f32_1row
1661
+ //------------------------------------------------------------------------------
1662
+ kernel void kernel_mul_mat_f16_f32_1row(
1663
+ global char * src0,
1664
+ ulong offset0,
1665
+ global char * src1,
1666
+ ulong offset1,
1667
+ global float * dst,
1668
+ ulong offsetd,
1669
+ int ne00,
1670
+ int ne01,
1671
+ int ne02,
1672
+ ulong nb00,
1673
+ ulong nb01,
1674
+ ulong nb02,
1675
+ ulong nb03,
1676
+ int ne10,
1677
+ int ne11,
1678
+ int ne12,
1679
+ ulong nb10,
1680
+ ulong nb11,
1681
+ ulong nb12,
1682
+ ulong nb13,
1683
+ int ne0,
1684
+ int ne1,
1685
+ int r2,
1686
+ int r3
1687
+ ) {
1688
+ src0 = (global char*)((global char*)src0 + offset0);
1689
+ src1 = (global char*)((global char*)src1 + offset1);
1690
+ dst = (global float*)((global char*)dst + offsetd);
1691
+
1692
+ int r0 = get_group_id(0);
1693
+ int r1 = get_group_id(1);
1694
+ int im = get_group_id(2);
1695
+
1696
+ int i12 = im%ne12;
1697
+ int i13 = im/ne12;
1698
+
1699
+ ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
1700
+ ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
1701
+
1702
+ global half * x = (global half *) (src0 + offset_src0);
1703
+ global float * y = (global float *) (src1 + offset_src1);
1704
+
1705
+ float sumf = 0;
1706
+ if (ne00 < 128) {
1707
+ for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) {
1708
+ sumf += (float) x[i] * (float) y[i];
1709
+ }
1710
+ float all_sum = sub_group_reduce_add(sumf);
1711
+ if (get_sub_group_local_id() == 0) {
1712
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1713
+ }
1714
+ } else {
1715
+ global half4 * x4 = (global half4 *) x;
1716
+ global float4 * y4 = (global float4 *) y;
1717
+ for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {
1718
+ sumf += (float) x4[i].s0 * y4[i].s0;
1719
+ sumf += (float) x4[i].s1 * y4[i].s1;
1720
+ sumf += (float) x4[i].s2 * y4[i].s2;
1721
+ sumf += (float) x4[i].s3 * y4[i].s3;
1722
+ }
1723
+ float all_sum = sub_group_reduce_add(sumf);
1724
+ if (get_sub_group_local_id() == 0) {
1725
+ for (int i = 4*(ne00/4); i < ne00; ++i) {
1726
+ all_sum += (float) x[i] * y[i];
1727
+ }
1728
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1729
+ }
1730
+ }
1731
+
1732
+ }
1733
+
1734
+ //------------------------------------------------------------------------------
1735
+ // mul_mat_f16_f32
1736
+ //------------------------------------------------------------------------------
1737
+ #define N_F16_F32 4
1738
+
1739
+ #ifdef ADRENO_GPU
1740
+ REQD_SUBGROUP_SIZE_64
1741
+ #endif
1742
+ kernel void kernel_mul_mat_f16_f32(
1743
+ global char * src0,
1744
+ ulong offset0,
1745
+ global char * src1,
1746
+ ulong offset1,
1747
+ global float * dst,
1748
+ ulong offsetd,
1749
+ int ne00,
1750
+ int ne01,
1751
+ int ne02,
1752
+ ulong nb00,
1753
+ ulong nb01,
1754
+ ulong nb02,
1755
+ ulong nb03,
1756
+ int ne10,
1757
+ int ne11,
1758
+ int ne12,
1759
+ ulong nb10,
1760
+ ulong nb11,
1761
+ ulong nb12,
1762
+ ulong nb13,
1763
+ int ne0,
1764
+ int ne1,
1765
+ int r2,
1766
+ int r3
1767
+ ) {
1768
+ src0 = (global char*)((global char*)src0 + offset0);
1769
+ src1 = (global char*)((global char*)src1 + offset1);
1770
+ dst = (global float*)((global char*)dst + offsetd);
1771
+
1772
+ int r0 = get_group_id(0);
1773
+ int rb = get_group_id(1)*N_F16_F32;
1774
+ int im = get_group_id(2);
1775
+
1776
+ int i12 = im%ne12;
1777
+ int i13 = im/ne12;
1778
+
1779
+ ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
1780
+
1781
+ global half * x = (global half *) (src0 + offset_src0);
1782
+
1783
+ if (ne00 < 128) {
1784
+ for (int row = 0; row < N_F16_F32; ++row) {
1785
+ int r1 = rb + row;
1786
+ if (r1 >= ne11) {
1787
+ break;
1788
+ }
1789
+
1790
+ ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
1791
+
1792
+ global float * y = (global float *) (src1 + offset_src1);
1793
+
1794
+ float sumf = 0;
1795
+ for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) {
1796
+ sumf += convert_float(x[i]) * y[i];
1797
+ }
1798
+
1799
+ float all_sum = sub_group_reduce_add(sumf);
1800
+ if (get_sub_group_local_id() == 0) {
1801
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1802
+ }
1803
+ }
1804
+ } else {
1805
+ global half4 * x4 = (global half4 *)x;
1806
+ for (int row = 0; row < N_F16_F32; ++row) {
1807
+ int r1 = rb + row;
1808
+ if (r1 >= ne11) {
1809
+ break;
1810
+ }
1811
+
1812
+ ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
1813
+
1814
+ global float * y = (global float *) (src1 + offset_src1);
1815
+ global float4 * y4 = (global float4 *) y;
1816
+
1817
+ float sumf = 0;
1818
+ for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {
1819
+ sumf += convert_float(x4[i].s0) * y4[i].s0;
1820
+ sumf += convert_float(x4[i].s1) * y4[i].s1;
1821
+ sumf += convert_float(x4[i].s2) * y4[i].s2;
1822
+ sumf += convert_float(x4[i].s3) * y4[i].s3;
1823
+ }
1824
+
1825
+ float all_sum = sub_group_reduce_add(sumf);
1826
+ if (get_sub_group_local_id() == 0) {
1827
+ for (int i = 4*(ne00/4); i < ne00; ++i) {
1828
+ all_sum += (float) x[i] * y[i];
1829
+ }
1830
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1831
+ }
1832
+ }
1833
+ }
1834
+ }
1835
+
1836
+ //------------------------------------------------------------------------------
1837
+ // mul_mat_f16_f32_l4
1838
+ //------------------------------------------------------------------------------
1839
+ // Assumes row size (ne00) is a multiple of 4
1840
+ #ifdef ADRENO_GPU
1841
+ REQD_SUBGROUP_SIZE_64
1842
+ #endif
1843
+ kernel void kernel_mul_mat_f16_f32_l4(
1844
+ global char * src0,
1845
+ ulong offset0,
1846
+ global char * src1,
1847
+ ulong offset1,
1848
+ global float * dst,
1849
+ ulong offsetd,
1850
+ int ne00,
1851
+ int ne01,
1852
+ int ne02,
1853
+ ulong nb00,
1854
+ ulong nb01,
1855
+ ulong nb02,
1856
+ ulong nb03,
1857
+ int ne10,
1858
+ int ne11,
1859
+ int ne12,
1860
+ ulong nb10,
1861
+ ulong nb11,
1862
+ ulong nb12,
1863
+ ulong nb13,
1864
+ int ne0,
1865
+ int ne1,
1866
+ int r2,
1867
+ int r3
1868
+ ) {
1869
+ src0 = (global char*)((global char*)src0 + offset0);
1870
+ src1 = (global char*)((global char*)src1 + offset1);
1871
+ dst = (global float*)((global char*)dst + offsetd);
1872
+
1873
+ int nrows = ne11;
1874
+ int r0 = get_group_id(0);
1875
+ int im = get_group_id(2);
1876
+
1877
+ int i12 = im%ne12;
1878
+ int i13 = im/ne12;
1879
+
1880
+ ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
1881
+
1882
+ global half4 * x4 = (global half4 *) (src0 + offset_src0);
1883
+
1884
+ for (int r1 = 0; r1 < nrows; ++r1) {
1885
+ ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
1886
+
1887
+ global float4 * y4 = (global float4 *) (src1 + offset_src1);
1888
+
1889
+ float sumf = 0;
1890
+ for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {
1891
+ sumf += convert_float(x4[i].s0) * y4[i].s0;
1892
+ sumf += convert_float(x4[i].s1) * y4[i].s1;
1893
+ sumf += convert_float(x4[i].s2) * y4[i].s2;
1894
+ sumf += convert_float(x4[i].s3) * y4[i].s3;
1895
+ }
1896
+
1897
+ float all_sum = sub_group_reduce_add(sumf);
1898
+ if (get_sub_group_local_id() == 0) {
1899
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1900
+ }
1901
+ }
1902
+ }
1903
+
1904
+ //------------------------------------------------------------------------------
1905
+ // mul_vec_q_n_f32
1906
+ //------------------------------------------------------------------------------
1907
+ // function to calculate the inner product between half a q4_0 block and 16 floats (yl); sumy is SUM(yl[i])
1908
+ // il indicates where the q4 quants begin (0 or QK4_0/4)
1909
+ // we assume that the yl's have been multiplied by the appropriate scale factor
1910
+ // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
1911
+ inline float block_q_4_0_dot_y(
1912
+ global struct block_q4_0 * qb_curr,
1913
+ float sumy,
1914
+ private float * yl,
1915
+ int il
1916
+ ) {
1917
+ float d = qb_curr->d;
1918
+ float2 acc = 0.f;
1919
+ global ushort * qs = ((global ushort *)qb_curr + 1 + il/2);
1920
+ for (int i = 0; i < 8; i+=2) {
1921
+ acc.s0 += yl[i + 0] * (qs[i / 2] & 0x000F)
1922
+ + yl[i + 1] * (qs[i / 2] & 0x0F00);
1923
+ acc.s1 += yl[i + 8] * (qs[i / 2] & 0x00F0)
1924
+ + yl[i + 9] * (qs[i / 2] & 0xF000);
1925
+ }
1926
+ return d * (sumy * -8.f + acc.s0 + acc.s1);
1927
+ }
1928
+
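The masking trick in block_q_4_0_dot_y can be checked in isolation: because yl[i+1], yl[i+8] and yl[i+9] were pre-divided by 256, 16 and 4096, multiplying them by the masked-but-unshifted nibbles is equivalent to shifting the nibbles down first, and the uniform -8 offset of q4_0 folds into a single sumy*-8.f term. A hypothetical host-side check (plain C, not part of the commit; the packed value and activations are made up, and the per-block scale d is omitted, i.e. d = 1):

#include <assert.h>
#include <math.h>
#include <stdio.h>

int main(void) {
    unsigned short qs = 0xB7C2;                    // two packed bytes: 0xC2 (low) and 0xB7 (high)
    float y[4] = {1.5f, -2.0f, 0.25f, 3.0f};       // activations paired with the 4 nibbles

    // reference: extract each 4-bit quant explicitly and subtract the q4_0 offset of 8
    int q0 = (qs >>  0) & 0xF;
    int q1 = (qs >>  8) & 0xF;
    int q2 = (qs >>  4) & 0xF;
    int q3 = (qs >> 12) & 0xF;
    float ref = (q0 - 8)*y[0] + (q1 - 8)*y[1] + (q2 - 8)*y[2] + (q3 - 8)*y[3];

    // kernel-style: pre-scale the activations by 1, 1/256, 1/16, 1/4096, mask without
    // shifting, and fold the -8 offset into a single sumy*-8 term
    float yl0 = y[0], yl1 = y[1]/256.f, yl2 = y[2]/16.f, yl3 = y[3]/4096.f;
    float sumy = y[0] + y[1] + y[2] + y[3];
    float acc  = yl0*(qs & 0x000F) + yl1*(qs & 0x0F00)
               + yl2*(qs & 0x00F0) + yl3*(qs & 0xF000);
    float opt  = sumy*-8.f + acc;

    printf("ref = %f, masked = %f\n", ref, opt);
    assert(fabsf(ref - opt) < 1e-4f);
    return 0;
}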
1929
+ #ifdef INTEL_GPU
1930
+ #define N_DST 4 // each SIMD group works on 4 rows
1931
+ #define N_SIMDGROUP 1 // number of SIMD groups in a thread group
1932
+ #define N_SIMDWIDTH 16 // assuming SIMD group size is 16
1933
+ #elif defined (ADRENO_GPU)
1934
+ #define N_DST 4
1935
+ #define N_SIMDGROUP 1
1936
+ #define N_SIMDWIDTH 64
1937
+ #endif
1938
+
1939
+ inline void mul_vec_q_n_f32(
1940
+ global void * src0,
1941
+ global float * src1,
1942
+ global float * dst,
1943
+ int ne00,
1944
+ int ne01,
1945
+ int ne02,
1946
+ int ne10,
1947
+ int ne12,
1948
+ int ne0,
1949
+ int ne1,
1950
+ int r2,
1951
+ int r3
1952
+ ) {
1953
+
1954
+ const ulong nb = ne00/QK4_0;
1955
+
1956
+ int r0 = get_group_id(0);
1957
+ int r1 = get_group_id(1);
1958
+ int im = get_group_id(2);
1959
+
1960
+ // (r0 * N_SIMDGROUP + get_sub_group_id()) is essentially the linear global
1961
+ // id of a SIMD group in the grid.
1962
+ int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
1963
+
1964
+ int i12 = im%ne12;
1965
+ int i13 = im/ne12;
1966
+
1967
+ ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
1968
+
1969
+ global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0;
1970
+ global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
1971
+
1972
+ float yl[16]; // src1 vector cache
1973
+ float sumf[N_DST]={0.f};
1974
+
1975
+ int ix = get_sub_group_local_id()/2;
1976
+ int il = 8*(get_sub_group_local_id()%2);
1977
+
1978
+ global float * yb = y + ix * QK4_0 + il;
1979
+
1980
+ // each thread in a SIMD group deals with half a block.
1981
+ for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
1982
+ float sumy = 0;
1983
+ for (int i = 0; i < 8; i += 2) {
1984
+ sumy += yb[i] + yb[i+1];
1985
+ yl[i+0] = yb[i+ 0];
1986
+ yl[i+1] = yb[i+ 1]/256.f;
1987
+ sumy += yb[i+16] + yb[i+17];
1988
+ yl[i+8] = yb[i+16]/16.f;
1989
+ yl[i+9] = yb[i+17]/4096.f;
1990
+ }
1991
+
1992
+ for (int row = 0; row < N_DST; row++) {
1993
+ sumf[row] += block_q_4_0_dot_y(x+ib+row*nb, sumy, yl, il);
1994
+ }
1995
+
1996
+ // One thread in a SIMD group (i.e., subgroup) handles a half block,
1997
+ // hence the entire SIMD group handles SIMDWIDTH/2 blocks.
1998
+ // y points to the activation matrix (of type float). Therefore for
1999
+ // one thread, the # of blocks y should advance is SIMDWIDTH/2 (because
2000
+ // SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of
2001
+ // floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size.
2002
+ yb += QK4_0 * (N_SIMDWIDTH/2);
2003
+ }
2004
+
2005
+ // The above does not work for Adreno - it produces incorrect results for
2006
+ // row = 1, 2, 3 and only row = 0 gives the correct result.
2007
+ // If N_DST is changed, the below array must be initialized accordingly.
2008
+ // This also seems to perform better on Intel.
2009
+ float tot[N_DST] = {
2010
+ sub_group_reduce_add(sumf[0]), sub_group_reduce_add(sumf[1]),
2011
+ sub_group_reduce_add(sumf[2]), sub_group_reduce_add(sumf[3])};
2012
+ for (int row = 0; row < N_DST; ++row) {
2013
+ if (get_sub_group_local_id() == 0 && first_row + row < ne01) {
2014
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot[row];
2015
+ }
2016
+ }
2017
+ }
2018
+
2019
+ #ifdef INTEL_GPU
2020
+ REQD_SUBGROUP_SIZE_16
2021
+ #elif defined (ADRENO_GPU)
2022
+ REQD_SUBGROUP_SIZE_64
2023
+ #endif
2024
+ kernel void kernel_mul_mat_q4_0_f32(
2025
+ global void * src0,
2026
+ ulong offset0,
2027
+ global float * src1,
2028
+ ulong offset1,
2029
+ global float * dst,
2030
+ ulong offsetd,
2031
+ int ne00,
2032
+ int ne01,
2033
+ int ne02,
2034
+ int ne10,
2035
+ int ne12,
2036
+ int ne0,
2037
+ int ne1,
2038
+ int r2,
2039
+ int r3
2040
+ ) {
2041
+ src0 = (global void*)((global char*)src0 + offset0);
2042
+ src1 = (global float*)((global char*)src1 + offset1);
2043
+ dst = (global float*)((global char*)dst + offsetd);
2044
+
2045
+ mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
2046
+ }
2047
+
2048
+ //
2049
+ // This variant unrolls the loops and uses vector types instead of pointers.
2050
+ // It improves performance on Adreno but not so much on Intel.
2051
+ //
2052
+ inline float block_q_4_0_dot_y_v(
2053
+ global struct block_q4_0 * qb_curr,
2054
+ float sumy,
2055
+ float16 yl,
2056
+ int il
2057
+ ) {
2058
+ float d = qb_curr->d;
2059
+ float acc = 0.f;
2060
+ global ushort * qs = ((global ushort *)qb_curr + 1 + il/2);
2061
+
2062
+ acc += yl.s0 * (qs[0] & 0x000F);
2063
+ acc += yl.s1 * (qs[0] & 0x0F00);
2064
+ acc += yl.s8 * (qs[0] & 0x00F0);
2065
+ acc += yl.s9 * (qs[0] & 0xF000);
2066
+
2067
+ acc += yl.s2 * (qs[1] & 0x000F);
2068
+ acc += yl.s3 * (qs[1] & 0x0F00);
2069
+ acc += yl.sa * (qs[1] & 0x00F0);
2070
+ acc += yl.sb * (qs[1] & 0xF000);
2071
+
2072
+ acc += yl.s4 * (qs[2] & 0x000F);
2073
+ acc += yl.s5 * (qs[2] & 0x0F00);
2074
+ acc += yl.sc * (qs[2] & 0x00F0);
2075
+ acc += yl.sd * (qs[2] & 0xF000);
2076
+
2077
+ acc += yl.s6 * (qs[3] & 0x000F);
2078
+ acc += yl.s7 * (qs[3] & 0x0F00);
2079
+ acc += yl.se * (qs[3] & 0x00F0);
2080
+ acc += yl.sf * (qs[3] & 0xF000);
2081
+
2082
+ return d * (sumy * -8.f + acc);
2083
+ }
2084
+
2085
+ #undef N_DST
2086
+ #undef N_SIMDGROUP
2087
+ #undef N_SIMDWIDTH
2088
+
2089
+ #ifdef INTEL_GPU
2090
+ #define N_DST 4 // each SIMD group works on 4 rows
2091
+ #define N_SIMDGROUP 1 // number of SIMD groups in a thread group
2092
+ #define N_SIMDWIDTH 16 // assuming SIMD group size is 16
2093
+ #elif defined (ADRENO_GPU)
2094
+ #define N_DST 4
2095
+ #define N_SIMDGROUP 1
2096
+ #define N_SIMDWIDTH 64
2097
+ #endif
2098
+
2099
+ inline void mul_vec_q_n_f32_v(
2100
+ global void * src0,
2101
+ global float * src1,
2102
+ global float * dst,
2103
+ int ne00,
2104
+ int ne01,
2105
+ int ne02,
2106
+ int ne10,
2107
+ int ne12,
2108
+ int ne0,
2109
+ int ne1,
2110
+ int r2,
2111
+ int r3
2112
+ ) {
2113
+ const ulong nb = ne00/QK4_0;
2114
+
2115
+ int r0 = get_group_id(0);
2116
+ int r1 = get_group_id(1);
2117
+ int im = get_group_id(2);
2118
+
2119
+ // (r0 * N_SIMDGROUP + get_sub_group_id()) is essentially the linear global
2120
+ // id of a SIMD group in the grid.
2121
+ int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
2122
+
2123
+ int i12 = im%ne12;
2124
+ int i13 = im/ne12;
2125
+
2126
+ ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2127
+
2128
+ global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0;
2129
+ global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
2130
+
2131
+ float16 yl; // src1 vector cache
2132
+ float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f);
2133
+
2134
+ int ix = get_sub_group_local_id()/2;
2135
+ int il = 8*(get_sub_group_local_id()%2);
2136
+
2137
+ global float * yb = y + ix * QK4_0 + il;
2138
+
2139
+ // each thread in a SIMD group deals with half a block.
2140
+ for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
2141
+ float sumy = 0;
2142
+
2143
+ sumy += yb[0];
2144
+ sumy += yb[1];
2145
+ sumy += yb[2];
2146
+ sumy += yb[3];
2147
+ sumy += yb[4];
2148
+ sumy += yb[5];
2149
+ sumy += yb[6];
2150
+ sumy += yb[7];
2151
+
2152
+ sumy += yb[16];
2153
+ sumy += yb[17];
2154
+ sumy += yb[18];
2155
+ sumy += yb[19];
2156
+ sumy += yb[20];
2157
+ sumy += yb[21];
2158
+ sumy += yb[22];
2159
+ sumy += yb[23];
2160
+
2161
+
2162
+ yl.s0 = yb[0];
2163
+ yl.s1 = yb[1]/256.f;
2164
+
2165
+ yl.s2 = yb[2];
2166
+ yl.s3 = yb[3]/256.f;
2167
+
2168
+ yl.s4 = yb[4];
2169
+ yl.s5 = yb[5]/256.f;
2170
+
2171
+ yl.s6 = yb[6];
2172
+ yl.s7 = yb[7]/256.f;
2173
+
2174
+ yl.s8 = yb[16]/16.f;
2175
+ yl.s9 = yb[17]/4096.f;
2176
+
2177
+ yl.sa = yb[18]/16.f;
2178
+ yl.sb = yb[19]/4096.f;
2179
+
2180
+ yl.sc = yb[20]/16.f;
2181
+ yl.sd = yb[21]/4096.f;
2182
+
2183
+ yl.se = yb[22]/16.f;
2184
+ yl.sf = yb[23]/4096.f;
2185
+
2186
+ sumf.s0 += block_q_4_0_dot_y_v(x+ib+0*nb, sumy, yl, il);
2187
+ sumf.s1 += block_q_4_0_dot_y_v(x+ib+1*nb, sumy, yl, il);
2188
+ sumf.s2 += block_q_4_0_dot_y_v(x+ib+2*nb, sumy, yl, il);
2189
+ sumf.s3 += block_q_4_0_dot_y_v(x+ib+3*nb, sumy, yl, il);
2190
+
2191
+ // One thread in a SIMD group (i.e., subgroup) handles a half block,
2192
+ // hence the entire SIMD group handles SIMDWIDTH/2 blocks.
2193
+ // y points to the activation matrix (of type float). Therefore for
2194
+ // one thread, the # of blocks y should advance is SIMDWIDTH/2 (because
2195
+ // SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of
2196
+ // floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size.
2197
+ yb += QK4_0 * (N_SIMDWIDTH/2);
2198
+ }
2199
+
2200
+ // The above does not work for Adreno - it produces incorrect results for
2201
+ // row = 1, 2, 3 and only row = 0 gives the correct result.
2202
+ // If N_DST is changed, the below array must be initialized accordingly.
2203
+ // This also seems to perform better on Intel.
2204
+ float4 tot = (float4)(
2205
+ sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
2206
+ sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
2207
+ );
2208
+
2209
+ if (get_sub_group_local_id() == 0) {
2210
+ if (first_row + 0 < ne01) {
2211
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
2212
+ }
2213
+ if (first_row + 1 < ne01) {
2214
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
2215
+ }
2216
+ if (first_row + 2 < ne01) {
2217
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
2218
+ }
2219
+ if (first_row + 3 < ne01) {
2220
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
2221
+ }
2222
+ }
2223
+ }
2224
+
2225
+ #ifdef INTEL_GPU
2226
+ REQD_SUBGROUP_SIZE_16
2227
+ #elif defined (ADRENO_GPU)
2228
+ REQD_SUBGROUP_SIZE_64
2229
+ #endif
2230
+ kernel void kernel_mul_mat_q4_0_f32_v(
2231
+ global void * src0,
2232
+ ulong offset0,
2233
+ global float * src1,
2234
+ ulong offset1,
2235
+ global float * dst,
2236
+ ulong offsetd,
2237
+ int ne00,
2238
+ int ne01,
2239
+ int ne02,
2240
+ int ne10,
2241
+ int ne12,
2242
+ int ne0,
2243
+ int ne1,
2244
+ int r2,
2245
+ int r3
2246
+ ) {
2247
+ src0 = (global void*)((global char*)src0 + offset0);
2248
+ src1 = (global float*)((global char*)src1 + offset1);
2249
+ dst = (global float*)((global char*)dst + offsetd);
2250
+
2251
+ mul_vec_q_n_f32_v(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
2252
+ }
2253
+
2254
+ //------------------------------------------------------------------------------
2255
+ // kernel_convert_block_q4_0
2256
+ // Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA).
2257
+ // This kernel does not deshuffle the bits.
2258
+ //------------------------------------------------------------------------------
2259
+ kernel void kernel_convert_block_q4_0(
2260
+ global struct block_q4_0 * src0,
2261
+ global uchar * dst_q,
2262
+ global half * dst_d
2263
+ ) {
2264
+ global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0);
2265
+ global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0);
2266
+ global half * d = (global half *) dst_d + get_global_id(0);
2267
+
2268
+ *d = b->d;
2269
+
2270
+ for (int i = 0; i < QK4_0/2; ++i) {
2271
+ q[i] = b->qs[i];
2272
+ }
2273
+ }
2274
+
2275
+ kernel void kernel_restore_block_q4_0(
2276
+ global uchar * src_q,
2277
+ global half * src_d,
2278
+ global struct block_q4_0 * dst
2279
+ ) {
2280
+ global struct block_q4_0 * b = (global struct block_q4_0 *) dst + get_global_id(0);
2281
+ global uchar * q = (global uchar *) src_q + QK4_0/2*get_global_id(0);
2282
+ global half * d = (global half *) src_d + get_global_id(0);
2283
+
2284
+ b->d = *d;
2285
+ for (int i = 0; i < QK4_0/2; ++i) {
2286
+ b->qs[i] = q[i];
2287
+ }
2288
+ }
2289
+
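For reference, kernel_convert_block_q4_0 above is a one-work-item-per-block copy: it splits each block_q4_0 (one half-precision scale plus QK4_0/2 quant bytes) into an entry of dst_q and an entry of dst_d, and kernel_restore_block_q4_0 reverses it. A rough host-side sketch of the resulting buffer sizes (plain C, not from the commit; the tensor dimensions are placeholders):

#include <stdint.h>
#include <stdio.h>

#define QK4_0 32

int main(void) {
    int64_t ne00 = 4096, ne01 = 4096;                    // hypothetical K x M weight matrix
    int64_t n_blocks = ne00 / QK4_0 * ne01;              // one block per QK4_0 weights

    size_t size_aos = (size_t)n_blocks * (2 + QK4_0/2);  // array of block_q4_0 (2-byte d + 16 quant bytes)
    size_t size_q   = (size_t)n_blocks * (QK4_0/2);      // dst_q: packed 4-bit quants only
    size_t size_d   = (size_t)n_blocks * 2;              // dst_d: one fp16 scale per block

    printf("AOS: %zu bytes -> SOA: %zu + %zu bytes\n", size_aos, size_q, size_d);
    return 0;
}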
2290
+ //------------------------------------------------------------------------------
2291
+ // mul_vec_q_n_f32_flat
2292
+ //
2293
+ // This variation uses flat arrays (struct of arrays, SOA) representation for
2294
+ // quant tensors.
2295
+ //------------------------------------------------------------------------------
2296
+
2297
+ // This function requires the original shuffled weights.
2298
+ // As a reminder, the original weights are shuffled so that (q[0], q[16]) are
2299
+ // packed together in a byte, so are (q[1], q[17]) and so on.
2300
+ inline float block_q_4_0_dot_y_flat(
2301
+ global uchar * x,
2302
+ global half * dh,
2303
+ float sumy,
2304
+ float16 yl,
2305
+ int il
2306
+ ) {
2307
+ float d = *dh;
2308
+ global ushort * qs = ((global ushort *)x + il/2);
2309
+ float acc = 0.f;
2310
+
2311
+ acc += yl.s0 * (qs[0] & 0x000F);
2312
+ acc += yl.s1 * (qs[0] & 0x0F00);
2313
+ acc += yl.s8 * (qs[0] & 0x00F0);
2314
+ acc += yl.s9 * (qs[0] & 0xF000);
2315
+
2316
+ acc += yl.s2 * (qs[1] & 0x000F);
2317
+ acc += yl.s3 * (qs[1] & 0x0F00);
2318
+ acc += yl.sa * (qs[1] & 0x00F0);
2319
+ acc += yl.sb * (qs[1] & 0xF000);
2320
+
2321
+ acc += yl.s4 * (qs[2] & 0x000F);
2322
+ acc += yl.s5 * (qs[2] & 0x0F00);
2323
+ acc += yl.sc * (qs[2] & 0x00F0);
2324
+ acc += yl.sd * (qs[2] & 0xF000);
2325
+
2326
+ acc += yl.s6 * (qs[3] & 0x000F);
2327
+ acc += yl.s7 * (qs[3] & 0x0F00);
2328
+ acc += yl.se * (qs[3] & 0x00F0);
2329
+ acc += yl.sf * (qs[3] & 0xF000);
2330
+
2331
+ return d * (sumy * -8.f + acc);
2332
+ }
2333
+
2334
+ #undef N_DST
2335
+ #undef N_SIMDGROUP
2336
+ #undef N_SIMDWIDTH
2337
+
2338
+ #ifdef INTEL_GPU
2339
+ #define N_DST 4 // each SIMD group works on 4 rows
2340
+ #define N_SIMDGROUP 1 // number of SIMD groups in a thread group
2341
+ #define N_SIMDWIDTH 16 // assuming SIMD group size is 16
2342
+ #elif defined (ADRENO_GPU)
2343
+ #define N_DST 4
2344
+ #define N_SIMDGROUP 1
2345
+ #define N_SIMDWIDTH 64
2346
+ #endif
2347
+
2348
+ inline void mul_vec_q_n_f32_flat(
2349
+ global uchar * src0_q,
2350
+ global half * src0_d,
2351
+ global float * src1,
2352
+ global float * dst,
2353
+ int ne00,
2354
+ int ne01,
2355
+ int ne02,
2356
+ int ne10,
2357
+ int ne12,
2358
+ int ne0,
2359
+ int ne1,
2360
+ int r2,
2361
+ int r3
2362
+ ) {
2363
+ const ulong nb = ne00/QK4_0;
2364
+
2365
+ int r0 = get_group_id(0);
2366
+ int r1 = get_group_id(1);
2367
+ int im = get_group_id(2);
2368
+
2369
+ // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of
2370
+ // a SIMD group in the grid. Each SIMD group produces N_DST values in the
2371
+ // result, hence uses nb blocks, i.e., the offset becomes first_row*nb.
2372
+ // Currently with llama2 7B, im is always 0.
2373
+ // TODO: how to handle im/gqa*(nb*ne0)?
2374
+ int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
2375
+
2376
+ int i12 = im%ne12;
2377
+ int i13 = im/ne12;
2378
+
2379
+ // The number of scales is the same as the number of blocks.
2380
+ ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2381
+ // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
2382
+ ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2;
2383
+
2384
+ global uchar * x = (global uchar *) src0_q + offset0_q;
2385
+ global half * d = (global half *) src0_d + offset0_d;
2386
+ global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
2387
+
2388
+ float16 yl;
2389
+ float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f);
2390
+
2391
+ int ix = get_sub_group_local_id()/2;
2392
+ int il = 8*(get_sub_group_local_id()%2);
2393
+
2394
+ global float * yb = y + ix*QK4_0 + il;
2395
+
2396
+ for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
2397
+ float sumy = 0.f;
2398
+
2399
+ sumy += yb[0];
2400
+ sumy += yb[1];
2401
+ sumy += yb[2];
2402
+ sumy += yb[3];
2403
+ sumy += yb[4];
2404
+ sumy += yb[5];
2405
+ sumy += yb[6];
2406
+ sumy += yb[7];
2407
+
2408
+ sumy += yb[16];
2409
+ sumy += yb[17];
2410
+ sumy += yb[18];
2411
+ sumy += yb[19];
2412
+ sumy += yb[20];
2413
+ sumy += yb[21];
2414
+ sumy += yb[22];
2415
+ sumy += yb[23];
2416
+
2417
+ yl.s0 = yb[0];
2418
+ yl.s1 = yb[1]/256.f;
2419
+
2420
+ yl.s2 = yb[2];
2421
+ yl.s3 = yb[3]/256.f;
2422
+
2423
+ yl.s4 = yb[4];
2424
+ yl.s5 = yb[5]/256.f;
2425
+
2426
+ yl.s6 = yb[6];
2427
+ yl.s7 = yb[7]/256.f;
2428
+
2429
+ yl.s8 = yb[16]/16.f;
2430
+ yl.s9 = yb[17]/4096.f;
2431
+
2432
+ yl.sa = yb[18]/16.f;
2433
+ yl.sb = yb[19]/4096.f;
2434
+
2435
+ yl.sc = yb[20]/16.f;
2436
+ yl.sd = yb[21]/4096.f;
2437
+
2438
+ yl.se = yb[22]/16.f;
2439
+ yl.sf = yb[23]/4096.f;
2440
+
2441
+ sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il);
2442
+ sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il);
2443
+ sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il);
2444
+ sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il);
2445
+
2446
+ yb += QK4_0 * (N_SIMDWIDTH/2);
2447
+ }
2448
+
2449
+ float4 tot = (float4)(
2450
+ sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
2451
+ sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
2452
+ );
2453
+
2454
+ if (get_sub_group_local_id() == 0) {
2455
+ if (first_row + 0 < ne01) {
2456
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
2457
+ }
2458
+ if (first_row + 1 < ne01) {
2459
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
2460
+ }
2461
+ if (first_row + 2 < ne01) {
2462
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
2463
+ }
2464
+ if (first_row + 3 < ne01) {
2465
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
2466
+ }
2467
+ }
2468
+ }
2469
+
2470
+ #ifdef INTEL_GPU
2471
+ REQD_SUBGROUP_SIZE_16
2472
+ #elif defined (ADRENO_GPU)
2473
+ REQD_SUBGROUP_SIZE_64
2474
+ #endif
2475
+ kernel void kernel_mul_mat_q4_0_f32_flat(
2476
+ global uchar * src0_q,
2477
+ global half * src0_d,
2478
+ global float * src1,
2479
+ ulong offset1,
2480
+ global float * dst,
2481
+ ulong offsetd,
2482
+ int ne00,
2483
+ int ne01,
2484
+ int ne02,
2485
+ int ne10,
2486
+ int ne12,
2487
+ int ne0,
2488
+ int ne1,
2489
+ int r2,
2490
+ int r3
2491
+ ) {
2492
+ src1 = (global float*)((global char*)src1 + offset1);
2493
+ dst = (global float*)((global char*)dst + offsetd);
2494
+
2495
+ mul_vec_q_n_f32_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
2496
+ }
2497
+
2498
+ //
2499
+ // This variant outputs 8 values.
2500
+ //
2501
+ #undef N_DST
2502
+ #undef N_SIMDGROUP
2503
+ #undef N_SIMDWIDTH
2504
+
2505
+ #ifdef INTEL_GPU
2506
+ #define N_DST 8 // each SIMD group works on 8 rows
2507
+ #define N_SIMDGROUP 1 // number of SIMD groups in a thread group
2508
+ #define N_SIMDWIDTH 16 // assuming SIMD group size is 16
2509
+ #elif defined (ADRENO_GPU)
2510
+ #define N_DST 8
2511
+ #define N_SIMDGROUP 1
2512
+ #define N_SIMDWIDTH 64
2513
+ #endif
2514
+
2515
+ inline void mul_vec_q_n_f32_8x_flat(
2516
+ global uchar * src0_q,
2517
+ global half * src0_d,
2518
+ global float * src1,
2519
+ global float * dst,
2520
+ int ne00,
2521
+ int ne01,
2522
+ int ne02,
2523
+ int ne10,
2524
+ int ne12,
2525
+ int ne0,
2526
+ int ne1,
2527
+ int r2,
2528
+ int r3
2529
+ ) {
2530
+ const ulong nb = ne00/QK4_0;
2531
+
2532
+ int r0 = get_group_id(0);
2533
+ int r1 = get_group_id(1);
2534
+ int im = get_group_id(2);
2535
+
2536
+ // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of
2537
+ // a SIMD group in the grid. Each SIMD group produces N_DST values in the
2538
+ // result, hence uses nb blocks, i.e., the offset becomes first_row*nb.
2539
+ // Currently with llama2 7B, im is always 0.
2540
+ // TODO: how to handle im/gqa*(nb*ne0)?
2541
+ int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
2542
+
2543
+ int i12 = im%ne12;
2544
+ int i13 = im/ne12;
2545
+
2546
+ // The number of scales is the same as the number of blocks.
2547
+ ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2548
+ // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
2549
+ ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2;
2550
+
2551
+ global uchar * x = (global uchar *) src0_q + offset0_q;
2552
+ global half * d = (global half *) src0_d + offset0_d;
2553
+ global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
2554
+
2555
+ float16 yl;
2556
+ float8 sumf = 0.f;
2557
+
2558
+ int ix = get_sub_group_local_id()/2;
2559
+ int il = 8*(get_sub_group_local_id()%2);
2560
+
2561
+ global float * yb = y + ix*QK4_0 + il;
2562
+
2563
+ for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
2564
+ float sumy = 0.f;
2565
+
2566
+ sumy += yb[0];
2567
+ sumy += yb[1];
2568
+ sumy += yb[2];
2569
+ sumy += yb[3];
2570
+ sumy += yb[4];
2571
+ sumy += yb[5];
2572
+ sumy += yb[6];
2573
+ sumy += yb[7];
2574
+
2575
+ sumy += yb[16];
2576
+ sumy += yb[17];
2577
+ sumy += yb[18];
2578
+ sumy += yb[19];
2579
+ sumy += yb[20];
2580
+ sumy += yb[21];
2581
+ sumy += yb[22];
2582
+ sumy += yb[23];
2583
+
2584
+ yl.s0 = yb[0];
2585
+ yl.s1 = yb[1]/256.f;
2586
+
2587
+ yl.s2 = yb[2];
2588
+ yl.s3 = yb[3]/256.f;
2589
+
2590
+ yl.s4 = yb[4];
2591
+ yl.s5 = yb[5]/256.f;
2592
+
2593
+ yl.s6 = yb[6];
2594
+ yl.s7 = yb[7]/256.f;
2595
+
2596
+ yl.s8 = yb[16]/16.f;
2597
+ yl.s9 = yb[17]/4096.f;
2598
+
2599
+ yl.sa = yb[18]/16.f;
2600
+ yl.sb = yb[19]/4096.f;
2601
+
2602
+ yl.sc = yb[20]/16.f;
2603
+ yl.sd = yb[21]/4096.f;
2604
+
2605
+ yl.se = yb[22]/16.f;
2606
+ yl.sf = yb[23]/4096.f;
2607
+
2608
+ sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il);
2609
+ sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il);
2610
+ sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il);
2611
+ sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il);
2612
+
2613
+ sumf.s4 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il);
2614
+ sumf.s5 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il);
2615
+ sumf.s6 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il);
2616
+ sumf.s7 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il);
2617
+
2618
+ yb += QK4_0 * (N_SIMDWIDTH/2);
2619
+ }
2620
+
2621
+ float8 tot = (float8)(
2622
+ sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
2623
+ sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3),
2624
+ sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5),
2625
+ sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7)
2626
+ );
2627
+
2628
+ if (get_sub_group_local_id() == 0) {
2629
+ if (first_row + 0 < ne01) {
2630
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
2631
+ }
2632
+ if (first_row + 1 < ne01) {
2633
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
2634
+ }
2635
+ if (first_row + 2 < ne01) {
2636
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
2637
+ }
2638
+ if (first_row + 3 < ne01) {
2639
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
2640
+ }
2641
+
2642
+ if (first_row + 4 < ne01) {
2643
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4;
2644
+ }
2645
+ if (first_row + 5 < ne01) {
2646
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5;
2647
+ }
2648
+ if (first_row + 6 < ne01) {
2649
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6;
2650
+ }
2651
+ if (first_row + 7 < ne01) {
2652
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7;
2653
+ }
2654
+ }
2655
+ }
2656
+
2657
+ #ifdef INTEL_GPU
2658
+ REQD_SUBGROUP_SIZE_16
2659
+ #elif defined (ADRENO_GPU)
2660
+ REQD_SUBGROUP_SIZE_64
2661
+ #endif
2662
+ kernel void kernel_mul_mat_q4_0_f32_8x_flat(
2663
+ global uchar * src0_q,
2664
+ global half * src0_d,
2665
+ global float * src1,
2666
+ ulong offset1,
2667
+ global float * dst,
2668
+ ulong offsetd,
2669
+ int ne00,
2670
+ int ne01,
2671
+ int ne02,
2672
+ int ne10,
2673
+ int ne12,
2674
+ int ne0,
2675
+ int ne1,
2676
+ int r2,
2677
+ int r3
2678
+ ) {
2679
+ src1 = (global float*)((global char*)src1 + offset1);
2680
+ dst = (global float*)((global char*)dst + offsetd);
2681
+
2682
+ mul_vec_q_n_f32_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
2683
+ }
ggml/src/ggml-opencl/kernels/ggml-opencl_cvt.cl ADDED
@@ -0,0 +1,106 @@
1
+ //------------------------------------------------------------------------------
2
+ // This file contains additional kernels for data conversion.
3
+ // These kernels are used when loading the model, so their performance is less
4
+ // important.
5
+ //------------------------------------------------------------------------------
6
+ #ifdef cl_khr_fp16
7
+ #pragma OPENCL EXTENSION cl_khr_fp16 : enable
8
+ #elif defined(cl_amd_fp16)
9
+ #pragma OPENCL EXTENSION cl_amd_fp16 : enable
10
+ #else
11
+ #error "Half precision floating point not supported by the OpenCL implementation on your device."
12
+ #endif
13
+
14
+ #ifdef cl_khr_subgroups
15
+ #pragma OPENCL EXTENSION cl_khr_subgroups : enable
16
+ #elif defined(cl_intel_subgroups)
17
+ #pragma OPENCL EXTENSION cl_intel_subgroups : enable
18
+ #else
19
+ #error "Subgroups are not supported on your device."
20
+ #endif
21
+
22
+ #ifdef cl_intel_required_subgroup_size
23
+ // Always use subgroup size of 32 on Intel.
24
+ #pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
25
+ #define INTEL_GPU 1
26
+ #define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
27
+ #define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
28
+ #elif defined(cl_qcom_reqd_sub_group_size)
29
+ // Always use a subgroup size of 64 on Adreno.
30
+ #pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
31
+ #define ADRENO_GPU 1
32
+ #define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
33
+ #define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
34
+ #else
35
+ // TODO: we do not know how to choose the subgroup size on other GPUs.
36
+ #error "Selecting subgroup size is not supported on your device."
37
+ #endif
38
+
39
+ #define QK4_0 32
40
+ #define QR4_0 2
41
+ #define QK4_1 32
42
+ #define QR4_1 2
43
+ #define QK5_0 32
44
+ #define QR5_0 2
45
+ #define QK5_1 32
46
+ #define QR5_1 2
47
+ #define QK8_0 32
48
+ #define QR8_0 1
49
+ #define QK_K 256
50
+ #define K_QUANTS_PER_ITERATION 2
51
+
52
+ typedef char int8_t;
53
+ typedef uchar uint8_t;
54
+ typedef short int16_t;
55
+ typedef ushort uint16_t;
56
+ typedef int int32_t;
57
+ typedef uint uint32_t;
58
+
59
+ //------------------------------------------------------------------------------
60
+ // block_q4_0
61
+ //------------------------------------------------------------------------------
62
+ struct block_q4_0
63
+ {
64
+ half d;
65
+ uint8_t qs[QK4_0 / 2];
66
+ };
67
+
68
+ //------------------------------------------------------------------------------
69
+ // mul_vec_q_n_f32_flat_noshuffle
70
+ //
71
+ // This variation uses flat arrays (struct of arrays, SOA) representation for
72
+ // quant tensors. It also uses non shuffled bit order for weights.
73
+ //
74
+ // The shuffled version is kept in the original file because moving it here
75
+ // seems to result in worse performance for Adreno.
76
+ //------------------------------------------------------------------------------
77
+
78
+ kernel void kernel_convert_block_q4_0_noshuffle(
79
+ global struct block_q4_0 * src0,
80
+ global uchar * dst_q,
81
+ global half * dst_d
82
+ ) {
83
+ global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0);
84
+ global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0);
85
+ global half * d = (global half *) dst_d + get_global_id(0);
86
+
87
+ *d = b->d;
88
+ for (int i = 0; i < QK4_0/4; ++i) {
89
+ uchar x0 = b->qs[2*i + 0];
90
+ uchar x1 = b->qs[2*i + 1];
91
+
92
+ q[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4);
93
+ q[i + QK4_0/4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0);
94
+
95
+ #ifdef ADRENO_GPU
96
+ // Workaround for Adreno - the following printf statement is required for
97
+ // the kernel to work properly. Otherwise it produces incorrect results.
98
+ // convert_uchar above also seems necessary.
99
+ // Compare against a large number so that it does not print anything.
100
+ // get_sub_group_local_id() also works.
101
+ if (get_global_id(0) == 65536*4096) {
102
+ printf("%04x - %02x\n", *(global ushort*)d, ((x0 & 0xF0) >> 4) | (x1 & 0xF0));
103
+ }
104
+ #endif
105
+ }
106
+ }
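The regrouping performed by kernel_convert_block_q4_0_noshuffle above can be summarised as: in block_q4_0, byte j packs weight j in its low nibble and weight j+16 in its high nibble; after conversion, byte i packs weights 2*i and 2*i+1, i.e. the weights end up in linear order. A small host-side check (plain C, not part of the commit; the weight values are arbitrary):

#include <assert.h>
#include <stdint.h>

#define QK4_0 32

int main(void) {
    uint8_t w[QK4_0];                                // 4-bit weights, one per entry
    for (int i = 0; i < QK4_0; ++i) {
        w[i] = (uint8_t)((5*i + 3) % 13);            // arbitrary values in 0..12
    }

    uint8_t qs[QK4_0/2], q[QK4_0/2];
    for (int j = 0; j < QK4_0/2; ++j) {              // shuffled (block_q4_0) packing
        qs[j] = (uint8_t)(w[j] | (w[j + QK4_0/2] << 4));
    }

    for (int i = 0; i < QK4_0/4; ++i) {              // the kernel's regrouping
        uint8_t x0 = qs[2*i + 0];
        uint8_t x1 = qs[2*i + 1];
        q[i          ] = (uint8_t)(( x0 & 0x0F)       | ((x1 & 0x0F) << 4));
        q[i + QK4_0/4] = (uint8_t)(((x0 & 0xF0) >> 4) |  (x1 & 0xF0));
    }

    for (int i = 0; i < QK4_0/2; ++i) {              // result: weights 2*i and 2*i+1 per byte
        assert((q[i] & 0x0F) == w[2*i + 0]);
        assert((q[i] >>   4) == w[2*i + 1]);
    }
    return 0;
}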
ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl ADDED
@@ -0,0 +1,265 @@
1
+ #pragma OPENCL EXTENSION cl_khr_fp16 : enable
2
+ #pragma OPENCL EXTENSION cl_khr_subgroups : enable
3
+ #pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable
4
+ #pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable
5
+ #pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable
6
+ #pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
7
+
8
+ // assumed constants
9
+ #define QK4_0 32
10
+ #define N_SIMDGROUP 4
11
+
12
+ #define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, y) \
13
+ float shared_y; \
14
+ shared_y = sub_group_broadcast(y.s0, 0); \
15
+ total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \
16
+ total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \
17
+ shared_y = sub_group_broadcast(y.s1, 0); \
18
+ total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
19
+ total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
20
+ shared_y = sub_group_broadcast(y.s2, 0); \
21
+ total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
22
+ total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
23
+ shared_y = sub_group_broadcast(y.s3, 0); \
24
+ total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
25
+ total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
26
+ shared_y = sub_group_broadcast(y.s4, 0); \
27
+ total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \
28
+ total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \
29
+ shared_y = sub_group_broadcast(y.s5, 0); \
30
+ total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
31
+ total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
32
+ shared_y = sub_group_broadcast(y.s6, 0); \
33
+ total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
34
+ total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
35
+ shared_y = sub_group_broadcast(y.s7, 0); \
36
+ total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
37
+ total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
38
+ shared_y = sub_group_broadcast(y.s0, 1); \
39
+ total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \
40
+ total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \
41
+ shared_y = sub_group_broadcast(y.s1, 1); \
42
+ total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
43
+ total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
44
+ shared_y = sub_group_broadcast(y.s2, 1); \
45
+ total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
46
+ total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
47
+ shared_y = sub_group_broadcast(y.s3, 1); \
48
+ total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
49
+ total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
50
+ shared_y = sub_group_broadcast(y.s4, 1); \
51
+ total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \
52
+ total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \
53
+ shared_y = sub_group_broadcast(y.s5, 1); \
54
+ total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
55
+ total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
56
+ shared_y = sub_group_broadcast(y.s6, 1); \
57
+ total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
58
+ total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
59
+ shared_y = sub_group_broadcast(y.s7, 1); \
60
+ total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
61
+ total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
62
+
63
+
64
+ #define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, y) \
65
+ shared_y = sub_group_broadcast(y.s0, 2); \
66
+ total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \
67
+ total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \
68
+ shared_y = sub_group_broadcast(y.s1, 2); \
69
+ total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
70
+ total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
71
+ shared_y = sub_group_broadcast(y.s2, 2); \
72
+ total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
73
+ total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
74
+ shared_y = sub_group_broadcast(y.s3, 2); \
75
+ total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
76
+ total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
77
+ shared_y = sub_group_broadcast(y.s4, 2); \
78
+ total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \
79
+ total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \
80
+ shared_y = sub_group_broadcast(y.s5, 2); \
81
+ total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
82
+ total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
83
+ shared_y = sub_group_broadcast(y.s6, 2); \
84
+ total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
85
+ total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
86
+ shared_y = sub_group_broadcast(y.s7, 2); \
87
+ total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
88
+ total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
89
+ shared_y = sub_group_broadcast(y.s0, 3); \
90
+ total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \
91
+ total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \
92
+ shared_y = sub_group_broadcast(y.s1, 3); \
93
+ total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
94
+ total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
95
+ shared_y = sub_group_broadcast(y.s2, 3); \
96
+ total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
97
+ total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
98
+ shared_y = sub_group_broadcast(y.s3, 3); \
99
+ total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
100
+ total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
101
+ shared_y = sub_group_broadcast(y.s4, 3); \
102
+ total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \
103
+ total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \
104
+ shared_y = sub_group_broadcast(y.s5, 3); \
105
+ total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
106
+ total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
107
+ shared_y = sub_group_broadcast(y.s6, 3); \
108
+ total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
109
+ total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
110
+ shared_y = sub_group_broadcast(y.s7, 3); \
111
+ total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
112
+ total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
113
+
114
+
115
+ #define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, y) \
116
+ float8 shared_y; \
117
+ shared_y = sub_group_broadcast(y, 0); \
118
+ total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
119
+ total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
120
+ total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
121
+ total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
122
+ total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
123
+ total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
124
+ total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
125
+ total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
126
+ total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
127
+ total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
128
+ total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
129
+ total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
130
+ total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
131
+ total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
132
+ total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
133
+ total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
134
+ shared_y = sub_group_broadcast(y, 1); \
135
+ total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
136
+ total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
137
+ total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
138
+ total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
139
+ total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
140
+ total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
141
+ total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
142
+ total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
143
+ total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
144
+ total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
145
+ total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
146
+ total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
147
+ total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
148
+ total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
149
+ total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
150
+ total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
151
+
152
+
153
+ #define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, y) \
154
+ shared_y = sub_group_broadcast(y, 2); \
155
+ total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
156
+ total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
157
+ total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
158
+ total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
159
+ total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
160
+ total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
161
+ total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
162
+ total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
163
+ total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
164
+ total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
165
+ total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
166
+ total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
167
+ total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
168
+ total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
169
+ total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
170
+ total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
171
+ shared_y = sub_group_broadcast(y, 3); \
172
+ total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
173
+ total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
174
+ total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
175
+ total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
176
+ total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
177
+ total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
178
+ total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
179
+ total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
180
+ total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
181
+ total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
182
+ total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
183
+ total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
184
+ total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
185
+ total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
186
+ total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
187
+ total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
188
+
189
+
190
+ __attribute__((qcom_reqd_sub_group_size("full")))
191
+ __kernel void kernel_gemv_noshuffle(
192
+ __read_only image1d_buffer_t src0_q, // quantized A
193
+ global half2 * src0_d, // A scales
194
+ __read_only image1d_buffer_t src1, // B
195
+ ulong offset1, // offset to B (0)
196
+ global float * dst, // C
197
+ ulong offsetd, // offset to C (0)
198
+ uint K, // K
199
+ int ne01, // M
200
+ int ne02, // 1
201
+ int ne10, // K
202
+ int ne12, // 1
203
+ int ne0, // M
204
+ int ne1, // N
205
+ int r2, // 1
206
+ int r3)
207
+ {
208
+ uint groupId = get_local_id(1);
209
+ uint gid = get_global_id(0);
210
+ ushort slid = get_sub_group_local_id();
211
+
212
+ __private uint4 regA;
213
+ __private half2 regS;
214
+ __private float8 regB;
215
+
216
+ __private float2 totalSum = (float2)(0.0f);
217
+
218
+ // loop along K in block granularity, skip 4 blocks every iter
219
+ for (uint k = groupId; k < (K / QK4_0); k += N_SIMDGROUP) {
220
+ regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of two rows
221
+ // first 4 fibers in each wave load 8 B values into their private scope
222
+ if (slid < 4) {
223
+ regB.s0123 = read_imagef(src1, (slid * 2 + k * 8));
224
+ regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8));
225
+ }
226
+
227
+ // load half weights for two blocks in consecutive rows
228
+ regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x;
229
+ regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;
230
+ regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x;
231
+ regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;
232
+ #ifdef VECTOR_SUB_GROUP_BROADCAT
233
+ dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB);
234
+ #else
235
+ dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB);
236
+ #endif // VECTOR_SUB_GROUP_BROADCAT
237
+
238
+ regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x;
239
+ regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x;
240
+ regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
241
+ regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
242
+ #ifdef VECTOR_SUB_GROUP_BROADCAT
243
+ dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB);
244
+ #else
245
+ dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB);
246
+ #endif // VECTOR_SUB_GROUP_BROADCAT
247
+ }
248
+
249
+ // reduction in local memory, assumes #wave=4
250
+ __local float2 reduceLM[SIMDGROUP_WIDTH * 3];
251
+ if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = totalSum;
252
+ if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum;
253
+ if (groupId == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum;
254
+ barrier(CLK_LOCAL_MEM_FENCE);
255
+ if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
256
+ if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
257
+ if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
258
+
259
+ // 2 outputs per fiber in wave 0
260
+ if (groupId == 0) {
261
+ dst = (global float*)((global char*)dst + offsetd);
262
+ vstore2(totalSum, 0, &(dst[gid * 2]));
263
+ }
264
+
265
+ }
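
Editor's note: this fixed-stride variant uses LINE_STRIDE_A, BLOCK_STRIDE_A and SIMDGROUP_WIDTH without defining them, so they are presumably injected by the host when the program is built; the general variant in the next file derives the first two from the tensor shape at run time (LINE_STRIDE_A = M/2, BLOCK_STRIDE_A = N_SIMDGROUP * M). A minimal sketch of such definitions, assuming a hypothetical M = 4096 and 64-fiber waves (the actual values and mechanism live in the host code and are not shown here):

    // Illustrative compile-time defines only; not part of the commit.
    #define SIMDGROUP_WIDTH 64                    // fibers per wave (assumed)
    #define LINE_STRIDE_A   (4096 / 2)            // M/2: one half2 scale covers two rows
    #define BLOCK_STRIDE_A  (N_SIMDGROUP * 4096)  // N_SIMDGROUP * M, as in the general variant
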
ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl ADDED
@@ -0,0 +1,271 @@
1
+ #pragma OPENCL EXTENSION cl_khr_fp16 : enable
2
+ #pragma OPENCL EXTENSION cl_khr_subgroups : enable
3
+ #pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable
4
+ #pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable
5
+ #pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable
6
+ #pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
7
+
8
+ // assumed quantization block size and number of waves per workgroup
9
+ #define QK4_0 32
10
+ #define N_SIMDGROUP 4
11
+
12
+ #define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, y) \
13
+ float shared_y; \
14
+ shared_y = sub_group_broadcast(y.s0, 0); \
15
+ total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \
16
+ total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \
17
+ shared_y = sub_group_broadcast(y.s1, 0); \
18
+ total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
19
+ total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
20
+ shared_y = sub_group_broadcast(y.s2, 0); \
21
+ total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
22
+ total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
23
+ shared_y = sub_group_broadcast(y.s3, 0); \
24
+ total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
25
+ total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
26
+ shared_y = sub_group_broadcast(y.s4, 0); \
27
+ total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \
28
+ total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \
29
+ shared_y = sub_group_broadcast(y.s5, 0); \
30
+ total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
31
+ total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
32
+ shared_y = sub_group_broadcast(y.s6, 0); \
33
+ total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
34
+ total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
35
+ shared_y = sub_group_broadcast(y.s7, 0); \
36
+ total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
37
+ total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
38
+ shared_y = sub_group_broadcast(y.s0, 1); \
39
+ total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \
40
+ total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \
41
+ shared_y = sub_group_broadcast(y.s1, 1); \
42
+ total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
43
+ total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
44
+ shared_y = sub_group_broadcast(y.s2, 1); \
45
+ total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
46
+ total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
47
+ shared_y = sub_group_broadcast(y.s3, 1); \
48
+ total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
49
+ total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
50
+ shared_y = sub_group_broadcast(y.s4, 1); \
51
+ total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \
52
+ total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \
53
+ shared_y = sub_group_broadcast(y.s5, 1); \
54
+ total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
55
+ total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
56
+ shared_y = sub_group_broadcast(y.s6, 1); \
57
+ total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
58
+ total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
59
+ shared_y = sub_group_broadcast(y.s7, 1); \
60
+ total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
61
+ total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
62
+
63
+
64
+ #define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, y) \
65
+ shared_y = sub_group_broadcast(y.s0, 2); \
66
+ total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \
67
+ total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \
68
+ shared_y = sub_group_broadcast(y.s1, 2); \
69
+ total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
70
+ total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
71
+ shared_y = sub_group_broadcast(y.s2, 2); \
72
+ total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
73
+ total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
74
+ shared_y = sub_group_broadcast(y.s3, 2); \
75
+ total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
76
+ total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
77
+ shared_y = sub_group_broadcast(y.s4, 2); \
78
+ total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \
79
+ total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \
80
+ shared_y = sub_group_broadcast(y.s5, 2); \
81
+ total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
82
+ total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
83
+ shared_y = sub_group_broadcast(y.s6, 2); \
84
+ total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
85
+ total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
86
+ shared_y = sub_group_broadcast(y.s7, 2); \
87
+ total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
88
+ total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
89
+ shared_y = sub_group_broadcast(y.s0, 3); \
90
+ total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \
91
+ total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \
92
+ shared_y = sub_group_broadcast(y.s1, 3); \
93
+ total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
94
+ total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
95
+ shared_y = sub_group_broadcast(y.s2, 3); \
96
+ total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
97
+ total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
98
+ shared_y = sub_group_broadcast(y.s3, 3); \
99
+ total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
100
+ total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
101
+ shared_y = sub_group_broadcast(y.s4, 3); \
102
+ total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \
103
+ total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \
104
+ shared_y = sub_group_broadcast(y.s5, 3); \
105
+ total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
106
+ total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
107
+ shared_y = sub_group_broadcast(y.s6, 3); \
108
+ total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
109
+ total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
110
+ shared_y = sub_group_broadcast(y.s7, 3); \
111
+ total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
112
+ total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
113
+
114
+
115
+ #define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, y) \
116
+ float8 shared_y; \
117
+ shared_y = sub_group_broadcast(y, 0); \
118
+ total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
119
+ total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
120
+ total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
121
+ total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
122
+ total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
123
+ total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
124
+ total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
125
+ total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
126
+ total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
127
+ total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
128
+ total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
129
+ total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
130
+ total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
131
+ total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
132
+ total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
133
+ total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
134
+ shared_y = sub_group_broadcast(y, 1); \
135
+ total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
136
+ total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
137
+ total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
138
+ total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
139
+ total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
140
+ total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
141
+ total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
142
+ total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
143
+ total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
144
+ total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
145
+ total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
146
+ total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
147
+ total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
148
+ total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
149
+ total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
150
+ total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
151
+
152
+
153
+ #define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, y) \
154
+ shared_y = sub_group_broadcast(y, 2); \
155
+ total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
156
+ total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
157
+ total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
158
+ total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
159
+ total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
160
+ total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
161
+ total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
162
+ total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
163
+ total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
164
+ total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
165
+ total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
166
+ total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
167
+ total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
168
+ total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
169
+ total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
170
+ total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
171
+ shared_y = sub_group_broadcast(y, 3); \
172
+ total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
173
+ total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
174
+ total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
175
+ total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
176
+ total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
177
+ total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
178
+ total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
179
+ total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
180
+ total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
181
+ total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
182
+ total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
183
+ total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
184
+ total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
185
+ total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
186
+ total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
187
+ total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
188
+
189
+
190
+ __attribute__((qcom_reqd_sub_group_size("full")))
191
+ __kernel void kernel_gemv_noshuffle(
192
+ __read_only image1d_buffer_t src0_q, // quantized A
193
+ global half2 * src0_d, // A scales
194
+ __read_only image1d_buffer_t src1, // B
195
+ ulong offset1, // offset to B (0)
196
+ global float * dst, // C
197
+ ulong offsetd, // offset to C (0)
198
+ int ne00, // K
199
+ int ne01, // M
200
+ int ne02, // 1
201
+ int ne10, // K
202
+ int ne12, // 1
203
+ int ne0, // M
204
+ int ne1, // N
205
+ int r2, // 1
206
+ int r3)
207
+ {
208
+ uint groupId = get_local_id(1);
209
+ uint gid = get_global_id(0);
210
+ ushort slid = get_sub_group_local_id();
211
+
212
+ uint K = ne00;
213
+ uint M = ne01;
214
+
215
+ uint LINE_STRIDE_A = M / 2;
216
+ uint BLOCK_STRIDE_A = N_SIMDGROUP * M;
217
+
218
+ __private uint4 regA;
219
+ __private half2 regS;
220
+ __private float8 regB;
221
+
222
+ __private float2 totalSum = (float2)(0.0f);
223
+
224
+ // loop along K in block granularity, skip 4 blocks every iter
225
+ for (uint k = groupId; k < (K / QK4_0); k += N_SIMDGROUP) {
226
+ regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of two rows
227
+ // first 4 fibers in each wave load 8 B values into their private scope
228
+ if (slid < 4) {
229
+ regB.s0123 = read_imagef(src1, (slid * 2 + k * 8));
230
+ regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8));
231
+ }
232
+
233
+ // load half weights for two blocks in consecutive rows
234
+ regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x;
235
+ regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;
236
+ regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x;
237
+ regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;
238
+ #ifdef VECTOR_SUB_GROUP_BROADCAT
239
+ dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB);
240
+ #else
241
+ dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB);
242
+ #endif // VECTOR_SUB_GROUP_BROADCAT
243
+
244
+ regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x;
245
+ regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x;
246
+ regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
247
+ regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
248
+ #ifdef VECTOR_SUB_GROUP_BROADCAT
249
+ dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB);
250
+ #else
251
+ dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB);
252
+ #endif // VECTOR_SUB_GROUP_BROADCAT
253
+ }
254
+
255
+ // reduction in local memory, assumes #wave=4
256
+ __local float2 reduceLM[SIMDGROUP_WIDTH * 3];
257
+ if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = totalSum;
258
+ if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum;
259
+ if (groupId == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum;
260
+ barrier(CLK_LOCAL_MEM_FENCE);
261
+ if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
262
+ if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
263
+ if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
264
+
265
+ // 2 outputs per fiber in wave 0
266
+ if (groupId == 0) {
267
+ dst = (global float*)((global char*)dst + offsetd);
268
+ vstore2(totalSum, 0, &(dst[gid * 2]));
269
+ }
270
+
271
+ }
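
Editor's note: both GEMV variants unroll the same per-element arithmetic, ((nibble - 8) * scale * y), with the nibble-to-activation pairing determined by the repacked ("noshuffle") weight layout. As a reading aid only, a scalar reference of the per-block dot product under the conventional ggml Q4_0 semantics (low nibble -> element i, high nibble -> element i + 16) is sketched below; it is not part of the commit, and the repacked layout maps indices differently, but the arithmetic per element is identical:

    // Scalar reference sketch: quants = the 16 packed bytes of one Q4_0 block,
    // d = block scale, y = the 32 activations it multiplies.
    float q4_0_block_dot(const uchar *quants, float d, const float *y) {
        float acc = 0.0f;
        for (int i = 0; i < 16; ++i) {
            acc += ((quants[i] & 0x0F) - 8) * y[i];        // low nibble -> element i
            acc += ((quants[i] >>   4) - 8) * y[i + 16];   // high nibble -> element i + 16
        }
        return d * acc;
    }
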
ggml/src/ggml-opencl/kernels/ggml-opencl_mm.cl ADDED
@@ -0,0 +1,1225 @@
1
+ //------------------------------------------------------------------------------
2
+ // This file contains additional mulmat kernels
3
+ // (and potentially other kernels).
4
+ //------------------------------------------------------------------------------
5
+ #ifdef cl_khr_fp16
6
+ #pragma OPENCL EXTENSION cl_khr_fp16 : enable
7
+ #elif defined(cl_amd_fp16)
8
+ #pragma OPENCL EXTENSION cl_amd_fp16 : enable
9
+ #else
10
+ #error "Half precision floating point not supported by the OpenCL implementation on your device."
11
+ #endif
12
+
13
+ #ifdef cl_khr_subgroups
14
+ #pragma OPENCL EXTENSION cl_khr_subgroups : enable
15
+ #elif defined(cl_intel_subgroups)
16
+ #pragma OPENCL EXTENSION cl_intel_subgroups : enable
17
+ #else
18
+ #error "Subgroups are not supported on your device."
19
+ #endif
20
+
21
+ #ifdef cl_intel_required_subgroup_size
22
+ // Always use subgroup size of 32 on Intel.
23
+ #pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
24
+ #define INTEL_GPU 1
25
+ #define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
26
+ #define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
27
+ #elif defined(cl_qcom_reqd_sub_group_size)
28
+ // Always use a subgroup size of 64 on Adreno.
29
+ #pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
30
+ #define ADRENO_GPU 1
31
+ #define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
32
+ #define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
33
+ #else
34
+ // TODO: we do not know how to choose the subgroup size on other GPUs.
35
+ #error "Selecting subgroup size is not supported on your device."
36
+ #endif
37
+
38
+ #define QK4_0 32
39
+ #define QR4_0 2
40
+ #define QK4_1 32
41
+ #define QR4_1 2
42
+ #define QK5_0 32
43
+ #define QR5_0 2
44
+ #define QK5_1 32
45
+ #define QR5_1 2
46
+ #define QK8_0 32
47
+ #define QR8_0 1
48
+ #define QK_K 256
49
+ #define K_QUANTS_PER_ITERATION 2
50
+
51
+ typedef char int8_t;
52
+ typedef uchar uint8_t;
53
+ typedef short int16_t;
54
+ typedef ushort uint16_t;
55
+ typedef int int32_t;
56
+ typedef uint uint32_t;
57
+
58
+ //------------------------------------------------------------------------------
59
+ // block_q4_0
60
+ //------------------------------------------------------------------------------
61
+ struct block_q4_0
62
+ {
63
+ half d;
64
+ uint8_t qs[QK4_0 / 2];
65
+ };
66
+
67
+ //------------------------------------------------------------------------------
68
+ // block_q6_K
69
+ //------------------------------------------------------------------------------
70
+ // 6-bit quantization
71
+ // weight is represented as x = a * q
72
+ // 16 blocks of 16 elements each
73
+ // Effectively 6.5625 bits per weight
74
+ typedef struct {
75
+ uint8_t ql[QK_K/2]; // quants, lower 4 bits
76
+ uint8_t qh[QK_K/4]; // quants, upper 2 bits
77
+ int8_t scales[QK_K/16]; // scales, quantized with 8 bits
78
+ half d; // super-block scale
79
+ } block_q6_K;
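
Editor's note: for readers unfamiliar with q6_K, the per-weight reconstruction behind the comment above can be sketched as follows (placeholder names; the exact bit positions follow ggml's dequantize_row_q6_K and this sketch is not part of the commit):

    // One 6-bit quant combines 4 low bits from ql and 2 bits from qh, is stored
    // unsigned and recentred by 32, then scaled by its 16-element group scale
    // and the super-block scale d:
    //   int   q6 = (ql_nibble & 0x0F) | ((qh_bits & 0x03) << 4);   // 0..63
    //   float w  = (float) d * scales[group] * (q6 - 32);
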
80
+
81
+ //------------------------------------------------------------------------------
82
+ // These are the variants for matrix-matrix multiplication, based on the matvecmul kernel with
83
+ // flattened block_q4_0.
84
+ //------------------------------------------------------------------------------
85
+
86
+ // Common dot prod.
87
+ inline float mm_block_q_4_0_dot_y_flat(
88
+ global uchar * x,
89
+ global half * dh,
90
+ float sumy,
91
+ float16 yl,
92
+ int il
93
+ ) {
94
+ float d = *dh;
95
+ global ushort * qs = ((global ushort *)x + il/2);
96
+ float acc = 0.f;
97
+
98
+ acc += yl.s0 * (qs[0] & 0x000F);
99
+ acc += yl.s1 * (qs[0] & 0x0F00);
100
+ acc += yl.s8 * (qs[0] & 0x00F0);
101
+ acc += yl.s9 * (qs[0] & 0xF000);
102
+
103
+ acc += yl.s2 * (qs[1] & 0x000F);
104
+ acc += yl.s3 * (qs[1] & 0x0F00);
105
+ acc += yl.sa * (qs[1] & 0x00F0);
106
+ acc += yl.sb * (qs[1] & 0xF000);
107
+
108
+ acc += yl.s4 * (qs[2] & 0x000F);
109
+ acc += yl.s5 * (qs[2] & 0x0F00);
110
+ acc += yl.sc * (qs[2] & 0x00F0);
111
+ acc += yl.sd * (qs[2] & 0xF000);
112
+
113
+ acc += yl.s6 * (qs[3] & 0x000F);
114
+ acc += yl.s7 * (qs[3] & 0x0F00);
115
+ acc += yl.se * (qs[3] & 0x00F0);
116
+ acc += yl.sf * (qs[3] & 0xF000);
117
+
118
+ return d * (sumy * -8.f + acc);
119
+ }
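
Editor's note: the helper above needs no explicit shifts because the caller pre-scales the activations paired with shifted nibbles (see the yl assignments further down, divided by 256, 16, or 4096), and the uniform recentring by 8 is folded into a single sumy term. A short worked equivalence, not part of the commit:

    //   ((q & 0x0F00) >> 8) * y   ==  (q & 0x0F00) * (y / 256.f)
    //   ((q & 0x00F0) >> 4) * y   ==  (q & 0x00F0) * (y / 16.f)
    // and, over a whole block,
    //   sum_i (q_i - 8) * y_i     ==  sum_i q_i * y_i  -  8 * sum_i y_i,
    // which is why the result is d * (sumy * -8.f + acc).
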
120
+
121
+ #undef N_DST
122
+ #undef N_SIMDGROUP
123
+ #undef N_SIMDWIDTH
124
+
125
+ #ifdef INTEL_GPU
126
+ #define N_DST 8 // each SIMD group works on 8 rows (in weights matrix)
127
+ #define N_SIMDGROUP 1 // number of SIMD groups in a thread group
128
+ #define N_SIMDWIDTH 16 // assuming SIMD group size is 16
129
+ #elif defined (ADRENO_GPU)
130
+ #define N_DST 8
131
+ #define N_SIMDGROUP 1
132
+ #define N_SIMDWIDTH 64
133
+ #endif
134
+ //
135
+ // This variant performs 1d blocking with 8x output.
136
+ // Each simdgroup outputs 8 values on `n0` dim (row in the output matrix).
137
+ //
138
+ inline void mul_mat_q_n_f32_1d_8x_flat(
139
+ global uchar * src0_q,
140
+ global half * src0_d,
141
+ global float * src1,
142
+ global float * dst,
143
+ int ne00,
144
+ int ne01,
145
+ int ne02,
146
+ int ne10,
147
+ int ne12,
148
+ int ne0,
149
+ int ne1,
150
+ int r2,
151
+ int r3
152
+ ) {
153
+ const int nb = ne00/QK4_0;
154
+
155
+ int r0 = get_group_id(0);
156
+ int r1 = get_group_id(1);
157
+ int im = get_group_id(2);
158
+
159
+ // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of
160
+ // a SIMD group in the grid. Each SIMD group produces N_DST values in the
161
+ // result, hence uses nb blocks, i.e., the offset becomes first_row*nb.
162
+ // Currently with llama2 7B, im is always 0.
163
+ // TODO: how to handle im/gqa*(nb*ne0)?
164
+ int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
165
+
166
+ int i12 = im%ne12;
167
+ int i13 = im/ne12;
168
+
169
+ // The number of scales is the same as the number of blocks.
170
+ ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
171
+ // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
172
+ ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2;
173
+
174
+ global uchar * x = (global uchar *) src0_q + offset0_q;
175
+ global half * d = (global half *) src0_d + offset0_d;
176
+ global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
177
+
178
+ float16 yl;
179
+ float8 sumf = (float8)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f);
180
+
181
+ int ix = get_sub_group_local_id()/2;
182
+ int il = 8*(get_sub_group_local_id()%2);
183
+
184
+ global float * yb = y + ix*QK4_0 + il;
185
+
186
+ for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
187
+ float sumy = 0.f;
188
+
189
+ sumy += yb[0];
190
+ sumy += yb[1];
191
+ sumy += yb[2];
192
+ sumy += yb[3];
193
+ sumy += yb[4];
194
+ sumy += yb[5];
195
+ sumy += yb[6];
196
+ sumy += yb[7];
197
+
198
+ sumy += yb[16];
199
+ sumy += yb[17];
200
+ sumy += yb[18];
201
+ sumy += yb[19];
202
+ sumy += yb[20];
203
+ sumy += yb[21];
204
+ sumy += yb[22];
205
+ sumy += yb[23];
206
+
207
+ yl.s0 = yb[0];
208
+ yl.s1 = yb[1]/256.f;
209
+
210
+ yl.s2 = yb[2];
211
+ yl.s3 = yb[3]/256.f;
212
+
213
+ yl.s4 = yb[4];
214
+ yl.s5 = yb[5]/256.f;
215
+
216
+ yl.s6 = yb[6];
217
+ yl.s7 = yb[7]/256.f;
218
+
219
+ yl.s8 = yb[16]/16.f;
220
+ yl.s9 = yb[17]/4096.f;
221
+
222
+ yl.sa = yb[18]/16.f;
223
+ yl.sb = yb[19]/4096.f;
224
+
225
+ yl.sc = yb[20]/16.f;
226
+ yl.sd = yb[21]/4096.f;
227
+
228
+ yl.se = yb[22]/16.f;
229
+ yl.sf = yb[23]/4096.f;
230
+
231
+ sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il);
232
+ sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il);
233
+ sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il);
234
+ sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il);
235
+
236
+ sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il);
237
+ sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il);
238
+ sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il);
239
+ sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il);
240
+
241
+ yb += QK4_0 * (N_SIMDWIDTH/2);
242
+ }
243
+
244
+ float8 tot = (float8)(
245
+ sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
246
+ sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3),
247
+ sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5),
248
+ sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7)
249
+ );
250
+
251
+ if (get_sub_group_local_id() == 0) {
252
+ if (first_row + 0 < ne01) {
253
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
254
+ }
255
+ if (first_row + 1 < ne01) {
256
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
257
+ }
258
+ if (first_row + 2 < ne01) {
259
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
260
+ }
261
+ if (first_row + 3 < ne01) {
262
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
263
+ }
264
+
265
+ if (first_row + 4 < ne01) {
266
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4;
267
+ }
268
+ if (first_row + 5 < ne01) {
269
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5;
270
+ }
271
+ if (first_row + 6 < ne01) {
272
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6;
273
+ }
274
+ if (first_row + 7 < ne01) {
275
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7;
276
+ }
277
+ }
278
+ }
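
Editor's note: a worked example of the lane mapping above for the Adreno configuration (N_SIMDWIDTH = 64), not part of the commit:

    // ix = lane/2 in 0..31 selects the block, il = 8*(lane%2) selects its half
    // (il/2 picks ushorts 0..3 or 4..7, i.e. bytes 0..7 or 8..15 of the block),
    // so one 64-fiber wave covers 32 blocks per iteration and the activation
    // pointer advances by QK4_0 * (N_SIMDWIDTH/2) = 32 * 32 = 1024 floats.
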
279
+
280
+ #ifdef INTEL_GPU
281
+ REQD_SUBGROUP_SIZE_16
282
+ #elif defined (ADRENO_GPU)
283
+ REQD_SUBGROUP_SIZE_64
284
+ #endif
285
+ kernel void kernel_mul_mat_q4_0_f32_1d_8x_flat(
286
+ global uchar * src0_q,
287
+ global half * src0_d,
288
+ global float * src1,
289
+ ulong offset1,
290
+ global float * dst,
291
+ ulong offsetd,
292
+ int ne00,
293
+ int ne01,
294
+ int ne02,
295
+ int ne10,
296
+ int ne12,
297
+ int ne0,
298
+ int ne1,
299
+ int r2,
300
+ int r3
301
+ ) {
302
+ src1 = (global float*)((global char*)src1 + offset1);
303
+ dst = (global float*)((global char*)dst + offsetd);
304
+
305
+ mul_mat_q_n_f32_1d_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
306
+ }
307
+
308
+ #undef N_DST
309
+ #undef N_SIMDGROUP
310
+ #undef N_SIMDWIDTH
311
+
312
+ #ifdef INTEL_GPU
313
+ #define N_DST 16 // each SIMD group works on 16 rows (in weights matrix)
314
+ #define N_SIMDGROUP 1 // number of SIMD groups in a thread group
315
+ #define N_SIMDWIDTH 16 // assuming SIMD group size is 16
316
+ #elif defined (ADRENO_GPU)
317
+ #define N_DST 16
318
+ #define N_SIMDGROUP 1
319
+ #define N_SIMDWIDTH 64
320
+ #endif
321
+ //
322
+ // This variant performs 1d blocking with 16x output.
323
+ // Each simdgroup outputs 16 values on `n0` dim (row in the output matrix).
324
+ //
325
+ inline void mul_mat_q_n_f32_1d_16x_flat(
326
+ global uchar * src0_q,
327
+ global half * src0_d,
328
+ global float * src1,
329
+ global float * dst,
330
+ int ne00,
331
+ int ne01,
332
+ int ne02,
333
+ int ne10,
334
+ int ne12,
335
+ int ne0,
336
+ int ne1,
337
+ int r2,
338
+ int r3
339
+ ) {
340
+ const int nb = ne00/QK4_0;
341
+
342
+ int r0 = get_group_id(0);
343
+ int r1 = get_group_id(1);
344
+ int im = get_group_id(2);
345
+
346
+ // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of
347
+ // a SIMD group in the grid. Each SIMD group produces N_DST values in the
348
+ // result, hence uses nb blocks, i.e., the offset becomes first_row*nb.
349
+ // Currently with llama2 7B, im is always 0.
350
+ // TODO: how to handle im/gqa*(nb*ne0)?
351
+ int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
352
+
353
+ int i12 = im%ne12;
354
+ int i13 = im/ne12;
355
+
356
+ // The number of scales is the same as the number of blocks.
357
+ ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
358
+ // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
359
+ ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2;
360
+
361
+ global uchar * x = (global uchar *) src0_q + offset0_q;
362
+ global half * d = (global half *) src0_d + offset0_d;
363
+ global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
364
+
365
+ float16 yl;
366
+ float16 sumf = (float16)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
367
+ 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f);
368
+
369
+ int ix = get_sub_group_local_id()/2;
370
+ int il = 8*(get_sub_group_local_id()%2);
371
+
372
+ global float * yb = y + ix*QK4_0 + il;
373
+
374
+ for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
375
+ float sumy = 0.f;
376
+
377
+ sumy += yb[0];
378
+ sumy += yb[1];
379
+ sumy += yb[2];
380
+ sumy += yb[3];
381
+ sumy += yb[4];
382
+ sumy += yb[5];
383
+ sumy += yb[6];
384
+ sumy += yb[7];
385
+
386
+ sumy += yb[16];
387
+ sumy += yb[17];
388
+ sumy += yb[18];
389
+ sumy += yb[19];
390
+ sumy += yb[20];
391
+ sumy += yb[21];
392
+ sumy += yb[22];
393
+ sumy += yb[23];
394
+
395
+ yl.s0 = yb[0];
396
+ yl.s1 = yb[1]/256.f;
397
+
398
+ yl.s2 = yb[2];
399
+ yl.s3 = yb[3]/256.f;
400
+
401
+ yl.s4 = yb[4];
402
+ yl.s5 = yb[5]/256.f;
403
+
404
+ yl.s6 = yb[6];
405
+ yl.s7 = yb[7]/256.f;
406
+
407
+ yl.s8 = yb[16]/16.f;
408
+ yl.s9 = yb[17]/4096.f;
409
+
410
+ yl.sa = yb[18]/16.f;
411
+ yl.sb = yb[19]/4096.f;
412
+
413
+ yl.sc = yb[20]/16.f;
414
+ yl.sd = yb[21]/4096.f;
415
+
416
+ yl.se = yb[22]/16.f;
417
+ yl.sf = yb[23]/4096.f;
418
+
419
+ sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il);
420
+ sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il);
421
+ sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il);
422
+ sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il);
423
+
424
+ sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il);
425
+ sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il);
426
+ sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il);
427
+ sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il);
428
+
429
+ sumf.s8 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 8*nb*QK4_0/2, d + ib + 8*nb, sumy, yl, il);
430
+ sumf.s9 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 9*nb*QK4_0/2, d + ib + 9*nb, sumy, yl, il);
431
+ sumf.sa += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 10*nb*QK4_0/2, d + ib + 10*nb, sumy, yl, il);
432
+ sumf.sb += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 11*nb*QK4_0/2, d + ib + 11*nb, sumy, yl, il);
433
+
434
+ sumf.sc += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 12*nb*QK4_0/2, d + ib + 12*nb, sumy, yl, il);
435
+ sumf.sd += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 13*nb*QK4_0/2, d + ib + 13*nb, sumy, yl, il);
436
+ sumf.se += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 14*nb*QK4_0/2, d + ib + 14*nb, sumy, yl, il);
437
+ sumf.sf += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 15*nb*QK4_0/2, d + ib + 15*nb, sumy, yl, il);
438
+
439
+ yb += QK4_0 * (N_SIMDWIDTH/2);
440
+ }
441
+
442
+ float16 tot = (float16)(
443
+ sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
444
+ sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3),
445
+ sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5),
446
+ sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7),
447
+
448
+ sub_group_reduce_add(sumf.s8), sub_group_reduce_add(sumf.s9),
449
+ sub_group_reduce_add(sumf.sa), sub_group_reduce_add(sumf.sb),
450
+ sub_group_reduce_add(sumf.sc), sub_group_reduce_add(sumf.sd),
451
+ sub_group_reduce_add(sumf.se), sub_group_reduce_add(sumf.sf)
452
+ );
453
+
454
+ if (get_sub_group_local_id() == 0) {
455
+ if (first_row + 0 < ne01) {
456
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
457
+ }
458
+ if (first_row + 1 < ne01) {
459
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
460
+ }
461
+ if (first_row + 2 < ne01) {
462
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
463
+ }
464
+ if (first_row + 3 < ne01) {
465
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
466
+ }
467
+
468
+ if (first_row + 4 < ne01) {
469
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4;
470
+ }
471
+ if (first_row + 5 < ne01) {
472
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5;
473
+ }
474
+ if (first_row + 6 < ne01) {
475
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6;
476
+ }
477
+ if (first_row + 7 < ne01) {
478
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7;
479
+ }
480
+
481
+ if (first_row + 8 < ne01) {
482
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 8] = tot.s8;
483
+ }
484
+ if (first_row + 9 < ne01) {
485
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 9] = tot.s9;
486
+ }
487
+ if (first_row + 10 < ne01) {
488
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 10] = tot.sa;
489
+ }
490
+ if (first_row + 11 < ne01) {
491
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 11] = tot.sb;
492
+ }
493
+
494
+ if (first_row + 12 < ne01) {
495
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 12] = tot.sc;
496
+ }
497
+ if (first_row + 13 < ne01) {
498
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 13] = tot.sd;
499
+ }
500
+ if (first_row + 14 < ne01) {
501
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 14] = tot.se;
502
+ }
503
+ if (first_row + 15 < ne01) {
504
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 15] = tot.sf;
505
+ }
506
+ }
507
+ }
508
+
509
+ #ifdef INTEL_GPU
510
+ REQD_SUBGROUP_SIZE_16
511
+ #elif defined (ADRENO_GPU)
512
+ REQD_SUBGROUP_SIZE_64
513
+ #endif
514
+ kernel void kernel_mul_mat_q4_0_f32_1d_16x_flat(
515
+ global uchar * src0_q,
516
+ global half * src0_d,
517
+ global float * src1,
518
+ ulong offset1,
519
+ global float * dst,
520
+ ulong offsetd,
521
+ int ne00,
522
+ int ne01,
523
+ int ne02,
524
+ int ne10,
525
+ int ne12,
526
+ int ne0,
527
+ int ne1,
528
+ int r2,
529
+ int r3
530
+ ) {
531
+ src1 = (global float*)((global char*)src1 + offset1);
532
+ dst = (global float*)((global char*)dst + offsetd);
533
+
534
+ mul_mat_q_n_f32_1d_16x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
535
+ }
536
+
537
+ //------------------------------------------------------------------------------
538
+ // kernel_mul_mat_q4_0_f32_flat_v0
539
+ //------------------------------------------------------------------------------
540
+ inline float block_q_4_0_dot_y_flat_v2(
541
+ half x,
542
+ half d,
543
+ float sumy,
544
+ float4 yl
545
+ ) {
546
+ uchar2 q = as_uchar2(x);
547
+ float acc = 0.0f;
548
+
549
+ acc += (q.s0 & 0x0F) * yl.s0;
550
+ acc += (q.s1 & 0x0F) * yl.s1;
551
+
552
+ acc += (q.s0 & 0xF0) * yl.s2;
553
+ acc += (q.s1 & 0xF0) * yl.s3;
554
+
555
+ return d * (sumy * -8.f + acc);
556
+ }
557
+
558
+ inline float block_q_4_0_dot_y_flat_v4(
559
+ float x,
560
+ half d,
561
+ float sumy,
562
+ float8 yl
563
+ ) {
564
+ uchar4 q = as_uchar4(x);
565
+ float acc = 0.0f;
566
+
567
+ acc += (q.s0 & 0x0F) * yl.s0;
568
+ acc += (q.s1 & 0x0F) * yl.s1;
569
+ acc += (q.s2 & 0x0F) * yl.s2;
570
+ acc += (q.s3 & 0x0F) * yl.s3;
571
+
572
+ acc += (q.s0 & 0xF0) * yl.s4;
573
+ acc += (q.s1 & 0xF0) * yl.s5;
574
+ acc += (q.s2 & 0xF0) * yl.s6;
575
+ acc += (q.s3 & 0xF0) * yl.s7;
576
+
577
+ return d * (sumy * -8.f + acc);
578
+ }
579
+
580
+ inline float block_q_4_0_dot_y_flat_v8(
581
+ float2 x,
582
+ half d,
583
+ float sumy,
584
+ float16 yl
585
+ ) {
586
+ uchar8 q = as_uchar8(x);
587
+ float acc = 0.0f;
588
+
589
+ acc += (q.s0 & 0x0F) * yl.s0;
590
+ acc += (q.s1 & 0x0F) * yl.s1;
591
+ acc += (q.s2 & 0x0F) * yl.s2;
592
+ acc += (q.s3 & 0x0F) * yl.s3;
593
+ acc += (q.s4 & 0x0F) * yl.s4;
594
+ acc += (q.s5 & 0x0F) * yl.s5;
595
+ acc += (q.s6 & 0x0F) * yl.s6;
596
+ acc += (q.s7 & 0x0F) * yl.s7;
597
+
598
+ acc += (q.s0 & 0xF0) * yl.s8;
599
+ acc += (q.s1 & 0xF0) * yl.s9;
600
+ acc += (q.s2 & 0xF0) * yl.sa;
601
+ acc += (q.s3 & 0xF0) * yl.sb;
602
+ acc += (q.s4 & 0xF0) * yl.sc;
603
+ acc += (q.s5 & 0xF0) * yl.sd;
604
+ acc += (q.s6 & 0xF0) * yl.se;
605
+ acc += (q.s7 & 0xF0) * yl.sf;
606
+
607
+ return d * (sumy * -8.f + acc);
608
+ }
609
+
610
+ #undef N_DST
611
+ #undef N_SIMDGROUP
612
+ #undef N_SIMDWIDTH
613
+
614
+ #ifdef INTEL_GPU
615
+ #define THREADS_PER_BLK 4 // Number of threads per block; each thread processes 1/THREADS_PER_BLK of a block
616
+ #define N_DST 4
617
+ #define N_SIMDGROUP 1
618
+ #define N_SIMDWIDTH 16
619
+ #elif defined (ADRENO_GPU)
620
+ #define THREADS_PER_BLK 4
621
+ #define N_DST 4
622
+ #define N_SIMDGROUP 1
623
+ #define N_SIMDWIDTH 64
624
+ #endif
625
+
626
+ #if THREADS_PER_BLK == 2 // Each thread processes 1/2 block
627
+ # define ACT_TY float16
628
+ # define Q_BLK_LD_TY float2
629
+ # define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v8
630
+ #elif THREADS_PER_BLK == 4 // Each thread processes 1/4 block
631
+ # define ACT_TY float8
632
+ # define Q_BLK_LD_TY float
633
+ # define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v4
634
+ #elif THREADS_PER_BLK == 8 // Each thread processes 1/8 block
635
+ # define ACT_TY float4
636
+ # define Q_BLK_LD_TY half
637
+ # define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v2
638
+ #endif
639
+
640
+ #define BTYES_PER_THREAD_IN_BLK (QK4_0/2/THREADS_PER_BLK)
641
+
642
+ #if N_DST == 2
643
+ # define SUM_TY float2
644
+ #elif N_DST == 4
645
+ # define SUM_TY float4
646
+ #elif N_DST == 8
647
+ # define SUM_TY float8
648
+ #elif N_DST == 16
649
+ # define SUM_TY float16
650
+ #endif
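
Editor's note: with the configuration selected above (THREADS_PER_BLK = 4, N_DST = 4), the type choices work out as follows; a worked example, not part of the commit:

    // BTYES_PER_THREAD_IN_BLK = QK4_0/2/THREADS_PER_BLK = 32/2/4 = 4 bytes,
    // so each thread loads one float (Q_BLK_LD_TY) holding 8 nibbles per block
    // and caches 8 activations (ACT_TY = float8); with N_DST = 4 the partial
    // sums for the 4 output rows fit in a float4 (SUM_TY).
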
651
+
652
+ #ifdef INTEL_GPU
653
+ REQD_SUBGROUP_SIZE_16
654
+ #elif defined (ADRENO_GPU)
655
+ REQD_SUBGROUP_SIZE_64
656
+ #endif
657
+ kernel void kernel_mul_mat_q4_0_f32_flat_v0(
658
+ global uchar * src0_q,
659
+ global half * src0_d,
660
+ global float * src1,
661
+ ulong offset1,
662
+ global float * dst,
663
+ ulong offsetd,
664
+ int ne00,
665
+ int ne01,
666
+ int ne02,
667
+ int ne10,
668
+ int ne12,
669
+ int ne0,
670
+ int ne1,
671
+ int r2,
672
+ int r3
673
+ ) {
674
+ src1 = (global float*)((global char*)src1 + offset1);
675
+ dst = (global float*)((global char*)dst + offsetd);
676
+
677
+ const int nb = ne00/QK4_0;
678
+
679
+ int r0 = get_group_id(0);
680
+ int r1 = get_group_id(1);
681
+ int im = get_group_id(2);
682
+
683
+ int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
684
+
685
+ int i12 = im%ne12;
686
+ int i13 = im/ne12;
687
+
688
+ // The number of scales is the same as the number of blocks.
689
+ ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
690
+ // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
691
+ ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2;
692
+
693
+ global uchar * x = (global uchar *) src0_q + offset0_q;
694
+ global half * d = (global half *) src0_d + offset0_d;
695
+ global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
696
+
697
+ int ix = get_sub_group_local_id()/THREADS_PER_BLK;
698
+ int il = get_sub_group_local_id()%THREADS_PER_BLK;
699
+
700
+ global float * yb = y + ix*QK4_0 + BTYES_PER_THREAD_IN_BLK*il;
701
+
702
+ // Registers for caching activation
703
+ ACT_TY yl = 0.f;
704
+
705
+ // Registers for caching quants
706
+ Q_BLK_LD_TY q_blk_0 = 0, q_blk_1 = 0;
707
+ #if N_DST == 4 || N_DST == 8 || N_DST == 16
708
+ Q_BLK_LD_TY q_blk_2 = 0, q_blk_3 = 0;
709
+ #endif
710
+ #if N_DST == 8 || N_DST == 16
711
+ Q_BLK_LD_TY q_blk_4 = 0, q_blk_5 = 0, q_blk_6 = 0, q_blk_7 = 0;
712
+ #endif
713
+
714
+ // Partial sum
715
+ SUM_TY sumf = 0.f;
716
+
717
+ for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/THREADS_PER_BLK) {
718
+ float sumy = 0.f;
719
+
720
+ q_blk_0 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 0*nb*QK4_0/2);
721
+ q_blk_1 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 1*nb*QK4_0/2);
722
+ #if N_DST == 4 || N_DST == 8 || N_DST == 16
723
+ q_blk_2 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 2*nb*QK4_0/2);
724
+ q_blk_3 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 3*nb*QK4_0/2);
725
+ #endif
726
+ #if N_DST == 8 || N_DST == 16
727
+ q_blk_4 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 4*nb*QK4_0/2));
728
+ q_blk_5 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 5*nb*QK4_0/2));
729
+ q_blk_6 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 6*nb*QK4_0/2));
730
+ q_blk_7 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 7*nb*QK4_0/2));
731
+ #endif
732
+
733
+ // Load activation
734
+ #if THREADS_PER_BLK == 2 // Each thread processes 1/2 block
735
+ yl.s01234567 = *(global float8 *)(yb);
736
+ yl.s89abcdef = *(global float8 *)(yb + 16);
737
+
738
+ sumy += yl.s0;
739
+ sumy += yl.s1;
740
+ sumy += yl.s2;
741
+ sumy += yl.s3;
742
+ sumy += yl.s4;
743
+ sumy += yl.s5;
744
+ sumy += yl.s6;
745
+ sumy += yl.s7;
746
+ sumy += yl.s8; yl.s8 /= 16.f;
747
+ sumy += yl.s9; yl.s9 /= 16.f;
748
+ sumy += yl.sa; yl.sa /= 16.f;
749
+ sumy += yl.sb; yl.sb /= 16.f;
750
+ sumy += yl.sc; yl.sc /= 16.f;
751
+ sumy += yl.sd; yl.sd /= 16.f;
752
+ sumy += yl.se; yl.se /= 16.f;
753
+ sumy += yl.sf; yl.sf /= 16.f;
754
+ #elif THREADS_PER_BLK == 4 // Each thread processes 1/4 block
755
+ yl.s0123 = *(global float4 *)(yb);
756
+ yl.s4567 = *(global float4 *)(yb + 16);
757
+
758
+ sumy += yl.s0;
759
+ sumy += yl.s1;
760
+ sumy += yl.s2;
761
+ sumy += yl.s3;
762
+ sumy += yl.s4; yl.s4 /= 16.f;
763
+ sumy += yl.s5; yl.s5 /= 16.f;
764
+ sumy += yl.s6; yl.s6 /= 16.f;
765
+ sumy += yl.s7; yl.s7 /= 16.f;
766
+ #elif THREADS_PER_BLK == 8 // Each thread processes 1/8 block
767
+ yl.s01 = *(global float2 *)(yb);
768
+ yl.s23 = *(global float2 *)(yb + 16);
769
+
770
+ sumy += yl.s0;
771
+ sumy += yl.s1;
772
+ sumy += yl.s2; yl.s2 /= 16.f;
773
+ sumy += yl.s3; yl.s3 /= 16.f;
774
+ #endif
775
+
776
+ sumf.s0 += block_q_4_0_dot_y_flat(q_blk_0, *(d + ib + 0*nb), sumy, yl);
777
+ sumf.s1 += block_q_4_0_dot_y_flat(q_blk_1, *(d + ib + 1*nb), sumy, yl);
778
+ #if N_DST == 4 || N_DST == 8 || N_DST == 16
779
+ sumf.s2 += block_q_4_0_dot_y_flat(q_blk_2, *(d + ib + 2*nb), sumy, yl);
780
+ sumf.s3 += block_q_4_0_dot_y_flat(q_blk_3, *(d + ib + 3*nb), sumy, yl);
781
+ #endif
782
+ #if N_DST == 8 || N_DST == 16
783
+ sumf.s4 += block_q_4_0_dot_y_flat(q_blk_4, *(d + ib + 4*nb), sumy, yl);
784
+ sumf.s5 += block_q_4_0_dot_y_flat(q_blk_5, *(d + ib + 5*nb), sumy, yl);
785
+ sumf.s6 += block_q_4_0_dot_y_flat(q_blk_6, *(d + ib + 6*nb), sumy, yl);
786
+ sumf.s7 += block_q_4_0_dot_y_flat(q_blk_7, *(d + ib + 7*nb), sumy, yl);
787
+ #endif
788
+
789
+ yb += QK4_0 * (N_SIMDWIDTH/THREADS_PER_BLK);
790
+ }
791
+
792
+ SUM_TY tot = (SUM_TY)(
793
+ sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1)
794
+ #if N_DST == 4 || N_DST == 8 || N_DST == 16
795
+ , sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
796
+ #endif
797
+ #if N_DST == 8 || N_DST == 16
798
+ , sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5)
799
+ , sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7)
800
+ #endif
801
+ );
802
+
803
+ if (get_sub_group_local_id() == 0) {
804
+ if (first_row + 0 < ne01) {
805
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
806
+ }
807
+ if (first_row + 1 < ne01) {
808
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
809
+ }
810
+ #if N_DST == 4 || N_DST == 8 || N_DST == 16
811
+ if (first_row + 2 < ne01) {
812
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
813
+ }
814
+ if (first_row + 3 < ne01) {
815
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
816
+ }
817
+ #endif
818
+ #if N_DST == 8 || N_DST == 16
819
+ if (first_row + 4 < ne01) {
820
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4;
821
+ }
822
+ if (first_row + 5 < ne01) {
823
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5;
824
+ }
825
+ if (first_row + 6 < ne01) {
826
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6;
827
+ }
828
+ if (first_row + 7 < ne01) {
829
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7;
830
+ }
831
+ #endif
832
+ }
833
+ }
834
+
835
+ //------------------------------------------------------------------------------
836
+ // Using image1d_buffer_t
837
+
838
+ #if defined(cl_qcom_subgroup_shuffle)
839
+ #pragma OPENCL EXTENSION cl_qcom_subgroup_shuffle : enable
840
+ float qcom_sub_group_reduce_add(float sum) {
841
+ sum += qcom_sub_group_shuffle_down(sum, 32, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f);
842
+ sum += qcom_sub_group_shuffle_down(sum, 16, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f);
843
+ sum += qcom_sub_group_shuffle_down(sum, 8, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f);
844
+ sum += qcom_sub_group_shuffle_down(sum, 4, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f);
845
+ sum += qcom_sub_group_shuffle_down(sum, 2, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f);
846
+ sum += qcom_sub_group_shuffle_down(sum, 1, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f);
847
+ return sum;
848
+ }
849
+ #define sub_group_reduce_add qcom_sub_group_reduce_add
850
+ #else
851
+ #define sub_group_reduce_add sub_group_reduce_add
852
+ #endif
853
+
854
+ #undef THREADS_PER_BLK
855
+ #undef N_DST
856
+ #undef N_SIMDGROUP
857
+ #undef N_SIMDWIDTH
858
+
859
+ #ifdef INTEL_GPU
860
+ #define THREADS_PER_BLK 4 // Number of threads per block; each thread processes 1/THREADS_PER_BLK of a block
861
+ #define N_DST 4
862
+ #define N_SIMDGROUP 1
863
+ #define N_SIMDWIDTH 16
864
+ #elif defined (ADRENO_GPU)
865
+ #define THREADS_PER_BLK 4
866
+ #define N_DST 4
867
+ #define N_SIMDGROUP 1
868
+ #define N_SIMDWIDTH 64
869
+ #endif
870
+
871
+ #if THREADS_PER_BLK == 2 // Each thread processes 1/2 block
872
+ # define ACT_TY float16
873
+ # define Q_BLK_LD_TY float2
874
+ # define EXTRACT_BLK_DATA(tmp, part) *((float2*)&tmp + part)
875
+ # define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v8
876
+ #elif THREADS_PER_BLK == 4 // Each thread processes 1/4 block
877
+ # define ACT_TY float8
878
+ # define Q_BLK_LD_TY float
879
+ # define EXTRACT_BLK_DATA(tmp, part) *((float*)&tmp + part)
880
+ # define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v4
881
+ #elif THREADS_PER_BLK == 8 // Each thread processes 1/8 block
882
+ # define ACT_TY float4
883
+ # define Q_BLK_LD_TY half
884
+ # define EXTRACT_BLK_DATA(tmp, part) *((half*)&tmp + part)
885
+ # define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v2
886
+ #endif
887
+
888
+ #define BTYES_PER_THREAD_IN_BLK (QK4_0/2/THREADS_PER_BLK)
889
+
890
+ #if N_DST == 2
891
+ # define SUM_TY float2
892
+ #elif N_DST == 4
893
+ # define SUM_TY float4
894
+ #elif N_DST == 8
895
+ # define SUM_TY float8
896
+ #elif N_DST == 16
897
+ # define SUM_TY float16
898
+ #endif
899
+
+ #ifdef INTEL_GPU
+ REQD_SUBGROUP_SIZE_16
+ #elif defined (ADRENO_GPU)
+ REQD_SUBGROUP_SIZE_64
+ #endif
+ kernel void kernel_mul_mat_q4_0_f32_flat_img_v0(
+ read_only image1d_buffer_t src0_q,
+ read_only image1d_buffer_t src0_d,
+ global float * src1,
+ ulong offset1,
+ global float * dst,
+ ulong offsetd,
+ int ne00,
+ int ne01,
+ int ne02,
+ int ne10,
+ int ne12,
+ int ne0,
+ int ne1,
+ int r2,
+ int r3
+ ) {
+ src1 = (global float*)((global char*)src1 + offset1);
+ dst = (global float*)((global char*)dst + offsetd);
+
+ const int nb = ne00/QK4_0;
+
+ int r0 = get_group_id(0);
+ int r1 = get_group_id(1);
+ int im = get_group_id(2);
+
+ int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
+
+ int i12 = im%ne12;
+ int i13 = im/ne12;
+
+ // The number of scales is the same as the number of blocks.
+ ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+ // Each block contains QK4_0/2 uchars, hence the offset for qs is as follows.
+ ulong offset0_q = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+ global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ int ix = get_sub_group_local_id()/THREADS_PER_BLK;
+ int il = get_sub_group_local_id()%THREADS_PER_BLK;
+
+ global float * yb = y + ix*QK4_0 + BTYES_PER_THREAD_IN_BLK*il;
+
+ // Registers for caching activation
+ ACT_TY yl = 0.f;
+
+ // Registers for caching quants
+ Q_BLK_LD_TY q_blk_0 = 0, q_blk_1 = 0;
+ #if N_DST == 4 || N_DST == 8 || N_DST == 16
+ Q_BLK_LD_TY q_blk_2 = 0, q_blk_3 = 0;
+ #endif
+ #if N_DST == 8 || N_DST == 16
+ Q_BLK_LD_TY q_blk_4 = 0, q_blk_5 = 0, q_blk_6 = 0, q_blk_7 = 0;
+ #endif
+
+ // Partial sum
+ SUM_TY sumf = 0.f;
+
+ for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/THREADS_PER_BLK) {
+ float sumy = 0.f;
+
+ float4 tmp;
+ tmp = read_imagef(src0_q, offset0_q + ib + 0*nb);
+ q_blk_0 = EXTRACT_BLK_DATA(tmp, il);
+ tmp = read_imagef(src0_q, offset0_q + ib + 1*nb);
+ q_blk_1 = EXTRACT_BLK_DATA(tmp, il);
+ #if N_DST == 4 || N_DST == 8 || N_DST == 16
+ tmp = read_imagef(src0_q, offset0_q + ib + 2*nb);
+ q_blk_2 = EXTRACT_BLK_DATA(tmp, il);
+ tmp = read_imagef(src0_q, offset0_q + ib + 3*nb);
+ q_blk_3 = EXTRACT_BLK_DATA(tmp, il);
+ #endif
+ #if N_DST == 8 || N_DST == 16
+ tmp = read_imagef(src0_q, offset0_q + ib + 4*nb);
+ q_blk_4 = EXTRACT_BLK_DATA(tmp, il);
+ tmp = read_imagef(src0_q, offset0_q + ib + 5*nb);
+ q_blk_5 = EXTRACT_BLK_DATA(tmp, il);
+ tmp = read_imagef(src0_q, offset0_q + ib + 6*nb);
+ q_blk_6 = EXTRACT_BLK_DATA(tmp, il);
+ tmp = read_imagef(src0_q, offset0_q + ib + 7*nb);
+ q_blk_7 = EXTRACT_BLK_DATA(tmp, il);
+ #endif
+
+ // Load activation
+ #if THREADS_PER_BLK == 2 // Each thread processes 1/2 block
+ yl.s01234567 = *(global float8 *)(yb);
+ yl.s89abcdef = *(global float8 *)(yb + 16);
+
+ sumy += yl.s0;
+ sumy += yl.s1;
+ sumy += yl.s2;
+ sumy += yl.s3;
+ sumy += yl.s4;
+ sumy += yl.s5;
+ sumy += yl.s6;
+ sumy += yl.s7;
+ sumy += yl.s8; yl.s8 /= 16.f;
+ sumy += yl.s9; yl.s9 /= 16.f;
+ sumy += yl.sa; yl.sa /= 16.f;
+ sumy += yl.sb; yl.sb /= 16.f;
+ sumy += yl.sc; yl.sc /= 16.f;
+ sumy += yl.sd; yl.sd /= 16.f;
+ sumy += yl.se; yl.se /= 16.f;
+ sumy += yl.sf; yl.sf /= 16.f;
+ #elif THREADS_PER_BLK == 4 // Each thread processes 1/4 block
+ yl.s0123 = *(global float4 *)(yb);
+ yl.s4567 = *(global float4 *)(yb + 16);
+
+ sumy += yl.s0;
+ sumy += yl.s1;
+ sumy += yl.s2;
+ sumy += yl.s3;
+ sumy += yl.s4; yl.s4 /= 16.f;
+ sumy += yl.s5; yl.s5 /= 16.f;
+ sumy += yl.s6; yl.s6 /= 16.f;
+ sumy += yl.s7; yl.s7 /= 16.f;
+ #elif THREADS_PER_BLK == 8 // Each thread processes 1/8 block
+ yl.s01 = *(global float2 *)(yb);
+ yl.s23 = *(global float2 *)(yb + 16);
+
+ sumy += yl.s0;
+ sumy += yl.s1;
+ sumy += yl.s2; yl.s2 /= 16.f;
+ sumy += yl.s3; yl.s3 /= 16.f;
+ #endif
+
+ sumf.s0 += block_q_4_0_dot_y_flat(q_blk_0, read_imageh(src0_d, offset0_d + ib + 0*nb).s0, sumy, yl);
+ sumf.s1 += block_q_4_0_dot_y_flat(q_blk_1, read_imageh(src0_d, offset0_d + ib + 1*nb).s0, sumy, yl);
+ #if N_DST == 4 || N_DST == 8 || N_DST == 16
+ sumf.s2 += block_q_4_0_dot_y_flat(q_blk_2, read_imageh(src0_d, offset0_d + ib + 2*nb).s0, sumy, yl);
+ sumf.s3 += block_q_4_0_dot_y_flat(q_blk_3, read_imageh(src0_d, offset0_d + ib + 3*nb).s0, sumy, yl);
+ #endif
+ #if N_DST == 8 || N_DST == 16
+ sumf.s4 += block_q_4_0_dot_y_flat(q_blk_4, read_imageh(src0_d, offset0_d + ib + 4*nb).s0, sumy, yl);
+ sumf.s5 += block_q_4_0_dot_y_flat(q_blk_5, read_imageh(src0_d, offset0_d + ib + 5*nb).s0, sumy, yl);
+ sumf.s6 += block_q_4_0_dot_y_flat(q_blk_6, read_imageh(src0_d, offset0_d + ib + 6*nb).s0, sumy, yl);
+ sumf.s7 += block_q_4_0_dot_y_flat(q_blk_7, read_imageh(src0_d, offset0_d + ib + 7*nb).s0, sumy, yl);
+ #endif
+
+ yb += QK4_0 * (N_SIMDWIDTH/THREADS_PER_BLK);
+ }
+
+ SUM_TY tot = (SUM_TY)(
+ sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1)
+ #if N_DST == 4 || N_DST == 8 || N_DST == 16
+ , sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
+ #endif
+ #if N_DST == 8 || N_DST == 16
+ , sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5)
+ , sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7)
+ #endif
+ );
+
+ if (get_sub_group_local_id() == 0) {
+ if (first_row + 0 < ne01) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
+ }
+ if (first_row + 1 < ne01) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
+ }
+ #if N_DST == 4 || N_DST == 8 || N_DST == 16
+ if (first_row + 2 < ne01) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
+ }
+ if (first_row + 3 < ne01) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
+ }
+ #endif
+ #if N_DST == 8 || N_DST == 16
+ if (first_row + 4 < ne01) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4;
+ }
+ if (first_row + 5 < ne01) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5;
+ }
+ if (first_row + 6 < ne01) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6;
+ }
+ if (first_row + 7 < ne01) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7;
+ }
+ #endif
+ }
+ }
+
+ //------------------------------------------------------------------------------
+ // kernel_mul_mv_q6_K_f32
+ //------------------------------------------------------------------------------
+
+ #undef N_DST
+ #undef N_SIMDGROUP
+ #undef N_SIMDWIDTH
+
+ #ifdef INTEL_GPU
+ #define N_DST 1 // number of rows each SIMD group works on
+ #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
+ #define N_SIMDWIDTH 16 // SIMD group size
+ #elif defined (ADRENO_GPU)
+ #define N_DST 1
+ #define N_SIMDGROUP 2
+ #define N_SIMDWIDTH 64
+ #endif
+
+ #define BLOCK_STRIDE (N_SIMDWIDTH/16) // number of blocks each subgroup processes
+
+ #ifdef INTEL_GPU
+ REQD_SUBGROUP_SIZE_16
+ #elif defined (ADRENO_GPU)
+ REQD_SUBGROUP_SIZE_64
+ #endif
+ kernel void kernel_mul_mv_q6_K_f32(
+ global void * src0,
+ ulong offset0,
+ global float * src1,
+ ulong offset1,
+ global float * dst,
+ ulong offsetd,
+ int ne00,
+ int ne01,
+ int ne02,
+ int ne10,
+ int ne12,
+ int ne0,
+ int ne1,
+ int r2,
+ int r3
+ ) {
+ src0 = (global void*)((global char*)src0 + offset0);
+ src1 = (global float*)((global char*)src1 + offset1);
+ dst = (global float*)((global char*)dst + offsetd);
+
+ uchar kmask1 = 0x03;
+ uchar kmask2 = 0x0C;
+ uchar kmask3 = 0x30;
+ uchar kmask4 = 0xC0;
+
+ int nb = ne00/QK_K;
+
+ int r0 = get_group_id(0);
+ int r1 = get_group_id(1);
+ int im = get_group_id(2);
+
+ int row = N_SIMDGROUP * r0 + get_sub_group_id();
+
+ int i12 = im%ne12;
+ int i13 = im/ne12;
+
+ ulong offset_src0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+ global block_q6_K * x = (global block_q6_K *) src0 + row*nb + offset_src0;
+ global float * yy = (global float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float sumf = 0;
+
+ // For Q6_K quantization, 16 values form a subblock and 16 subblocks form a
+ // block. Values in a subblock share a scale that is quantized with 8 bits;
+ // the entire block shares a single floating-point scale.
+ // For work distribution, each thread processes a subblock (16 weights), hence
+ // 16 threads process a (super) block -- a subgroup thus handles SIMDWIDTH/16
+ // (super) blocks -- this is the block stride.
+ // The 16 threads that process a (super) block are split into 2 portions of
+ // 8 threads each; each portion works on 8 subblocks.
+ // For a subgroup of 16 threads, the entire subgroup works on a single (super) block
+ // before moving to the next (super) block. Thread0 - thread7 work on the
+ // first 8 subblocks; thread8 - thread15 work on the last 8 subblocks.
+ // Thread0 - thread3 work on subblocks 0, 2, 4, 6; thread4 - thread7 work on
+ // subblocks 1, 3, 5, 7. Each thread does not work on an entire subblock, but
+ // works on a total of 16 weight values.
+ int tid = get_sub_group_local_id()/BLOCK_STRIDE; // first block_stride groups have tid=0
+ int ix = get_sub_group_local_id()%BLOCK_STRIDE; // first block is 0..block_stride-1
+ int ip = tid/8; // first or second half of (super) block (0 or 1)
+ int il = tid%8; // each half has 8 parts, one per scale
+ int n = 4; // 4 scales at a time (and 4 sums)
+ int l0 = n*il; // offset into half-block, 0..28
+ int is = 8*ip + l0/16; // 0, 1, 8, 9
+
+ int y_offset = 128*ip + l0;
+ int q_offset_l = 64*ip + l0;
+ int q_offset_h = 32*ip + l0;
+
+ for (int i = ix; i < nb; i += BLOCK_STRIDE) {
+
+ global uint8_t * q1 = x[i].ql + q_offset_l;
+ global uint8_t * q2 = q1 + QK_K/8;
+ global uint8_t * qh = x[i].qh + q_offset_h;
+ global int8_t * sc = x[i].scales + is;
+
+ global float * y = yy + i * QK_K + y_offset;
+
+ float dall = x[i].d;
+
+ float4 sums = {0.f, 0.f, 0.f, 0.f};
+
+ sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | ((qh[0] & kmask1) << 4)) - 32.f);
+ sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & kmask2) << 2)) - 32.f);
+ sums.s2 += y[0+64] * ((float)((q1[0] >> 4) | ((qh[0] & kmask3) << 0)) - 32.f);
+ sums.s3 += y[0+96] * ((float)((q2[0] >> 4) | ((qh[0] & kmask4) >> 2)) - 32.f);
+
+ sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & kmask1) << 4)) - 32.f);
+ sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & kmask2) << 2)) - 32.f);
+ sums.s2 += y[1+64] * ((float)((q1[1] >> 4) | ((qh[1] & kmask3) << 0)) - 32.f);
+ sums.s3 += y[1+96] * ((float)((q2[1] >> 4) | ((qh[1] & kmask4) >> 2)) - 32.f);
+
+ sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & kmask1) << 4)) - 32.f);
+ sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & kmask2) << 2)) - 32.f);
+ sums.s2 += y[2+64] * ((float)((q1[2] >> 4) | ((qh[2] & kmask3) << 0)) - 32.f);
+ sums.s3 += y[2+96] * ((float)((q2[2] >> 4) | ((qh[2] & kmask4) >> 2)) - 32.f);
+
+ sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & kmask1) << 4)) - 32.f);
+ sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & kmask2) << 2)) - 32.f);
+ sums.s2 += y[3+64] * ((float)((q1[3] >> 4) | ((qh[3] & kmask3) << 0)) - 32.f);
+ sums.s3 += y[3+96] * ((float)((q2[3] >> 4) | ((qh[3] & kmask4) >> 2)) - 32.f);
+
+ sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]);
+ }
+
+ float tot = sub_group_reduce_add(sumf);
+ if (get_sub_group_local_id() == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + row] = tot;
+ }
+ }
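The work distribution described in the comment inside kernel_mul_mv_q6_K_f32 can be checked by evaluating the same index arithmetic on the host. The C sketch below does this for the 16-lane Intel configuration (N_SIMDWIDTH 16, so BLOCK_STRIDE is 1) and prints, per lane, which half of the block it covers (ip), its part within that half (il), the offset l0 and the scale index is (0, 1, 8 or 9), plus the derived y/ql/qh offsets. It is an illustration of the mapping, not part of the commit.

// Illustrative sketch; not from the commit.
#include <stdio.h>

#define N_SIMDWIDTH  16
#define BLOCK_STRIDE (N_SIMDWIDTH / 16)

int main(void) {
    for (int lane = 0; lane < N_SIMDWIDTH; ++lane) {
        int tid = lane / BLOCK_STRIDE;  // position within the (super) block
        int ix  = lane % BLOCK_STRIDE;  // which block of the stride (always 0 here)
        int ip  = tid / 8;              // first or second half of the block
        int il  = tid % 8;              // 8 parts per half
        int l0  = 4 * il;               // 0, 4, ..., 28
        int is  = 8 * ip + l0 / 16;     // scale index: 0, 1, 8 or 9

        printf("lane %2d: ix=%d ip=%d il=%d l0=%2d is=%d y_off=%3d ql_off=%2d qh_off=%2d\n",
               lane, ix, ip, il, l0, is, 128*ip + l0, 64*ip + l0, 32*ip + l0);
    }
    return 0;
}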
ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl ADDED
@@ -0,0 +1,130 @@
+ // src0_q, src0_d, src1 are transposed as a preprocessing step
+ // 4-bit weights are transposed in groups of 4 (unsigned short int)
+ // consider weights originally "next to each other", now "on top of each other"
+ // each fiber computes an 8x4 tile of output elements
+ // using unshuffled weights
+
+ #pragma OPENCL EXTENSION cl_khr_fp16 : enable
+ #pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+
+ __attribute__((qcom_reqd_sub_group_size("full")))
+ kernel void kernel_mul_mat_Ab_Bi_8x4(
+ global const ushort * src0_q, // quantized A
+ global const half * src0_d, // A scales
+ __read_only image1d_buffer_t src1, // B (1d image)
+ global float * dst, // C
+ int m, // M
+ int n, // N with padding
+ int k, // K
+ int n_no_padding // N without padding
+ ) {
+
+ int m_4 = m >> 2;
+ int n_4 = n >> 2;
+
+ int gy = get_global_id(0);
+ int gx = get_global_id(1);
+ int gx_2 = gx << 2;
+
+ half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; // 8x4 output elements
+ half8 B; // registers for activations
+ half4 dequantized_weights; // registers for dequantized weights
+ __global const ushort* weight_ptr = src0_q + gx_2; // pointer for weights
+ __global const half* scale_ptr = src0_d + gx_2; // pointer for scales
+
+ for(int i=0; i<k; i+=4){ // loop through K dimension
+
+ B.s0123 = read_imageh(src1, gy*2 + (i)*(n_4));
+ B.s4567 = read_imageh(src1, gy*2 + (i)*(n_4)+1);
+
+ // keep (i/4) and (i/32) in parentheses; integer division rounds down
+ // load 4 consecutive groups of 4 weights
+ ushort4 bits4 = vload4(0, weight_ptr + (i/4)*(m)); // (i/4) because weights are grouped in 4s
+
+ // load 4 consecutive scales
+ half4 scale = vload4(0, scale_ptr + (i/32)*(m)); // (i/32) because there is 1 scale per 32 elements
+
+ // j=0
+ dequantized_weights.s0 = ((bits4.s0 & (0x000F)) - 8) * scale.s0; // dequantize a row of the 16 weights
+ dequantized_weights.s1 = ((bits4.s1 & (0x000F)) - 8) * scale.s1;
+ dequantized_weights.s2 = ((bits4.s2 & (0x000F)) - 8) * scale.s2;
+ dequantized_weights.s3 = ((bits4.s3 & (0x000F)) - 8) * scale.s3;
+ c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate
+ c1 += B * dequantized_weights.s1;
+ c2 += B * dequantized_weights.s2;
+ c3 += B * dequantized_weights.s3;
+
+ // j=1
+ B.s0123 = read_imageh(src1, gy*2 + (i+1)*(n_4));
+ B.s4567 = read_imageh(src1, gy*2 + (i+1)*(n_4)+1);
+ dequantized_weights.s0 = (((bits4.s0 & (0x00F0)) >> 4) - 8) * scale.s0; // dequantize a row of the 16 weights
+ dequantized_weights.s1 = (((bits4.s1 & (0x00F0)) >> 4) - 8) * scale.s1;
+ dequantized_weights.s2 = (((bits4.s2 & (0x00F0)) >> 4) - 8) * scale.s2;
+ dequantized_weights.s3 = (((bits4.s3 & (0x00F0)) >> 4) - 8) * scale.s3;
+ c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate
+ c1 += B * dequantized_weights.s1;
+ c2 += B * dequantized_weights.s2;
+ c3 += B * dequantized_weights.s3;
+
+ // j=2
+ B.s0123 = read_imageh(src1, gy*2 + (i+2)*(n_4));
+ B.s4567 = read_imageh(src1, gy*2 + (i+2)*(n_4)+1);
+ dequantized_weights.s0 = (((bits4.s0 & (0x0F00)) >> 8) - 8) * scale.s0; // dequantize a row of the 16 weights
+ dequantized_weights.s1 = (((bits4.s1 & (0x0F00)) >> 8) - 8) * scale.s1;
+ dequantized_weights.s2 = (((bits4.s2 & (0x0F00)) >> 8) - 8) * scale.s2;
+ dequantized_weights.s3 = (((bits4.s3 & (0x0F00)) >> 8) - 8) * scale.s3;
+ c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate
+ c1 += B * dequantized_weights.s1;
+ c2 += B * dequantized_weights.s2;
+ c3 += B * dequantized_weights.s3;
+
+ // j=3
+ B.s0123 = read_imageh(src1, gy*2 + (i+3)*(n_4));
+ B.s4567 = read_imageh(src1, gy*2 + (i+3)*(n_4)+1);
+ dequantized_weights.s0 = (((bits4.s0 & (0xF000)) >> 12) - 8) * scale.s0; // dequantize a row of the 16 weights
+ dequantized_weights.s1 = (((bits4.s1 & (0xF000)) >> 12) - 8) * scale.s1;
+ dequantized_weights.s2 = (((bits4.s2 & (0xF000)) >> 12) - 8) * scale.s2;
+ dequantized_weights.s3 = (((bits4.s3 & (0xF000)) >> 12) - 8) * scale.s3;
+ c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate
+ c1 += B * dequantized_weights.s1;
+ c2 += B * dequantized_weights.s2;
+ c3 += B * dequantized_weights.s3;
+ }
+
+ int idx = (gy<<3)*m + (gx<<2); // vectorized store of 16 elements
+
+ // conditional check that the store is to a valid location; required when N is not a multiple of 8
+ // if statements allow registers to be reused for each store
+ // provides a performance boost due to reduced register footprint, which increases the number of concurrent waves
+ if(idx+3 < m*n_no_padding){
+ vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx);
+ idx += m;
+ }
+ if(idx+3 < m*n_no_padding){
+ vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx);
+ idx += m;
+ }
+ if(idx+3 < m*n_no_padding){
+ vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx);
+ idx += m;
+ }
+ if(idx+3 < m*n_no_padding){
+ vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx);
+ idx += m;
+ }
+ if(idx+3 < m*n_no_padding){
+ vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx);
+ idx += m;
+ }
+ if(idx+3 < m*n_no_padding){
+ vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx);
+ idx += m;
+ }
+ if(idx+3 < m*n_no_padding){
+ vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx);
+ idx += m;
+ }
+ if(idx+3 < m*n_no_padding){
+ vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx);
+ }
+ }
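The j=0..3 steps in the kernel above all perform the same extraction: nibble j of each packed ushort is ((bits >> (4*j)) & 0xF), shifted to a signed range by subtracting 8, then multiplied by the group's scale. The small host-side C sketch below shows just that extraction; the bit pattern and scale value are made-up illustration values and not from the commit.

// Illustrative sketch; not from the commit.
#include <stdio.h>

int main(void) {
    unsigned short bits = 0xC305;   // example packed 4-bit weights (illustrative value)
    float scale = 0.25f;            // example per-group scale (illustrative value)

    for (int j = 0; j < 4; ++j) {
        int   q = (bits >> (4 * j)) & 0xF;    // raw 4-bit value, 0..15
        float w = (float)(q - 8) * scale;     // centered and scaled weight
        printf("j=%d: q=%2d  w=%+.3f\n", j, q, w);
    }
    return 0;
}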
ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl ADDED
@@ -0,0 +1,32 @@
+ // 16-bit transpose, loading/storing an 8x8 tile of elements
+
+ kernel void kernel_transpose_16(
+ __read_only image1d_buffer_t input,
+ __write_only image1d_buffer_t output,
+ const uint rows,
+ const uint cols
+ ) {
+
+ const int i = get_global_id(0);
+ const int j = get_global_id(1);
+ const int i_3 = i<<3;
+ const int j_3 = j<<3;
+
+ ushort8 temp0 = as_ushort8(read_imagef(input, (j_3+0)*cols+i));
+ ushort8 temp1 = as_ushort8(read_imagef(input, (j_3+1)*cols+i));
+ ushort8 temp2 = as_ushort8(read_imagef(input, (j_3+2)*cols+i));
+ ushort8 temp3 = as_ushort8(read_imagef(input, (j_3+3)*cols+i));
+ ushort8 temp4 = as_ushort8(read_imagef(input, (j_3+4)*cols+i));
+ ushort8 temp5 = as_ushort8(read_imagef(input, (j_3+5)*cols+i));
+ ushort8 temp6 = as_ushort8(read_imagef(input, (j_3+6)*cols+i));
+ ushort8 temp7 = as_ushort8(read_imagef(input, (j_3+7)*cols+i));
+
+ write_imagef(output, (i_3+0)*rows+j, as_float4((ushort8)(temp0.s0, temp1.s0, temp2.s0, temp3.s0, temp4.s0, temp5.s0, temp6.s0, temp7.s0)));
+ write_imagef(output, (i_3+1)*rows+j, as_float4((ushort8)(temp0.s1, temp1.s1, temp2.s1, temp3.s1, temp4.s1, temp5.s1, temp6.s1, temp7.s1)));
+ write_imagef(output, (i_3+2)*rows+j, as_float4((ushort8)(temp0.s2, temp1.s2, temp2.s2, temp3.s2, temp4.s2, temp5.s2, temp6.s2, temp7.s2)));
+ write_imagef(output, (i_3+3)*rows+j, as_float4((ushort8)(temp0.s3, temp1.s3, temp2.s3, temp3.s3, temp4.s3, temp5.s3, temp6.s3, temp7.s3)));
+ write_imagef(output, (i_3+4)*rows+j, as_float4((ushort8)(temp0.s4, temp1.s4, temp2.s4, temp3.s4, temp4.s4, temp5.s4, temp6.s4, temp7.s4)));
+ write_imagef(output, (i_3+5)*rows+j, as_float4((ushort8)(temp0.s5, temp1.s5, temp2.s5, temp3.s5, temp4.s5, temp5.s5, temp6.s5, temp7.s5)));
+ write_imagef(output, (i_3+6)*rows+j, as_float4((ushort8)(temp0.s6, temp1.s6, temp2.s6, temp3.s6, temp4.s6, temp5.s6, temp6.s6, temp7.s6)));
+ write_imagef(output, (i_3+7)*rows+j, as_float4((ushort8)(temp0.s7, temp1.s7, temp2.s7, temp3.s7, temp4.s7, temp5.s7, temp6.s7, temp7.s7)));
+ }
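At the element level, the transpose above implements out[c*R + r] = in[r*C + c], processed in 8x8 tiles so each work item moves a full tile between the two image buffers (the kernel's rows/cols arguments appear to count image texels rather than scalar elements). Below is a plain C reference of that tiled mapping, ignoring the kernel's 8-wide vector packing; the matrix sizes are illustrative and assumed to be multiples of 8, and none of this is part of the commit.

// Illustrative sketch; not from the commit.
#include <stdint.h>
#include <stdio.h>

static void transpose_16bit_tiled(const uint16_t *in, uint16_t *out, int rows, int cols) {
    for (int jt = 0; jt < rows; jt += 8) {        // tile origin along input rows
        for (int it = 0; it < cols; it += 8) {    // tile origin along input columns
            for (int r = 0; r < 8; ++r) {
                for (int c = 0; c < 8; ++c) {
                    out[(it + c) * rows + (jt + r)] = in[(jt + r) * cols + (it + c)];
                }
            }
        }
    }
}

int main(void) {
    enum { R = 8, C = 16 };
    uint16_t a[R * C], b[C * R];
    for (int i = 0; i < R * C; ++i) a[i] = (uint16_t)i;
    transpose_16bit_tiled(a, b, R, C);
    printf("a[1*%d+2]=%u  b[2*%d+1]=%u\n", C, a[1 * C + 2], R, b[2 * R + 1]);  // both print 18
    return 0;
}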
ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32.cl ADDED
@@ -0,0 +1,25 @@
+ // 32-bit transpose, loading/storing a 4x4 tile of elements
+
+ kernel void kernel_transpose_32(
+ __read_only image1d_buffer_t input,
+ __write_only image1d_buffer_t output,
+ const uint rows,
+ const uint cols
+ ) {
+
+ const int i = get_global_id(0);
+ const int j = get_global_id(1);
+ const int i_2 = i<<2;
+ const int j_2 = j<<2;
+
+ float4 temp0 = read_imagef(input, (j_2+0)*cols+i);
+ float4 temp1 = read_imagef(input, (j_2+1)*cols+i);
+ float4 temp2 = read_imagef(input, (j_2+2)*cols+i);
+ float4 temp3 = read_imagef(input, (j_2+3)*cols+i);
+
+ write_imagef(output, (i_2+0)*rows+j, (float4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0));
+ write_imagef(output, (i_2+1)*rows+j, (float4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1));
+ write_imagef(output, (i_2+2)*rows+j, (float4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2));
+ write_imagef(output, (i_2+3)*rows+j, (float4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3));
+
+ }
ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32_16.cl ADDED
@@ -0,0 +1,35 @@
+ // 32-bit transpose, loading/storing a 4x4 tile of elements
+ // Only used for activations
+ // converts to FP16
+ // also adds zero padding for prompt lengths that are not a multiple of 8
+ #pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+ kernel void kernel_transpose_32_16(__read_only image1d_buffer_t input, __write_only image1d_buffer_t output, const uint rows, const uint cols, const uint padded_rows) {
+
+ const int i = get_global_id(0);
+ const int j = get_global_id(1);
+ const int i_2 = i<<2;
+ const int j_2 = j<<2;
+ half4 temp0 = {0,0,0,0}; // initialize outputs to 0
+ half4 temp1 = {0,0,0,0};
+ half4 temp2 = {0,0,0,0};
+ half4 temp3 = {0,0,0,0};
+
+ if((j_2+0)*cols+i*4+3 < rows*cols*16){ // only load from a valid location; otherwise keep the register data at 0
+ temp0 = read_imageh(input, (j_2+0)*cols+i);
+ }
+ if((j_2+1)*cols+i*4+3 < rows*cols*16){
+ temp1 = read_imageh(input, (j_2+1)*cols+i);
+ }
+ if((j_2+2)*cols+i*4+3 < rows*cols*16){
+ temp2 = read_imageh(input, (j_2+2)*cols+i);
+ }
+ if((j_2+3)*cols+i*4+3 < rows*cols*16){
+ temp3 = read_imageh(input, (j_2+3)*cols+i);
+ }
+
+ write_imageh(output, (i_2+0)*padded_rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); // no conditionals for output; includes zero padding
+ write_imageh(output, (i_2+1)*padded_rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1));
+ write_imageh(output, (i_2+2)*padded_rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2));
+ write_imageh(output, (i_2+3)*padded_rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3));
+ }
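kernel_transpose_32_16 transposes the activations, converts them to FP16, and zero-pads the transposed row count: the guarded reads keep out-of-range registers at zero and the unguarded writes emit that padding. The host-side C sketch below shows the same idea in float only; the round-up-to-a-multiple-of-8 rule is an assumption based on the "not a multiple of 8" comment (the kernel itself simply receives padded_rows as an argument), all sizes are illustrative, and nothing here is part of the commit.

// Illustrative sketch; not from the commit.
#include <stdio.h>
#include <stdlib.h>

int main(void) {
    int rows = 5, cols = 3;              // e.g. 5 prompt tokens, 3 features (illustrative)
    int padded_rows = (rows + 7) & ~7;   // assumed padding rule: round up to a multiple of 8 -> 8

    float *in  = malloc(sizeof(float) * rows * cols);
    float *out = calloc((size_t)cols * padded_rows, sizeof(float)); // zero-initialized = padding
    for (int r = 0; r < rows; ++r)
        for (int c = 0; c < cols; ++c)
            in[r * cols + c] = (float)(r * cols + c);

    // transpose the valid elements; everything past `rows` stays zero
    for (int r = 0; r < rows; ++r)
        for (int c = 0; c < cols; ++c)
            out[c * padded_rows + r] = in[r * cols + c];

    printf("padded_rows=%d, out[%d]=%.1f (zero pad)\n", padded_rows, rows, out[rows]);
    free(in);
    free(out);
    return 0;
}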