sampling : delegate input allocation to the scheduler (#19266)

* sampling : delegate input allocation to the scheduler

* graph : compute backend samplers only if needed
This commit is contained in:
Georgi Gerganov
2026-02-03 22:16:16 +02:00
committed by GitHub
parent 32b17abdb0
commit faa1bc26ee
3 changed files with 33 additions and 73 deletions
+13 -6
View File
@@ -2419,6 +2419,9 @@ void llm_graph_context::build_sampling() const {
return;
}
std::array<ggml_tensor *, 2> outs;
outs[0] = res->t_logits;
auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
res->add_input(std::move(inp_sampling));
@@ -2439,14 +2442,14 @@ void llm_graph_context::build_sampling() const {
// add a dummy row of logits
// this trick makes the graph static, regardless of which samplers are activated
// this is important in order to minimize graph reallocations
// TODO: use `ggml_build_forward_select()` when available (https://github.com/ggml-org/llama.cpp/pull/18550)
ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0);
for (const auto & [seq_id, sampler] : samplers) {
const auto it = seq_to_logit_row.find(seq_id);
// inactive samplers always work on the first row
const auto row_idx = seq_to_logit_row.find(seq_id) != seq_to_logit_row.end() ? it->second : 0;
const auto row_idx = it != seq_to_logit_row.end() ? it->second : 0;
const int i_out = it != seq_to_logit_row.end() ? 1 : 0;
ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
@@ -2463,22 +2466,26 @@ void llm_graph_context::build_sampling() const {
if (data.sampled != nullptr) {
res->t_sampled[seq_id] = data.sampled;
ggml_build_forward_expand(gf, data.sampled);
outs[1] = data.sampled;
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
}
if (data.probs != nullptr) {
res->t_sampled_probs[seq_id] = data.probs;
ggml_build_forward_expand(gf, data.probs);
outs[1] = data.probs;
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
}
if (data.logits != nullptr) {
res->t_sampled_logits[seq_id] = data.logits;
ggml_build_forward_expand(gf, data.logits);
outs[1] = data.logits;
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
}
if (data.candidates != nullptr) {
res->t_candidates[seq_id] = data.candidates;
ggml_build_forward_expand(gf, data.candidates);
outs[1] = data.candidates;
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
}
}