ggml-webgpu: add vectorized flash attention (#20709)
* naive vectorized version * add vectorized flash attention * update vec version * remove unused path and shader * remove unused helper functions * add comments * remove pad path * ggml-webgpu: fix flash-attn vec nwg=1 path and tighten vec specialization * change back to vec4 * enable multi split * enable vec path when: - Q->ne[1] < 20 - Q->ne[0] % 32 == 0 - V->ne[0] % 4 == 0 - K->type == f16 * update flast_attn_vec_split.wgsl to reduce redundant workgroup barrier usage and use select * enable vec path for q4 and q8 * flash-attn vec nwg=1 fast path (skip tmp/reduce staging) * use packed f16 K loads in flash-attn vec split * use packed f16 K loads in flash-attn vec split on host side * tune flash-attn vec f16 VEC_NE by head dim * cleanup * cleanup * keep host side clean * cleanup host side * change back to original host wait/submit behavior * formatting * reverted param-buffer pool r ecfactor * add helper functions * ggml-webgpu: move flash-attn vec pipeline caching back into shader lib * ggml-webgpu: remove duplicate functions * ggml-webgpu: reserve flash-attn vec scratch in dst buffer allocation * ggml-webgpu: revert unrelated change * ggml-webgpu: revert deleted comment * disable uniformity check * remove unnecessary change * Update ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl * Update ggml/src/ggml-webgpu/ggml-webgpu.cpp --------- Co-authored-by: Reese Levine <reeselevine1@gmail.com>
This commit is contained in:
@@ -658,7 +658,6 @@ static webgpu_command ggml_backend_webgpu_build_multi(
|
||||
for (size_t i = 0; i < params_bufs_list.size(); i++) {
|
||||
ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t));
|
||||
}
|
||||
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs();
|
||||
if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
|
||||
@@ -1481,7 +1480,6 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
|
||||
}
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
ggml_tensor * Q,
|
||||
ggml_tensor * K,
|
||||
@@ -1565,30 +1563,248 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
|
||||
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = Q,
|
||||
.src1 = K,
|
||||
.src2 = V,
|
||||
.src3 = mask,
|
||||
.src4 = sinks,
|
||||
.dst = dst,
|
||||
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||
.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
|
||||
const uint32_t k_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type));
|
||||
const uint32_t v_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type));
|
||||
const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u);
|
||||
|
||||
const bool kv_direct = (K->type == GGML_TYPE_F16) && f16_vec4_aligned &&
|
||||
(Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) &&
|
||||
(K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
|
||||
|
||||
const bool kv_vec_type_supported =
|
||||
K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
|
||||
const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported &&
|
||||
(K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type);
|
||||
const uint32_t vec_nwg_cap =
|
||||
std::max(1u, std::min<uint32_t>(32u, ctx->global_ctx->capabilities.max_subgroup_size));
|
||||
const bool use_blk = use_vec && has_mask;
|
||||
|
||||
ggml_webgpu_flash_attn_pipeline_key key = {
|
||||
.kv_type = K->type,
|
||||
.head_dim_qk = (uint32_t) Q->ne[0],
|
||||
.head_dim_v = (uint32_t) V->ne[0],
|
||||
.kv_direct = kv_direct,
|
||||
.has_mask = static_cast<bool>(has_mask),
|
||||
.has_sinks = static_cast<bool>(has_sinks),
|
||||
.uses_logit_softcap = logit_softcap != 0.0f,
|
||||
.use_vec = use_vec,
|
||||
};
|
||||
|
||||
ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = {
|
||||
.key = key,
|
||||
.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
|
||||
.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
|
||||
.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k,
|
||||
.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
|
||||
.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size,
|
||||
};
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx);
|
||||
|
||||
auto * decisions = static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
|
||||
uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches
|
||||
|
||||
wgpu::Buffer blk_buf = {};
|
||||
uint64_t blk_size_bytes = 0;
|
||||
uint32_t blk_nblk0 = 0;
|
||||
uint32_t blk_nblk1 = 0;
|
||||
uint32_t blk_batch_count = 0;
|
||||
|
||||
if (use_vec) {
|
||||
uint32_t nwg = 1u;
|
||||
const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile);
|
||||
while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) {
|
||||
nwg <<= 1;
|
||||
}
|
||||
nwg = std::min(nwg, vec_nwg_cap);
|
||||
GGML_ASSERT(nwg <= ctx->global_ctx->capabilities.max_subgroup_size);
|
||||
const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3];
|
||||
const bool use_vec_reduce = nwg > 1u;
|
||||
GGML_ASSERT(nrows <= UINT32_MAX);
|
||||
|
||||
uint64_t tmp_stats_base = 0;
|
||||
uint64_t tmp_size_bytes = 0;
|
||||
wgpu::Buffer tmp_buf = {};
|
||||
uint64_t tmp_bind_offset = 0;
|
||||
uint64_t tmp_bind_size = 0;
|
||||
const size_t align_bytes = ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
|
||||
const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
|
||||
size_t scratch_offset = ROUNDUP_POW2(dst_offset + ggml_nbytes(dst), align_bytes);
|
||||
|
||||
if (use_vec_reduce) {
|
||||
const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg;
|
||||
const uint64_t tmp_stats_elems = nrows * 2u * nwg;
|
||||
tmp_stats_base = tmp_data_elems;
|
||||
tmp_size_bytes =
|
||||
ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
GGML_ASSERT(tmp_stats_base <= UINT32_MAX);
|
||||
tmp_buf = ggml_webgpu_tensor_buf(dst);
|
||||
tmp_bind_offset = scratch_offset;
|
||||
tmp_bind_size = tmp_size_bytes;
|
||||
scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes);
|
||||
} else {
|
||||
// nwg==1 writes final dst directly in vec-split; keep tmp binding valid without extra allocation.
|
||||
tmp_buf = ggml_webgpu_tensor_buf(dst);
|
||||
tmp_bind_offset = ggml_webgpu_tensor_align_offset(ctx, dst);
|
||||
tmp_bind_size = ggml_webgpu_tensor_binding_size(ctx, dst);
|
||||
}
|
||||
|
||||
webgpu_pipeline blk_pipeline;
|
||||
std::vector<uint32_t> blk_params;
|
||||
std::vector<wgpu::BindGroupEntry> blk_entries;
|
||||
if (use_blk) {
|
||||
GGML_ASSERT(has_mask);
|
||||
|
||||
blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile);
|
||||
blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], decisions->q_tile);
|
||||
blk_buf = ggml_webgpu_tensor_buf(dst);
|
||||
const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type));
|
||||
blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u;
|
||||
const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count;
|
||||
blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
ggml_webgpu_flash_attn_blk_shader_lib_context blk_shader_ctx = {
|
||||
.key =
|
||||
{
|
||||
.q_tile = decisions->q_tile,
|
||||
.kv_tile = decisions->kv_tile,
|
||||
},
|
||||
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||
};
|
||||
blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx);
|
||||
|
||||
blk_params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask
|
||||
(uint32_t) Q->ne[1], // seq_len_q
|
||||
(uint32_t) K->ne[1], // seq_len_kv
|
||||
stride_mask3, // stride_mask3
|
||||
blk_nblk0, // nblk0
|
||||
blk_nblk1, // nblk1
|
||||
};
|
||||
blk_entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(mask),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, mask),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, mask) },
|
||||
{ .binding = 1, .buffer = blk_buf, .offset = scratch_offset, .size = blk_size_bytes },
|
||||
};
|
||||
scratch_offset = ROUNDUP_POW2(scratch_offset + blk_size_bytes, align_bytes);
|
||||
}
|
||||
|
||||
std::vector<uint32_t> split_params = params;
|
||||
if (use_blk) {
|
||||
split_params.push_back(0u); // blk_base
|
||||
split_params.push_back(blk_nblk0); // blk_nblk0
|
||||
split_params.push_back(blk_nblk1); // blk_nblk1
|
||||
}
|
||||
split_params.push_back(0u); // tmp_data_base
|
||||
split_params.push_back((uint32_t) tmp_stats_base); // tmp_stats_base
|
||||
split_params.push_back(nwg); // nwg
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> split_entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(Q),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, Q),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, Q) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(K),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, K),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, K) },
|
||||
{ .binding = 2,
|
||||
.buffer = ggml_webgpu_tensor_buf(V),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, V),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, V) },
|
||||
};
|
||||
uint32_t split_binding_index = 3;
|
||||
if (has_mask) {
|
||||
split_entries.push_back({ .binding = split_binding_index++,
|
||||
.buffer = ggml_webgpu_tensor_buf(mask),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, mask),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, mask) });
|
||||
}
|
||||
if (has_sinks) {
|
||||
split_entries.push_back({ .binding = split_binding_index++,
|
||||
.buffer = ggml_webgpu_tensor_buf(sinks),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, sinks),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, sinks) });
|
||||
}
|
||||
if (use_blk) {
|
||||
split_entries.push_back(
|
||||
{ .binding = split_binding_index++, .buffer = blk_buf, .offset = blk_entries[1].offset, .size = blk_size_bytes });
|
||||
}
|
||||
split_entries.push_back(
|
||||
{ .binding = split_binding_index++, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_bind_size });
|
||||
split_entries.push_back({ .binding = split_binding_index++,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
|
||||
|
||||
webgpu_pipeline reduce_pipeline;
|
||||
std::vector<uint32_t> reduce_params;
|
||||
std::vector<wgpu::BindGroupEntry> reduce_entries;
|
||||
if (use_vec_reduce) {
|
||||
const uint32_t reduce_wg_size = std::max(
|
||||
32u,
|
||||
std::min<uint32_t>(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup));
|
||||
ggml_webgpu_flash_attn_vec_reduce_shader_lib_context reduce_shader_ctx = {
|
||||
.key =
|
||||
{
|
||||
.head_dim_v = (uint32_t) V->ne[0],
|
||||
.wg_size = reduce_wg_size,
|
||||
},
|
||||
.max_wg_size = reduce_wg_size,
|
||||
};
|
||||
reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx);
|
||||
|
||||
reduce_params = {
|
||||
(uint32_t) nrows, // nrows
|
||||
(uint32_t) Q->ne[1], // seq_len_q
|
||||
(uint32_t) Q->ne[2], // n_heads
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), // offset_dst
|
||||
nwg, // nwg
|
||||
0u, // tmp_data_base
|
||||
(uint32_t) tmp_stats_base, // tmp_stats_base
|
||||
};
|
||||
|
||||
reduce_entries = {
|
||||
{ .binding = 0, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_size_bytes },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) },
|
||||
};
|
||||
}
|
||||
|
||||
const uint64_t split_wg_total = (uint64_t) wg_x * nwg;
|
||||
GGML_ASSERT(split_wg_total <= UINT32_MAX);
|
||||
std::vector<webgpu_pipeline> pipelines;
|
||||
std::vector<std::vector<uint32_t>> params_list;
|
||||
std::vector<std::vector<wgpu::BindGroupEntry>> entries_list;
|
||||
std::vector<std::pair<uint32_t, uint32_t>> workgroups_list;
|
||||
|
||||
if (use_blk) {
|
||||
pipelines.push_back(blk_pipeline);
|
||||
params_list.push_back(std::move(blk_params));
|
||||
entries_list.push_back(std::move(blk_entries));
|
||||
workgroups_list.push_back({ blk_nblk0, blk_nblk1 * blk_batch_count });
|
||||
}
|
||||
pipelines.push_back(pipeline);
|
||||
params_list.push_back(std::move(split_params));
|
||||
entries_list.push_back(std::move(split_entries));
|
||||
workgroups_list.push_back({ (uint32_t) split_wg_total, 1u });
|
||||
if (use_vec_reduce) {
|
||||
pipelines.push_back(reduce_pipeline);
|
||||
params_list.push_back(std::move(reduce_params));
|
||||
entries_list.push_back(std::move(reduce_entries));
|
||||
workgroups_list.push_back({ (uint32_t) nrows, 1u });
|
||||
}
|
||||
|
||||
return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list,
|
||||
entries_list, workgroups_list);
|
||||
}
|
||||
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
||||
}
|
||||
#endif
|
||||
|
||||
static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||
bool is_unary = dst->op == GGML_OP_UNARY;
|
||||
@@ -2559,7 +2775,6 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
|
||||
std::vector<webgpu_submission> subs;
|
||||
uint32_t num_batched_kernels = 0;
|
||||
bool contains_set_rows = false;
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) {
|
||||
contains_set_rows = true;
|
||||
@@ -2834,6 +3049,86 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
|
||||
}
|
||||
}
|
||||
break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
const ggml_tensor * Q = tensor->src[0];
|
||||
const ggml_tensor * K = tensor->src[1];
|
||||
const ggml_tensor * V = tensor->src[2];
|
||||
const ggml_tensor * mask = tensor->src[3];
|
||||
const ggml_tensor * sinks = tensor->src[4];
|
||||
if (Q && K && V) {
|
||||
GGML_UNUSED(sinks);
|
||||
const bool kv_direct = (K->type == GGML_TYPE_F16) &&
|
||||
(Q->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k == 0) &&
|
||||
(K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
|
||||
const bool kv_vec_type_supported =
|
||||
K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
|
||||
const bool use_vec =
|
||||
(Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported &&
|
||||
(V->type == K->type);
|
||||
if (use_vec) {
|
||||
const uint32_t sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
|
||||
const uint32_t sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
|
||||
const size_t limit_bytes =
|
||||
ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
||||
const size_t q_tile = sg_mat_m;
|
||||
const size_t base_q_bytes =
|
||||
(Q->ne[0] + V->ne[0]) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
|
||||
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
|
||||
size_t bytes_per_kv = 0;
|
||||
if (!kv_direct) {
|
||||
bytes_per_kv += std::max(Q->ne[0], V->ne[0]);
|
||||
}
|
||||
if (mask != nullptr) {
|
||||
bytes_per_kv += q_tile;
|
||||
}
|
||||
bytes_per_kv += q_tile;
|
||||
bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
|
||||
uint32_t kv_tile =
|
||||
((limit_bytes - base_q_bytes) / bytes_per_kv / sg_mat_n) * sg_mat_n;
|
||||
kv_tile = std::max(sg_mat_n, std::min(32u, kv_tile));
|
||||
kv_tile = (kv_tile / sg_mat_n) * sg_mat_n;
|
||||
if (kv_direct) {
|
||||
GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
|
||||
while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
|
||||
kv_tile -= sg_mat_n;
|
||||
}
|
||||
}
|
||||
|
||||
const uint32_t vec_nwg_cap = std::max(
|
||||
1u, std::min<uint32_t>(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size));
|
||||
uint32_t nwg = 1u;
|
||||
const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile);
|
||||
while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) {
|
||||
nwg <<= 1;
|
||||
}
|
||||
nwg = std::min(nwg, vec_nwg_cap);
|
||||
|
||||
const size_t align = ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
|
||||
const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3];
|
||||
if (nwg > 1u) {
|
||||
const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg;
|
||||
const uint64_t tmp_stats_elems = nrows * 2u * nwg;
|
||||
const size_t tmp_size_bytes = ROUNDUP_POW2(
|
||||
(tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
res += tmp_size_bytes + align;
|
||||
}
|
||||
if (mask != nullptr) {
|
||||
const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile);
|
||||
const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u);
|
||||
const uint32_t stride_mask3 =
|
||||
(uint32_t) (mask->nb[3] / ggml_type_size(mask->type));
|
||||
const uint32_t blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u;
|
||||
const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count;
|
||||
const size_t blk_size_bytes =
|
||||
ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
res += blk_size_bytes + align;
|
||||
}
|
||||
res = ROUNDUP_POW2(res, WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user