common : inhibit lazy grammar sampler while reasoning is active (#20970)
* common : inhibit grammar while reasoning budget is active * cont : update force_pos in accept * cont : fix tests * cont : tweak should apply logic * cont : return early not using grammar sampler * Add tests * cont : prevent backend sampling when reasoning budget enabled * cont : fix typo --------- Co-authored-by: Piotr Wilkin <piotr.wilkin@syndatis.com>
This commit is contained in:
+46
-10
@@ -7,6 +7,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <climits>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <unordered_map>
|
||||
@@ -109,6 +110,7 @@ struct common_sampler {
|
||||
common_params_sampling params;
|
||||
|
||||
struct llama_sampler * grmr;
|
||||
struct llama_sampler * rbudget;
|
||||
struct llama_sampler * chain;
|
||||
|
||||
ring_buffer<llama_token> prev;
|
||||
@@ -188,6 +190,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
||||
lparams.no_perf = params.no_perf;
|
||||
|
||||
llama_sampler * grmr = nullptr;
|
||||
llama_sampler * rbudget = nullptr;
|
||||
llama_sampler * chain = llama_sampler_chain_init(lparams);
|
||||
|
||||
std::vector<llama_sampler *> samplers;
|
||||
@@ -270,7 +273,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
||||
}
|
||||
}
|
||||
|
||||
if (grmr) {
|
||||
if (grmr && !params.grammar_lazy) {
|
||||
try {
|
||||
for (const auto & token : prefill_tokens) {
|
||||
llama_sampler_accept(grmr, token);
|
||||
@@ -284,15 +287,15 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
||||
}
|
||||
}
|
||||
|
||||
// reasoning budget sampler — added first so it can force tokens before other samplers
|
||||
if (params.reasoning_budget_tokens >= 0 && !params.reasoning_budget_forced.empty()) {
|
||||
samplers.push_back(common_reasoning_budget_init(
|
||||
// reasoning budget sampler
|
||||
if (!params.reasoning_budget_start.empty() && !params.reasoning_budget_end.empty()) {
|
||||
rbudget = common_reasoning_budget_init(
|
||||
vocab,
|
||||
params.reasoning_budget_start,
|
||||
params.reasoning_budget_end,
|
||||
params.reasoning_budget_forced,
|
||||
params.reasoning_budget_tokens,
|
||||
prefill_tokens));
|
||||
params.reasoning_budget_tokens < 0 ? INT_MAX : params.reasoning_budget_tokens,
|
||||
prefill_tokens);
|
||||
}
|
||||
|
||||
if (params.has_logit_bias()) {
|
||||
@@ -383,6 +386,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
||||
auto * result = new common_sampler {
|
||||
/* .params = */ params,
|
||||
/* .grmr = */ grmr,
|
||||
/* .rbudget = */ rbudget,
|
||||
/* .chain = */ chain,
|
||||
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
|
||||
/* .cur = */ {},
|
||||
@@ -398,11 +402,27 @@ void common_sampler_free(struct common_sampler * gsmpl) {
|
||||
}
|
||||
|
||||
llama_sampler_free(gsmpl->grmr);
|
||||
llama_sampler_free(gsmpl->rbudget);
|
||||
llama_sampler_free(gsmpl->chain);
|
||||
|
||||
delete gsmpl;
|
||||
}
|
||||
|
||||
static bool grammar_should_apply(struct common_sampler * gsmpl) {
|
||||
if (!gsmpl->grmr) {
|
||||
return false;
|
||||
}
|
||||
if (!gsmpl->rbudget) {
|
||||
return true;
|
||||
}
|
||||
if (gsmpl->params.grammar_lazy) {
|
||||
// if grammar is lazy, only apply when reasoning budget is not active
|
||||
const auto state = common_reasoning_budget_get_state(gsmpl->rbudget);
|
||||
return state == REASONING_BUDGET_IDLE || state == REASONING_BUDGET_DONE;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
|
||||
if (!gsmpl) {
|
||||
return;
|
||||
@@ -410,6 +430,11 @@ void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, boo
|
||||
|
||||
const auto tm = gsmpl->tm();
|
||||
|
||||
// grammar_should_apply() checks the reasoning budget state, so calculate this before we accept
|
||||
accept_grammar = accept_grammar && grammar_should_apply(gsmpl);
|
||||
|
||||
llama_sampler_accept(gsmpl->rbudget, token);
|
||||
|
||||
if (gsmpl->grmr && accept_grammar) {
|
||||
llama_sampler_accept(gsmpl->grmr, token);
|
||||
}
|
||||
@@ -431,6 +456,7 @@ struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
|
||||
return new common_sampler {
|
||||
/* .params = */ gsmpl->params,
|
||||
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
|
||||
/* .rbudget = */ llama_sampler_clone(gsmpl->rbudget),
|
||||
/* .chain = */ llama_sampler_clone(gsmpl->chain),
|
||||
/* .prev = */ gsmpl->prev,
|
||||
/* .cur = */ gsmpl->cur,
|
||||
@@ -500,6 +526,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
||||
llama_token id = LLAMA_TOKEN_NULL;
|
||||
|
||||
auto & grmr = gsmpl->grmr;
|
||||
auto & rbudget = gsmpl->rbudget;
|
||||
auto & chain = gsmpl->chain;
|
||||
auto & cur_p = gsmpl->cur_p; // initialized by set_logits
|
||||
|
||||
@@ -511,7 +538,8 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
||||
if (id != LLAMA_TOKEN_NULL) {
|
||||
LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);
|
||||
|
||||
GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported");
|
||||
GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported");
|
||||
GGML_ASSERT(!gsmpl->rbudget && "using reasoning budget in combination with backend sampling is not supported");
|
||||
|
||||
// TODO: simplify
|
||||
gsmpl->cur.resize(1);
|
||||
@@ -524,7 +552,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
||||
|
||||
gsmpl->set_logits(ctx, idx);
|
||||
|
||||
if (grammar_first) {
|
||||
// apply reasoning budget first
|
||||
llama_sampler_apply(rbudget, &cur_p);
|
||||
|
||||
if (grammar_first && grammar_should_apply(gsmpl)) {
|
||||
llama_sampler_apply(grmr, &cur_p);
|
||||
}
|
||||
|
||||
@@ -532,7 +563,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
||||
|
||||
id = cur_p.data[cur_p.selected].id;
|
||||
|
||||
if (grammar_first) {
|
||||
if (grammar_first || !grammar_should_apply(gsmpl)) {
|
||||
return id;
|
||||
}
|
||||
|
||||
@@ -553,7 +584,12 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
||||
// if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
|
||||
gsmpl->set_logits(ctx, idx);
|
||||
|
||||
llama_sampler_apply(grmr, &cur_p);
|
||||
llama_sampler_apply(rbudget, &cur_p);
|
||||
|
||||
if (grammar_should_apply(gsmpl)) {
|
||||
llama_sampler_apply(grmr, &cur_p);
|
||||
}
|
||||
|
||||
llama_sampler_apply(chain, &cur_p);
|
||||
|
||||
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
|
||||
|
||||
Reference in New Issue
Block a user