common : refactor common_sampler + grammar logic changes (#17937)

* common : refactor common_sampler + grammar logic changes

* tests : increase max_tokens to get needed response

* batched : fix uninitialized samplers
Georgi Gerganov
2025-12-14 10:11:13 +02:00
committed by GitHub
parent 3238b1400c
commit 254098a279
27 changed files with 372 additions and 293 deletions
@@ -153,7 +153,7 @@ struct server_slot {
// sampling
json json_schema;
-struct common_sampler * smpl = nullptr;
+common_sampler_ptr smpl;
llama_token sampled; // in speculative mode, this is the last accepted token
llama_tokens drafted;
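The new member type is not shown in this diff; below is a minimal sketch of what `common_sampler_ptr` plausibly looks like, assuming the usual llama.cpp RAII pattern of a `std::unique_ptr` whose deleter forwards to the C-style free function. The deleter/alias shape here is an assumption, not the confirmed definition.

```cpp
// Sketch only: common_sampler and common_sampler_free come from
// common/sampling.h; the deleter and alias names are assumptions.
#include <memory>

struct common_sampler;                         // opaque handle
void common_sampler_free(common_sampler * s);  // existing C-style API

struct common_sampler_deleter {
    void operator()(common_sampler * s) const { common_sampler_free(s); }
};

using common_sampler_ptr = std::unique_ptr<common_sampler, common_sampler_deleter>;
```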
@@ -510,8 +510,8 @@ struct server_context_impl {
common_params params_base;
// note: keep these alive - they determine the lifetime of the model, context, etc.
-common_init_result llama_init;
-common_init_result llama_init_dft;
+common_init_result_ptr llama_init;
+common_init_result_ptr llama_init_dft;
llama_model * model = nullptr;
llama_context * ctx = nullptr;
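Likewise, the diff only shows the call sites of `common_init_result_ptr`. A hedged sketch of the interface those call sites imply (accessors plus an early-release hook); everything beyond `model()`, `context()`, and `free_context()` is an assumption:

```cpp
// Sketch of the interface implied by the call sites below. llama_model_ptr
// and llama_context_ptr are the RAII aliases from llama-cpp.h.
struct common_init_result {
    llama_model   * model()   const { return model_.get(); }
    llama_context * context() const { return context_.get(); }

    // drop only the context, keeping the model alive (see the draft-model
    // hunk below, where each slot creates its own context)
    void free_context() { context_.reset(); }

    llama_model_ptr   model_;   // assumed member names
    llama_context_ptr context_;
};

using common_init_result_ptr = std::unique_ptr<common_init_result>;
```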
@@ -557,9 +557,6 @@ struct server_context_impl {
// Clear any sampling context
for (server_slot & slot : slots) {
-common_sampler_free(slot.smpl);
-slot.smpl = nullptr;
llama_free(slot.ctx_dft);
slot.ctx_dft = nullptr;
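Because `slot.smpl` is now an owning handle, the explicit free-and-null pair above can simply disappear; the equivalent teardown, if it were ever needed explicitly, is a one-liner:

```cpp
// Before the refactor: manual lifetime management
//   common_sampler_free(slot.smpl);
//   slot.smpl = nullptr;
// After: the unique_ptr's deleter performs the free
slot.smpl.reset();
```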
@@ -580,8 +577,8 @@ struct server_context_impl {
llama_init = common_init_from_params(params_base);
-model = llama_init.model.get();
-ctx = llama_init.context.get();
+model = llama_init->model();
+ctx = llama_init->context();
if (model == nullptr) {
SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str());
@@ -613,25 +610,25 @@ struct server_context_impl {
llama_init_dft = common_init_from_params(params_dft);
-model_dft = llama_init_dft.model.get();
+model_dft = llama_init_dft->model();
if (model_dft == nullptr) {
SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str());
return false;
}
-vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft.context.get());
+vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft->context());
if (!vocab_dft_compatible) {
SRV_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str());
}
-const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get());
+const int n_ctx_dft = llama_n_ctx(llama_init_dft->context());
cparams_dft = common_context_params_to_llama(params_dft);
cparams_dft.n_batch = n_ctx_dft;
// the context is not needed - we will create one for each slot
-llama_init_dft.context.reset();
+llama_init_dft->free_context();
}
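Releasing the shared draft context right after probing its size works because each slot later builds its own context from the retained draft model. A sketch of that per-slot step, assuming the standard `llama_init_from_model` API:

```cpp
// Hypothetical per-slot initialization; cparams_dft was sized above with
// n_batch = n_ctx_dft. The model outlives the freed context because
// common_init_result keeps owning it.
slot.ctx_dft = llama_init_from_model(llama_init_dft->model(), cparams_dft);
if (slot.ctx_dft == nullptr) {
    SRV_ERR("failed to create draft context, '%s'\n", params_base.speculative.model.path.c_str());
}
```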
chat_templates = common_chat_templates_init(model, params_base.chat_template);
@@ -1051,18 +1048,15 @@ struct server_context_impl {
// initialize samplers
{
-if (slot.smpl != nullptr) {
-    common_sampler_free(slot.smpl);
-}
-slot.smpl = common_sampler_init(model, task.params.sampling);
+slot.smpl.reset(common_sampler_init(model, task.params.sampling));
if (slot.smpl == nullptr) {
// for now, the only error that may happen here is invalid grammar
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
return false;
}
-SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl).c_str());
+SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
}
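Note that the error check itself needs no change: `std::unique_ptr` compares against `nullptr` directly, so `slot.smpl == nullptr` still detects a failed `common_sampler_init` (e.g. an invalid grammar) exactly as the raw pointer did. Only calls into the C-style sampler API now go through `slot.smpl.get()`, and `reset()` frees any previous sampler before taking ownership of the new one, which is why the old guarded-free dance above could be dropped.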
// initialize draft batch
@@ -1216,11 +1210,10 @@ struct server_context_impl {
}
void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const {
-size_t n_probs = slot.task->params.sampling.n_probs;
-size_t n_vocab = llama_vocab_n_tokens(vocab);
+const size_t n_probs = slot.task->params.sampling.n_probs;
if (post_sampling) {
-const auto * cur_p = common_sampler_get_candidates(slot.smpl, true);
+const auto * cur_p = common_sampler_get_candidates(slot.smpl.get(), true);
const size_t max_probs = cur_p->size;
// set probability for sampled token
@@ -1245,7 +1238,7 @@ struct server_context_impl {
std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
// set probability for sampled token
-for (size_t i = 0; i < n_vocab; i++) {
+for (size_t i = 0; i < cur.size(); i++) {
// set probability for sampled token
if (cur[i].id == result.tok) {
result.prob = cur[i].p;
@@ -1255,7 +1248,7 @@ struct server_context_impl {
// set probability for top n_probs tokens
result.probs.reserve(n_probs);
-for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
+for (size_t i = 0; i < std::min(cur.size(), n_probs); i++) {
result.probs.push_back({
cur[i].id,
common_token_to_piece(ctx, cur[i].id, special),
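The switch from `n_vocab` to `cur.size()` makes the vector returned by `get_token_probabilities` the single source of truth for both loops, so an out-of-bounds read is impossible even if that vector is ever shorter than the vocabulary, and the separate vocab lookup becomes unnecessary. A sketch of the resulting pattern, assuming the probability entry is `{id, piece, p}` as the surrounding hunks suggest:

```cpp
// Bound both the search and the top-n copy by the actual vector size.
const size_t n_top = std::min(cur.size(), n_probs);
result.probs.reserve(n_top);
for (size_t i = 0; i < n_top; i++) {
    result.probs.push_back({
        cur[i].id,                                      // token id
        common_token_to_piece(ctx, cur[i].id, special), // detokenized piece
        cur[i].p,                                       // probability
    });
}
```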
@@ -2301,13 +2294,13 @@ struct server_context_impl {
GGML_ASSERT(batch.n_tokens > 0);
-common_sampler_reset(slot.smpl);
+common_sampler_reset(slot.smpl.get());
// Process all prompt tokens through sampler system
for (int i = 0; i < slot.task->n_tokens(); ++i) {
llama_token id = input_tokens[i];
if (id != LLAMA_TOKEN_NULL) {
-common_sampler_accept(slot.smpl, id, false);
+common_sampler_accept(slot.smpl.get(), id, false);
}
}
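The `false` passed here is the sampler's `accept_grammar` flag: prompt tokens should enter the sampler's history (so repetition penalties see them) without advancing the grammar state, which only generated tokens do:

```cpp
// Prompt tokens: update history only, do not step the grammar.
common_sampler_accept(slot.smpl.get(), id, /*accept_grammar=*/false);
// Generated tokens (later in the decode loop): also advance the grammar.
common_sampler_accept(slot.smpl.get(), id, /*accept_grammar=*/true);
```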
@@ -2525,11 +2518,11 @@ struct server_context_impl {
const int tok_idx = slot.i_batch - i;
-llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
+llama_token id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx);
slot.i_batch = -1;
-common_sampler_accept(slot.smpl, id, true);
+common_sampler_accept(slot.smpl.get(), id, true);
slot.n_decoded += 1;
@@ -2570,7 +2563,7 @@ struct server_context_impl {
size_t n_draft = slot.drafted.size();
// the accepted tokens from the speculation
-const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, slot.i_batch_dft, slot.drafted);
+const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted);
slot.i_batch_dft.clear();
slot.drafted.clear();
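For context, a hedged reading of what `common_sampler_sample_and_accept_n` does with the drafted tokens (the exact contract lives in common/sampling.h): it samples from the target model's logits at each draft position, accepting draft tokens while they agree, and the last returned id is a freshly sampled token rather than a confirmed draft. Under that reading, the accounting after the call would look like:

```cpp
// Assumed accounting: every returned id is a real output token, and the
// final id is the "correction" sampled at the first mismatch (or after
// the last accepted draft), so n_accepted excludes it.
const size_t n_accepted = ids.size() - 1;
slot.n_decoded += ids.size();
SLT_INF(slot, "accepted %zu/%zu draft tokens\n", n_accepted, n_draft);
```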