sampling : delegate input allocation to the scheduler (#19266)
* sampling : delegate input allocation to the scheduler * graph : compute backend samplers only if needed
This commit is contained in:
+13
-6
@@ -2419,6 +2419,9 @@ void llm_graph_context::build_sampling() const {
|
||||
return;
|
||||
}
|
||||
|
||||
std::array<ggml_tensor *, 2> outs;
|
||||
outs[0] = res->t_logits;
|
||||
|
||||
auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
|
||||
res->add_input(std::move(inp_sampling));
|
||||
|
||||
@@ -2439,14 +2442,14 @@ void llm_graph_context::build_sampling() const {
|
||||
// add a dummy row of logits
|
||||
// this trick makes the graph static, regardless of which samplers are activated
|
||||
// this is important in order to minimize graph reallocations
|
||||
// TODO: use `ggml_build_forward_select()` when available (https://github.com/ggml-org/llama.cpp/pull/18550)
|
||||
ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0);
|
||||
|
||||
for (const auto & [seq_id, sampler] : samplers) {
|
||||
const auto it = seq_to_logit_row.find(seq_id);
|
||||
|
||||
// inactive samplers always work on the first row
|
||||
const auto row_idx = seq_to_logit_row.find(seq_id) != seq_to_logit_row.end() ? it->second : 0;
|
||||
const auto row_idx = it != seq_to_logit_row.end() ? it->second : 0;
|
||||
const int i_out = it != seq_to_logit_row.end() ? 1 : 0;
|
||||
|
||||
ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
|
||||
ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
|
||||
@@ -2463,22 +2466,26 @@ void llm_graph_context::build_sampling() const {
|
||||
|
||||
if (data.sampled != nullptr) {
|
||||
res->t_sampled[seq_id] = data.sampled;
|
||||
ggml_build_forward_expand(gf, data.sampled);
|
||||
outs[1] = data.sampled;
|
||||
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
|
||||
}
|
||||
|
||||
if (data.probs != nullptr) {
|
||||
res->t_sampled_probs[seq_id] = data.probs;
|
||||
ggml_build_forward_expand(gf, data.probs);
|
||||
outs[1] = data.probs;
|
||||
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
|
||||
}
|
||||
|
||||
if (data.logits != nullptr) {
|
||||
res->t_sampled_logits[seq_id] = data.logits;
|
||||
ggml_build_forward_expand(gf, data.logits);
|
||||
outs[1] = data.logits;
|
||||
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
|
||||
}
|
||||
|
||||
if (data.candidates != nullptr) {
|
||||
res->t_candidates[seq_id] = data.candidates;
|
||||
ggml_build_forward_expand(gf, data.candidates);
|
||||
outs[1] = data.candidates;
|
||||
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user