server : refactor "use checkpoint" logic (#22114)

This commit is contained in:
Georgi Gerganov
2026-04-20 08:42:37 +03:00
committed by GitHub
parent 788fcbc5dd
commit de71b5f81c
7 changed files with 93 additions and 92 deletions
+14 -48
View File
@@ -164,8 +164,8 @@ struct common_speculative_state_draft : public common_speculative_state {
llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
llama_context * ctx_dft;
bool use_ckpt = false;
struct common_speculative_checkpoint ckpt;
bool use_checkpoint;
common_sampler * smpl;
@@ -180,11 +180,11 @@ struct common_speculative_state_draft : public common_speculative_state {
llama_context * ctx_tgt,
llama_context * ctx_dft,
const std::vector<std::pair<std::string, std::string>> & replacements,
bool use_checkpoint)
bool use_ckpt)
: common_speculative_state(type)
, ctx_tgt(ctx_tgt)
, ctx_dft(ctx_dft)
, use_checkpoint(use_checkpoint)
, use_ckpt(use_ckpt)
{
batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
smpl = nullptr;
@@ -239,7 +239,7 @@ struct common_speculative_state_draft : public common_speculative_state {
}
void begin(const llama_tokens & prompt) override {
if (use_checkpoint && ckpt.size() > 0) {
if (use_ckpt && ckpt.size() > 0) {
// delete checkpoint
LOG_DBG("%s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%" PRId64 ", size=%.3f MiB\n",
__func__, prompt.size(), ckpt.pos_min, ckpt.pos_max, ckpt.n_tokens, (float) ckpt.data.size() / 1024 / 1024);
@@ -351,7 +351,7 @@ struct common_speculative_state_draft : public common_speculative_state {
LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n",
__func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size());
if (use_checkpoint && ckpt.ckpt_size == 0 && reuse_n > 0) {
if (use_ckpt && ckpt.ckpt_size == 0 && reuse_n > 0) {
LOG_DBG("%s: no checkpoint available, no reuse, (reuse_i=%d, reuse_n=%d) -> (0, 0)\n",
__func__, reuse_i, reuse_n);
reuse_i = 0;
@@ -361,8 +361,8 @@ struct common_speculative_state_draft : public common_speculative_state {
result.clear();
result.reserve(params.n_max);
bool needs_ckpt = use_checkpoint && prompt_dft.size() > 0;
if (reuse_n == 0 || (use_checkpoint && reuse_i > 0)) {
bool needs_ckpt = use_ckpt && prompt_dft.size() > 0;
if (reuse_n == 0 || (use_ckpt && reuse_i > 0)) {
llama_memory_clear(mem_dft, false);
prompt_dft.clear();
} else {
@@ -400,7 +400,7 @@ struct common_speculative_state_draft : public common_speculative_state {
}
if (reuse_n < (int) prompt_dft.size() || do_restore) {
if (use_checkpoint) {
if (use_ckpt) {
if (ckpt.n_tokens > (int64_t) prompt_dft.size()) {
LOG_INF("%s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%" PRId64 ", reuse_n=%d, prompt_dft.size=%zu\n",
__func__, prompt_tgt.size(), ckpt.n_tokens, reuse_n, prompt_dft.size());
@@ -912,42 +912,6 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
return it->second;
}
common_speculative_compat_type common_speculative_is_compat(llama_context * ctx_tgt) {
auto * mem = llama_get_memory(ctx_tgt);
if (mem == nullptr) {
return COMMON_SPECULATIVE_COMPAT_TYPE_NO;
}
common_speculative_compat_type res = COMMON_SPECULATIVE_COMPAT_TYPE_FULL;
llama_memory_clear(mem, true);
// eval 2 tokens to check if the context is compatible
std::vector<llama_token> tmp;
tmp.push_back(0);
tmp.push_back(0);
int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size()));
if (ret != 0) {
LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
res = COMMON_SPECULATIVE_COMPAT_TYPE_NO;
goto done;
}
// try to remove the last tokens
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
res = COMMON_SPECULATIVE_COMPAT_TYPE_CKPT;
goto done;
}
done:
llama_memory_clear(mem, true);
llama_synchronize(ctx_tgt);
return res;
}
// initialization of the speculative decoding system
//
common_speculative * common_speculative_init(
@@ -1022,11 +986,13 @@ common_speculative * common_speculative_init(
case COMMON_SPECULATIVE_TYPE_NONE:
break;
case COMMON_SPECULATIVE_TYPE_DRAFT: {
const bool use_ckpt = common_context_can_seq_rm(ctx_dft) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
/* .ctx_tgt = */ ctx_tgt,
/* .ctx_dft = */ ctx_dft,
/* .replacements = */ params.replacements,
/* .use_checkpoint= */ params.use_checkpoints // TODO: this should be based on the draft model!
/* .ctx_tgt = */ ctx_tgt,
/* .ctx_dft = */ ctx_dft,
/* .replacements = */ params.replacements,
/* .use_ckpt = */ use_ckpt
));
break;
}