diff --git a/common/speculative.cpp b/common/speculative.cpp index c114bccde..bda9993b1 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -467,7 +467,7 @@ struct common_speculative_state_draft : public common_speculative_state { prompt_dft.push_back(id_last); - LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str()); + //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str()); int ret = llama_decode(ctx_dft, batch); if (ret != 0 && ret != 1) { @@ -495,14 +495,14 @@ struct common_speculative_state_draft : public common_speculative_state { common_sampler_accept(smpl, id, true); - result.push_back(id); - - if (sparams.n_max <= (int) result.size()) { + // only collect very high-confidence draft tokens + if (cur_p->data[0].p < sparams.p_min) { break; } - // only collect very high-confidence draft tokens - if (cur_p->data[0].p < sparams.p_min) { + result.push_back(id); + + if (sparams.n_max <= (int) result.size()) { break; } diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index e3822225b..ee8366d28 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -354,6 +354,7 @@ struct server_slot { // generate a new draft spec_draft = common_speculative_draft(spec.get(), params_spec, tokens, sampled); + n_draft_total += spec_draft.size(); if (spec_draft.size() > (size_t) n_draft_max) { SLT_WRN(*this, "draft size %d exceeds max %d, truncating\n", (int) spec_draft.size(), n_draft_max); @@ -3019,7 +3020,6 @@ private: // update how many tokens out of those tested were accepted slot.n_draft_accepted += ids.size() - 1; - slot.n_draft_total += n_draft; // add accepted tokens to the prompt slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);