vulkan: Reduce temporary memory usage for TOP_K (#17623)
- Compute row size for the temp buffer based on the output of the first pass. - Update shader addressing math to use the output row size - Pass the output row size as "ncols_output", what used to be "ncols_output" is now "k" For the common case of K=40 and src0=(200000,1,1,1), this reduces the temporary buffer from about 3.2MB to 500KB.
This commit is contained in:
@@ -1227,6 +1227,7 @@ struct vk_op_topk_push_constants {
|
||||
uint32_t orig_ncols;
|
||||
uint32_t ncols_input;
|
||||
uint32_t ncols_output;
|
||||
uint32_t k;
|
||||
uint32_t nrows;
|
||||
uint32_t first_pass;
|
||||
uint32_t last_pass;
|
||||
@@ -1673,6 +1674,14 @@ class vk_perf_logger {
|
||||
timings[name.str()].push_back(time);
|
||||
return;
|
||||
}
|
||||
if (node->op == GGML_OP_TOP_K) {
|
||||
std::stringstream name;
|
||||
name << ggml_op_name(node->op) <<
|
||||
" K=" << node->ne[0] <<
|
||||
" (" << node->src[0]->ne[0] << "," << node->src[0]->ne[1] << "," << node->src[0]->ne[2] << "," << node->src[0]->ne[3] << ")";
|
||||
timings[name.str()].push_back(time);
|
||||
return;
|
||||
}
|
||||
timings[ggml_op_name(node->op)].push_back(time);
|
||||
}
|
||||
private:
|
||||
@@ -10345,17 +10354,8 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
||||
uint32_t nrows = ggml_nrows(src0);
|
||||
uint32_t k = dst->ne[0];
|
||||
|
||||
vk_op_topk_push_constants pc { ncols, ncols, k, nrows, 0, 0 };
|
||||
vk_op_topk_push_constants pc { ncols, ncols, ncols, k, nrows, 0, 0 };
|
||||
|
||||
// Reserve space for ivec2 per element, double buffered
|
||||
const size_t dbl_buf_size = size_t{ncols} * nrows * 2 * sizeof(int);
|
||||
const size_t x_sz = dbl_buf_size * 2;
|
||||
uint32_t dbl_buf_index = 0;
|
||||
|
||||
if (ctx->prealloc_size_x < x_sz) {
|
||||
ctx->prealloc_size_x = x_sz;
|
||||
ggml_vk_preallocate_buffers(ctx, subctx);
|
||||
}
|
||||
if (ctx->prealloc_x_need_sync) {
|
||||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
}
|
||||
@@ -10370,8 +10370,9 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
||||
// largest elements. Repeat until we have the top K elements.
|
||||
// Need to do at least one iteration to write out the results.
|
||||
bool done_one_iter = false;
|
||||
uint32_t dbl_buf_index = 0;
|
||||
size_t dbl_buf_size;
|
||||
while (num_elements > k || !done_one_iter) {
|
||||
done_one_iter = true;
|
||||
|
||||
// Prefer going as small as num_topk_pipelines - 3 for perf reasons.
|
||||
// But if K is larger, then we need a larger workgroup
|
||||
@@ -10411,6 +10412,21 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
||||
// Number of elements remaining after this pass
|
||||
uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]);
|
||||
|
||||
pc2.ncols_output = num_dst_elements;
|
||||
|
||||
if (!done_one_iter) {
|
||||
// Reserve space for ivec2 per element, double buffered
|
||||
// K per workgroup per row
|
||||
dbl_buf_size = num_dst_elements * nrows * 2 * sizeof(int);
|
||||
dbl_buf_size = ROUNDUP_POW2(dbl_buf_size, ctx->device->properties.limits.minStorageBufferOffsetAlignment);
|
||||
const size_t x_sz = dbl_buf_size * 2;
|
||||
|
||||
if (ctx->prealloc_size_x < x_sz) {
|
||||
ctx->prealloc_size_x = x_sz;
|
||||
ggml_vk_preallocate_buffers(ctx, subctx);
|
||||
}
|
||||
}
|
||||
|
||||
vk_subbuffer src_buf;
|
||||
vk_subbuffer dst_buf;
|
||||
|
||||
@@ -10436,6 +10452,7 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
||||
if (num_elements > k) {
|
||||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
}
|
||||
done_one_iter = true;
|
||||
}
|
||||
ctx->prealloc_x_need_sync = true;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user