[WebGPU] Plug memory leaks and free resources on shutdown (#19315)

* Fix memory leaks in shader lib, backend, backend_context, buffer_context, and webgpu_buf_pool

* Free pools

* Cleanup

* More cleanup

* Run clang-format

* Fix arg-parser and tokenizer test errors that free an unallocated buffer

* Fix device lost callback to not print on device teardown

* Fix include and run clang-format

* remove unused unused

* Update binary ops

---------

Co-authored-by: Reese Levine <reeselevine1@gmail.com>
This commit is contained in:
Nikhil Jain
2026-02-10 08:04:00 -08:00
committed by GitHub
parent fc0fe40049
commit 57487a64c8
2 changed files with 94 additions and 76 deletions
+53 -36
View File
@@ -186,11 +186,17 @@ struct webgpu_buf_pool {
void cleanup() {
std::lock_guard<std::mutex> lock(mutex);
for (auto & bufs : free) {
bufs.host_buf.Destroy();
bufs.dev_buf.Destroy();
if (bufs.host_buf) {
bufs.host_buf.Destroy();
}
if (bufs.dev_buf) {
bufs.dev_buf.Destroy();
}
}
free.clear();
}
~webgpu_buf_pool() { this->cleanup(); }
};
#ifdef GGML_WEBGPU_GPU_PROFILE
@@ -252,13 +258,15 @@ struct webgpu_gpu_profile_buf_pool {
}
free.clear();
}
~webgpu_gpu_profile_buf_pool() { this->cleanup(); }
};
#endif
struct webgpu_pipeline {
wgpu::ComputePipeline pipeline;
std::string name;
void * context = nullptr;
std::shared_ptr<void> context = nullptr;
};
struct webgpu_command {
@@ -319,6 +327,23 @@ struct webgpu_global_context_struct {
wgpu::Buffer debug_host_buf;
wgpu::Buffer debug_dev_buf;
#endif
~webgpu_global_context_struct() {
if (this->get_tensor_staging_buf) {
this->get_tensor_staging_buf.Destroy();
this->get_tensor_staging_buf = nullptr;
}
#ifdef GGML_WEBGPU_DEBUG
if (this->debug_host_buf) {
this->debug_host_buf.Destroy();
this->debug_host_buf = nullptr;
}
if (this->debug_dev_buf) {
this->debug_dev_buf.Destroy();
this->debug_dev_buf = nullptr;
}
#endif
}
};
typedef std::shared_ptr<webgpu_global_context_struct> webgpu_global_context;
@@ -744,7 +769,6 @@ static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
return ctx->name.c_str();
}
// TODO: implement proper cleanup
static void ggml_backend_webgpu_free(ggml_backend_t backend) {
ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");
@@ -788,9 +812,8 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) {
std::cout << "ggml_webgpu: gpu/cpu ratio: " << (total_cpu > 0.0 ? total_gpu / total_cpu : 0.0) << "\n";
#endif
#if !defined(GGML_WEBGPU_CPU_PROFILE) && !defined(GGML_WEBGPU_GPU_PROFILE)
GGML_UNUSED(ctx);
#endif
delete ctx;
delete backend;
}
static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
@@ -896,8 +919,7 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g
ctx->pad_pipelines.emplace(pipeline_key, pipeline);
}
ggml_webgpu_generic_shader_decisions decisions =
*static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
const uint32_t ne = (uint32_t) ggml_nelements(dst);
@@ -941,7 +963,7 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size);
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
@@ -975,8 +997,7 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
ctx->set_rows_pipelines.emplace(key, pipeline);
}
ggml_webgpu_generic_shader_decisions decisions =
*static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
std::optional<webgpu_pool_bufs> error_bufs = std::nullopt;
if (key.i64_idx) {
@@ -1028,7 +1049,7 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
} else {
threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
}
uint32_t wg_x = CEIL_DIV(threads, decisions.wg_size);
uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1,
error_bufs);
}
@@ -1297,10 +1318,9 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
ctx->flash_attn_pipelines.emplace(key, pipeline);
}
ggml_webgpu_flash_attn_shader_decisions decisions =
*static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context);
auto * decisions = static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context.get());
uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions.q_tile);
uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
@@ -1331,8 +1351,7 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s
ctx->unary_pipelines.emplace(pipeline_key, pipeline);
}
ggml_webgpu_generic_shader_decisions decisions =
*static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t ne = (uint32_t) ggml_nelements(dst);
@@ -1392,7 +1411,7 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size);
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
@@ -1425,8 +1444,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
ctx->binary_pipelines.emplace(pipeline_key, pipeline);
}
ggml_webgpu_generic_shader_decisions decisions =
*static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context);
auto * decisions = static_cast<ggml_webgpu_argsort_shader_decisions *>(pipeline.context.get());
uint32_t ne = (uint32_t) ggml_nelements(dst);
@@ -1471,7 +1489,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size);
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
@@ -1821,8 +1839,7 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr
argsort_pipeline.context = processed.decisions;
ctx->argsort_pipelines.emplace(order, argsort_pipeline);
}
ggml_webgpu_argsort_shader_decisions argsort_decisions =
*static_cast<ggml_webgpu_argsort_shader_decisions *>(argsort_pipeline.context);
auto * argsort_decisions = static_cast<ggml_webgpu_argsort_shader_decisions *>(argsort_pipeline.context.get());
webgpu_pipeline argsort_merge_pipeline;
it = ctx->argsort_merge_pipelines.find(order);
@@ -1839,13 +1856,13 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr
const uint32_t src_ne0 = (uint32_t) src->ne[0];
const uint32_t nrows = (uint32_t) ggml_nrows(src);
const uint32_t npr = CEIL_DIV(src_ne0, argsort_decisions.wg_size);
const uint32_t npr = CEIL_DIV(src_ne0, argsort_decisions->wg_size);
const uint32_t block_size =
is_top_k ? std::min(argsort_decisions.wg_size, (uint32_t) dst->ne[0]) : argsort_decisions.wg_size;
is_top_k ? std::min(argsort_decisions->wg_size, (uint32_t) dst->ne[0]) : argsort_decisions->wg_size;
uint32_t out_ne0 = src_ne0;
if (is_top_k) {
if (npr > 1) {
const uint32_t last_tile = src_ne0 - (npr - 1) * argsort_decisions.wg_size;
const uint32_t last_tile = src_ne0 - (npr - 1) * argsort_decisions->wg_size;
out_ne0 = (npr - 1) * block_size + std::min(last_tile, block_size);
} else {
out_ne0 = block_size;
@@ -2198,7 +2215,10 @@ static ggml_backend_i ggml_backend_webgpu_i = {
static void ggml_backend_webgpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
ggml_backend_webgpu_buffer_context * ctx = static_cast<ggml_backend_webgpu_buffer_context *>(buffer->context);
ctx->buffer.Destroy();
if (ctx != nullptr && ctx->buffer != nullptr) {
ctx->buffer.Destroy();
delete ctx;
}
}
// Returns the "fake" base pointer.
@@ -2926,12 +2946,12 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
dev_desc.SetDeviceLostCallback(
wgpu::CallbackMode::AllowSpontaneous,
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
if (reason == wgpu::DeviceLostReason::Destroyed) {
return;
}
GGML_UNUSED(device);
GGML_UNUSED(reason);
GGML_UNUSED(message);
//TODO: uncomment once proper free logic is in place
//GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
//std::string(message).c_str());
GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
std::string(message).c_str());
});
dev_desc.SetUncapturedErrorCallback(
[](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
@@ -3365,10 +3385,7 @@ static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
return ctx->device_count;
}
// TODO: Does this need to be thread safe? Is it only called once?
// TODO: move most logic to device_init function so backend can be freed/initialized properly
// Only one device is supported for now
static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
GGML_ASSERT(index == 0);
WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()");