common : refactor common_sampler + grammar logic changes (#17937)
* common : refactor common_sampler + grammar logic changes * tests : increase max_tokens to get needed response * batched : fix uninitialized samplers
This commit is contained in:
@@ -153,7 +153,7 @@ struct server_slot {
|
||||
// sampling
|
||||
json json_schema;
|
||||
|
||||
struct common_sampler * smpl = nullptr;
|
||||
common_sampler_ptr smpl;
|
||||
|
||||
llama_token sampled; // in speculative mode, this is the last accepted token
|
||||
llama_tokens drafted;
|
||||
@@ -510,8 +510,8 @@ struct server_context_impl {
|
||||
common_params params_base;
|
||||
|
||||
// note: keep these alive - they determine the lifetime of the model, context, etc.
|
||||
common_init_result llama_init;
|
||||
common_init_result llama_init_dft;
|
||||
common_init_result_ptr llama_init;
|
||||
common_init_result_ptr llama_init_dft;
|
||||
|
||||
llama_model * model = nullptr;
|
||||
llama_context * ctx = nullptr;
|
||||
@@ -557,9 +557,6 @@ struct server_context_impl {
|
||||
|
||||
// Clear any sampling context
|
||||
for (server_slot & slot : slots) {
|
||||
common_sampler_free(slot.smpl);
|
||||
slot.smpl = nullptr;
|
||||
|
||||
llama_free(slot.ctx_dft);
|
||||
slot.ctx_dft = nullptr;
|
||||
|
||||
@@ -580,8 +577,8 @@ struct server_context_impl {
|
||||
|
||||
llama_init = common_init_from_params(params_base);
|
||||
|
||||
model = llama_init.model.get();
|
||||
ctx = llama_init.context.get();
|
||||
model = llama_init->model();
|
||||
ctx = llama_init->context();
|
||||
|
||||
if (model == nullptr) {
|
||||
SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str());
|
||||
@@ -613,25 +610,25 @@ struct server_context_impl {
|
||||
|
||||
llama_init_dft = common_init_from_params(params_dft);
|
||||
|
||||
model_dft = llama_init_dft.model.get();
|
||||
model_dft = llama_init_dft->model();
|
||||
|
||||
if (model_dft == nullptr) {
|
||||
SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft.context.get());
|
||||
vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft->context());
|
||||
if (!vocab_dft_compatible) {
|
||||
SRV_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str());
|
||||
}
|
||||
|
||||
const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get());
|
||||
const int n_ctx_dft = llama_n_ctx(llama_init_dft->context());
|
||||
|
||||
cparams_dft = common_context_params_to_llama(params_dft);
|
||||
cparams_dft.n_batch = n_ctx_dft;
|
||||
|
||||
// the context is not needed - we will create one for each slot
|
||||
llama_init_dft.context.reset();
|
||||
llama_init_dft->free_context();
|
||||
}
|
||||
|
||||
chat_templates = common_chat_templates_init(model, params_base.chat_template);
|
||||
@@ -1051,18 +1048,15 @@ struct server_context_impl {
|
||||
|
||||
// initialize samplers
|
||||
{
|
||||
if (slot.smpl != nullptr) {
|
||||
common_sampler_free(slot.smpl);
|
||||
}
|
||||
slot.smpl.reset(common_sampler_init(model, task.params.sampling));
|
||||
|
||||
slot.smpl = common_sampler_init(model, task.params.sampling);
|
||||
if (slot.smpl == nullptr) {
|
||||
// for now, the only error that may happen here is invalid grammar
|
||||
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
|
||||
SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl).c_str());
|
||||
SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
|
||||
}
|
||||
|
||||
// initialize draft batch
|
||||
@@ -1216,11 +1210,10 @@ struct server_context_impl {
|
||||
}
|
||||
|
||||
void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const {
|
||||
size_t n_probs = slot.task->params.sampling.n_probs;
|
||||
size_t n_vocab = llama_vocab_n_tokens(vocab);
|
||||
const size_t n_probs = slot.task->params.sampling.n_probs;
|
||||
|
||||
if (post_sampling) {
|
||||
const auto * cur_p = common_sampler_get_candidates(slot.smpl, true);
|
||||
const auto * cur_p = common_sampler_get_candidates(slot.smpl.get(), true);
|
||||
const size_t max_probs = cur_p->size;
|
||||
|
||||
// set probability for sampled token
|
||||
@@ -1245,7 +1238,7 @@ struct server_context_impl {
|
||||
std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
|
||||
|
||||
// set probability for sampled token
|
||||
for (size_t i = 0; i < n_vocab; i++) {
|
||||
for (size_t i = 0; i < cur.size(); i++) {
|
||||
// set probability for sampled token
|
||||
if (cur[i].id == result.tok) {
|
||||
result.prob = cur[i].p;
|
||||
@@ -1255,7 +1248,7 @@ struct server_context_impl {
|
||||
|
||||
// set probability for top n_probs tokens
|
||||
result.probs.reserve(n_probs);
|
||||
for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
|
||||
for (size_t i = 0; i < std::min(cur.size(), n_probs); i++) {
|
||||
result.probs.push_back({
|
||||
cur[i].id,
|
||||
common_token_to_piece(ctx, cur[i].id, special),
|
||||
@@ -2301,13 +2294,13 @@ struct server_context_impl {
|
||||
|
||||
GGML_ASSERT(batch.n_tokens > 0);
|
||||
|
||||
common_sampler_reset(slot.smpl);
|
||||
common_sampler_reset(slot.smpl.get());
|
||||
|
||||
// Process all prompt tokens through sampler system
|
||||
for (int i = 0; i < slot.task->n_tokens(); ++i) {
|
||||
llama_token id = input_tokens[i];
|
||||
if (id != LLAMA_TOKEN_NULL) {
|
||||
common_sampler_accept(slot.smpl, id, false);
|
||||
common_sampler_accept(slot.smpl.get(), id, false);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2525,11 +2518,11 @@ struct server_context_impl {
|
||||
|
||||
const int tok_idx = slot.i_batch - i;
|
||||
|
||||
llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
|
||||
llama_token id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx);
|
||||
|
||||
slot.i_batch = -1;
|
||||
|
||||
common_sampler_accept(slot.smpl, id, true);
|
||||
common_sampler_accept(slot.smpl.get(), id, true);
|
||||
|
||||
slot.n_decoded += 1;
|
||||
|
||||
@@ -2570,7 +2563,7 @@ struct server_context_impl {
|
||||
size_t n_draft = slot.drafted.size();
|
||||
|
||||
// the accepted tokens from the speculation
|
||||
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, slot.i_batch_dft, slot.drafted);
|
||||
const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted);
|
||||
slot.i_batch_dft.clear();
|
||||
slot.drafted.clear();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user