diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 5d3169904..44e3bf822 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2832,22 +2832,107 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str return GGML_STATUS_SUCCESS; } +struct ggml_backend_webgpu_event_context { + webgpu_global_context global_ctx; + wgpu::Future future; + bool recorded = false; +}; + +static ggml_backend_event_t ggml_backend_webgpu_device_event_new(ggml_backend_dev_t device) { + ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) device->context; + + auto * event_ctx = new ggml_backend_webgpu_event_context(); + event_ctx->global_ctx = dev_ctx->webgpu_global_ctx; + + auto * event = new ggml_backend_event; + event->device = device; + event->context = event_ctx; + return event; +} + +static void ggml_backend_webgpu_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) { + GGML_UNUSED(dev); + delete static_cast(event->context); + delete event; +} + +static void ggml_backend_webgpu_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) { + GGML_UNUSED(dev); + ggml_backend_webgpu_event_context * event_ctx = (ggml_backend_webgpu_event_context *) event->context; + if (!event_ctx->recorded) { + return; + } + wgpu::WaitStatus status = + event_ctx->global_ctx->instance.WaitAny(event_ctx->future, WEBGPU_RUNTIME_WAIT_TIMEOUT_NS); + if (status == wgpu::WaitStatus::TimedOut) { + GGML_ABORT("ggml_webgpu: event_synchronize timed out after %u ms\n", WEBGPU_RUNTIME_WAIT_TIMEOUT_MS); + } + event_ctx->recorded = false; +} + +static void ggml_backend_webgpu_event_record(ggml_backend_t backend, ggml_backend_event_t event) { + ggml_backend_webgpu_context * backend_ctx = (ggml_backend_webgpu_context *) backend->context; + ggml_backend_webgpu_event_context * event_ctx = (ggml_backend_webgpu_event_context *) event->context; + + event_ctx->future = backend_ctx->webgpu_ctx->global_ctx->queue.OnSubmittedWorkDone( + wgpu::CallbackMode::AllowSpontaneous, [](wgpu::QueueWorkDoneStatus, wgpu::StringView) {}); + event_ctx->recorded = true; +} + +static void ggml_backend_webgpu_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { + GGML_UNUSED(backend); + ggml_backend_webgpu_device_event_synchronize(nullptr, event); +} + +static void ggml_backend_webgpu_set_tensor_async(ggml_backend_t backend, + ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size) { + GGML_UNUSED(backend); + auto * buf_ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context; + size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; + + // Write aligned portion + buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4); + + if (size % 4 != 0) { + // If size is not a multiple of 4, we need to memset the remaining bytes + size_t remaining_size = size % 4; + + // pack the remaining bytes into a uint32_t + uint32_t val32 = 0; + + for (size_t i = 0; i < remaining_size; i++) { + ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i]; + } + // memset the remaining bytes + ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32, + total_offset + (size - remaining_size), remaining_size); + } +} + +static void ggml_backend_webgpu_synchronize(ggml_backend_t backend) { + ggml_backend_webgpu_context * backend_ctx = (ggml_backend_webgpu_context *) backend->context; + ggml_backend_webgpu_wait_queue(backend_ctx->webgpu_ctx->global_ctx); +} + static ggml_backend_i ggml_backend_webgpu_i = { /* .get_name = */ ggml_backend_webgpu_name, /* .free = */ ggml_backend_webgpu_free, - /* .set_tensor_async = */ NULL, + /* .set_tensor_async = */ ggml_backend_webgpu_set_tensor_async, /* .get_tensor_async = */ NULL, /* .get_tensor_2d_async = */ NULL, /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, - /* .synchronize = */ NULL, + /* .synchronize = */ ggml_backend_webgpu_synchronize, /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, /* .graph_plan_update = */ NULL, /* .graph_plan_compute = */ NULL, /* .graph_compute = */ ggml_backend_webgpu_graph_compute, - /* .event_record = */ NULL, - /* .event_wait = */ NULL, + /* .event_record = */ ggml_backend_webgpu_event_record, + /* .event_wait = */ ggml_backend_webgpu_event_wait, /* .graph_optimize = */ NULL, }; @@ -3810,9 +3895,9 @@ static struct ggml_backend_device_i ggml_backend_webgpu_device_i = { /* .supports_op = */ ggml_backend_webgpu_device_supports_op, /* .supports_buft = */ ggml_backend_webgpu_device_supports_buft, /* .offload_op = */ NULL, - /* .event_new = */ NULL, - /* .event_free = */ NULL, - /* .event_synchronize = */ NULL, + /* .event_new = */ ggml_backend_webgpu_device_event_new, + /* .event_free = */ ggml_backend_webgpu_device_event_free, + /* .event_synchronize = */ ggml_backend_webgpu_device_event_synchronize, }; /* End GGML Backend Device Interface */