speculative-simple : add checkpoint support (#22227)

* speculative-simple : add checkpoint support

* cont : fix build
Georgi Gerganov
2026-04-22 15:44:45 +03:00
committed by GitHub
parent 225088ea76
commit bcb5eeb645
2 changed files with 112 additions and 10 deletions
@@ -8,8 +8,24 @@
#include <clocale>
#include <cstdio>
#include <cstring>
#include <cinttypes>
#include <string>
#include <vector>
#include <utility>
struct spec_checkpoint {
int64_t n_tokens = 0;
std::vector<uint8_t> data;
size_t size() const {
return data.size();
}
bool empty() const {
return data.empty();
}
};
int main(int argc, char ** argv) {
std::setlocale(LC_NUMERIC, "C");
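
A note on the new struct: spec_checkpoint is just a byte buffer plus the prompt length at the moment the snapshot was taken. A minimal sketch of the round-trip it enables, using only the llama_state_seq_*_ext and llama_memory_seq_rm calls that appear later in this commit (the helper names are hypothetical, not part of the change):

    // illustrative sketch, not part of this commit
    static bool spec_checkpoint_save(llama_context * ctx, spec_checkpoint & ckpt, int64_t n_tokens) {
        const size_t size = llama_state_seq_get_size_ext(ctx, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
        ckpt.data.resize(size);
        const size_t n = llama_state_seq_get_data_ext(ctx, ckpt.data.data(), size, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
        ckpt.n_tokens = n_tokens; // prompt length at snapshot time
        return n == size;
    }

    static bool spec_checkpoint_restore(llama_context * ctx, const spec_checkpoint & ckpt) {
        // write the snapshot back into sequence 0 ...
        const size_t n = llama_state_seq_set_data_ext(ctx, ckpt.data.data(), ckpt.size(), 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
        // ... and drop everything the cache holds past the snapshot point
        llama_memory_seq_rm(llama_get_memory(ctx), 0, ckpt.n_tokens, -1);
        return n == ckpt.size();
    }
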
@@ -46,6 +62,14 @@ int main(int argc, char ** argv) {
model_tgt = llama_init_tgt->model();
ctx_tgt = llama_init_tgt->context();
// check if the context supports partial sequence removal
const auto ctx_seq_rm = common_context_can_seq_rm(ctx_tgt);
const bool use_ckpt = (ctx_seq_rm == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);
if (use_ckpt) {
LOG_INF("speculative decoding will use checkpoints (context does not support partial sequence removal)\n");
}
const llama_vocab * vocab = llama_model_get_vocab(model_tgt);
// load the draft model
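
For context on the check above: some memory types cannot remove a partial token range from a sequence, only a whole sequence, and that is the case the checkpoint machinery covers. The same limitation surfaces in the public API as the return value of llama_memory_seq_rm; a hedged sketch:

    // illustrative sketch, not part of this commit: llama_memory_seq_rm()
    // returns false when the memory cannot drop a partial range, which is
    // exactly the situation use_ckpt handles with a full state snapshot
    if (!llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, n_past, -1)) {
        // partial removal unsupported -> restore a checkpoint instead
    }
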
@@ -119,7 +143,7 @@ int main(int argc, char ** argv) {
const auto t_enc_start = ggml_time_us();
// target model sampling context
-struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);
+common_sampler_ptr smpl(common_sampler_init(model_tgt, params.sampling));
// eval the prompt
llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
@@ -142,21 +166,61 @@ int main(int argc, char ** argv) {
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
size_t n_draft = 0;
llama_tokens draft;
spec_checkpoint spec_ckpt;
const auto t_enc_end = ggml_time_us();
const auto t_dec_start = ggml_time_us();
while (true) {
-// optionally, generate draft tokens that can be appended to the target batch
+// generate or reuse draft tokens
//
// this is the most important part of the speculation. the more probable tokens that are provided here
// the better the performance will be. in theory, this computation can be performed asynchronously and even
// offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
// from a cache or lookup tables.
//
-llama_tokens draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last);
+if (draft.empty()) {
+// generate a new draft
+draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last);
//LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str());
if ((int) draft.size() > params_spec.n_max) {
LOG_WRN("draft size %zu exceeds max %d, truncating\n", draft.size(), params_spec.n_max);
draft.resize(params_spec.n_max);
}
if ((int) draft.size() < params_spec.n_min) {
LOG_DBG("ignoring small draft: %zu < %d\n", draft.size(), params_spec.n_min);
draft.clear();
}
// save the original draft size
n_draft = draft.size();
// save a checkpoint of the target context before evaluating the draft
// this allows us to restore the state if partial draft acceptance occurs
if (!draft.empty() && use_ckpt) {
const size_t ckpt_size = llama_state_seq_get_size_ext(ctx_tgt, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
spec_ckpt.data.resize(ckpt_size);
const size_t n = llama_state_seq_get_data_ext(ctx_tgt, spec_ckpt.data.data(), ckpt_size, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
GGML_ASSERT(n == ckpt_size);
spec_ckpt.n_tokens = (int64_t) prompt_tgt.size();
LOG_DBG("created speculative checkpoint (n_tokens = %" PRId64 ", size = %.3f MiB)\n",
spec_ckpt.n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024);
}
} else {
// we have a previous (partial) draft to reuse from checkpoint restoration
if (use_ckpt) {
GGML_ASSERT(!spec_ckpt.empty());
}
GGML_ASSERT(n_draft > 0); // a reused (partial) draft is never empty
}
// always have a token to evaluate from before - id_last
common_batch_clear(batch_tgt);
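
As the comment in the hunk above points out, the draft source does not have to be an LLM. A hypothetical drop-in provider with the same output shape as common_speculative_draft, drafting from a lookup table keyed by the last sampled token (the cache type and function name are assumptions, not part of the commit):

    // illustrative sketch, not part of this commit (requires <map>)
    static llama_tokens draft_from_cache(const std::map<llama_token, llama_tokens> & cache, llama_token id_last) {
        const auto it = cache.find(id_last);
        return it != cache.end() ? it->second : llama_tokens{};
    }
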
@@ -178,6 +242,12 @@ int main(int argc, char ** argv) {
llama_decode(ctx_tgt, batch_tgt);
}
// only save the sampler state if we use checkpoints
common_sampler_ptr smpl_save;
if (use_ckpt) {
smpl_save.reset(common_sampler_clone(smpl.get()));
}
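// note (illustrative, not part of this commit): the sampler is stateful - its
// history of accepted tokens and its RNG advance with every sample below - so
// restoring the KV state alone is not enough; the clone taken above is the
// sampler-side half of the checkpoint and pairs with smpl = std::move(smpl_save)
// on rollback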
// sample from the full target batch and return the accepted tokens based on the target sampler
//
// for each token to be accepted, the sampler would have to sample that same token
@@ -185,14 +255,38 @@ int main(int argc, char ** argv) {
// available logits from the batch and sample the next token until we run out of logits or the sampler
// disagrees with the draft
//
-const auto ids = common_sampler_sample_and_accept_n(smpl, ctx_tgt, draft);
+auto ids = common_sampler_sample_and_accept_n(smpl.get(), ctx_tgt, draft);
//LOG_DBG("ids: %s\n", string_from(ctx_tgt, ids).c_str());
GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
// check for partial draft acceptance:
// if the context doesn't support partial sequence removal, restore the checkpoint
// and make the accepted tokens the new partial draft for the next iteration
if (use_ckpt && ids.size() - 1 < draft.size()) {
LOG_DBG("partial acceptance: %zu < %zu, restoring checkpoint\n", ids.size() - 1, draft.size());
draft = std::move(ids);
const size_t n = llama_state_seq_set_data_ext(ctx_tgt, spec_ckpt.data.data(), spec_ckpt.size(), 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
GGML_ASSERT(n == spec_ckpt.size());
llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, spec_ckpt.n_tokens, -1);
prompt_tgt.resize(spec_ckpt.n_tokens);
smpl = std::move(smpl_save);
n_past = (int) prompt_tgt.size();
continue;
}
common_speculative_accept(spec, ids.size() - 1);
// full acceptance: consume the draft and commit accepted tokens
n_past += ids.size() - 1;
-n_drafted += draft.size(); // note: we ignore the discarded small drafts
+n_drafted += n_draft; // note: we ignore the discarded small drafts
n_accept += ids.size() - 1;
n_predict += ids.size();
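
A concrete walk-through of the two outcomes above (token values made up): suppose the draft holds 8 tokens and the target sampler agrees with the first 3. Then ids holds 4 tokens (3 accepted plus the sampler's correction), ids.size() - 1 == 3 < 8, and the checkpoint branch runs: the 4 tokens of ids become the new draft. Because the sampler was restored to its pre-sampling clone, the next iteration re-samples those same 4 tokens and accepts them all, so the loop makes forward progress even without partial sequence removal. On that second pass n_drafted is charged the original n_draft = 8 exactly once, which is why the original draft size is saved up front.

    // illustrative walk-through, not part of this commit:
    // draft = [t0 t1 t2 t3 t4 t5 t6 t7]   n_draft = 8
    // ids   = [t0 t1 t2 x]                3 accepted + 1 correction
    // ids.size() - 1 = 3 < 8  -> restore checkpoint, draft = std::move(ids)
    // next pass: all 4 reused tokens accepted, n_past += 4, n_drafted += 8
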
@@ -222,6 +316,9 @@ int main(int argc, char ** argv) {
LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last);
// clear the draft since it has been consumed
draft.clear();
{
LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
@@ -254,11 +351,10 @@ int main(int argc, char ** argv) {
LOG_INF("\n");
LOG_INF("target:\n\n");
-common_perf_print(ctx_tgt, smpl);
+common_perf_print(ctx_tgt, smpl.get());
llama_batch_free(batch_tgt);
-common_sampler_free(smpl);
common_speculative_free(spec);
llama_backend_free();