Reese Levine committed
Commit 4b3da1d · 1 Parent(s): 6dd510c

ggml: Add initial WebGPU backend (llama/14521)

ggml/src/ggml-webgpu/CMakeLists.txt ADDED
@@ -0,0 +1,54 @@
+ cmake_minimum_required(VERSION 3.13)
+
+ find_package(Python3 REQUIRED)
+
+ # Shader locations
+ set(SHADER_DIR "${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders")
+ set(SHADER_OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/generated")
+ set(SHADER_HEADER "${SHADER_OUTPUT_DIR}/ggml-wgsl-shaders.hpp")
+ file(MAKE_DIRECTORY ${SHADER_OUTPUT_DIR})
+
+ message(STATUS "Shader output dir: ${SHADER_OUTPUT_DIR}")
+
+ # Find all WGSL files
+ file(GLOB WGSL_SHADER_FILES "${SHADER_DIR}/*.wgsl")
+
+ # Generate the header using a Python script
+ add_custom_command(
+     OUTPUT ${SHADER_HEADER}
+     COMMAND ${CMAKE_COMMAND} -E echo "Embedding WGSL shaders to ggml-wgsl-shaders.hpp"
+     COMMAND ${CMAKE_COMMAND} -E make_directory ${SHADER_OUTPUT_DIR}
+     COMMAND ${CMAKE_COMMAND} -E env PYTHONIOENCODING=utf-8
+         ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py
+         --input "${SHADER_DIR}"
+         --output "${SHADER_HEADER}"
+     DEPENDS ${WGSL_SHADER_FILES} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py
+     VERBATIM
+ )
+
+ add_custom_target(generate_shaders DEPENDS ${SHADER_HEADER})
+
+ ggml_add_backend_library(ggml-webgpu
+     ggml-webgpu.cpp
+     ${SHADER_HEADER}
+     ../../include/ggml-webgpu.h
+ )
+
+ add_dependencies(ggml-webgpu generate_shaders)
+
+ if(EMSCRIPTEN)
+     set(EMDAWNWEBGPU_DIR "" CACHE PATH "Path to emdawnwebgpu_pkg")
+
+     target_compile_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py")
+     target_link_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py")
+ else()
+     find_package(Dawn REQUIRED)
+     set(DawnWebGPU_TARGET dawn::webgpu_dawn)
+ endif()
+
+ if (GGML_WEBGPU_DEBUG)
+     target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_DEBUG=1)
+ endif()
+
+ target_include_directories(ggml-webgpu PRIVATE ${SHADER_OUTPUT_DIR})
+ target_link_libraries(ggml-webgpu PRIVATE ${DawnWebGPU_TARGET})
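For orientation, the generated ggml-wgsl-shaders.hpp is nothing more than one raw string literal per .wgsl file, following the wgsl_<basename> naming scheme that embed_wgsl.py (added below) produces and that ggml-webgpu.cpp consumes. A minimal sketch of its shape, with the shader bodies elided (set_rows.wgsl is referenced by the C++ code but not shown in this excerpt):

// Auto-generated shader embedding 

const char* wgsl_cpy      = R"(/* contents of cpy.wgsl */)";
const char* wgsl_memset   = R"(/* contents of memset.wgsl */)";
const char* wgsl_mul_mat  = R"(/* contents of mul_mat.wgsl */)";
const char* wgsl_set_rows = R"(/* contents of set_rows.wgsl */)";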
ggml/src/ggml-webgpu/ggml-webgpu.cpp ADDED
@@ -0,0 +1,1190 @@
1
+ /*
2
+ WebGPU backend implementation.
3
+ Note: Use ClangFormat to format this file.
4
+ */
5
+
6
+ #include "ggml-webgpu.h"
7
+
8
+ #include "ggml-backend-impl.h"
9
+ #include "ggml-impl.h"
10
+ #include "ggml-wgsl-shaders.hpp"
11
+
12
+ #include <webgpu/webgpu_cpp.h>
13
+
14
+ #include <condition_variable>
15
+ #include <cstring>
16
+ #include <iostream>
17
+ #include <mutex>
18
+ #include <string>
19
+ #include <vector>
20
+
21
+ #ifdef GGML_WEBGPU_DEBUG
22
+ # define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
23
+ # define WEBGPU_DEBUG_BUF_ELEMS 32
24
+ #else
25
+ # define WEBGPU_LOG_DEBUG(msg) ((void) 0)
26
+ #endif // GGML_WEBGPU_DEBUG
27
+
28
+ /* Constants */
29
+
30
+ #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16
31
+ #define WEBGPU_MUL_MAT_WG_SIZE 64
32
+ #define WEBGPU_NUM_PARAM_BUFS 100
33
+ #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
34
+ #define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 32
35
+ #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
36
+ #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
37
+
38
+ /* End Constants */
39
+
40
+ // This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
41
+ static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT
42
+
43
+ // Always returns the base offset of a tensor, regardless of views.
44
+ static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) {
45
+ if (tensor->view_src) {
46
+ return (uint8_t *) tensor->view_src->data - (uint8_t *) webgpu_ptr_base;
47
+ }
48
+ return (uint8_t *) tensor->data - (uint8_t *) webgpu_ptr_base;
49
+ }
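A minimal standalone sketch of the arithmetic above (example_base and the 256-byte offset are illustrative, not part of the backend): tensor->data holds webgpu_ptr_base plus the allocation offset, so subtracting the base recovers the byte offset into the wgpu::Buffer.

#include <cstdint>

int main() {
    // Illustrative stand-in for webgpu_ptr_base.
    void * const example_base = (void *) (uintptr_t) 0x1000;
    // "Allocate" a tensor 256 bytes into its buffer: data = base + offset.
    void * data = (uint8_t *) example_base + 256;
    // webgpu_tensor_offset() recovers the offset by subtracting the base again.
    uint64_t offset = (uint64_t) ((uint8_t *) data - (uint8_t *) example_base);
    return offset == 256 ? 0 : 1;
}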
50
+
51
+ /* Struct definitions */
52
+
53
+ // Forward reference
54
+ static void ggml_webgpu_create_buffer(wgpu::Device & device,
55
+ wgpu::Buffer & buffer,
56
+ size_t size,
57
+ wgpu::BufferUsage usage,
58
+ const char * label);
59
+
60
+ struct webgpu_pool_bufs {
61
+ wgpu::Buffer host_buf;
62
+ wgpu::Buffer dev_buf;
63
+ };
64
+
65
+ // Holds a pool of parameter buffers for WebGPU operations
66
+ struct webgpu_buf_pool {
67
+ std::vector<webgpu_pool_bufs> free;
68
+
69
+ std::mutex mutex;
70
+
71
+ std::condition_variable cv;
72
+
73
+ void init(wgpu::Device device,
74
+ int num_bufs,
75
+ size_t buf_size,
76
+ wgpu::BufferUsage dev_buf_usage,
77
+ wgpu::BufferUsage host_buf_usage) {
78
+ for (int i = 0; i < num_bufs; i++) {
79
+ wgpu::Buffer host_buf;
80
+ wgpu::Buffer dev_buf;
81
+ ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
82
+ ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
83
+ free.push_back({ host_buf, dev_buf });
84
+ }
85
+ }
86
+
87
+ webgpu_pool_bufs alloc_bufs() {
88
+ std::unique_lock<std::mutex> lock(mutex);
89
+ cv.wait(lock, [this] { return !free.empty(); });
90
+ webgpu_pool_bufs bufs = free.back();
91
+ free.pop_back();
92
+ return bufs;
93
+ }
94
+
95
+ void free_bufs(std::vector<webgpu_pool_bufs> bufs) {
96
+ std::lock_guard<std::mutex> lock(mutex);
97
+ free.insert(free.end(), bufs.begin(), bufs.end());
98
+ cv.notify_all();
99
+ }
100
+
101
+ void cleanup() {
102
+ std::lock_guard<std::mutex> lock(mutex);
103
+ for (auto & bufs : free) {
104
+ bufs.host_buf.Destroy();
105
+ bufs.dev_buf.Destroy();
106
+ }
107
+ free.clear();
108
+ }
109
+ };
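The pool gives simple backpressure: alloc_bufs() blocks on the condition variable until a queue-completion callback hands buffers back through free_bufs(). A self-contained sketch of that pattern without any wgpu types (blocking_pool and the integer slots are stand-ins for illustration):

#include <condition_variable>
#include <mutex>
#include <thread>
#include <vector>

struct blocking_pool {
    std::vector<int>        free_slots; // stand-in for std::vector<webgpu_pool_bufs>
    std::mutex              mutex;
    std::condition_variable cv;

    int alloc() {
        std::unique_lock<std::mutex> lock(mutex);
        cv.wait(lock, [this] { return !free_slots.empty(); });
        int slot = free_slots.back();
        free_slots.pop_back();
        return slot;
    }

    void release(int slot) {
        std::lock_guard<std::mutex> lock(mutex);
        free_slots.push_back(slot);
        cv.notify_all();
    }
};

int main() {
    blocking_pool pool;
    pool.release(0);                                          // seed one slot, like init()
    int slot = pool.alloc();                                  // taken by an enqueued command
    std::thread done([&pool, slot] { pool.release(slot); });  // "OnSubmittedWorkDone" callback
    done.join();
    return pool.alloc() == slot ? 0 : 1;                      // slot is reusable afterwards
}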
110
+
111
+ // All the base objects needed to run operations on a WebGPU device
112
+ struct webgpu_context_struct {
113
+ wgpu::Instance instance;
114
+ wgpu::Adapter adapter;
115
+ wgpu::Device device;
116
+ wgpu::Queue queue;
117
+ wgpu::Limits limits;
118
+
119
+ std::recursive_mutex mutex;
120
+
121
+ bool device_init = false;
122
+
123
+ webgpu_buf_pool param_buf_pool;
124
+ webgpu_buf_pool set_rows_error_buf_pool;
125
+
126
+ wgpu::ComputePipeline memset_pipeline;
127
+ wgpu::ComputePipeline mul_mat_pipeline;
128
+ wgpu::ComputePipeline set_rows_pipeline;
129
+ wgpu::ComputePipeline cpy_pipeline;
130
+
131
+ size_t memset_bytes_per_thread;
132
+
133
+ // Staging buffer for reading data from the GPU
134
+ wgpu::Buffer get_tensor_staging_buf;
135
+
136
+ // Command buffers which need to be submitted
137
+ std::vector<wgpu::CommandBuffer> staged_command_bufs;
138
+
139
+ // Parameter buffers associated with the staged command buffers
140
+ std::vector<webgpu_pool_bufs> staged_param_bufs;
141
+ // Buffers associated with set_rows operations, used to store potential errors
142
+ std::vector<webgpu_pool_bufs> staged_set_row_error_bufs;
143
+
144
+ std::vector<wgpu::FutureWaitInfo> callback_futures;
145
+
146
+ #ifdef GGML_WEBGPU_DEBUG
147
+ wgpu::Buffer debug_host_buf;
148
+ wgpu::Buffer debug_dev_buf;
149
+ #endif
150
+ };
151
+
152
+ typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
153
+
154
+ struct ggml_backend_webgpu_reg_context {
155
+ webgpu_context webgpu_ctx;
156
+ size_t device_count;
157
+ const char * name;
158
+ };
159
+
160
+ struct ggml_backend_webgpu_device_context {
161
+ webgpu_context webgpu_ctx;
162
+ std::string device_name;
163
+ std::string device_desc;
164
+ };
165
+
166
+ struct ggml_backend_webgpu_context {
167
+ webgpu_context webgpu_ctx;
168
+ std::string name;
169
+ };
170
+
171
+ struct ggml_backend_webgpu_buffer_context {
172
+ webgpu_context webgpu_ctx;
173
+ wgpu::Buffer buffer;
174
+
175
+ ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf) :
176
+ webgpu_ctx(std::move(ctx)),
177
+ buffer(std::move(buf)) {}
178
+ };
179
+
180
+ /* End struct definitions */
181
+
182
+ /* WebGPU object initializations */
183
+
184
+ static void ggml_webgpu_create_pipeline(wgpu::Device & device,
185
+ wgpu::ComputePipeline & pipeline,
186
+ const char * shader_code,
187
+ const char * label,
188
+ const std::vector<wgpu::ConstantEntry> & constants = {}) {
189
+ WEBGPU_LOG_DEBUG("ggml_webgpu_create_pipeline()");
190
+
191
+ wgpu::ShaderSourceWGSL shader_source;
192
+ shader_source.code = shader_code;
193
+
194
+ wgpu::ShaderModuleDescriptor shader_desc;
195
+ shader_desc.nextInChain = &shader_source;
196
+
197
+ wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);
198
+
199
+ wgpu::ComputePipelineDescriptor pipeline_desc;
200
+ pipeline_desc.label = label;
201
+ pipeline_desc.compute.module = shader_module;
202
+ pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code
203
+ pipeline_desc.layout = nullptr; // nullptr means auto layout
204
+ if (constants.size() > 0) {
205
+ pipeline_desc.compute.constants = constants.data();
206
+ pipeline_desc.compute.constantCount = constants.size();
207
+ }
208
+ pipeline = device.CreateComputePipeline(&pipeline_desc);
209
+ }
210
+
211
+ static void ggml_webgpu_create_buffer(wgpu::Device & device,
212
+ wgpu::Buffer & buffer,
213
+ size_t size,
214
+ wgpu::BufferUsage usage,
215
+ const char * label) {
216
+ WEBGPU_LOG_DEBUG("ggml_webgpu_create_buffer()");
217
+
218
+ wgpu::BufferDescriptor buffer_desc;
219
+ buffer_desc.size = size;
220
+ buffer_desc.usage = usage;
221
+ buffer_desc.label = label;
222
+ buffer_desc.mappedAtCreation = false;
223
+
224
+ // TODO: error handling
225
+ buffer = device.CreateBuffer(&buffer_desc);
226
+ }
227
+
228
+ /** End WebGPU object initializations */
229
+
230
+ /** WebGPU Actions */
231
+
232
+ // Wait for the queue to finish processing all submitted work
233
+ static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
234
+ std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
235
+ if (ctx->callback_futures.empty()) {
236
+ // no existing callbacks, wait on queue submission
237
+ ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(
238
+ wgpu::CallbackMode::AllowSpontaneous,
239
+ [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
240
+ if (status != wgpu::QueueWorkDoneStatus::Success) {
241
+ GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data);
242
+ }
243
+ }),
244
+ UINT64_MAX);
245
+ } else {
246
+ // existing callbacks, wait on them
247
+ ctx->instance.WaitAny(ctx->callback_futures.size(), ctx->callback_futures.data(), UINT64_MAX);
248
+ ctx->callback_futures.clear();
249
+ }
250
+ }
251
+
252
+ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
253
+ std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
254
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_submit_queue()");
255
+ if (ctx->staged_command_bufs.empty()) {
256
+ // Nothing to submit
257
+ return;
258
+ }
259
+ ctx->queue.Submit(ctx->staged_command_bufs.size(), ctx->staged_command_bufs.data());
260
+
261
+ // If there are SET_ROWS operations in this submission, copy their error buffers to the host.
262
+ if (ctx->staged_set_row_error_bufs.size() > 0) {
263
+ wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
264
+ for (auto & error_bufs : ctx->staged_set_row_error_bufs) {
265
+ // Copy the error buffer to the host buffer
266
+ encoder.CopyBufferToBuffer(error_bufs.dev_buf, 0, error_bufs.host_buf, 0, error_bufs.host_buf.GetSize());
267
+ }
268
+ wgpu::CommandBuffer commands = encoder.Finish();
269
+ ctx->queue.Submit(1, &commands);
270
+ }
271
+
272
+ ctx->staged_command_bufs.clear();
273
+ std::vector<webgpu_pool_bufs> staged_param_bufs = std::move(ctx->staged_param_bufs);
274
+ std::vector<webgpu_pool_bufs> staged_set_row_error_bufs = std::move(ctx->staged_set_row_error_bufs);
275
+
276
+ // Free the staged parameter buffers once the submission completes
277
+ wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
278
+ wgpu::CallbackMode::AllowSpontaneous,
279
+ [ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
280
+ if (status != wgpu::QueueWorkDoneStatus::Success) {
281
+ GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data);
282
+ }
283
+ // Free the staged buffers
284
+ ctx->param_buf_pool.free_bufs(staged_param_bufs);
285
+ });
286
+ ctx->callback_futures.push_back({ p_f });
287
+
288
+ // Check for errors in SET_ROWS operations
289
+ for (auto & error_bufs : staged_set_row_error_bufs) {
290
+ wgpu::Future f = error_bufs.host_buf.MapAsync(
291
+ wgpu::MapMode::Read,
292
+ 0,
293
+ error_bufs.host_buf.GetSize(),
294
+ wgpu::CallbackMode::AllowSpontaneous,
295
+ [ctx, error_bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
296
+ if (status != wgpu::MapAsyncStatus::Success) {
297
+ GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", message.data);
298
+ } else {
299
+ const uint32_t * error_data = (const uint32_t *) error_bufs.host_buf.GetConstMappedRange();
300
+ if (*error_data) {
301
+ GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
302
+ }
303
+ // We can't unmap in here due to WebGPU reentrancy limitations.
304
+ ctx->set_rows_error_buf_pool.free_bufs({ error_bufs });
305
+ }
306
+ });
307
+ ctx->callback_futures.push_back({ f });
308
+ }
309
+ }
310
+
311
+ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
312
+ wgpu::Buffer & buffer,
313
+ wgpu::MapMode mode,
314
+ size_t offset,
315
+ size_t size) {
316
+ ctx->instance.WaitAny(buffer.MapAsync(mode,
317
+ offset,
318
+ size,
319
+ wgpu::CallbackMode::AllowSpontaneous,
320
+ [](wgpu::MapAsyncStatus status, wgpu::StringView message) {
321
+ if (status != wgpu::MapAsyncStatus::Success) {
322
+ GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n",
323
+ message.data);
324
+ }
325
+ }),
326
+ UINT64_MAX);
327
+ }
328
+
329
+ #ifdef GGML_WEBGPU_DEBUG
330
+ // This function adds debugging information to shaders, as WebGPU does not support printing directly.
331
+ // To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and
332
+ // debug statements in the shader, and then call this function after encoding the commands and submitting them.
333
+ static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
334
+ wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
335
+ encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
336
+ wgpu::CommandBuffer commands = encoder.Finish();
337
+ ctx->queue.Submit(1, &commands);
338
+
339
+ ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
340
+ const uint32_t * debug_data = (const uint32_t *) ctx->debug_host_buf.GetConstMappedRange();
341
+ std::cout << "debug data:";
342
+ for (size_t i = 0; i < WEBGPU_DEBUG_BUF_ELEMS; i++) {
343
+ std::cout << " " << i << ": " << debug_data[i];
344
+ }
345
+ std::cout << "\n";
346
+ ctx->debug_host_buf.Unmap();
347
+ }
348
+ #endif
349
+
350
+ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context & ctx,
351
+ wgpu::ComputePipeline & pipeline,
352
+ std::vector<uint32_t> params,
353
+ std::vector<wgpu::BindGroupEntry> bind_group_entries,
354
+ uint32_t wg_x,
355
+ bool submit_and_wait = false) {
356
+ webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
357
+
358
+ ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
359
+ uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
360
+ for (size_t i = 0; i < params.size(); i++) {
361
+ _params[i] = params[i];
362
+ }
363
+
364
+ params_bufs.host_buf.Unmap();
365
+
366
+ uint32_t params_bufs_binding_num = bind_group_entries.size();
367
+ bind_group_entries.push_back({ .binding = params_bufs_binding_num,
368
+ .buffer = params_bufs.dev_buf,
369
+ .offset = 0,
370
+ .size = params_bufs.dev_buf.GetSize() });
371
+
372
+ wgpu::BindGroupDescriptor bind_group_desc;
373
+ bind_group_desc.layout = pipeline.GetBindGroupLayout(0);
374
+ bind_group_desc.entryCount = bind_group_entries.size();
375
+ bind_group_desc.entries = bind_group_entries.data();
376
+ wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc);
377
+
378
+ wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
379
+ encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());
380
+ wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
381
+ pass.SetPipeline(pipeline);
382
+ pass.SetBindGroup(0, bind_group);
383
+ pass.DispatchWorkgroups(wg_x, 1, 1);
384
+ pass.End();
385
+ wgpu::CommandBuffer commands = encoder.Finish();
386
+ if (submit_and_wait) {
387
+ // Submit and wait immediately
388
+ ctx->queue.Submit(1, &commands);
389
+ ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(
390
+ wgpu::CallbackMode::AllowSpontaneous,
391
+ [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
392
+ if (status != wgpu::QueueWorkDoneStatus::Success) {
393
+ GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data);
394
+ }
395
+ ctx->param_buf_pool.free_bufs({ params_bufs });
396
+ }),
397
+ UINT64_MAX);
398
+ } else {
399
+ // Lock the context mutex when pushing to the staging vectors.
400
+ std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
401
+ // Enqueue commands and only submit if we have enough staged commands
402
+ ctx->staged_command_bufs.push_back(commands);
403
+ ctx->staged_param_bufs.push_back(params_bufs);
404
+ if (ctx->staged_command_bufs.size() == WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
405
+ ggml_backend_webgpu_submit_queue(ctx);
406
+ }
407
+ }
408
+ }
409
+
410
+ static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
411
+ wgpu::Buffer & buf,
412
+ uint32_t value,
413
+ size_t offset,
414
+ size_t size) {
415
+ std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value };
416
+ std::vector<wgpu::BindGroupEntry> entries = {
417
+ { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() }
418
+ };
419
+ size_t bytes_per_wg = ctx->limits.maxComputeWorkgroupSizeX * ctx->memset_bytes_per_thread;
420
+ uint32_t wg_x = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg;
421
+ ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, true);
422
+ }
423
+
424
+ static size_t ggml_backend_webgpu_tensor_offset(const ggml_tensor * tensor) {
425
+ return webgpu_tensor_offset(tensor) + tensor->view_offs;
426
+ }
427
+
428
+ static wgpu::Buffer ggml_backend_webgpu_tensor_buf(const ggml_tensor * tensor) {
429
+ ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
430
+ return ctx->buffer;
431
+ }
432
+
433
+ /** End WebGPU Actions */
434
+
435
+ /** GGML Backend Interface */
436
+
437
+ static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
438
+ ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
439
+ return ctx->name.c_str();
440
+ }
441
+
442
+ static void ggml_backend_webgpu_free(ggml_backend_t backend) {
443
+ ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
444
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");
445
+
446
+ // TODO: cleanup
447
+ GGML_UNUSED(ctx);
448
+ }
449
+
450
+ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
451
+ size_t src_offset = ggml_backend_webgpu_tensor_offset(src);
452
+ // assumes power of 2 offset alignment
453
+ size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
454
+ // align to minimum offset alignment
455
+ src_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
456
+ size_t dst_offset = ggml_backend_webgpu_tensor_offset(dst);
457
+ size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
458
+ dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
459
+ uint32_t ne = (uint32_t) ggml_nelements(dst);
460
+ std::vector<uint32_t> params = { ne,
461
+ (uint32_t) (src_misalignment / ggml_type_size(src->type)),
462
+ (uint32_t) (dst_misalignment / ggml_type_size(dst->type)),
463
+ // Convert byte-strides to element-strides
464
+ (uint32_t) (src->nb[0] / ggml_type_size(src->type)),
465
+ (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
466
+ (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
467
+ (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
468
+ (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
469
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
470
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
471
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
472
+ // Logical shape — same for both tensors even if permuted
473
+ (uint32_t) src->ne[0],
474
+ (uint32_t) src->ne[1],
475
+ (uint32_t) src->ne[2],
476
+ (uint32_t) src->ne[3] };
477
+
478
+ std::vector<wgpu::BindGroupEntry> entries = {
479
+ { .binding = 0,
480
+ .buffer = ggml_backend_webgpu_tensor_buf(src),
481
+ .offset = src_offset,
482
+ .size = (ggml_nbytes(src) + src_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) &
483
+ ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) },
484
+ { .binding = 1,
485
+ .buffer = ggml_backend_webgpu_tensor_buf(dst),
486
+ .offset = dst_offset,
487
+ .size = (ggml_nbytes(dst) + dst_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) &
488
+ ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) }
489
+ };
490
+
491
+ size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
492
+ uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size;
493
+ ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x);
494
+ }
495
+
496
+ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
497
+ // For set rows specifically, we need to check if src and idx are empty tensors.
498
+ if (ggml_is_empty(src) || ggml_is_empty(idx)) {
499
+ return;
500
+ }
501
+
502
+ webgpu_pool_bufs error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
503
+ if (error_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
504
+ error_bufs.host_buf.Unmap();
505
+ }
506
+
507
+ size_t src_offset = ggml_backend_webgpu_tensor_offset(src);
508
+ // assumes power of 2 offset alignment
509
+ size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
510
+ // align to minimum offset alignment
511
+ src_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
512
+ size_t idx_offset = ggml_backend_webgpu_tensor_offset(idx);
513
+ size_t idx_misalignment = idx_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
514
+ idx_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
515
+ size_t dst_offset = ggml_backend_webgpu_tensor_offset(dst);
516
+ size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
517
+ dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
518
+
519
+ std::vector<uint32_t> params = { (uint32_t) (src_misalignment / ggml_type_size(src->type)),
520
+ (uint32_t) (idx_misalignment / ggml_type_size(idx->type)),
521
+ (uint32_t) (dst_misalignment / ggml_type_size(dst->type)),
522
+ // Convert byte-strides to element-strides
523
+ (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
524
+ (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
525
+ (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
526
+ (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
527
+ (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)),
528
+ (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
529
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
530
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
531
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
532
+ // Shape of src
533
+ (uint32_t) src->ne[0],
534
+ (uint32_t) src->ne[1],
535
+ (uint32_t) src->ne[2],
536
+ (uint32_t) src->ne[3],
537
+ // Shape of idx
538
+ (uint32_t) (idx->ne[1]),
539
+ (uint32_t) (idx->ne[2]) };
540
+
541
+ std::vector<wgpu::BindGroupEntry> entries = {
542
+ { .binding = 0,
543
+ .buffer = ggml_backend_webgpu_tensor_buf(src),
544
+ .offset = ggml_backend_webgpu_tensor_offset(src),
545
+ .size = ggml_nbytes(src) },
546
+ { .binding = 1,
547
+ .buffer = ggml_backend_webgpu_tensor_buf(idx),
548
+ .offset = ggml_backend_webgpu_tensor_offset(idx),
549
+ .size = ggml_nbytes(idx) },
550
+ { .binding = 2,
551
+ .buffer = ggml_backend_webgpu_tensor_buf(dst),
552
+ .offset = ggml_backend_webgpu_tensor_offset(dst),
553
+ .size = ggml_nbytes(dst) },
554
+ { .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
555
+ };
556
+
557
+ size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
558
+ uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
559
+
560
+ std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
561
+ ctx->staged_set_row_error_bufs.push_back(error_bufs);
562
+
563
+ ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x);
564
+ }
565
+
566
+ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
567
+ std::vector<uint32_t> params = {
568
+ (uint32_t) dst->ne[1], // number of rows in result (M)
569
+ (uint32_t) dst->ne[0], // number of columns in result (N)
570
+ (uint32_t) src0->ne[0], // number of columns in src0/src1 (K)
571
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements) of src0 in dimension 1
572
+ (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements) of src1 in dimension 1
573
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), // stride (elements) of src0 in dimension 2
574
+ (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), // stride (elements) of src1 in dimension 2
575
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), // stride (elements) of src0 in dimension 3
576
+ (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), // stride (elements) of src1 in dimension 3
577
+ (uint32_t) src0->ne[2], // batch size in dimension 2
578
+ (uint32_t) src0->ne[3], // batch size in dimension 3
579
+ (uint32_t) (src1->ne[2] / src0->ne[2]), // broadcast in dimension 2
580
+ (uint32_t) (src1->ne[3] / src0->ne[3]) // broadcast in dimension 3
581
+ };
582
+
583
+ std::vector<wgpu::BindGroupEntry> entries = {
584
+ { .binding = 0,
585
+ .buffer = ggml_backend_webgpu_tensor_buf(src0),
586
+ .offset = ggml_backend_webgpu_tensor_offset(src0),
587
+ .size = ggml_nbytes(src0) },
588
+ { .binding = 1,
589
+ .buffer = ggml_backend_webgpu_tensor_buf(src1),
590
+ .offset = ggml_backend_webgpu_tensor_offset(src1),
591
+ .size = ggml_nbytes(src1) },
592
+ { .binding = 2,
593
+ .buffer = ggml_backend_webgpu_tensor_buf(dst),
594
+ .offset = ggml_backend_webgpu_tensor_offset(dst),
595
+ .size = ggml_nbytes(dst) }
596
+ };
597
+
598
+ uint32_t wg_x =
599
+ (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE;
600
+ ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline, params, entries, wg_x);
601
+ }
602
+
603
+ // Returns true if node has enqueued work into the queue, false otherwise
604
+ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
605
+ if (ggml_is_empty(node)) {
606
+ return false;
607
+ }
608
+ WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");
609
+
610
+ ggml_tensor * src0 = node->src[0];
611
+ ggml_tensor * src1 = node->src[1];
612
+
613
+ switch (node->op) {
614
+ // no-ops
615
+ case GGML_OP_NONE:
616
+ case GGML_OP_VIEW:
617
+ case GGML_OP_PERMUTE:
618
+ return false;
619
+ case GGML_OP_CPY:
620
+ {
621
+ ggml_webgpu_cpy(ctx, src0, node);
622
+ break;
623
+ }
624
+ case GGML_OP_SET_ROWS:
625
+ {
626
+ ggml_webgpu_set_rows(ctx, src0, src1, node);
627
+ break;
628
+ }
629
+ case GGML_OP_MUL_MAT:
630
+ {
631
+ ggml_webgpu_mul_mat(ctx, src0, src1, node);
632
+ break;
633
+ }
634
+ default:
635
+ return false;
636
+ }
637
+ return true;
638
+ }
639
+
640
+ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
641
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)");
642
+
643
+ ggml_backend_webgpu_context * backend_ctx = static_cast<ggml_backend_webgpu_context *>(backend->context);
644
+ webgpu_context ctx = backend_ctx->webgpu_ctx;
645
+
646
+ for (int i = 0; i < cgraph->n_nodes; i++) {
647
+ ggml_webgpu_encode_node(ctx, cgraph->nodes[i]);
648
+ }
649
+
650
+ ggml_backend_webgpu_submit_queue(ctx);
651
+ ggml_backend_webgpu_wait_on_submission(ctx);
652
+
653
+ return GGML_STATUS_SUCCESS;
654
+ }
655
+
656
+ static ggml_backend_i ggml_backend_webgpu_i = {
657
+ /* .get_name = */ ggml_backend_webgpu_name,
658
+ /* .free = */ ggml_backend_webgpu_free,
659
+ /* .set_tensor_async = */ NULL,
660
+ /* .get_tensor_async = */ NULL,
661
+ /* .cpy_tensor_async = */ NULL,
662
+ /* .synchronize = */ NULL,
663
+ /* .graph_plan_create = */ NULL,
664
+ /* .graph_plan_free = */ NULL,
665
+ /* .graph_plan_update = */ NULL,
666
+ /* .graph_plan_compute = */ NULL,
667
+ /* .graph_compute = */ ggml_backend_webgpu_graph_compute,
668
+ /* .event_record = */ NULL,
669
+ /* .event_wait = */ NULL,
670
+ };
671
+
672
+ /* End GGML Backend Interface */
673
+
674
+ /* GGML Backend Buffer Interface */
675
+
676
+ static void ggml_backend_webgpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
677
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_free_buffer()");
678
+ ggml_backend_webgpu_buffer_context * ctx = static_cast<ggml_backend_webgpu_buffer_context *>(buffer->context);
679
+ ctx->buffer.Destroy();
680
+ }
681
+
682
+ // Returns the "fake" base pointer.
683
+ static void * ggml_backend_webgpu_buffer_get_base(ggml_backend_buffer_t buffer) {
684
+ GGML_UNUSED(buffer);
685
+ return webgpu_ptr_base;
686
+ }
687
+
688
+ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffer,
689
+ ggml_tensor * tensor,
690
+ uint8_t value,
691
+ size_t offset,
692
+ size_t size) {
693
+ if (size == 0) {
694
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do.");
695
+ return;
696
+ }
697
+
698
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", "
699
+ << offset << ", " << size << ")");
700
+
701
+ ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
702
+
703
+ size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
704
+
705
+ // This is a trick to set all bytes of a u32 to the same 1 byte value.
706
+ uint32_t val32 = (uint32_t) value * 0x01010101;
707
+ ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size);
708
+ }
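A quick standalone check of the byte-replication trick above (value 0xAB is just an example):

#include <cassert>
#include <cstdint>

int main() {
    // Multiplying a byte by 0x01010101 copies it into all four byte lanes,
    // e.g. 0xAB -> 0xABABABAB, which the memset shader then writes per u32 word.
    uint8_t  value = 0xAB;
    uint32_t val32 = (uint32_t) value * 0x01010101;
    assert(val32 == 0xABABABABu);
    return 0;
}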
709
+
710
+ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
711
+ ggml_tensor * tensor,
712
+ const void * data,
713
+ size_t offset,
714
+ size_t size) {
715
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", "
716
+ << offset << ", " << size << ")");
717
+ ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
718
+ webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
719
+
720
+ size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
721
+
722
+ webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
723
+
724
+ if (size % 4 != 0) {
725
+ // If size is not a multiple of 4, we need to memset the remaining bytes
726
+ size_t remaining_size = size % 4;
727
+
728
+ // pack the remaining bytes into a uint32_t
729
+ uint32_t val32 = 0;
730
+
731
+ for (size_t i = 0; i < remaining_size; i++) {
732
+ ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];
733
+ }
734
+ // memset the remaining bytes
735
+ ggml_backend_webgpu_buffer_memset(
736
+ webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size);
737
+ } else {
738
+ // wait for WriteBuffer to complete
739
+ ggml_backend_webgpu_wait_on_submission(webgpu_ctx);
740
+ }
741
+ }
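To make the split above concrete, here is a host-side sketch for a hypothetical 10-byte upload: the first (size / 4) * 4 = 8 bytes go through WriteBuffer, and the trailing 2 bytes are packed into a u32 and written through the memset path.

#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
    // Hypothetical 10-byte payload; the last two bytes are the "unaligned tail".
    const uint8_t data[10] = { 0, 1, 2, 3, 4, 5, 6, 7, 0xCD, 0xEF };
    size_t size      = sizeof(data);
    size_t aligned   = (size / 4) * 4; // 8 bytes written with queue.WriteBuffer()
    size_t remaining = size % 4;       // 2 bytes handled by the memset fallback

    // Pack the tail bytes into a u32, exactly as set_tensor does above.
    uint32_t val32 = 0;
    for (size_t i = 0; i < remaining; i++) {
        ((uint8_t *) &val32)[i] = data[size - remaining + i];
    }

    uint8_t packed[4];
    std::memcpy(packed, &val32, sizeof(val32));
    assert(aligned == 8 && remaining == 2);
    assert(packed[0] == 0xCD && packed[1] == 0xEF && packed[2] == 0 && packed[3] == 0);
    return 0;
}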
742
+
743
+ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
744
+ const ggml_tensor * tensor,
745
+ void * data,
746
+ size_t offset,
747
+ size_t size) {
748
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", "
749
+ << offset << ", " << size << ")");
750
+
751
+ ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
752
+ webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
753
+ wgpu::Device device = webgpu_ctx->device;
754
+
755
+ size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
756
+
757
+ size_t final_size = size;
758
+ if (size % 4 != 0) {
759
+ // If size is not a multiple of 4, we need to round it up to the next multiple of 4
760
+ final_size = size + (4 - (size % 4));
761
+ }
762
+
763
+ std::lock_guard<std::recursive_mutex> lock(webgpu_ctx->mutex);
764
+
765
+ if (webgpu_ctx->get_tensor_staging_buf == nullptr || webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) {
766
+ // Create a new staging buffer if it doesn't exist or is too small
767
+ if (webgpu_ctx->get_tensor_staging_buf) {
768
+ webgpu_ctx->get_tensor_staging_buf.Destroy();
769
+ }
770
+ ggml_webgpu_create_buffer(device,
771
+ webgpu_ctx->get_tensor_staging_buf,
772
+ final_size,
773
+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
774
+ "get_tensor_staging_buf");
775
+ }
776
+
777
+ // Copy the data from the buffer to the staging buffer
778
+ wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
779
+ encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->get_tensor_staging_buf, 0, final_size);
780
+ wgpu::CommandBuffer commands = encoder.Finish();
781
+
782
+ // Submit the command buffer to the queue
783
+ webgpu_ctx->queue.Submit(1, &commands);
784
+
785
+ // Map the staging buffer to read the data
786
+ ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf, wgpu::MapMode::Read, 0, final_size);
787
+ // Must specify size here since the staging buffer might be larger than the tensor size
788
+ const void * mapped_range = webgpu_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size);
789
+
790
+ // Copy the data from the mapped range to the output buffer
791
+ std::memcpy(data, mapped_range, size);
792
+ webgpu_ctx->get_tensor_staging_buf.Unmap();
793
+ }
794
+
795
+ static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
796
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")");
797
+ ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
798
+ ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, 0, buffer->size);
799
+ }
800
+
801
+ static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
802
+ /* .free_buffer = */ ggml_backend_webgpu_buffer_free_buffer,
803
+ /* .get_base = */ ggml_backend_webgpu_buffer_get_base,
804
+ /* .init_tensor = */ NULL, // TODO: optional, needed?
805
+ /* .memset_tensor = */ ggml_backend_webgpu_buffer_memset_tensor,
806
+ /* .set_tensor = */ ggml_backend_webgpu_buffer_set_tensor,
807
+ /* .get_tensor = */ ggml_backend_webgpu_buffer_get_tensor,
808
+ /* .cpy_tensor = */ NULL, // TODO: optional, implement this
809
+ /* .clear = */ ggml_backend_webgpu_buffer_clear,
810
+ /* .reset = */ NULL, // TODO: optional, think it coordinates with .init_tensor
811
+ };
812
+
813
+ /* End GGML Backend Buffer Interface */
814
+
815
+ /* GGML Backend Buffer Type Interface */
816
+
817
+ static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
818
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
819
+ return ctx->device_name.c_str();
820
+ }
821
+
822
+ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
823
+ size_t size) {
824
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer(" << size << ")");
825
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
826
+
827
+ wgpu::Buffer buf;
828
+ ggml_webgpu_create_buffer(ctx->webgpu_ctx->device,
829
+ buf,
830
+ size,
831
+ wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
832
+ "allocated_buffer");
833
+
834
+ ggml_backend_webgpu_buffer_context * buf_ctx = new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf);
835
+
836
+ return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size);
837
+ }
838
+
839
+ static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
840
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
841
+ return ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment;
842
+ }
843
+
844
+ // maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding.
845
+ static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
846
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
847
+ return ctx->webgpu_ctx->limits.maxStorageBufferBindingSize;
848
+ }
849
+
850
+ /* End GGML Backend Buffer Type Interface */
851
+
852
+ /* GGML Backend Device Interface */
853
+
854
+ static const char * ggml_backend_webgpu_device_get_name(ggml_backend_dev_t dev) {
855
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
856
+ return ctx->device_name.c_str();
857
+ }
858
+
859
+ static const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_t dev) {
860
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
861
+ return ctx->device_desc.c_str();
862
+ }
863
+
864
+ static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
865
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
866
+ // TODO: what do we actually want to return here? maxBufferSize might not be the full available memory.
867
+ *free = ctx->webgpu_ctx->limits.maxBufferSize;
868
+ *total = ctx->webgpu_ctx->limits.maxBufferSize;
869
+ }
870
+
871
+ static enum ggml_backend_dev_type ggml_backend_webgpu_device_get_type(ggml_backend_dev_t dev) {
872
+ GGML_UNUSED(dev);
873
+ return GGML_BACKEND_DEVICE_TYPE_GPU;
874
+ }
875
+
876
+ static void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
877
+ props->name = ggml_backend_webgpu_device_get_name(dev);
878
+ props->description = ggml_backend_webgpu_device_get_description(dev);
879
+ props->type = ggml_backend_webgpu_device_get_type(dev);
880
+ ggml_backend_webgpu_device_get_memory(dev, &props->memory_free, &props->memory_total);
881
+ props->caps = {
882
+ /* .async = */ false,
883
+ /* .host_buffer = */ false,
884
+ /* .buffer_from_host_ptr = */ false,
885
+ /* .events = */ false,
886
+ };
887
+ }
888
+
889
+ static ggml_guid_t ggml_backend_webgpu_guid(void) {
890
+ static const char * guid_str = "__ggml_webgpu :)";
891
+ return reinterpret_cast<ggml_guid_t>((void *) guid_str);
892
+ }
893
+
894
+ static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
895
+ // we use the maximum workgroup size for the memset pipeline
896
+ size_t max_wg_size = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
897
+ size_t max_threads = max_wg_size * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension;
898
+ // Size the bytes_per_thread so that the largest buffer size can be handled
899
+ webgpu_ctx->memset_bytes_per_thread =
900
+ (webgpu_ctx->limits.maxStorageBufferBindingSize + max_threads - 1) / max_threads;
901
+ std::vector<wgpu::ConstantEntry> constants(2);
902
+ constants[0].key = "wg_size";
903
+ constants[0].value = max_wg_size;
904
+ constants[1].key = "bytes_per_thread";
905
+ constants[1].value = webgpu_ctx->memset_bytes_per_thread;
906
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->memset_pipeline, wgsl_memset, "memset", constants);
907
+ }
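As a worked example of the sizing above, using assumed limits (the real values come from wgpu::Limits at runtime): with maxComputeWorkgroupSizeX = 256, maxComputeWorkgroupsPerDimension = 65535 and maxStorageBufferBindingSize = 128 MiB, bytes_per_thread comes out to 9, so a single maximum-size dispatch can cover the largest bindable buffer.

#include <cstdio>

int main() {
    // Assumed example limits, not guaranteed by any adapter.
    const size_t max_wg_size     = 256;           // maxComputeWorkgroupSizeX
    const size_t max_wgs_per_dim = 65535;         // maxComputeWorkgroupsPerDimension
    const size_t max_binding     = 128ull << 20;  // maxStorageBufferBindingSize

    const size_t max_threads      = max_wg_size * max_wgs_per_dim;
    const size_t bytes_per_thread = (max_binding + max_threads - 1) / max_threads; // 9

    std::printf("bytes_per_thread = %zu, covered = %zu bytes (binding limit %zu)\n",
                bytes_per_thread, max_threads * bytes_per_thread, max_binding);
    return 0;
}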
908
+
909
+ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
910
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline, wgsl_mul_mat, "mul_mat");
911
+ }
912
+
913
+ static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
914
+ std::vector<wgpu::ConstantEntry> constants(1);
915
+ constants[0].key = "wg_size";
916
+ constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
917
+ ggml_webgpu_create_pipeline(
918
+ webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows", constants);
919
+ }
920
+
921
+ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
922
+ std::vector<wgpu::ConstantEntry> constants(1);
923
+ constants[0].key = "wg_size";
924
+ constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
925
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy", constants);
926
+ }
927
+
928
+ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
929
+ GGML_UNUSED(params);
930
+
931
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()");
932
+
933
+ ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
934
+ webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx;
935
+
936
+ // Multiple threads may try to initialize the device
937
+ std::lock_guard<std::recursive_mutex> lock(webgpu_ctx->mutex);
938
+ if (!webgpu_ctx->device_init) {
939
+ // Initialize device
940
+ std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,
941
+ wgpu::FeatureName::ImplicitDeviceSynchronization };
942
+ wgpu::DeviceDescriptor dev_desc;
943
+ dev_desc.requiredLimits = &webgpu_ctx->limits;
944
+ dev_desc.requiredFeatures = required_features.data();
945
+ dev_desc.requiredFeatureCount = required_features.size();
946
+ dev_desc.SetDeviceLostCallback(
947
+ wgpu::CallbackMode::AllowSpontaneous,
948
+ [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
949
+ GGML_UNUSED(device);
950
+ GGML_LOG_ERROR(
951
+ "ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data);
952
+ });
953
+ dev_desc.SetUncapturedErrorCallback(
954
+ [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
955
+ GGML_UNUSED(device);
956
+ GGML_LOG_ERROR(
957
+ "ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data);
958
+ });
959
+ webgpu_ctx->instance.WaitAny(
960
+ webgpu_ctx->adapter.RequestDevice(
961
+ &dev_desc,
962
+ wgpu::CallbackMode::AllowSpontaneous,
963
+ [webgpu_ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
964
+ if (status != wgpu::RequestDeviceStatus::Success) {
965
+ GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", message.data);
966
+ return;
967
+ }
968
+ webgpu_ctx->device = std::move(device);
969
+ }),
970
+ UINT64_MAX);
971
+ GGML_ASSERT(webgpu_ctx->device != nullptr);
972
+
973
+ // Initialize (compute) queue
974
+ webgpu_ctx->queue = webgpu_ctx->device.GetQueue();
975
+
976
+ // Create buffer pool for shader parameters
977
+ webgpu_ctx->param_buf_pool.init(webgpu_ctx->device,
978
+ WEBGPU_NUM_PARAM_BUFS,
979
+ WEBGPU_PARAMS_BUF_SIZE_BYTES,
980
+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
981
+ wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
982
+ webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->device,
983
+ WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
984
+ WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
985
+ wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
986
+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
987
+
988
+ ggml_webgpu_init_memset_pipeline(webgpu_ctx);
989
+ ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
990
+ ggml_webgpu_init_set_rows_pipeline(webgpu_ctx);
991
+ ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
992
+
993
+ #ifdef GGML_WEBGPU_DEBUG
994
+ // Initialize debug buffers
995
+ ggml_webgpu_create_buffer(webgpu_ctx->device,
996
+ webgpu_ctx->debug_host_buf,
997
+ WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
998
+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
999
+ "debug_host_buf");
1000
+ ggml_webgpu_create_buffer(webgpu_ctx->device,
1001
+ webgpu_ctx->debug_dev_buf,
1002
+ WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
1003
+ wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc,
1004
+ "debug_dev_buf");
1005
+ #endif
1006
+ webgpu_ctx->device_init = true;
1007
+ }
1008
+
1009
+ static ggml_backend_webgpu_context backend_ctx;
1010
+ backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
1011
+ backend_ctx.webgpu_ctx = webgpu_ctx;
1012
+
1013
+ // See GGML Backend Interface section
1014
+ static ggml_backend backend = {
1015
+ /* .guid = */ ggml_backend_webgpu_guid(),
1016
+ /* .interface = */ ggml_backend_webgpu_i,
1017
+ /* .device = */ dev,
1018
+ /* .context = */ &backend_ctx,
1019
+ };
1020
+
1021
+ return &backend;
1022
+ }
1023
+
1024
+ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggml_backend_dev_t dev) {
1025
+ // See GGML Backend Buffer Type Interface section
1026
+ static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = {
1027
+ /* .iface = */ {
1028
+ /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name,
1029
+ /* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer,
1030
+ /* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment,
1031
+ /* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size,
1032
+ /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
1033
+ /* .is_host = */ NULL, // defaults to false
1034
+ },
1035
+ /* .device = */
1036
+ dev,
1037
+ /* .context = */ NULL,
1038
+ };
1039
+
1040
+ return &ggml_backend_webgpu_buffer_type;
1041
+ }
1042
+
1043
+ static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
1044
+ GGML_UNUSED(dev);
1045
+ return buft->iface.get_name == ggml_backend_webgpu_buffer_type_get_name;
1046
+ }
1047
+
1048
+ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
1049
+ GGML_UNUSED(dev);
1050
+
1051
+ switch (op->op) {
1052
+ case GGML_OP_NONE:
1053
+ case GGML_OP_VIEW:
1054
+ case GGML_OP_PERMUTE:
1055
+ return true;
1056
+ case GGML_OP_CPY:
+ case GGML_OP_SET_ROWS:
1057
+ return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32;
1058
+ case GGML_OP_MUL_MAT:
1059
+ return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
1060
+ default:
1061
+ return false;
1062
+ }
1063
+ }
1064
+
1065
+ static struct ggml_backend_device_i ggml_backend_webgpu_device_i = {
1066
+ /* .get_name = */ ggml_backend_webgpu_device_get_name,
1067
+ /* .get_description = */ ggml_backend_webgpu_device_get_description,
1068
+ /* .get_memory = */ ggml_backend_webgpu_device_get_memory,
1069
+ /* .get_type = */ ggml_backend_webgpu_device_get_type,
1070
+ /* .get_props = */ ggml_backend_webgpu_device_get_props,
1071
+ /* .init_backend = */ ggml_backend_webgpu_device_init,
1072
+ /* .get_buffer_type = */ ggml_backend_webgpu_device_get_buffer_type,
1073
+ /* .get_host_buffer_type = */ NULL,
1074
+ /* .buffer_from_host_ptr = */ NULL,
1075
+ /* .supports_op = */ ggml_backend_webgpu_device_supports_op,
1076
+ /* .supports_buft = */ ggml_backend_webgpu_device_supports_buft,
1077
+ /* .offload_op = */ NULL,
1078
+ /* .event_new = */ NULL,
1079
+ /* .event_free = */ NULL,
1080
+ /* .event_synchronize = */ NULL,
1081
+ };
1082
+
1083
+ /* End GGML Backend Device Interface */
1084
+
1085
+ /* GGML Backend Registration Interface */
1086
+
1087
+ static const char * ggml_backend_webgpu_reg_get_name(ggml_backend_reg_t reg) {
1088
+ ggml_backend_webgpu_reg_context * ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
1089
+ return ctx->name;
1090
+ }
1091
+
1092
+ static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
1093
+ ggml_backend_webgpu_reg_context * ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
1094
+ return ctx->device_count;
1095
+ }
1096
+
1097
+ // TODO: Does this need to be thread safe? Is it only called once?
1098
+ // Only one device is supported for now
1099
+ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
1100
+ GGML_ASSERT(index == 0);
1101
+ WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()");
1102
+
1103
+ ggml_backend_webgpu_reg_context * reg_ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
1104
+
1105
+ webgpu_context ctx = reg_ctx->webgpu_ctx;
1106
+
1107
+ wgpu::RequestAdapterOptions options = {};
1108
+ auto callback =
1109
+ [](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message, void * userdata) {
1110
+ if (status != wgpu::RequestAdapterStatus::Success) {
1111
+ GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
1112
+ return;
1113
+ }
1114
+ *static_cast<wgpu::Adapter *>(userdata) = std::move(adapter);
1115
+ };
1116
+ void * userdata = &ctx->adapter;
1117
+ ctx->instance.WaitAny(
1118
+ ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::AllowSpontaneous, callback, userdata), UINT64_MAX);
1119
+ GGML_ASSERT(ctx->adapter != nullptr);
1120
+
1121
+ ctx->adapter.GetLimits(&ctx->limits);
1122
+
1123
+ wgpu::AdapterInfo info{};
1124
+ ctx->adapter.GetInfo(&info);
1125
+
1126
+ static ggml_backend_webgpu_device_context device_ctx;
1127
+ device_ctx.webgpu_ctx = ctx;
1128
+ device_ctx.device_name = GGML_WEBGPU_NAME;
1129
+ device_ctx.device_desc = std::string(info.description.data);
1130
+
1131
+ GGML_LOG_INFO(
1132
+ "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
1133
+ "device_desc: %s\n",
1134
+ info.vendorID,
1135
+ info.vendor.data,
1136
+ info.architecture.data,
1137
+ info.deviceID,
1138
+ info.device.data,
1139
+ info.description.data);
1140
+
1141
+ // See GGML Backend Device Interface section
1142
+ static ggml_backend_device device = {
1143
+ /* .iface = */ ggml_backend_webgpu_device_i,
1144
+ /* .reg = */ reg,
1145
+ /* .context = */ &device_ctx,
1146
+ };
1147
+ return &device;
1148
+ }
1149
+
1150
+ static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = {
1151
+ /* .get_name = */ ggml_backend_webgpu_reg_get_name,
1152
+ /* .get_device_count = */ ggml_backend_webgpu_reg_get_device_count,
1153
+ /* .get_device = */ ggml_backend_webgpu_reg_get_device,
1154
+ /* .get_proc_address = */ NULL,
1155
+ };
1156
+
1157
+ /* End GGML Backend Registration Interface */
1158
+
1159
+ ggml_backend_reg_t ggml_backend_webgpu_reg() {
1160
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()");
1161
+
1162
+ webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
1163
+
1164
+ static ggml_backend_webgpu_reg_context ctx;
1165
+ ctx.webgpu_ctx = webgpu_ctx;
1166
+ ctx.name = GGML_WEBGPU_NAME;
1167
+ ctx.device_count = 1;
1168
+
1169
+ wgpu::InstanceDescriptor instance_descriptor{};
1170
+ std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
1171
+ instance_descriptor.requiredFeatures = instance_features.data();
1172
+ instance_descriptor.requiredFeatureCount = instance_features.size();
1173
+ webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor);
1174
+ GGML_ASSERT(webgpu_ctx->instance != nullptr);
1175
+
1176
+ static ggml_backend_reg reg = {
1177
+ /* .api_version = */ GGML_BACKEND_API_VERSION,
1178
+ /* .iface = */ ggml_backend_webgpu_reg_i,
1179
+ /* .context = */ &ctx,
1180
+ };
1181
+ return &reg;
1182
+ }
1183
+
1184
+ ggml_backend_t ggml_backend_webgpu_init(void) {
1185
+ ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0);
1186
+
1187
+ return ggml_backend_webgpu_device_init(dev, nullptr);
1188
+ }
1189
+
1190
+ GGML_BACKEND_DL_IMPL(ggml_backend_webgpu_reg)
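As a usage sketch (not part of this commit), an application links against this backend through the standard ggml-backend entry points; only ggml_backend_webgpu_init() from ggml-webgpu.h is WebGPU-specific here.

#include "ggml-backend.h"
#include "ggml-webgpu.h"

#include <cstdio>

int main() {
    // Walks the registration -> device -> backend chain defined in this file.
    ggml_backend_t backend = ggml_backend_webgpu_init();
    if (backend == nullptr) {
        std::fprintf(stderr, "WebGPU backend unavailable\n");
        return 1;
    }
    std::printf("initialized: %s\n", ggml_backend_name(backend));
    ggml_backend_free(backend);
    return 0;
}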
ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl ADDED
@@ -0,0 +1,60 @@
+ enable f16;
+
+ @group(0) @binding(0)
+ var<storage, read_write> src: array<f32>;
+
+ @group(0) @binding(1)
+ var<storage, read_write> dst: array<f16>;
+
+ struct Params {
+     ne: u32,         // total number of elements
+     offset_src: u32, // in elements
+     offset_dst: u32, // in elements
+
+     // Strides (in elements) — may be permuted
+     stride_src0: u32,
+     stride_src1: u32,
+     stride_src2: u32,
+     stride_src3: u32,
+
+     stride_dst0: u32,
+     stride_dst1: u32,
+     stride_dst2: u32,
+     stride_dst3: u32,
+
+     // Logical shape (same for both tensors)
+     ne0: u32,
+     ne1: u32,
+     ne2: u32,
+     ne3: u32,
+ };
+
+ @group(0) @binding(2)
+ var<uniform> params: Params;
+
+ override wg_size: u32;
+ @compute @workgroup_size(wg_size)
+ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+     if (gid.x >= params.ne) {
+         return;
+     }
+
+     var i = gid.x;
+
+     let i3 = i / (params.ne2 * params.ne1 * params.ne0);
+     i = i % (params.ne2 * params.ne1 * params.ne0);
+
+     let i2 = i / (params.ne1 * params.ne0);
+     i = i % (params.ne1 * params.ne0);
+
+     let i1 = i / params.ne0;
+     let i0 = i % params.ne0;
+
+     let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
+                   i2 * params.stride_src2 + i3 * params.stride_src3;
+
+     let dst_idx = i0 * params.stride_dst0 + i1 * params.stride_dst1 +
+                   i2 * params.stride_dst2 + i3 * params.stride_dst3;
+
+     dst[params.offset_dst + dst_idx] = f16(src[params.offset_src + src_idx]);
+ }
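A small host-side check of the index unflattening above, mirroring the WGSL arithmetic for an assumed logical shape ne = (4, 3, 2, 1) and flat index 17:

#include <cassert>
#include <cstdint>

int main() {
    // Same decomposition as the shader: peel off i3, i2, i1, i0 from a flat index.
    const uint32_t ne0 = 4, ne1 = 3, ne2 = 2;
    uint32_t i  = 17;                      // flattened element index
    uint32_t i3 = i / (ne2 * ne1 * ne0);   // 17 / 24 = 0
    i           = i % (ne2 * ne1 * ne0);   // 17
    uint32_t i2 = i / (ne1 * ne0);         // 17 / 12 = 1
    i           = i % (ne1 * ne0);         // 5
    uint32_t i1 = i / ne0;                 // 1
    uint32_t i0 = i % ne0;                 // 1
    assert(i0 == 1 && i1 == 1 && i2 == 1 && i3 == 0);
    // The per-tensor strides then map (i0, i1, i2, i3) to possibly-permuted
    // element offsets in src and dst.
    return 0;
}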
ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py ADDED
@@ -0,0 +1,35 @@
+ import os
+ import argparse
+
+
+ def escape_triple_quotes(wgsl):
+ # Simple defense in case of embedded """
+ return wgsl.replace('"""', '\\"""')
+
+
+ def to_cpp_string_literal(varname, content):
+ return f'const char* wgsl_{varname} = R"({content})";\n'
+
+
+ def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--input', required=True)
+ parser.add_argument('--output', required=True)
+ args = parser.parse_args()
+
+ with open(args.output, 'w', encoding='utf-8') as out:
+ out.write("// Auto-generated shader embedding \n\n")
+ for fname in sorted(os.listdir(args.input)):
+ if not fname.endswith('.wgsl'):
+ continue
+ shader_path = os.path.join(args.input, fname)
+ varname = os.path.splitext(fname)[0]
+ with open(shader_path, 'r', encoding='utf-8') as f:
+ content = f.read()
+ content = escape_triple_quotes(content)
+ out.write(to_cpp_string_literal(varname, content))
+ out.write('\n')
+
+
+ if __name__ == '__main__':
+ main()
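Editor's note: for each .wgsl file, the script emits one C++ raw string literal named after the file, which ggml-webgpu.cpp can reference directly. Roughly what the generated ggml-wgsl-shaders.hpp looks like for one input (contents abbreviated; the exact text depends on the shader source). Since the output is a raw string literal, the triple-quote escaping is mostly precautionary; the practical constraint is that shader sources must not contain the sequence `)"`.

    // Auto-generated shader embedding
    const char* wgsl_memset = R"(@group(0) @binding(0)
    var<storage, read_write> output_buffer: array<u32>;
    /* ... rest of memset.wgsl ... */)";
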
ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl ADDED
@@ -0,0 +1,40 @@
+ @group(0) @binding(0)
+ var<storage, read_write> output_buffer: array<u32>;
+
+ struct Params {
+ offset: u32, // in bytes
+ size: u32, // in bytes
+ value: u32, // 4 8-bit values, which are either repeating (memset_tensor) or may be separate (cleaning up unaligned set_tensor operations)
+ };
+
+ @group(0) @binding(1)
+ var<uniform> params: Params;
+
+ override wg_size: u32;
+ override bytes_per_thread: u32;
+
+ @compute @workgroup_size(wg_size)
+ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+ let i = gid.x * bytes_per_thread;
+ let start = params.offset;
+ let end = params.offset + params.size;
+
+ for (var j: u32 = 0u; j < bytes_per_thread; j = j + 1u) {
+ let byte_index = start + i + j;
+ if (byte_index + 4u <= end) {
+ output_buffer[(byte_index >> 2u)] = params.value;
+ } else {
+ // Handle tail (unaligned)
+ for (var k: u32 = 0u; k < 4u; k = k + 1u) {
+ let idx = byte_index + k;
+ if (idx < end) {
+ let word_idx = idx >> 2u;
+ let byte_offset = (idx & 3u) * 8u;
+ let mask = ~(0xffu << byte_offset);
+ let existing = output_buffer[word_idx];
+ output_buffer[word_idx] = (existing & mask) | ((params.value & 0xffu) << byte_offset);
+ }
+ }
+ }
+ }
+ }
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl ADDED
@@ -0,0 +1,56 @@
+ struct MulMatParams {
+ m: u32,
+ n: u32,
+ k: u32,
+ // all strides are in elements
+ stride_01: u32,
+ stride_11: u32,
+ stride_02: u32,
+ stride_12: u32,
+ stride_03: u32,
+ stride_13: u32,
+
+ bs02: u32,
+ bs03: u32,
+ broadcast2: u32,
+ broadcast3: u32
+ };
+
+ @group(0) @binding(0) var<storage, read_write> src0: array<f32>; // N rows, K columns
+ @group(0) @binding(1) var<storage, read_write> src1: array<f32>; // M rows, K columns (transposed)
+ @group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns
+
+ @group(0) @binding(3) var<uniform> params: MulMatParams;
+
+ @compute @workgroup_size(64)
+ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
+ let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
+ if (global_id.x >= total) {
+ return;
+ }
+
+ let dst2_stride = params.m * params.n;
+ let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
+
+ let dst3_idx = global_id.x / dst3_stride;
+ let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension
+ let src13_idx = dst3_idx; // src1 is not broadcast
+ let dst3_rem = global_id.x % dst3_stride;
+
+ let dst2_idx = dst3_rem / dst2_stride;
+ let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension
+ let src12_idx = dst2_idx; // src1 is not broadcast
+
+ let dst2_rem = dst3_rem % dst2_stride;
+
+ let row = dst2_rem / params.n; // output row
+ let col = dst2_rem % params.n; // output column
+
+ var sum = 0.0;
+ for (var i: u32 = 0u; i < params.k; i = i + 1u) {
+ let src0_idx = src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01 + i;
+ let src1_idx = src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11 + i;
+ sum = sum + src0[src0_idx] * src1[src1_idx];
+ }
+ dst[dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.n + col] = sum;
+ }
ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl ADDED
@@ -0,0 +1,82 @@
+ enable f16;
+
+ @group(0) @binding(0)
+ var<storage, read_write> src: array<f32>;
+
+ @group(0) @binding(1)
+ var<storage, read_write> idx: array<u32>;
+
+ @group(0) @binding(2)
+ var<storage, read_write> dst: array<f16>;
+
+ @group(0) @binding(3)
+ var<storage, read_write> error: atomic<u32>;
+
+ struct Params {
+ offset_src: u32, // in elements
+ offset_idx: u32, // in elements
+ offset_dst: u32, // in elements
+
+ // Strides (in elements)
+ stride_src1: u32,
+ stride_src2: u32,
+ stride_src3: u32,
+
+ stride_idx0: u32,
+ stride_idx1: u32,
+ stride_idx2: u32,
+
+ stride_dst1: u32,
+ stride_dst2: u32,
+ stride_dst3: u32,
+
+ // Shape of src
+ ne0: u32,
+ n_rows: u32,
+ ne2: u32,
+ ne3: u32,
+
+ // Shape of idx
+ idx1: u32,
+ idx2: u32,
+ };
+
+ @group(0) @binding(4)
+ var<uniform> params: Params;
+
+ override wg_size: u32;
+ @compute @workgroup_size(wg_size)
+ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+ if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
+ return;
+ }
+ var i = gid.x;
+ let i_src3 = i / (params.ne2 * params.n_rows);
+ let i_dst3 = i / (params.ne2 * params.n_rows);
+
+ i = i % (params.ne2 * params.n_rows);
+ let i_src2 = i / params.n_rows;
+ let i_src1 = i % params.n_rows;
+
+ let i_idx2 = i_src3 % params.idx2;
+ let i_idx1 = i_src2 % params.idx1;
+ let i_idx0 = i_src1;
+
+ let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2;
+
+ let idx_high_val = idx[idx_high];
+ let idx_low_val = idx[idx_high + 1];
+
+ if (idx_low_val != 0) {
+ // Upper bits of index are not zero, output will be incorrect
+ atomicStore(&error, 1);
+ return;
+ }
+
+ let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
+ let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;
+
+ for (var i: u32 = 0; i < params.ne0; i++) {
+ dst[i_dst_row + i] = f16(src[i_src_row + i]);
+ }
+ }