ggml-webgpu: Update register tiling matmul to use f32 accumulation (#21644)
* Update register tiling matmul to use f32 accumulation * fix profiling code * Fix register tiling matmul for chrome, i'm blaming dawn * Update batch tuning value for iOS * compile fix * Fix use of new load function
This commit is contained in:
@@ -79,7 +79,7 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
|
|||||||
|
|
||||||
/* Constants */
|
/* Constants */
|
||||||
|
|
||||||
#define WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE 32u
|
#define WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE 64u
|
||||||
#define WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN 10u
|
#define WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN 10u
|
||||||
#define WEBGPU_RUNTIME_WAIT_TIMEOUT_MS 30000u
|
#define WEBGPU_RUNTIME_WAIT_TIMEOUT_MS 30000u
|
||||||
#define WEBGPU_RUNTIME_WAIT_TIMEOUT_NS (WEBGPU_RUNTIME_WAIT_TIMEOUT_MS * 1e6)
|
#define WEBGPU_RUNTIME_WAIT_TIMEOUT_NS (WEBGPU_RUNTIME_WAIT_TIMEOUT_MS * 1e6)
|
||||||
@@ -97,14 +97,6 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
|
|||||||
|
|
||||||
/* End Constants */
|
/* End Constants */
|
||||||
|
|
||||||
static inline wgpu::CallbackMode ggml_webgpu_callback_mode() {
|
|
||||||
#ifdef __EMSCRIPTEN__
|
|
||||||
return wgpu::CallbackMode::AllowProcessEvents;
|
|
||||||
#else
|
|
||||||
return wgpu::CallbackMode::AllowSpontaneous;
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
// This is a "fake" base pointer, since WebGPU buffers do not have pointers to
|
// This is a "fake" base pointer, since WebGPU buffers do not have pointers to
|
||||||
// their locations.
|
// their locations.
|
||||||
static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT
|
static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT
|
||||||
@@ -445,34 +437,25 @@ static void ggml_backend_webgpu_check_wait_status(wgpu::WaitStatus wait_status,
|
|||||||
}
|
}
|
||||||
|
|
||||||
#ifdef __EMSCRIPTEN__
|
#ifdef __EMSCRIPTEN__
|
||||||
// iOS browsers seem to have very strict limits on the number of in-flight GPU commands, so we need to throttle to avoid failures.
|
|
||||||
EM_JS(int, ggml_webgpu_is_ios_browser, (), {
|
EM_JS(int, ggml_webgpu_is_ios_browser, (), {
|
||||||
const ua = navigator.userAgent;
|
const ua = navigator.userAgent;
|
||||||
return (ua.includes('iPhone') || ua.includes('iPad')) ? 1 : 0;
|
return (ua.includes('iPhone') || ua.includes('iPad')) ? 1 : 0;
|
||||||
});
|
});
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
static uint32_t ggml_backend_webgpu_get_max_inflight_batches(const wgpu::AdapterInfo & info) {
|
// TODO: these next two functions may want tuning across different platforms and workloads,
|
||||||
|
static uint32_t ggml_backend_webgpu_get_max_inflight_batches() {
|
||||||
#ifdef __EMSCRIPTEN__
|
#ifdef __EMSCRIPTEN__
|
||||||
|
// iOS has very strict limits on the number of in-flight GPU commands,
|
||||||
|
// so we need to throttle to avoid failures.
|
||||||
if (ggml_webgpu_is_ios_browser()) {
|
if (ggml_webgpu_is_ios_browser()) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
#else
|
|
||||||
GGML_UNUSED(info);
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
return UINT32_MAX;
|
return UINT32_MAX;
|
||||||
}
|
}
|
||||||
|
|
||||||
static uint32_t ggml_backend_webgpu_get_command_submit_batch_size(const wgpu::AdapterInfo & info) {
|
static uint32_t ggml_backend_webgpu_get_command_submit_batch_size() {
|
||||||
#ifdef __EMSCRIPTEN__
|
|
||||||
if (ggml_webgpu_is_ios_browser()) {
|
|
||||||
return 16;
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
GGML_UNUSED(info);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
return WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE;
|
return WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -482,7 +465,7 @@ static void ggml_backend_webgpu_wait_queue(webgpu_global_context & ctx) {
|
|||||||
|
|
||||||
const wgpu::WaitStatus wait_status = ctx->instance.WaitAny(
|
const wgpu::WaitStatus wait_status = ctx->instance.WaitAny(
|
||||||
ctx->queue.OnSubmittedWorkDone(
|
ctx->queue.OnSubmittedWorkDone(
|
||||||
ggml_webgpu_callback_mode(),
|
wgpu::CallbackMode::AllowSpontaneous,
|
||||||
[&callback_status, &callback_message](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
|
[&callback_status, &callback_message](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
|
||||||
callback_status = status;
|
callback_status = status;
|
||||||
callback_message = std::string(message);
|
callback_message = std::string(message);
|
||||||
@@ -502,7 +485,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx,
|
|||||||
std::string callback_message;
|
std::string callback_message;
|
||||||
|
|
||||||
const wgpu::WaitStatus wait_status = ctx->instance.WaitAny(
|
const wgpu::WaitStatus wait_status = ctx->instance.WaitAny(
|
||||||
buffer.MapAsync(mode, offset, size, ggml_webgpu_callback_mode(),
|
buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
|
||||||
[&callback_status, &callback_message](wgpu::MapAsyncStatus status, wgpu::StringView message) {
|
[&callback_status, &callback_message](wgpu::MapAsyncStatus status, wgpu::StringView message) {
|
||||||
callback_status = status;
|
callback_status = status;
|
||||||
callback_message = std::string(message);
|
callback_message = std::string(message);
|
||||||
@@ -542,15 +525,15 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||||
static void ggml_backend_webgpu_collect_profile_futures(webgpu_global_context & ctx,
|
static void ggml_backend_webgpu_collect_profile_futures(webgpu_global_context & ctx,
|
||||||
const std::vector<webgpu_command> & commands,
|
const std::vector<webgpu_encoded_op> & commands,
|
||||||
std::vector<wgpu::FutureWaitInfo> & futures) {
|
std::vector<wgpu::FutureWaitInfo> & futures) {
|
||||||
for (const auto & command : commands) {
|
for (const auto & command : commands) {
|
||||||
auto label = command.pipeline_name;
|
auto label = command.pipeline_name;
|
||||||
auto ts_bufs = command.timestamp_query_bufs;
|
auto ts_bufs = command.timestamp_query_bufs;
|
||||||
|
|
||||||
wgpu::Future f = ts_bufs.host_buf.MapAsync(
|
wgpu::Future f = ts_bufs.host_buf.MapAsync(
|
||||||
wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), ggml_webgpu_callback_mode(),
|
wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
|
||||||
[ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) {
|
[ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) {
|
||||||
if (status != wgpu::MapAsyncStatus::Success) {
|
if (status != wgpu::MapAsyncStatus::Success) {
|
||||||
GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str());
|
GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str());
|
||||||
@@ -3428,7 +3411,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
|||||||
|
|
||||||
ctx->webgpu_global_ctx->instance.WaitAny(
|
ctx->webgpu_global_ctx->instance.WaitAny(
|
||||||
ctx->webgpu_global_ctx->instance.RequestAdapter(
|
ctx->webgpu_global_ctx->instance.RequestAdapter(
|
||||||
&options, ggml_webgpu_callback_mode(),
|
&options, wgpu::CallbackMode::AllowSpontaneous,
|
||||||
[&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
|
[&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
|
||||||
if (status != wgpu::RequestAdapterStatus::Success) {
|
if (status != wgpu::RequestAdapterStatus::Success) {
|
||||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
|
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
|
||||||
@@ -3449,8 +3432,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
ctx->webgpu_global_ctx->adapter.GetInfo(&info);
|
ctx->webgpu_global_ctx->adapter.GetInfo(&info);
|
||||||
ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size(info);
|
ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size();
|
||||||
ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches(info);
|
ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches();
|
||||||
wgpu::SupportedFeatures features;
|
wgpu::SupportedFeatures features;
|
||||||
ctx->webgpu_global_ctx->adapter.GetFeatures(&features);
|
ctx->webgpu_global_ctx->adapter.GetFeatures(&features);
|
||||||
// we require f16 support
|
// we require f16 support
|
||||||
@@ -3501,7 +3484,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
|||||||
dev_desc.requiredFeatures = required_features.data();
|
dev_desc.requiredFeatures = required_features.data();
|
||||||
dev_desc.requiredFeatureCount = required_features.size();
|
dev_desc.requiredFeatureCount = required_features.size();
|
||||||
dev_desc.SetDeviceLostCallback(
|
dev_desc.SetDeviceLostCallback(
|
||||||
ggml_webgpu_callback_mode(),
|
wgpu::CallbackMode::AllowSpontaneous,
|
||||||
[ctx](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
|
[ctx](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
|
||||||
if (reason == wgpu::DeviceLostReason::Destroyed) {
|
if (reason == wgpu::DeviceLostReason::Destroyed) {
|
||||||
return;
|
return;
|
||||||
@@ -3535,7 +3518,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
|||||||
|
|
||||||
ctx->webgpu_global_ctx->instance.WaitAny(
|
ctx->webgpu_global_ctx->instance.WaitAny(
|
||||||
ctx->webgpu_global_ctx->adapter.RequestDevice(
|
ctx->webgpu_global_ctx->adapter.RequestDevice(
|
||||||
&dev_desc, ggml_webgpu_callback_mode(),
|
&dev_desc, wgpu::CallbackMode::AllowSpontaneous,
|
||||||
[ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
|
[ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
|
||||||
if (status != wgpu::RequestDeviceStatus::Success) {
|
if (status != wgpu::RequestDeviceStatus::Success) {
|
||||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str());
|
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str());
|
||||||
|
|||||||
@@ -502,12 +502,6 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||||||
let d = load_f16_at(&src0, block_byte_base);
|
let d = load_f16_at(&src0, block_byte_base);
|
||||||
let dmin = load_f16_at(&src0, block_byte_base + 2u);
|
let dmin = load_f16_at(&src0, block_byte_base + 2u);
|
||||||
|
|
||||||
// Load packed scales
|
|
||||||
var scale_vals: array<u32, 3>;
|
|
||||||
for (var i: u32 = 0u; i < 3u; i++) {
|
|
||||||
scale_vals[i] = load_u32_at(&src0, block_byte_base + 4u + 4u * i);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Map k_in_block to loop structure:
|
// Map k_in_block to loop structure:
|
||||||
// Outer loop over 64-element groups (alternating q_b_idx)
|
// Outer loop over 64-element groups (alternating q_b_idx)
|
||||||
// Inner loop over 2 shifts per group
|
// Inner loop over 2 shifts per group
|
||||||
@@ -523,15 +517,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||||||
var sc: u32;
|
var sc: u32;
|
||||||
var mn: u32;
|
var mn: u32;
|
||||||
|
|
||||||
|
let scale_base = block_byte_base + 4u;
|
||||||
|
|
||||||
if (is < 4u) {
|
if (is < 4u) {
|
||||||
let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
|
let sc_byte = get_byte(load_u32_at(&src0, scale_base), is % 4u);
|
||||||
let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
|
let min_byte = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);
|
||||||
sc = sc_byte & 63u;
|
sc = sc_byte & 63u;
|
||||||
mn = min_byte & 63u;
|
mn = min_byte & 63u;
|
||||||
} else {
|
} else {
|
||||||
let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
|
let sc_min_lo = get_byte(load_u32_at(&src0, scale_base + 8), (is + 4u) % 4u);
|
||||||
let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
|
let sc_hi = get_byte(load_u32_at(&src0, scale_base), (is - 4u) % 4u);
|
||||||
let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
|
let min_hi = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);
|
||||||
|
|
||||||
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
|
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
|
||||||
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
|
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
|
||||||
@@ -578,11 +574,6 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||||||
let d = load_f16_at(&src0, block_byte_base);
|
let d = load_f16_at(&src0, block_byte_base);
|
||||||
let dmin = load_f16_at(&src0, block_byte_base + 2u);
|
let dmin = load_f16_at(&src0, block_byte_base + 2u);
|
||||||
|
|
||||||
// Load packed scales
|
|
||||||
var scale_vals: array<u32, 3>;
|
|
||||||
for (var i: u32 = 0u; i < 3u; i++) {
|
|
||||||
scale_vals[i] = load_u32_at(&src0, block_byte_base + 4u + 4u * i);
|
|
||||||
}
|
|
||||||
|
|
||||||
// The original loop processes elements in groups of 64
|
// The original loop processes elements in groups of 64
|
||||||
// Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4]
|
// Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4]
|
||||||
@@ -603,15 +594,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||||||
var sc: u32;
|
var sc: u32;
|
||||||
var mn: u32;
|
var mn: u32;
|
||||||
|
|
||||||
|
let scale_base = block_byte_base + 4u;
|
||||||
|
|
||||||
if (is < 4u) {
|
if (is < 4u) {
|
||||||
let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
|
let sc_byte = get_byte(load_u32_at(&src0, scale_base), is % 4u);
|
||||||
let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
|
let min_byte = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);
|
||||||
sc = sc_byte & 63u;
|
sc = sc_byte & 63u;
|
||||||
mn = min_byte & 63u;
|
mn = min_byte & 63u;
|
||||||
} else {
|
} else {
|
||||||
let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
|
let sc_min_lo = get_byte(load_u32_at(&src0, scale_base + 8), (is + 4u) % 4u);
|
||||||
let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
|
let sc_hi = get_byte(load_u32_at(&src0, scale_base), (is - 4u) % 4u);
|
||||||
let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
|
let min_hi = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);
|
||||||
|
|
||||||
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
|
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
|
||||||
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
|
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
|
||||||
|
|||||||
@@ -4,14 +4,14 @@ enable f16;
|
|||||||
#include "mul_mat_decls.tmpl"
|
#include "mul_mat_decls.tmpl"
|
||||||
|
|
||||||
#ifdef VEC
|
#ifdef VEC
|
||||||
fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> {
|
fn store_val(acc: array<array<f32, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> {
|
||||||
return vec4<f32>(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn]));
|
return vec4<f32>(acc[tm][tn], acc[tm + 1][tn], acc[tm + 2][tn], acc[tm + 3][tn]);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef SCALAR
|
#ifdef SCALAR
|
||||||
fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 {
|
fn store_val(acc: array<array<f32, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 {
|
||||||
return f32(acc[tm][tn]);
|
return acc[tm][tn];
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@@ -98,7 +98,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||||||
let offset_m = wg_m * WORKGROUP_SIZE_M * TILE_M;
|
let offset_m = wg_m * WORKGROUP_SIZE_M * TILE_M;
|
||||||
let offset_n = wg_n * WORKGROUP_SIZE_N * TILE_N;
|
let offset_n = wg_n * WORKGROUP_SIZE_N * TILE_N;
|
||||||
|
|
||||||
var acc: array<array<f16, TILE_N>, TILE_M>;
|
var acc: array<array<f32, TILE_N>, TILE_M>;
|
||||||
|
|
||||||
for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) {
|
for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) {
|
||||||
|
|
||||||
@@ -122,7 +122,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||||||
let src1_idx = src1_n * TILE_K + k_inner;
|
let src1_idx = src1_n * TILE_K + k_inner;
|
||||||
let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx];
|
let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx];
|
||||||
for (var tm = 0u; tm < TILE_M; tm++) {
|
for (var tm = 0u; tm < TILE_M; tm++) {
|
||||||
acc[tm][tn] += src0_tile[tm] * src1_val;
|
acc[tm][tn] += f32(src0_tile[tm]) * f32(src1_val);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,9 @@ enable chromium_experimental_subgroup_matrix;
|
|||||||
#include "common_decls.tmpl"
|
#include "common_decls.tmpl"
|
||||||
#include "mul_mat_decls.tmpl"
|
#include "mul_mat_decls.tmpl"
|
||||||
|
|
||||||
|
// TODO: this shader path does not work with some models like qwen2.5 on Metal devices, f16 accumulation causes NaNs.
|
||||||
|
// See https://github.com/ggml-org/llama.cpp/issues/21602
|
||||||
|
|
||||||
#ifdef VEC
|
#ifdef VEC
|
||||||
fn store_dst(shmem_idx: u32, dst_idx: u32) {
|
fn store_dst(shmem_idx: u32, dst_idx: u32) {
|
||||||
dst[dst_idx] = vec4<f32>(
|
dst[dst_idx] = vec4<f32>(
|
||||||
|
|||||||
Reference in New Issue
Block a user