spec : refactor params (#22397)

* spec : refactor params

* cont : fix

* cont : rename "sparam" to "sampling"

* cont : add spec params category

* cont : add info about removed arguments

* cont : skip param length check for spec params

* cont : adapt server tests
Author: Georgi Gerganov
Date: 2026-04-28 09:07:33 +03:00 (committed by GitHub)
Parent: 516e8d7a8a
Commit: 14e733e36f
18 changed files with 661 additions and 409 deletions
+11 -14
@@ -309,8 +309,10 @@ struct server_slot {
return 0;
}
+const int n_draft_min = common_speculative_n_min(spec.get(), task->params.speculative);
// determine the max draft that fits the current slot state
-int n_draft_max = task->params.speculative.n_max;
+int n_draft_max = common_speculative_n_max(spec.get(), task->params.speculative);
// note: slot.prompt is not yet expanded with the `id` token sampled above
// also, need to leave space for 1 extra token to allow context shifts
@@ -322,8 +324,8 @@ struct server_slot {
SLT_DBG(*this, "max possible draft: %d\n", n_draft_max);
-if (n_draft_max < task->params.speculative.n_min) {
-SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, task->params.speculative.n_min);
+if (n_draft_max < n_draft_min) {
+SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, n_draft_min);
n_draft_max = 0;
}
@@ -358,11 +360,6 @@ struct server_slot {
spec_draft.resize(n_draft_max);
}
-if (spec_draft.size() < (size_t) params_spec.n_min) {
-SLT_DBG(*this, "ignoring small draft: %d < %d\n", (int) spec_draft.size(), params_spec.n_min);
-spec_draft.clear();
-}
if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
const auto n_tokens = prompt.tokens.size();
@@ -770,9 +767,9 @@ private:
if (params_base.speculative.has_dft()) {
// TODO speculative: move to common/speculative.cpp?
SRV_INF("loading draft model '%s'\n", params_base.speculative.mparams_dft.path.c_str());
const auto & params_spec = params_base.speculative.draft;
const auto & params_spec = params_base.speculative;
SRV_INF("loading draft model '%s'\n", params_spec.mparams.path.c_str());
auto params_dft = params_base;
@@ -780,7 +777,7 @@ private:
params_dft.n_ctx = params_spec.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_spec.n_ctx;
params_dft.n_batch = llama_n_ctx_seq(ctx);
params_dft.devices = params_spec.devices;
-params_dft.model = params_spec.mparams_dft;
+params_dft.model = params_spec.mparams;
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
params_dft.cache_type_k = params_spec.cache_type_k;
params_dft.cache_type_v = params_spec.cache_type_v;
@@ -800,8 +797,8 @@ private:
return false;
}
-params_base.speculative.model_dft = model_dft.get();
-params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft);
+params_base.speculative.draft.model = model_dft.get();
+params_base.speculative.draft.cparams = common_context_params_to_llama(params_dft);
}
std::string & mmproj_path = params_base.mmproj.path;
@@ -1310,7 +1307,7 @@ private:
backend_sampling &= task.params.sampling.backend_sampling;
// TODO: speculative decoding requires multiple samples per batch - not supported yet
-backend_sampling &= !(slot.can_speculate() && task.params.speculative.n_max > 0);
+backend_sampling &= !(slot.can_speculate() && common_speculative_n_max(slot.spec.get(), task.params.speculative) > 0);
// TODO: getting post/pre sampling logits is not yet supported with backend sampling
backend_sampling &= !need_logits;
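The common_speculative_n_min/n_max helpers used in the hunks above are not part of this excerpt (presumably they live in common/speculative, per the TODO comment). A minimal sketch of what they might return, assuming they simply expose the clamped draft bounds (mirroring the clamping that params_from_json_cmpl applies below) and keep the speculator handle around so other methods such as ngram could report different bounds later:

    // Sketch only; the real signatures and behavior may differ.
    #include <algorithm>

    struct common_speculative;   // opaque speculator state, as passed via spec.get()

    struct common_params_speculative_draft { int n_min = 0; int n_max = 16; };
    struct common_params_speculative       { common_params_speculative_draft draft; };

    // smallest draft worth verifying for this request
    static int common_speculative_n_min(const common_speculative * /*spec*/,
                                        const common_params_speculative & params) {
        return std::max(0, std::min(params.draft.n_min, params.draft.n_max));
    }

    // hard cap on the draft size; 0 effectively disables speculation for the request
    static int common_speculative_n_max(const common_speculative * /*spec*/,
                                        const common_params_speculative & params) {
        return std::max(0, params.draft.n_max);
    }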
+10 -18
@@ -76,13 +76,7 @@ json task_params::to_json(bool only_metrics) const {
{"reasoning_in_content", chat_parser_params.reasoning_in_content},
{"generation_prompt", chat_parser_params.generation_prompt},
{"samplers", samplers},
{"speculative.n_max", speculative.n_max},
{"speculative.n_min", speculative.n_min},
{"speculative.p_min", speculative.p_min},
{"speculative.type", common_speculative_type_to_str(speculative.type)},
{"speculative.ngram_size_n", speculative.ngram_size_n},
{"speculative.ngram_size_m", speculative.ngram_size_m},
{"speculative.ngram_m_hits", speculative.ngram_min_hits},
{"timings_per_token", timings_per_token},
{"post_sampling_probs", post_sampling_probs},
{"backend_sampling", sampling.backend_sampling},
@@ -139,13 +133,7 @@ json task_params::to_json(bool only_metrics) const {
{"reasoning_in_content", chat_parser_params.reasoning_in_content},
{"generation_prompt", chat_parser_params.generation_prompt},
{"samplers", samplers},
{"speculative.n_max", speculative.n_max},
{"speculative.n_min", speculative.n_min},
{"speculative.p_min", speculative.p_min},
{"speculative.type", common_speculative_type_to_str(speculative.type)},
{"speculative.ngram_size_n", speculative.ngram_size_n},
{"speculative.ngram_size_m", speculative.ngram_size_m},
{"speculative.ngram_m_hits", speculative.ngram_min_hits},
{"timings_per_token", timings_per_token},
{"post_sampling_probs", post_sampling_probs},
{"backend_sampling", sampling.backend_sampling},
@@ -308,14 +296,17 @@ task_params server_task::params_from_json_cmpl(
params.speculative = defaults.speculative;
-params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
-params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
-params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
+// TODO: for now, be able to adjust only the draft-model based speculative parameters
+params.speculative.draft.n_min = json_value(data, "speculative.n_min", defaults.speculative.draft.n_min);
+params.speculative.draft.n_max = json_value(data, "speculative.n_max", defaults.speculative.draft.n_max);
+params.speculative.draft.p_min = json_value(data, "speculative.p_min", defaults.speculative.draft.p_min);
-params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min);
-params.speculative.n_min = std::max(params.speculative.n_min, 0);
-params.speculative.n_max = std::max(params.speculative.n_max, 0);
+params.speculative.draft.n_min = std::min(params.speculative.draft.n_max, params.speculative.draft.n_min);
+params.speculative.draft.n_min = std::max(params.speculative.draft.n_min, 0);
+params.speculative.draft.n_max = std::max(params.speculative.draft.n_max, 0);
+#if 0
// for debugging and research purposes
params.speculative.type = common_speculative_type_from_name(json_value(data, "speculative.type", common_speculative_type_to_str(defaults.speculative.type)));
params.speculative.ngram_size_n = json_value(data, "speculative.ngram_size_n", defaults.speculative.ngram_size_n);
@@ -325,6 +316,7 @@ task_params server_task::params_from_json_cmpl(
params.speculative.ngram_size_n = std::max(std::min(1, (int) params.speculative.ngram_size_n), 1024);
params.speculative.ngram_size_m = std::max(std::min(1, (int) params.speculative.ngram_size_m), 1024);
params.speculative.ngram_min_hits = std::max(std::min(1, (int) params.speculative.ngram_min_hits), 1024);
+#endif
// Use OpenAI API logprobs only if n_probs wasn't provided
if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){
+6 -9
@@ -83,15 +83,14 @@ class ServerProcess:
kv_unified: bool | None = False
server_slots: bool | None = False
pooling: str | None = None
-draft: int | None = None
api_key: str | None = None
models_dir: str | None = None
models_max: int | None = None
no_models_autoload: bool | None = None
lora_files: List[str] | None = None
enable_ctx_shift: int | None = False
-draft_min: int | None = None
-draft_max: int | None = None
+spec_draft_n_min: int | None = None
+spec_draft_n_max: int | None = None
no_webui: bool | None = None
jinja: bool | None = None
reasoning_format: Literal['deepseek', 'none', 'nothink'] | None = None
@@ -165,8 +164,6 @@ class ServerProcess:
server_args.extend(["--threads", self.n_threads])
if self.n_gpu_layer:
server_args.extend(["--n-gpu-layers", self.n_gpu_layer])
-if self.draft is not None:
-server_args.extend(["--draft", self.draft])
if self.server_continuous_batching:
server_args.append("--cont-batching")
if self.server_embeddings:
@@ -214,10 +211,10 @@ class ServerProcess:
server_args.append("--context-shift")
if self.api_key:
server_args.extend(["--api-key", self.api_key])
-if self.draft_max:
-server_args.extend(["--draft-max", self.draft_max])
-if self.draft_min:
-server_args.extend(["--draft-min", self.draft_min])
+if self.spec_draft_n_max:
+server_args.extend(["--spec-draft-n-max", self.spec_draft_n_max])
+if self.spec_draft_n_min:
+server_args.extend(["--spec-draft-n-min", self.spec_draft_n_min])
if self.no_webui:
server_args.append("--no-webui")
if self.no_models_autoload: