ggml-webgpu: Fix bug in FlashAttention support check (#22492)
* Fix flashattention support check for devices that don't support subgroups * set path to none if kv_tile doesn't fit
This commit is contained in:
@@ -494,9 +494,10 @@ struct ggml_webgpu_unary_pipeline_key_hash {
|
|||||||
/** FlashAttention */
|
/** FlashAttention */
|
||||||
|
|
||||||
enum ggml_webgpu_flash_attn_path : uint32_t {
|
enum ggml_webgpu_flash_attn_path : uint32_t {
|
||||||
GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX = 0u,
|
GGML_WEBGPU_FLASH_ATTN_PATH_NONE = 0u,
|
||||||
GGML_WEBGPU_FLASH_ATTN_PATH_TILE = 1u,
|
GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX = 1u,
|
||||||
GGML_WEBGPU_FLASH_ATTN_PATH_VEC = 2u,
|
GGML_WEBGPU_FLASH_ATTN_PATH_TILE = 2u,
|
||||||
|
GGML_WEBGPU_FLASH_ATTN_PATH_VEC = 3u,
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ggml_webgpu_flash_attn_pipeline_key {
|
struct ggml_webgpu_flash_attn_pipeline_key {
|
||||||
@@ -534,7 +535,7 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct ggml_webgpu_flash_attn_decisions {
|
struct ggml_webgpu_flash_attn_decisions {
|
||||||
uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX;
|
uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_NONE;
|
||||||
uint32_t q_tile = 0;
|
uint32_t q_tile = 0;
|
||||||
uint32_t kv_tile = 0;
|
uint32_t kv_tile = 0;
|
||||||
uint32_t wg_size = 0;
|
uint32_t wg_size = 0;
|
||||||
@@ -709,19 +710,29 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
|
|||||||
(context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
|
(context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
|
||||||
(context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && !use_vec;
|
(context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && !use_vec;
|
||||||
|
|
||||||
decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC :
|
decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC :
|
||||||
use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE :
|
use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE :
|
||||||
GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX;
|
context.supports_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX :
|
||||||
|
GGML_WEBGPU_FLASH_ATTN_PATH_NONE;
|
||||||
|
|
||||||
|
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) {
|
||||||
|
return decisions;
|
||||||
|
}
|
||||||
|
|
||||||
const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path);
|
const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path);
|
||||||
decisions.kv_direct = key.kv_direct;
|
decisions.kv_direct = key.kv_direct;
|
||||||
|
const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key);
|
||||||
|
// invalidate if even the smallest kv_tile doesn't fit in shared memory
|
||||||
|
if (max_kv_tile == 0) {
|
||||||
|
decisions.path = GGML_WEBGPU_FLASH_ATTN_PATH_NONE;
|
||||||
|
return decisions;
|
||||||
|
}
|
||||||
|
|
||||||
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
|
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
|
||||||
const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key);
|
decisions.q_tile = 1u;
|
||||||
decisions.q_tile = 1u;
|
decisions.kv_tile = std::max(8u, std::min(32u, max_kv_tile));
|
||||||
decisions.kv_tile = std::max(8u, std::min(32u, min_kv_tile));
|
decisions.kv_tile = (decisions.kv_tile / 8u) * 8u;
|
||||||
decisions.kv_tile = (decisions.kv_tile / 8u) * 8u;
|
decisions.wg_size = std::max(1u, std::min<uint32_t>(32u, context.max_subgroup_size));
|
||||||
decisions.wg_size = std::max(1u, std::min<uint32_t>(32u, context.max_subgroup_size));
|
|
||||||
if (decisions.kv_direct) {
|
if (decisions.kv_direct) {
|
||||||
decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
|
decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
|
||||||
while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) {
|
while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) {
|
||||||
@@ -734,9 +745,8 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
|
|||||||
decisions.q_tile =
|
decisions.q_tile =
|
||||||
decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE : context.sg_mat_m;
|
decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE : context.sg_mat_m;
|
||||||
decisions.kv_tile = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ?
|
decisions.kv_tile = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ?
|
||||||
std::min(64u, ggml_webgpu_flash_attn_max_kv_tile(context, key)) :
|
std::min(64u, max_kv_tile) :
|
||||||
std::min(ggml_webgpu_flash_attn_max_kv_tile(context, key),
|
std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
|
||||||
context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
|
|
||||||
decisions.wg_size = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ?
|
decisions.wg_size = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ?
|
||||||
GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE :
|
GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE :
|
||||||
std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
|
std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
|
||||||
@@ -755,7 +765,6 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
|
|||||||
context.sg_mat_n;
|
context.sg_mat_n;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return decisions;
|
return decisions;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1364,7 +1373,7 @@ class ggml_webgpu_shader_lib {
|
|||||||
if (key.src_type == GGML_TYPE_Q1_0) {
|
if (key.src_type == GGML_TYPE_Q1_0) {
|
||||||
defines.push_back("BLOCK_SIZE=128u");
|
defines.push_back("BLOCK_SIZE=128u");
|
||||||
} else if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
|
} else if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
|
||||||
key.src_type == GGML_TYPE_IQ4_NL) {
|
key.src_type == GGML_TYPE_IQ4_NL) {
|
||||||
defines.push_back("BLOCK_SIZE=32u");
|
defines.push_back("BLOCK_SIZE=32u");
|
||||||
} else if (key.src_type >= GGML_TYPE_Q2_K) {
|
} else if (key.src_type >= GGML_TYPE_Q2_K) {
|
||||||
defines.push_back("BLOCK_SIZE=256u");
|
defines.push_back("BLOCK_SIZE=256u");
|
||||||
@@ -2325,6 +2334,7 @@ class ggml_webgpu_shader_lib {
|
|||||||
size_t storage_offset_alignment) {
|
size_t storage_offset_alignment) {
|
||||||
const ggml_webgpu_flash_attn_decisions decisions =
|
const ggml_webgpu_flash_attn_decisions decisions =
|
||||||
ggml_webgpu_flash_attn_get_decisions(context, storage_offset_alignment);
|
ggml_webgpu_flash_attn_get_decisions(context, storage_offset_alignment);
|
||||||
|
GGML_ASSERT(decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_NONE);
|
||||||
ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path);
|
ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path);
|
||||||
auto it = flash_attn_pipelines.find(key);
|
auto it = flash_attn_pipelines.find(key);
|
||||||
if (it != flash_attn_pipelines.end()) {
|
if (it != flash_attn_pipelines.end()) {
|
||||||
|
|||||||
@@ -3918,6 +3918,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|||||||
shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
|
shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
|
||||||
const size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
const size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
||||||
const bool has_mask = op->src[3] != nullptr;
|
const bool has_mask = op->src[3] != nullptr;
|
||||||
|
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) {
|
||||||
|
supports_op = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
|
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
|
||||||
const size_t min_bytes =
|
const size_t min_bytes =
|
||||||
ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
|
ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
|
||||||
|
|||||||
Reference in New Issue
Block a user