spec : refactor params (#22397)
* spec : refactor params
* cont : fix
* cont : rename "sparam" to "sampling"
* cont : add spec params category
* cont : add info about removed arguments
* cont : skip param length check for spec params
* cont : adapt server tests
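For context, the server-side change below boils down to one pattern: derive the draft window through the common_speculative_n_min()/common_speculative_n_max() helpers that the new code calls instead of reading n_min/n_max off the params directly, clamp the maximum by the remaining slot/context space, and skip speculative decoding when the resulting window is too small. The following standalone sketch only mirrors that control flow; the names spec_params, spec_n_min, spec_n_max and plan_draft are illustrative stand-ins, not the real common/ or server API.

// Minimal sketch of the draft-size clamping pattern from this commit.
// All identifiers here are hypothetical stand-ins for the llama.cpp code.
#include <algorithm>
#include <cstdio>

struct spec_params {
    int n_min = 0;   // smallest draft worth verifying
    int n_max = 16;  // largest draft the drafter may propose
};

// stand-ins for common_speculative_n_min()/common_speculative_n_max()
static int spec_n_min(const spec_params & p) { return std::max(0, p.n_min); }
static int spec_n_max(const spec_params & p) { return std::max(0, p.n_max); }

// returns the number of draft tokens to request, or 0 to skip speculation
static int plan_draft(const spec_params & p, int n_ctx, int n_past) {
    const int n_draft_min = spec_n_min(p);

    // determine the max draft that fits the current state;
    // leave space for 1 extra token to allow context shifts
    int n_draft_max = spec_n_max(p);
    n_draft_max = std::min(n_draft_max, n_ctx - n_past - 1);

    if (n_draft_max < n_draft_min) {
        std::printf("max possible draft is too small: %d < %d - skipping\n", n_draft_max, n_draft_min);
        return 0;
    }
    return n_draft_max;
}

int main() {
    spec_params p;
    p.n_min = 2;
    p.n_max = 8;

    std::printf("draft size: %d\n", plan_draft(p, /*n_ctx=*/512, /*n_past=*/500)); // clamped by free context
    std::printf("draft size: %d\n", plan_draft(p, /*n_ctx=*/512, /*n_past=*/511)); // window too small -> 0
}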
@@ -309,8 +309,10 @@ struct server_slot {
             return 0;
         }
 
+        const int n_draft_min = common_speculative_n_min(spec.get(), task->params.speculative);
+
         // determine the max draft that fits the current slot state
-        int n_draft_max = task->params.speculative.n_max;
+        int n_draft_max = common_speculative_n_max(spec.get(), task->params.speculative);
 
         // note: slot.prompt is not yet expanded with the `id` token sampled above
         // also, need to leave space for 1 extra token to allow context shifts
@@ -322,8 +324,8 @@ struct server_slot {
 
         SLT_DBG(*this, "max possible draft: %d\n", n_draft_max);
 
-        if (n_draft_max < task->params.speculative.n_min) {
-            SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, task->params.speculative.n_min);
+        if (n_draft_max < n_draft_min) {
+            SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, n_draft_min);
             n_draft_max = 0;
         }
 
@@ -358,11 +360,6 @@ struct server_slot {
             spec_draft.resize(n_draft_max);
         }
 
-        if (spec_draft.size() < (size_t) params_spec.n_min) {
-            SLT_DBG(*this, "ignoring small draft: %d < %d\n", (int) spec_draft.size(), params_spec.n_min);
-            spec_draft.clear();
-        }
-
         if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
             const auto n_tokens = prompt.tokens.size();
 
@@ -770,9 +767,9 @@ private:
 
         if (params_base.speculative.has_dft()) {
             // TODO speculative: move to common/speculative.cpp?
-            SRV_INF("loading draft model '%s'\n", params_base.speculative.mparams_dft.path.c_str());
-            const auto & params_spec = params_base.speculative.draft;
 
+            const auto & params_spec = params_base.speculative;
+            SRV_INF("loading draft model '%s'\n", params_spec.mparams.path.c_str());
 
             auto params_dft = params_base;
 
@@ -780,7 +777,7 @@ private:
             params_dft.n_ctx = params_spec.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_spec.n_ctx;
             params_dft.n_batch = llama_n_ctx_seq(ctx);
             params_dft.devices = params_spec.devices;
-            params_dft.model = params_spec.mparams_dft;
+            params_dft.model = params_spec.mparams;
             params_dft.n_gpu_layers = params_spec.n_gpu_layers;
             params_dft.cache_type_k = params_spec.cache_type_k;
             params_dft.cache_type_v = params_spec.cache_type_v;
@@ -800,8 +797,8 @@ private:
                 return false;
             }
 
-            params_base.speculative.model_dft = model_dft.get();
-            params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft);
+            params_base.speculative.draft.model = model_dft.get();
+            params_base.speculative.draft.cparams = common_context_params_to_llama(params_dft);
         }
 
         std::string & mmproj_path = params_base.mmproj.path;
@@ -1310,7 +1307,7 @@ private:
         backend_sampling &= task.params.sampling.backend_sampling;
 
         // TODO: speculative decoding requires multiple samples per batch - not supported yet
-        backend_sampling &= !(slot.can_speculate() && task.params.speculative.n_max > 0);
+        backend_sampling &= !(slot.can_speculate() && common_speculative_n_max(slot.spec.get(), task.params.speculative) > 0);
 
         // TODO: getting post/pre sampling logits is not yet supported with backend sampling
         backend_sampling &= !need_logits;