ggml webgpu: initial flashattention implementation (#18610)
* FlashAttention (#13) * Add inplace softmax * Move rms_norm to split row approach * Update debug for supports_op * clean up debug statements * neg f16xf32xip builds and runs, havent actually ran a model that uses neg kernel yet though * neg passes backend test * unary operators pass ggml tests * rms_norm double declaration bug atoned * abides by editor-config * removed vestigial files * fixed autoconfig * All operators (inlcluding xielu) working * removed unnecesarry checking if node->src[1] exists for unary operators * responded and dealt with PR comments * implemented REPL_Template support and removed bug in unary operators kernel * formatted embed wgsl and ggml-webgpu.cpp * Faster tensors (#8) Add fast matrix and matrix/vector multiplication. * Use map for shader replacements instead of pair of strings * Wasm (#9) * webgpu : fix build on emscripten * more debugging stuff * test-backend-ops: force single thread on wasm * fix single-thread case for init_tensor_uniform * use jspi * add pthread * test: remember to set n_thread for cpu backend * Add buffer label and enable dawn-specific toggles to turn off some checks * Intermediate state * Fast working f16/f32 vec4 * Working float fast mul mat * Clean up naming of mul_mat to match logical model, start work on q mul_mat * Setup for subgroup matrix mat mul * Basic working subgroup matrix * Working subgroup matrix tiling * Handle weirder sg matrix sizes (but still % sg matrix size) * Working start to gemv * working f16 accumulation with shared memory staging * Print out available subgroup matrix configurations * Vectorize dst stores for sg matrix shader * Gemv working scalar * Minor set_rows optimization (#4) * updated optimization, fixed errors * non vectorized version now dispatches one thread per element * Simplify * Change logic for set_rows pipelines --------- Co-authored-by: Neha Abbas <nehaabbas@macbookpro.lan> Co-authored-by: Neha Abbas <nehaabbas@ReeseLevines-MacBook-Pro.local> Co-authored-by: Reese Levine <reeselevine1@gmail.com> * Comment on dawn toggles * Working subgroup matrix code for (semi)generic sizes * Remove some comments * Cleanup code * Update dawn version and move to portable subgroup size * Try to fix new dawn release * Update subgroup size comment * Only check for subgroup matrix configs if they are supported * Add toggles for subgroup matrix/f16 support on nvidia+vulkan * Make row/col naming consistent * Refactor shared memory loading * Move sg matrix stores to correct file * Working q4_0 * Formatting * Work with emscripten builds * Fix test-backend-ops emscripten for f16/quantized types * Use emscripten memory64 to support get_memory * Add build flags and try ci --------- Co-authored-by: Xuan Son Nguyen <son@huggingface.co> * Remove extra whitespace * Move wasm single-thread logic out of test-backend-ops for cpu backend * Disable multiple threads for emscripten single-thread builds in ggml_graph_plan * Refactored pipelines and workgroup calculations (#10) * refactored pipelines * refactored workgroup calculation * removed commented out block of prior maps * Clean up ceiling division pattern --------- Co-authored-by: Neha Abbas <nehaabbas@eduroam-169-233-141-223.ucsc.edu> Co-authored-by: Reese Levine <reeselevine1@gmail.com> * Start work on flash attention * Shader structure set up (many bugs still) * debugging * Working first test * Working with head grouping, head sizes to 128, logit softcap, mask/sinks enabled, f32 * Generalize softmax to work with multiple subgroups, f16 accumulation, mask shared memory tiling * Start work on integrating pre-wgsl * Separate structs/initial shader compilation library into separate files * Work on compilation choices for flashattention * Work on subgroup matrix/tile size portability * subgroup size agnostic online softmax * Cleanups, quantization types * more cleanup * fix wasm build * Refactor flashattention to increase parallelism, use direct loads for KV in somce cases * Checkpoint * formatting * Update to account for default kv cache padding * formatting shader * Add workflow for ggml-ci webgpu * Try passing absolute path to dawn in ggml-ci * Avoid error on device destruction, add todos for proper cleanup * Fix unused warning * Forgot one parameter unused * Move some flashattn computation to f32 for correctness
This commit is contained in:
@@ -7,7 +7,9 @@
|
||||
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-webgpu-shader-lib.hpp"
|
||||
#include "ggml-wgsl-shaders.hpp"
|
||||
#include "pre_wgsl.hpp"
|
||||
|
||||
#ifdef __EMSCRIPTEN__
|
||||
# include <emscripten/emscripten.h>
|
||||
@@ -30,7 +32,7 @@
|
||||
|
||||
#ifdef GGML_WEBGPU_DEBUG
|
||||
# define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
|
||||
# define WEBGPU_DEBUG_BUF_ELEMS 32
|
||||
# define WEBGPU_DEBUG_BUF_ELEMS 512
|
||||
#else
|
||||
# define WEBGPU_LOG_DEBUG(msg) ((void) 0)
|
||||
#endif // GGML_WEBGPU_DEBUG
|
||||
@@ -251,6 +253,7 @@ struct webgpu_gpu_profile_buf_pool {
|
||||
struct webgpu_pipeline {
|
||||
wgpu::ComputePipeline pipeline;
|
||||
std::string name;
|
||||
void * context = nullptr;
|
||||
};
|
||||
|
||||
struct webgpu_command {
|
||||
@@ -263,6 +266,46 @@ struct webgpu_command {
|
||||
#endif
|
||||
};
|
||||
|
||||
struct flash_attn_pipeline_key {
|
||||
int q_type;
|
||||
int kv_type;
|
||||
int dst_type;
|
||||
uint32_t head_dim_qk;
|
||||
uint32_t head_dim_v;
|
||||
bool kv_direct;
|
||||
bool has_mask;
|
||||
bool has_sinks;
|
||||
bool uses_logit_softcap;
|
||||
|
||||
bool operator==(const flash_attn_pipeline_key & other) const {
|
||||
return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type &&
|
||||
head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct &&
|
||||
has_mask == other.has_mask && has_sinks == other.has_sinks &&
|
||||
uses_logit_softcap == other.uses_logit_softcap;
|
||||
}
|
||||
};
|
||||
|
||||
// Same hash combine function as in boost
|
||||
template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
|
||||
seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
|
||||
}
|
||||
|
||||
struct flash_attn_pipeline_key_hash {
|
||||
size_t operator()(const flash_attn_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.q_type);
|
||||
ggml_webgpu_hash_combine(seed, key.kv_type);
|
||||
ggml_webgpu_hash_combine(seed, key.dst_type);
|
||||
ggml_webgpu_hash_combine(seed, key.head_dim_qk);
|
||||
ggml_webgpu_hash_combine(seed, key.head_dim_v);
|
||||
ggml_webgpu_hash_combine(seed, key.kv_direct);
|
||||
ggml_webgpu_hash_combine(seed, key.has_mask);
|
||||
ggml_webgpu_hash_combine(seed, key.has_sinks);
|
||||
ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
// All the base objects needed to run operations on a WebGPU device
|
||||
struct webgpu_context_struct {
|
||||
wgpu::Instance instance;
|
||||
@@ -271,12 +314,12 @@ struct webgpu_context_struct {
|
||||
wgpu::Queue queue;
|
||||
wgpu::Limits limits;
|
||||
|
||||
uint32_t subgroup_size;
|
||||
uint32_t max_subgroup_size;
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
bool supports_subgroup_matrix = false;
|
||||
wgpu::SubgroupMatrixConfig subgroup_matrix_config;
|
||||
#endif
|
||||
bool supports_subgroup_matrix = false;
|
||||
uint32_t sg_mat_m;
|
||||
uint32_t sg_mat_n;
|
||||
uint32_t sg_mat_k;
|
||||
|
||||
std::recursive_mutex mutex;
|
||||
std::atomic_uint inflight_threads = 0;
|
||||
@@ -284,20 +327,24 @@ struct webgpu_context_struct {
|
||||
webgpu_buf_pool param_buf_pool;
|
||||
webgpu_buf_pool set_rows_error_buf_pool;
|
||||
|
||||
pre_wgsl::Preprocessor p;
|
||||
|
||||
std::map<int, webgpu_pipeline> memset_pipelines; // variant or type index
|
||||
|
||||
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> mul_mat_pipelines; // src0_type, src1_type, vectorized
|
||||
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>>
|
||||
mul_mat_vec_pipelines; // src0_type, src1_type, vectorized
|
||||
|
||||
std::map<int, std::map<int, webgpu_pipeline>> set_rows_pipelines; // dst_type, vectorized
|
||||
std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines; // src_type, vectorized
|
||||
std::unordered_map<flash_attn_pipeline_key, webgpu_pipeline, flash_attn_pipeline_key_hash> flash_attn_pipelines;
|
||||
|
||||
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
|
||||
std::map<int, std::map<int, webgpu_pipeline>> add_pipelines; // type, inplace
|
||||
std::map<int, std::map<int, webgpu_pipeline>> sub_pipelines; // type, inplace
|
||||
std::map<int, std::map<int, webgpu_pipeline>> mul_pipelines; // type, inplace
|
||||
std::map<int, std::map<int, webgpu_pipeline>> div_pipelines; // type, inplace
|
||||
std::map<int, std::map<int, webgpu_pipeline>> set_rows_pipelines; // dst_type, vectorized
|
||||
std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines; // src_type, vectorized
|
||||
|
||||
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
|
||||
std::map<int, std::map<int, webgpu_pipeline>> add_pipelines; // type, inplace
|
||||
std::map<int, std::map<int, webgpu_pipeline>> sub_pipelines; // type, inplace
|
||||
std::map<int, std::map<int, webgpu_pipeline>> mul_pipelines; // type, inplace
|
||||
std::map<int, std::map<int, webgpu_pipeline>> div_pipelines; // type, inplace
|
||||
|
||||
std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
|
||||
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
|
||||
@@ -361,8 +408,6 @@ struct ggml_backend_webgpu_buffer_context {
|
||||
label(std::move(lbl)) {}
|
||||
};
|
||||
|
||||
/* End struct definitions */
|
||||
|
||||
/* WebGPU object initializations */
|
||||
|
||||
// Process a WGSL shader string, replacing tokens of the form {{KEY}} with
|
||||
@@ -484,14 +529,9 @@ static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
|
||||
encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
|
||||
wgpu::CommandBuffer commands = encoder.Finish();
|
||||
ctx->queue.Submit(1, &commands);
|
||||
|
||||
ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
|
||||
const uint32_t * debug_data = (const uint32_t *) ctx->debug_host_buf.GetConstMappedRange();
|
||||
std::cout << "debug data:";
|
||||
for (size_t i = 0; i < WEBGPU_DEBUG_BUF_ELEMS; i++) {
|
||||
std::cout << " " << i << ": " << debug_data[i];
|
||||
}
|
||||
std::cout << "\n";
|
||||
const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange();
|
||||
std::cout << "debug[0]: " << debug_data[0] << "\n";
|
||||
ctx->debug_host_buf.Unmap();
|
||||
}
|
||||
#endif
|
||||
@@ -673,6 +713,7 @@ static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
|
||||
return ctx->name.c_str();
|
||||
}
|
||||
|
||||
// TODO: implement proper cleanup
|
||||
static void ggml_backend_webgpu_free(ggml_backend_t backend) {
|
||||
ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
|
||||
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");
|
||||
@@ -730,12 +771,12 @@ static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
|
||||
return ctx->buffer;
|
||||
}
|
||||
|
||||
static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, ggml_tensor * t) {
|
||||
static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) {
|
||||
size_t offset = ggml_webgpu_tensor_offset(t);
|
||||
return offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
}
|
||||
|
||||
static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, ggml_tensor * t) {
|
||||
static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {
|
||||
size_t offset = ggml_webgpu_tensor_offset(t);
|
||||
return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
}
|
||||
@@ -964,12 +1005,10 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
#ifndef __EMSCRIPTEN__
|
||||
if (ctx->supports_subgroup_matrix) {
|
||||
// The total number of subgroups/workgroups needed per matrix.
|
||||
uint32_t wg_m_sg_tile =
|
||||
WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->subgroup_matrix_config.M;
|
||||
wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
|
||||
uint32_t wg_n_sg_tile =
|
||||
WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->subgroup_matrix_config.N;
|
||||
wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
|
||||
uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->sg_mat_m;
|
||||
wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
|
||||
uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->sg_mat_n;
|
||||
wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
|
||||
} else {
|
||||
#endif
|
||||
uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M;
|
||||
@@ -986,6 +1025,146 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
ggml_tensor * Q,
|
||||
ggml_tensor * K,
|
||||
ggml_tensor * V,
|
||||
ggml_tensor * mask,
|
||||
ggml_tensor * sinks,
|
||||
ggml_tensor * dst) {
|
||||
float scale = *(float *) dst->op_params;
|
||||
float max_bias;
|
||||
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
|
||||
float logit_softcap;
|
||||
memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
|
||||
if (logit_softcap != 0.0f) {
|
||||
scale /= logit_softcap;
|
||||
}
|
||||
float n_head_log2 = float(1u << (uint32_t) floor(log2(Q->ne[2])));
|
||||
float m0 = powf(2.0f, -(max_bias) / n_head_log2);
|
||||
float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
const int has_mask = (mask != nullptr);
|
||||
const int has_sinks = (sinks != nullptr);
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)),
|
||||
has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0,
|
||||
has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0,
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
(uint32_t) Q->ne[2], // number of heads
|
||||
(uint32_t) Q->ne[1], // sequence length (Q)
|
||||
(uint32_t) K->ne[1], // sequence length (K/V)
|
||||
(uint32_t) (Q->nb[1] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 1
|
||||
(uint32_t) (Q->nb[2] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 2
|
||||
(uint32_t) (Q->nb[3] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 3
|
||||
(uint32_t) (K->nb[1] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 1
|
||||
(uint32_t) (K->nb[2] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 2
|
||||
(uint32_t) (K->nb[3] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 3
|
||||
(uint32_t) (V->nb[1] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 1
|
||||
(uint32_t) (V->nb[2] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 2
|
||||
(uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3
|
||||
has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3
|
||||
(uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA)
|
||||
*(uint32_t *) &scale, // scale (possibly adjusted for logit softcap)
|
||||
*(uint32_t *) &max_bias,
|
||||
*(uint32_t *) &logit_softcap,
|
||||
*(uint32_t *) &n_head_log2,
|
||||
*(uint32_t *) &m0,
|
||||
*(uint32_t *) &m1
|
||||
|
||||
};
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(Q),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, Q),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, Q) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(K),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, K),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, K) },
|
||||
{ .binding = 2,
|
||||
.buffer = ggml_webgpu_tensor_buf(V),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, V),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, V) }
|
||||
};
|
||||
uint32_t binding_index = 3;
|
||||
if (has_mask) {
|
||||
entries.push_back({ .binding = binding_index++,
|
||||
.buffer = ggml_webgpu_tensor_buf(mask),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, mask),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, mask) });
|
||||
}
|
||||
if (has_sinks) {
|
||||
entries.push_back({ .binding = binding_index++,
|
||||
.buffer = ggml_webgpu_tensor_buf(sinks),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, sinks),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, sinks) });
|
||||
}
|
||||
entries.push_back({ .binding = binding_index++,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
|
||||
|
||||
bool kv_direct =
|
||||
(K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->sg_mat_k == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
|
||||
|
||||
flash_attn_pipeline_key key = {
|
||||
.q_type = Q->type,
|
||||
.kv_type = K->type,
|
||||
.dst_type = dst->type,
|
||||
.head_dim_qk = (uint32_t) Q->ne[0],
|
||||
.head_dim_v = (uint32_t) V->ne[0],
|
||||
.kv_direct = kv_direct,
|
||||
.has_mask = static_cast<bool>(has_mask),
|
||||
.has_sinks = static_cast<bool>(has_sinks),
|
||||
.uses_logit_softcap = logit_softcap != 0.0f,
|
||||
};
|
||||
|
||||
webgpu_pipeline pipeline;
|
||||
ggml_webgpu_flash_attn_shader_decisions decisions = {};
|
||||
|
||||
auto it = ctx->flash_attn_pipelines.find(key);
|
||||
if (it != ctx->flash_attn_pipelines.end()) {
|
||||
pipeline = it->second;
|
||||
decisions = *static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context);
|
||||
} else {
|
||||
std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
|
||||
it = ctx->flash_attn_pipelines.find(key);
|
||||
if (it != ctx->flash_attn_pipelines.end()) {
|
||||
pipeline = it->second;
|
||||
decisions = *static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context);
|
||||
} else {
|
||||
ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { .kv_type = K->type,
|
||||
.head_dim_qk = (uint32_t) Q->ne[0],
|
||||
.head_dim_v = (uint32_t) V->ne[0],
|
||||
.kv_direct = kv_direct,
|
||||
.has_mask = static_cast<bool>(has_mask),
|
||||
.has_sinks = static_cast<bool>(has_sinks),
|
||||
.uses_logit_softcap = logit_softcap != 0.0f,
|
||||
.sg_mat_m = ctx->sg_mat_m,
|
||||
.sg_mat_n = ctx->sg_mat_n,
|
||||
.sg_mat_k = ctx->sg_mat_k,
|
||||
.wg_mem_limit_bytes =
|
||||
ctx->limits.maxComputeWorkgroupStorageSize,
|
||||
.max_subgroup_size = ctx->max_subgroup_size };
|
||||
|
||||
ggml_webgpu_processed_shader processed =
|
||||
ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx);
|
||||
pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
|
||||
pipeline.context = new ggml_webgpu_flash_attn_shader_decisions(processed.decisions);
|
||||
ctx->flash_attn_pipelines.emplace(key, pipeline);
|
||||
decisions = processed.decisions;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions.q_tile);
|
||||
uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||
uint32_t ne = (uint32_t) ggml_nelements(dst);
|
||||
ggml_unary_op unary_op = ggml_get_unary_op(dst);
|
||||
@@ -1397,6 +1576,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
|
||||
return ggml_webgpu_get_rows(ctx, src0, src1, node);
|
||||
case GGML_OP_MUL_MAT:
|
||||
return ggml_webgpu_mul_mat(ctx, src0, src1, node);
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node);
|
||||
case GGML_OP_ADD:
|
||||
{
|
||||
int inplace = ggml_webgpu_tensor_equal(src0, node);
|
||||
@@ -1466,6 +1647,7 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
|
||||
webgpu_submission_futures new_futures = ggml_backend_webgpu_submit(ctx, commands);
|
||||
futures.push_back(new_futures);
|
||||
}
|
||||
|
||||
ggml_backend_webgpu_wait(ctx, futures);
|
||||
ctx->inflight_threads--;
|
||||
WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx);
|
||||
@@ -1808,15 +1990,15 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
|
||||
#ifndef __EMSCRIPTEN__
|
||||
if (webgpu_ctx->supports_subgroup_matrix) {
|
||||
std::map<std::string, std::string> sg_matrix_repls;
|
||||
sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->subgroup_size);
|
||||
sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->max_subgroup_size);
|
||||
sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K);
|
||||
sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M);
|
||||
sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N);
|
||||
sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M);
|
||||
sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N);
|
||||
sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.M);
|
||||
sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.N);
|
||||
sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.K);
|
||||
sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->sg_mat_m);
|
||||
sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->sg_mat_n);
|
||||
sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->sg_mat_k);
|
||||
|
||||
proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
|
||||
proc_mul_mat_f32_f32_vec =
|
||||
@@ -2328,6 +2510,7 @@ static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
|
||||
webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace", constants);
|
||||
}
|
||||
|
||||
// TODO: move most initialization logic here
|
||||
static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
|
||||
GGML_UNUSED(params);
|
||||
|
||||
@@ -2489,6 +2672,29 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
}
|
||||
break;
|
||||
}
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
if (!webgpu_ctx->supports_subgroup_matrix) {
|
||||
break;
|
||||
}
|
||||
// Head dimensions must fit in workgroup memory with minimum tile sizes
|
||||
size_t limit_bytes = webgpu_ctx->limits.maxComputeWorkgroupStorageSize;
|
||||
const bool has_mask = op->src[3] != nullptr;
|
||||
const bool kv_direct = src1->type == GGML_TYPE_F16 && (src0->ne[0] % webgpu_ctx->sg_mat_k) == 0 &&
|
||||
(src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
|
||||
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
|
||||
webgpu_ctx->sg_mat_m, webgpu_ctx->sg_mat_n, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0],
|
||||
has_mask, kv_direct);
|
||||
if (min_bytes > limit_bytes) {
|
||||
break;
|
||||
}
|
||||
|
||||
supports_op = src0->type == GGML_TYPE_F32 &&
|
||||
(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||
|
||||
src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&
|
||||
src2->type == src1->type && op->type == GGML_TYPE_F32;
|
||||
break;
|
||||
}
|
||||
case GGML_OP_RMS_NORM:
|
||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
|
||||
break;
|
||||
@@ -2606,6 +2812,7 @@ static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
|
||||
}
|
||||
|
||||
// TODO: Does this need to be thread safe? Is it only called once?
|
||||
// TODO: move most logic to device_init function so backend can be freed/initialized properly
|
||||
// Only one device is supported for now
|
||||
static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
|
||||
GGML_ASSERT(index == 0);
|
||||
@@ -2665,7 +2872,9 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
||||
if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
|
||||
config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
|
||||
config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
|
||||
ctx->subgroup_matrix_config = config;
|
||||
ctx->sg_mat_m = config.M;
|
||||
ctx->sg_mat_n = config.N;
|
||||
ctx->sg_mat_k = config.K;
|
||||
valid_subgroup_matrix_config = true;
|
||||
break;
|
||||
}
|
||||
@@ -2676,7 +2885,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
||||
#endif
|
||||
// For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
|
||||
// Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
|
||||
ctx->subgroup_size = info.subgroupMaxSize;
|
||||
ctx->max_subgroup_size = info.subgroupMaxSize;
|
||||
|
||||
// Initialize device
|
||||
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
|
||||
@@ -2701,8 +2910,11 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
|
||||
GGML_UNUSED(device);
|
||||
GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
|
||||
std::string(message).c_str());
|
||||
GGML_UNUSED(reason);
|
||||
GGML_UNUSED(message);
|
||||
//TODO: uncomment once proper free logic is in place
|
||||
//GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
|
||||
//std::string(message).c_str());
|
||||
});
|
||||
dev_desc.SetUncapturedErrorCallback(
|
||||
[](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
|
||||
|
||||
Reference in New Issue
Block a user