server : speculative checkpointing (#19493)
* server : speculative decoding using checkpoints * server : fix draft check with checkpoints * server : rename spec vars * server : log levels * server : refactored spec logic to speculative.cpp * server : renamed spec checkpoints option * server : fix spec checkpoints, logging * speculative : checkpoints with draft model, logging * server : n_tokens_cur and create_checkpoint in draft * server : fix server_speculative_callback (slot.id) * spec : fix ngram-map/begin idx_last_check * spec : init ckpt (begin() wasn't called) * chore: update webui build output * server : restore sampler in spec checkpoint and clear mem * cont : avoid --spec-use-checkpoints argument * cont : remove server_prompt_checkpoint_with_size * spec : rename (leave_draft_state) * cont : clean-up * cont : do not ignore partial drafts even if the are short * cont : spec callback owned by session * cont : simplify * cont : avoid empty speculative session * cont : simplify * cont : simplify * cont : enable mtmd speculative decoding * cont : keep the spec sampler alive * cont : simplify * cont : fix nullptr deref + draft checkpoints * cont : remove common_speculative_accept_response * cont : remove callback * cont : simplify * cont : minor * cont : simplify * cont : fix accepted number --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
+1
-1
@@ -2334,7 +2334,7 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
|
|||||||
? input
|
? input
|
||||||
: params.generation_prompt + input;
|
: params.generation_prompt + input;
|
||||||
|
|
||||||
LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), effective_input.c_str());
|
//LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), effective_input.c_str());
|
||||||
|
|
||||||
common_peg_parse_flags flags = COMMON_PEG_PARSE_FLAG_LENIENT;
|
common_peg_parse_flags flags = COMMON_PEG_PARSE_FLAG_LENIENT;
|
||||||
if (params.debug) {
|
if (params.debug) {
|
||||||
|
|||||||
+2
-2
@@ -11,7 +11,6 @@
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <string_view>
|
#include <string_view>
|
||||||
#include <variant>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
|
||||||
@@ -303,7 +302,7 @@ struct common_params_speculative {
|
|||||||
// general-purpose speculative decoding parameters
|
// general-purpose speculative decoding parameters
|
||||||
|
|
||||||
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
|
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
|
||||||
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
|
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
|
||||||
float p_split = 0.1f; // speculative decoding split probability
|
float p_split = 0.1f; // speculative decoding split probability
|
||||||
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
|
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
|
||||||
|
|
||||||
@@ -312,6 +311,7 @@ struct common_params_speculative {
|
|||||||
uint16_t ngram_size_n = 12; // ngram size for lookup
|
uint16_t ngram_size_n = 12; // ngram size for lookup
|
||||||
uint16_t ngram_size_m = 48; // mgram size for speculative tokens
|
uint16_t ngram_size_m = 48; // mgram size for speculative tokens
|
||||||
uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
|
uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
|
||||||
|
bool use_checkpoints = false; // use checkpoints to rewind in token history of recurrent models
|
||||||
|
|
||||||
std::shared_ptr<common_ngram_mod> ngram_mod;
|
std::shared_ptr<common_ngram_mod> ngram_mod;
|
||||||
|
|
||||||
|
|||||||
@@ -208,7 +208,7 @@ void common_ngram_map_begin(
|
|||||||
count_keys, count_keys_del, count_values_del, count_map_entries_upd);
|
count_keys, count_keys_del, count_values_del, count_map_entries_upd);
|
||||||
}
|
}
|
||||||
|
|
||||||
map.idx_last_check = (map.size_last_begin > 0) ? map.size_last_begin - 1 : 0;
|
map.idx_last_check = size_begin;
|
||||||
map.size_last_begin = size_begin;
|
map.size_last_begin = size_begin;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -231,7 +231,7 @@ void common_ngram_map_draft(common_ngram_map & map,
|
|||||||
GGML_ABORT("%s: cur_len exceeds UINT32_MAX: %zu", __func__, cur_len);
|
GGML_ABORT("%s: cur_len exceeds UINT32_MAX: %zu", __func__, cur_len);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (map.idx_last_check > cur_len) {
|
if (map.idx_last_check > cur_len) {
|
||||||
// Should not happen because of common_ngram_map_begin().
|
// Should not happen because of common_ngram_map_begin().
|
||||||
GGML_ABORT("%s: map.idx_last_check > cur_len: %zu > %zu", __func__, map.idx_last_check, cur_len);
|
GGML_ABORT("%s: map.idx_last_check > cur_len: %zu > %zu", __func__, map.idx_last_check, cur_len);
|
||||||
}
|
}
|
||||||
@@ -386,7 +386,7 @@ void common_ngram_map_draft(common_ngram_map & map,
|
|||||||
LOG_DBG("%s: key_idx = %zu, key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__,
|
LOG_DBG("%s: key_idx = %zu, key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__,
|
||||||
curr_key.key_idx, key_offset, curr_key.key_num, draft.size());
|
curr_key.key_idx, key_offset, curr_key.key_num, draft.size());
|
||||||
|
|
||||||
map.last_draft_created = false;
|
map.last_draft_created = true;
|
||||||
map.last_draft_key_idx = key_offset;
|
map.last_draft_key_idx = key_offset;
|
||||||
map.last_draft_value_idx = 0; // value 0 is used for simple mode
|
map.last_draft_value_idx = 0; // value 0 is used for simple mode
|
||||||
return;
|
return;
|
||||||
@@ -524,7 +524,7 @@ void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted) {
|
|||||||
struct common_ngram_map_value & curr_value = curr_key.values[val_idx]; // value used for draft generation.
|
struct common_ngram_map_value & curr_value = curr_key.values[val_idx]; // value used for draft generation.
|
||||||
|
|
||||||
// update the value statistics
|
// update the value statistics
|
||||||
LOG_INF("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n",
|
LOG_DBG("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n",
|
||||||
n_accepted, curr_value.n_accepted);
|
n_accepted, curr_value.n_accepted);
|
||||||
curr_value.n_accepted = n_accepted;
|
curr_value.n_accepted = n_accepted;
|
||||||
}
|
}
|
||||||
|
|||||||
+139
-23
@@ -13,6 +13,7 @@
|
|||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <cinttypes>
|
||||||
|
|
||||||
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
|
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
|
||||||
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
||||||
@@ -144,10 +145,28 @@ struct common_speculative_state {
|
|||||||
virtual void accept(uint16_t n_accepted) = 0;
|
virtual void accept(uint16_t n_accepted) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct common_speculative_checkpoint {
|
||||||
|
llama_pos pos_min = 0;
|
||||||
|
llama_pos pos_max = 0;
|
||||||
|
|
||||||
|
int64_t n_tokens = 0;
|
||||||
|
|
||||||
|
std::vector<uint8_t> data;
|
||||||
|
|
||||||
|
size_t size() const {
|
||||||
|
return data.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t ckpt_size = 0;
|
||||||
|
};
|
||||||
|
|
||||||
struct common_speculative_state_draft : public common_speculative_state {
|
struct common_speculative_state_draft : public common_speculative_state {
|
||||||
llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
|
llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
|
||||||
llama_context * ctx_dft;
|
llama_context * ctx_dft;
|
||||||
|
|
||||||
|
struct common_speculative_checkpoint ckpt;
|
||||||
|
bool use_checkpoint;
|
||||||
|
|
||||||
common_sampler * smpl;
|
common_sampler * smpl;
|
||||||
|
|
||||||
llama_batch batch;
|
llama_batch batch;
|
||||||
@@ -160,10 +179,12 @@ struct common_speculative_state_draft : public common_speculative_state {
|
|||||||
enum common_speculative_type type,
|
enum common_speculative_type type,
|
||||||
llama_context * ctx_tgt,
|
llama_context * ctx_tgt,
|
||||||
llama_context * ctx_dft,
|
llama_context * ctx_dft,
|
||||||
const std::vector<std::pair<std::string, std::string>> & replacements)
|
const std::vector<std::pair<std::string, std::string>> & replacements,
|
||||||
|
bool use_checkpoint)
|
||||||
: common_speculative_state(type)
|
: common_speculative_state(type)
|
||||||
, ctx_tgt(ctx_tgt)
|
, ctx_tgt(ctx_tgt)
|
||||||
, ctx_dft(ctx_dft)
|
, ctx_dft(ctx_dft)
|
||||||
|
, use_checkpoint(use_checkpoint)
|
||||||
{
|
{
|
||||||
batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
|
batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
|
||||||
smpl = nullptr;
|
smpl = nullptr;
|
||||||
@@ -218,7 +239,48 @@ struct common_speculative_state_draft : public common_speculative_state {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void begin(const llama_tokens & prompt) override {
|
void begin(const llama_tokens & prompt) override {
|
||||||
GGML_UNUSED(prompt);
|
if (use_checkpoint && ckpt.size() > 0) {
|
||||||
|
// delete checkpoint
|
||||||
|
LOG_DBG("%s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%" PRId64 ", size=%.3f MiB\n",
|
||||||
|
__func__, prompt.size(), ckpt.pos_min, ckpt.pos_max, ckpt.n_tokens, (float) ckpt.data.size() / 1024 / 1024);
|
||||||
|
ckpt.pos_min = 0;
|
||||||
|
ckpt.pos_max = 0;
|
||||||
|
ckpt.n_tokens = 0;
|
||||||
|
ckpt.ckpt_size = 0;
|
||||||
|
ckpt.data.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t draft_create_checkpoint(int n_tokens_prompt, int n_tokens_batch) {
|
||||||
|
int slot_id = 0;
|
||||||
|
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||||
|
|
||||||
|
ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), slot_id);
|
||||||
|
ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), slot_id);
|
||||||
|
ckpt.n_tokens = n_tokens_prompt - n_tokens_batch;
|
||||||
|
ckpt.data.resize(checkpoint_size);
|
||||||
|
|
||||||
|
const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||||
|
if (n != checkpoint_size) {
|
||||||
|
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_DBG("%s: pos_min = %d, pos_max = %d, size = %.3f MiB\n", __func__,
|
||||||
|
ckpt.pos_min, ckpt.pos_max, (float) ckpt.data.size() / 1024 / 1024);
|
||||||
|
return n;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t draft_restore_checkpoint(size_t ckpt_size_part_expected) {
|
||||||
|
int slot_id = 0;
|
||||||
|
LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max);
|
||||||
|
const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||||
|
if (n != ckpt_size_part_expected) {
|
||||||
|
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
|
||||||
|
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n);
|
||||||
|
}
|
||||||
|
llama_memory_seq_rm(llama_get_memory(ctx_dft), slot_id, ckpt.pos_max + 1, -1);
|
||||||
|
|
||||||
|
return n;
|
||||||
}
|
}
|
||||||
|
|
||||||
void draft(
|
void draft(
|
||||||
@@ -236,8 +298,8 @@ struct common_speculative_state_draft : public common_speculative_state {
|
|||||||
|
|
||||||
auto * mem_dft = llama_get_memory(ctx_dft);
|
auto * mem_dft = llama_get_memory(ctx_dft);
|
||||||
|
|
||||||
int reuse_i = 0;
|
int reuse_i = 0; // index of part to be reused in prompt_dft
|
||||||
int reuse_n = 0;
|
int reuse_n = 0; // length of part to be reused in prompt_dft
|
||||||
|
|
||||||
const int n_ctx = llama_n_ctx(ctx_dft) - params.n_max;
|
const int n_ctx = llama_n_ctx(ctx_dft) - params.n_max;
|
||||||
|
|
||||||
@@ -287,18 +349,26 @@ struct common_speculative_state_draft : public common_speculative_state {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size());
|
LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n",
|
||||||
|
__func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size());
|
||||||
|
if (use_checkpoint && ckpt.ckpt_size == 0 && reuse_n > 0) {
|
||||||
|
LOG_DBG("%s: no checkpoint available, no reuse, (reuse_i=%d, reuse_n=%d) -> (0, 0)\n",
|
||||||
|
__func__, reuse_i, reuse_n);
|
||||||
|
reuse_i = 0;
|
||||||
|
reuse_n = 0;
|
||||||
|
}
|
||||||
|
|
||||||
result.clear();
|
result.clear();
|
||||||
result.reserve(params.n_max);
|
result.reserve(params.n_max);
|
||||||
|
|
||||||
if (reuse_n == 0) {
|
bool needs_ckpt = use_checkpoint && prompt_dft.size() > 0;
|
||||||
|
if (reuse_n == 0 || (use_checkpoint && reuse_i > 0)) {
|
||||||
llama_memory_clear(mem_dft, false);
|
llama_memory_clear(mem_dft, false);
|
||||||
prompt_dft.clear();
|
prompt_dft.clear();
|
||||||
} else {
|
} else {
|
||||||
// this happens when a previous draft has been discarded (for example, due to being too small), but the
|
// this happens when a previous draft has been discarded (for example, due to being too small), but the
|
||||||
// target model agreed with it. in this case, we simply pass back the previous results to save compute
|
// target model agreed with it. in this case, we simply pass back the previous results to save compute
|
||||||
if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) {
|
if (reuse_i + reuse_n < (int64_t) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) {
|
||||||
for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) {
|
for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) {
|
||||||
result.push_back(prompt_dft[i]);
|
result.push_back(prompt_dft[i]);
|
||||||
|
|
||||||
@@ -310,19 +380,50 @@ struct common_speculative_state_draft : public common_speculative_state {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool do_restore = false;
|
||||||
|
if (prompt_dft.size() > prompt_cur.size() && reuse_i + reuse_n < (int64_t) prompt_dft.size()) {
|
||||||
|
// This can happen after a partial acceptance (speculative decoding with checkpoints)
|
||||||
|
LOG_DBG("%s: #prompt_dft=%zu, #prompt_cur=%zu, shorten draft\n",
|
||||||
|
__func__, prompt_dft.size(), prompt_cur.size());
|
||||||
|
prompt_dft.resize(prompt_cur.size());
|
||||||
|
do_restore = true;
|
||||||
|
}
|
||||||
|
|
||||||
if (reuse_i > 0) {
|
if (reuse_i > 0) {
|
||||||
llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
|
bool is_removed = llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
|
||||||
|
if (!is_removed) {
|
||||||
|
LOG_ERR("%s: llama_memory_seq_rm failed, reuse_i=%d\n", __func__, reuse_i);
|
||||||
|
}
|
||||||
llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);
|
llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);
|
||||||
|
|
||||||
prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
|
prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (reuse_n < (int) prompt_dft.size()) {
|
if (reuse_n < (int) prompt_dft.size() || do_restore) {
|
||||||
llama_memory_seq_rm (mem_dft, 0, reuse_n, -1);
|
if (use_checkpoint) {
|
||||||
prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
|
if (ckpt.n_tokens > (int64_t) prompt_dft.size()) {
|
||||||
|
LOG_INF("%s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%" PRId64 ", reuse_n=%d, prompt_dft.size=%zu\n",
|
||||||
|
__func__, prompt_tgt.size(), ckpt.n_tokens, reuse_n, prompt_dft.size());
|
||||||
|
}
|
||||||
|
draft_restore_checkpoint(ckpt.ckpt_size);
|
||||||
|
reuse_n = ckpt.n_tokens;
|
||||||
|
prompt_dft.resize(reuse_n);
|
||||||
|
needs_ckpt = false;
|
||||||
|
} else {
|
||||||
|
bool is_removed = llama_memory_seq_rm (mem_dft, 0, reuse_n, -1);
|
||||||
|
if (!is_removed) {
|
||||||
|
LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n",
|
||||||
|
__func__, reuse_n, prompt_dft.size());
|
||||||
|
}
|
||||||
|
prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (needs_ckpt) {
|
||||||
|
ckpt.ckpt_size = draft_create_checkpoint(prompt_dft.size(), batch.n_tokens);
|
||||||
|
}
|
||||||
|
|
||||||
// prepare a batch to evaluate any new tokens in the prompt
|
// prepare a batch to evaluate any new tokens in the prompt
|
||||||
common_batch_clear(batch);
|
common_batch_clear(batch);
|
||||||
|
|
||||||
@@ -337,7 +438,11 @@ struct common_speculative_state_draft : public common_speculative_state {
|
|||||||
if (batch.n_tokens > 0) {
|
if (batch.n_tokens > 0) {
|
||||||
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
|
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
|
||||||
|
|
||||||
llama_decode(ctx_dft, batch);
|
int ret = llama_decode(ctx_dft, batch);
|
||||||
|
if (ret != 0 && ret != 1) {
|
||||||
|
LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu\n",
|
||||||
|
__func__, ret, prompt_cur.size());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_pos n_past = prompt_dft.size();
|
const llama_pos n_past = prompt_dft.size();
|
||||||
@@ -351,7 +456,11 @@ struct common_speculative_state_draft : public common_speculative_state {
|
|||||||
|
|
||||||
LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());
|
LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());
|
||||||
|
|
||||||
llama_decode(ctx_dft, batch);
|
int ret = llama_decode(ctx_dft, batch);
|
||||||
|
if (ret != 0 && ret != 1) {
|
||||||
|
LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n",
|
||||||
|
__func__, ret, prompt_cur.size(), prompt_dft.size());
|
||||||
|
}
|
||||||
|
|
||||||
common_sampler_reset(smpl);
|
common_sampler_reset(smpl);
|
||||||
|
|
||||||
@@ -387,7 +496,11 @@ struct common_speculative_state_draft : public common_speculative_state {
|
|||||||
common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
|
common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
|
||||||
|
|
||||||
// evaluate the drafted tokens on the draft model
|
// evaluate the drafted tokens on the draft model
|
||||||
llama_decode(ctx_dft, batch);
|
ret = llama_decode(ctx_dft, batch);
|
||||||
|
if (ret != 0) {
|
||||||
|
LOG_WRN("%s: llama_decode[%d] returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n",
|
||||||
|
__func__, i, ret, prompt_cur.size(), prompt_dft.size());
|
||||||
|
}
|
||||||
|
|
||||||
prompt_dft.push_back(id);
|
prompt_dft.push_back(id);
|
||||||
}
|
}
|
||||||
@@ -739,6 +852,7 @@ struct common_speculative_state_ngram_cache : public common_speculative_state {
|
|||||||
|
|
||||||
struct common_speculative {
|
struct common_speculative {
|
||||||
std::vector<std::unique_ptr<common_speculative_state>> impls; // list of implementations to use and their states
|
std::vector<std::unique_ptr<common_speculative_state>> impls; // list of implementations to use and their states
|
||||||
|
|
||||||
common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats)
|
common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats)
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -798,13 +912,13 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
|
|||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool common_speculative_is_compat(llama_context * ctx_tgt) {
|
common_speculative_compat_type common_speculative_is_compat(llama_context * ctx_tgt) {
|
||||||
auto * mem = llama_get_memory(ctx_tgt);
|
auto * mem = llama_get_memory(ctx_tgt);
|
||||||
if (mem == nullptr) {
|
if (mem == nullptr) {
|
||||||
return false;
|
return COMMON_SPECULATIVE_COMPAT_TYPE_NO;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool res = true;
|
common_speculative_compat_type res = COMMON_SPECULATIVE_COMPAT_TYPE_FULL;
|
||||||
|
|
||||||
llama_memory_clear(mem, true);
|
llama_memory_clear(mem, true);
|
||||||
|
|
||||||
@@ -816,14 +930,14 @@ bool common_speculative_is_compat(llama_context * ctx_tgt) {
|
|||||||
int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size()));
|
int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size()));
|
||||||
if (ret != 0) {
|
if (ret != 0) {
|
||||||
LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
|
LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
|
||||||
res = false;
|
res = COMMON_SPECULATIVE_COMPAT_TYPE_NO;
|
||||||
goto done;
|
goto done;
|
||||||
}
|
}
|
||||||
|
|
||||||
// try to remove the last tokens
|
// try to remove the last tokens
|
||||||
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
|
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
|
||||||
LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
|
LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
|
||||||
res = false;
|
res = COMMON_SPECULATIVE_COMPAT_TYPE_CKPT;
|
||||||
goto done;
|
goto done;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -909,9 +1023,10 @@ common_speculative * common_speculative_init(
|
|||||||
break;
|
break;
|
||||||
case COMMON_SPECULATIVE_TYPE_DRAFT: {
|
case COMMON_SPECULATIVE_TYPE_DRAFT: {
|
||||||
impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
|
impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
|
||||||
/* .ctx_tgt = */ ctx_tgt,
|
/* .ctx_tgt = */ ctx_tgt,
|
||||||
/* .ctx_dft = */ ctx_dft,
|
/* .ctx_dft = */ ctx_dft,
|
||||||
/* .replacements = */ params.replacements
|
/* .replacements = */ params.replacements,
|
||||||
|
/* .use_checkpoint= */ params.use_checkpoints // TODO: this should be based on the draft model!
|
||||||
));
|
));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -966,7 +1081,8 @@ common_speculative * common_speculative_init(
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto * result = new common_speculative {
|
auto * result = new common_speculative {
|
||||||
/* .impls = */ std::move(impls)
|
/* .impls = */ std::move(impls),
|
||||||
|
/* .curr_impl = */ nullptr,
|
||||||
};
|
};
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
|||||||
+13
-1
@@ -14,9 +14,15 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
|
|||||||
// convert type to string
|
// convert type to string
|
||||||
std::string common_speculative_type_to_str(enum common_speculative_type type);
|
std::string common_speculative_type_to_str(enum common_speculative_type type);
|
||||||
|
|
||||||
|
enum common_speculative_compat_type {
|
||||||
|
COMMON_SPECULATIVE_COMPAT_TYPE_NO = 0,
|
||||||
|
COMMON_SPECULATIVE_COMPAT_TYPE_FULL = 1,
|
||||||
|
COMMON_SPECULATIVE_COMPAT_TYPE_CKPT = 2,
|
||||||
|
};
|
||||||
|
|
||||||
// check if the llama_context is compatible for speculative decoding
|
// check if the llama_context is compatible for speculative decoding
|
||||||
// note: clears the memory of the context
|
// note: clears the memory of the context
|
||||||
bool common_speculative_is_compat(llama_context * ctx_tgt);
|
common_speculative_compat_type common_speculative_is_compat(llama_context * ctx_tgt);
|
||||||
|
|
||||||
common_speculative * common_speculative_init(
|
common_speculative * common_speculative_init(
|
||||||
common_params_speculative & params,
|
common_params_speculative & params,
|
||||||
@@ -39,3 +45,9 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
|
|||||||
|
|
||||||
// print statistics about the speculative decoding
|
// print statistics about the speculative decoding
|
||||||
void common_speculative_print_stats(const common_speculative * spec);
|
void common_speculative_print_stats(const common_speculative * spec);
|
||||||
|
|
||||||
|
struct common_speculative_deleter {
|
||||||
|
void operator()(common_speculative * s) { common_speculative_free(s); }
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef std::unique_ptr<common_speculative, common_speculative_deleter> common_speculative_ptr;
|
||||||
|
|||||||
@@ -391,15 +391,25 @@ void server_tokens::push_back(server_tokens & tokens) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void server_tokens::insert(const llama_tokens & inp_tokens) {
|
void server_tokens::insert(const llama_tokens & inp_tokens) {
|
||||||
GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
|
|
||||||
tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end());
|
tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_tokens & server_tokens::get_text_tokens() const {
|
const llama_tokens & server_tokens::get_tokens() const {
|
||||||
GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
|
GGML_ASSERT(!has_mtmd);
|
||||||
return tokens;
|
return tokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llama_tokens server_tokens::get_text_tokens() const {
|
||||||
|
llama_tokens res;
|
||||||
|
res.reserve(tokens.size());
|
||||||
|
for (llama_token t : tokens) {
|
||||||
|
if (t != LLAMA_TOKEN_NULL) {
|
||||||
|
res.push_back(t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
void server_tokens::set_token(llama_pos pos, llama_token id) {
|
void server_tokens::set_token(llama_pos pos, llama_token id) {
|
||||||
GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
|
GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
|
||||||
tokens[pos] = id;
|
tokens[pos] = id;
|
||||||
|
|||||||
@@ -190,7 +190,9 @@ public:
|
|||||||
void insert(const llama_tokens & inp_tokens);
|
void insert(const llama_tokens & inp_tokens);
|
||||||
|
|
||||||
// for compatibility with speculative decoding, ctx shift, slot save/load
|
// for compatibility with speculative decoding, ctx shift, slot save/load
|
||||||
const llama_tokens & get_text_tokens() const;
|
const llama_tokens & get_tokens() const;
|
||||||
|
|
||||||
|
llama_tokens get_text_tokens() const;
|
||||||
|
|
||||||
// for compatibility with speculative decoding
|
// for compatibility with speculative decoding
|
||||||
void set_token(llama_pos pos, llama_token id);
|
void set_token(llama_pos pos, llama_token id);
|
||||||
|
|||||||
+232
-144
@@ -1,3 +1,4 @@
|
|||||||
|
|
||||||
#include "server-context.h"
|
#include "server-context.h"
|
||||||
#include "server-common.h"
|
#include "server-common.h"
|
||||||
#include "server-http.h"
|
#include "server-http.h"
|
||||||
@@ -19,6 +20,7 @@
|
|||||||
#include <exception>
|
#include <exception>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
// fix problem with std::min and std::max
|
// fix problem with std::min and std::max
|
||||||
#if defined(_WIN32)
|
#if defined(_WIN32)
|
||||||
@@ -33,6 +35,31 @@ using json = nlohmann::ordered_json;
|
|||||||
|
|
||||||
constexpr int HTTP_POLLING_SECONDS = 1;
|
constexpr int HTTP_POLLING_SECONDS = 1;
|
||||||
|
|
||||||
|
static server_prompt_checkpoint server_get_checkpoint(llama_context * ctx, int id, int64_t n_tokens, llama_pos pos_min = -1, llama_pos pos_max = -1) {
|
||||||
|
if (pos_min == -1) {
|
||||||
|
pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), id);
|
||||||
|
}
|
||||||
|
if (pos_max == -1) {
|
||||||
|
pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), id);
|
||||||
|
}
|
||||||
|
|
||||||
|
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||||
|
|
||||||
|
auto cur = server_prompt_checkpoint {
|
||||||
|
/*.pos_min = */ pos_min,
|
||||||
|
/*.pos_max = */ pos_max,
|
||||||
|
/*.n_tokens = */ n_tokens,
|
||||||
|
/*.data = */ std::vector<uint8_t>(checkpoint_size),
|
||||||
|
};
|
||||||
|
|
||||||
|
const size_t n = llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||||
|
if (n != checkpoint_size) {
|
||||||
|
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
return cur;
|
||||||
|
}
|
||||||
|
|
||||||
// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
|
// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
|
||||||
enum slot_state {
|
enum slot_state {
|
||||||
SLOT_STATE_IDLE,
|
SLOT_STATE_IDLE,
|
||||||
@@ -57,7 +84,12 @@ struct server_slot {
|
|||||||
// multimodal
|
// multimodal
|
||||||
mtmd_context * mctx = nullptr;
|
mtmd_context * mctx = nullptr;
|
||||||
|
|
||||||
common_speculative * spec = nullptr;
|
// speculative decoding
|
||||||
|
llama_tokens spec_draft;
|
||||||
|
std::vector<int32_t> spec_i_batch;
|
||||||
|
server_prompt_checkpoint spec_ckpt;
|
||||||
|
common_speculative_ptr spec;
|
||||||
|
|
||||||
|
|
||||||
// TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
|
// TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
|
||||||
// see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
|
// see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
|
||||||
@@ -83,11 +115,6 @@ struct server_slot {
|
|||||||
std::string debug_generated_text;
|
std::string debug_generated_text;
|
||||||
llama_tokens generated_tokens;
|
llama_tokens generated_tokens;
|
||||||
|
|
||||||
// idx of draft tokens in the main batch
|
|
||||||
// non-empty if we went to evaluate draft tokens
|
|
||||||
// ref: https://github.com/ggml-org/llama.cpp/pull/17808
|
|
||||||
std::vector<int32_t> i_batch_dft;
|
|
||||||
|
|
||||||
std::vector<completion_token_output> generated_token_probs;
|
std::vector<completion_token_output> generated_token_probs;
|
||||||
|
|
||||||
bool has_next_token = true;
|
bool has_next_token = true;
|
||||||
@@ -147,8 +174,7 @@ struct server_slot {
|
|||||||
|
|
||||||
common_sampler_ptr smpl;
|
common_sampler_ptr smpl;
|
||||||
|
|
||||||
llama_token sampled; // in speculative mode, this is the last accepted token
|
llama_token sampled; // in speculative mode, this is the last accepted token
|
||||||
llama_tokens drafted;
|
|
||||||
|
|
||||||
// stats
|
// stats
|
||||||
size_t n_sent_text = 0; // number of sent text character
|
size_t n_sent_text = 0; // number of sent text character
|
||||||
@@ -178,8 +204,11 @@ struct server_slot {
|
|||||||
stopping_word = "";
|
stopping_word = "";
|
||||||
n_sent_text = 0;
|
n_sent_text = 0;
|
||||||
|
|
||||||
drafted.clear();
|
if (can_speculate()) {
|
||||||
i_batch_dft.clear();
|
spec_draft.clear();
|
||||||
|
spec_i_batch.clear();
|
||||||
|
spec_ckpt.clear();
|
||||||
|
}
|
||||||
generated_tokens.clear();
|
generated_tokens.clear();
|
||||||
generated_token_probs.clear();
|
generated_token_probs.clear();
|
||||||
json_schema = json();
|
json_schema = json();
|
||||||
@@ -300,6 +329,85 @@ struct server_slot {
|
|||||||
return n_draft_max;
|
return n_draft_max;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void update_batch(llama_batch & batch) {
|
||||||
|
const int n_draft_max = get_n_draft_max();
|
||||||
|
if (n_draft_max > 0) {
|
||||||
|
GGML_ASSERT(can_speculate());
|
||||||
|
|
||||||
|
// generate draft tokens in speculative decoding mode
|
||||||
|
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
|
||||||
|
// perform the speculative drafting for all sequences at the same time in a single batch
|
||||||
|
const llama_tokens & tokens = prompt.tokens.get_text_tokens();
|
||||||
|
|
||||||
|
const auto & params_spec = task->params.speculative;
|
||||||
|
|
||||||
|
if (!spec_draft.empty()) {
|
||||||
|
// we have a previous (partial) draft to reuse
|
||||||
|
if (task->params.speculative.use_checkpoints) {
|
||||||
|
GGML_ASSERT(!spec_ckpt.empty());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
GGML_ASSERT(spec_i_batch.empty());
|
||||||
|
|
||||||
|
// generate a new draft
|
||||||
|
spec_draft = common_speculative_draft(spec.get(), params_spec, tokens, sampled);
|
||||||
|
|
||||||
|
if (spec_draft.size() > (size_t) n_draft_max) {
|
||||||
|
SLT_WRN(*this, "draft size %d exceeds max %d, truncating\n", (int) spec_draft.size(), n_draft_max);
|
||||||
|
spec_draft.resize(n_draft_max);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (spec_draft.size() < (size_t) params_spec.n_min) {
|
||||||
|
SLT_DBG(*this, "ignoring small draft: %d < %d\n", (int) spec_draft.size(), params_spec.n_min);
|
||||||
|
spec_draft.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!spec_draft.empty() && params_spec.use_checkpoints) {
|
||||||
|
const auto n_tokens = prompt.tokens.size();
|
||||||
|
|
||||||
|
auto & ckpt = spec_ckpt;
|
||||||
|
|
||||||
|
ckpt = server_get_checkpoint(ctx, this->id, n_tokens);
|
||||||
|
|
||||||
|
SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
|
||||||
|
ckpt.pos_min, ckpt.pos_max, n_tokens, (float) ckpt.data.size() / 1024 / 1024);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_ASSERT(spec_draft.size() <= (size_t) n_draft_max);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (spec_draft.empty()) {
|
||||||
|
// no speculative decoding
|
||||||
|
i_batch = batch.n_tokens;
|
||||||
|
|
||||||
|
common_batch_add(batch, sampled, prompt.tokens.pos_next(), { this->id }, true);
|
||||||
|
|
||||||
|
SLT_DBG(*this, "slot decode token, id=%d, n_ctx = %d, n_tokens = %d, truncated = %d\n",
|
||||||
|
sampled, n_ctx, prompt.n_tokens(), truncated);
|
||||||
|
} else {
|
||||||
|
SLT_DBG(*this, "generate_draft: id=%d, #tokens=%zu, #draft=%zu, pos_next=%d\n",
|
||||||
|
sampled, prompt.tokens.size(), spec_draft.size(), prompt.tokens.pos_next());
|
||||||
|
|
||||||
|
GGML_ASSERT(spec_i_batch.empty());
|
||||||
|
|
||||||
|
spec_i_batch.push_back(batch.n_tokens);
|
||||||
|
for (size_t i = 0; i < spec_draft.size(); i++) {
|
||||||
|
spec_i_batch.push_back(batch.n_tokens + i + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto pos0 = prompt.tokens.pos_next();
|
||||||
|
|
||||||
|
common_batch_add(batch, sampled, pos0++, { this->id }, true);
|
||||||
|
for (auto token : spec_draft) {
|
||||||
|
common_batch_add(batch, token, pos0++, { this->id }, true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt.tokens.push_back(sampled);
|
||||||
|
prompt.tokens.insert(spec_draft);
|
||||||
|
}
|
||||||
|
|
||||||
void release() {
|
void release() {
|
||||||
if (is_processing()) {
|
if (is_processing()) {
|
||||||
GGML_ASSERT(task);
|
GGML_ASSERT(task);
|
||||||
@@ -400,7 +508,7 @@ struct server_slot {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
common_speculative_print_stats(spec);
|
common_speculative_print_stats(spec.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
json to_json(bool only_metrics = false) const {
|
json to_json(bool only_metrics = false) const {
|
||||||
@@ -591,16 +699,17 @@ private:
|
|||||||
|
|
||||||
void destroy() {
|
void destroy() {
|
||||||
llama_init.reset();
|
llama_init.reset();
|
||||||
|
|
||||||
ctx = nullptr;
|
ctx = nullptr;
|
||||||
model = nullptr;
|
model = nullptr;
|
||||||
|
|
||||||
mtmd_free(mctx);
|
mtmd_free(mctx);
|
||||||
mctx = nullptr;
|
mctx = nullptr;
|
||||||
|
|
||||||
// Clear any sampling context
|
|
||||||
for (server_slot & slot : slots) {
|
for (server_slot & slot : slots) {
|
||||||
common_speculative_free(slot.spec);
|
if (slot.can_speculate()) {
|
||||||
slot.spec = nullptr;
|
slot.spec.reset();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_batch_free(batch);
|
llama_batch_free(batch);
|
||||||
@@ -642,9 +751,6 @@ private:
|
|||||||
|
|
||||||
llama_init = common_init_from_params(params_base);
|
llama_init = common_init_from_params(params_base);
|
||||||
|
|
||||||
// propagate model-metadata sampling defaults back to caller
|
|
||||||
params.sampling = params_base.sampling;
|
|
||||||
|
|
||||||
model = llama_init->model();
|
model = llama_init->model();
|
||||||
ctx = llama_init->context();
|
ctx = llama_init->context();
|
||||||
|
|
||||||
@@ -660,6 +766,7 @@ private:
|
|||||||
add_bos_token = llama_vocab_get_add_bos(vocab);
|
add_bos_token = llama_vocab_get_add_bos(vocab);
|
||||||
|
|
||||||
if (params_base.speculative.has_dft()) {
|
if (params_base.speculative.has_dft()) {
|
||||||
|
// TODO speculative: move to common/speculative.cpp?
|
||||||
SRV_INF("loading draft model '%s'\n", params_base.speculative.mparams_dft.path.c_str());
|
SRV_INF("loading draft model '%s'\n", params_base.speculative.mparams_dft.path.c_str());
|
||||||
|
|
||||||
const auto & params_spec = params_base.speculative;
|
const auto & params_spec = params_base.speculative;
|
||||||
@@ -727,11 +834,6 @@ private:
|
|||||||
params_base.n_cache_reuse = 0;
|
params_base.n_cache_reuse = 0;
|
||||||
SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
|
SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params_base.speculative.type != COMMON_SPECULATIVE_TYPE_NONE) {
|
|
||||||
params_base.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
|
|
||||||
SRV_WRN("%s\n", "speculative decoding is not supported by multimodal, it will be disabled");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!llama_memory_can_shift(llama_get_memory(ctx))) {
|
if (!llama_memory_can_shift(llama_get_memory(ctx))) {
|
||||||
@@ -769,14 +871,23 @@ private:
|
|||||||
|
|
||||||
slots.clear();
|
slots.clear();
|
||||||
|
|
||||||
const bool can_spec = common_speculative_is_compat(ctx);
|
const auto spec_type = common_speculative_is_compat(ctx);
|
||||||
if (!can_spec) {
|
if (spec_type == COMMON_SPECULATIVE_COMPAT_TYPE_NO) {
|
||||||
SRV_WRN("%s", "speculative decoding not supported by this context\n");
|
SRV_WRN("%s", "speculative decoding not supported by this context\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (spec_type == COMMON_SPECULATIVE_COMPAT_TYPE_CKPT) {
|
||||||
|
SRV_WRN("%s", "speculative decoding will use checkpoints\n");
|
||||||
|
params_base.speculative.use_checkpoints = true;
|
||||||
|
}
|
||||||
|
|
||||||
// initialize slots
|
// initialize slots
|
||||||
for (int i = 0; i < params_base.n_parallel; i++) {
|
for (int i = 0; i < params_base.n_parallel; i++) {
|
||||||
server_slot slot;
|
slots.emplace_back();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < params_base.n_parallel; i++) {
|
||||||
|
server_slot & slot = slots[i];
|
||||||
|
|
||||||
slot.id = i;
|
slot.id = i;
|
||||||
slot.ctx = ctx;
|
slot.ctx = ctx;
|
||||||
@@ -786,16 +897,11 @@ private:
|
|||||||
slot.prompt.tokens.has_mtmd = mctx != nullptr;
|
slot.prompt.tokens.has_mtmd = mctx != nullptr;
|
||||||
|
|
||||||
// try speculative decoding
|
// try speculative decoding
|
||||||
if (can_spec) {
|
if (spec_type != COMMON_SPECULATIVE_COMPAT_TYPE_NO) {
|
||||||
slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
|
slot.spec.reset(common_speculative_init(params_base.speculative, slot.ctx));
|
||||||
|
|
||||||
if (slot.spec) {
|
if (slot.spec) {
|
||||||
if (mctx) {
|
|
||||||
SRV_ERR("%s\n", "speculative decoding is not supported with multimodal");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
SLT_INF(slot, "%s", "speculative decoding context initialized\n");
|
SLT_INF(slot, "%s", "speculative decoding context initialized\n");
|
||||||
} else {
|
|
||||||
SLT_INF(slot, "%s", "speculative decoding context not initialized\n");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -806,8 +912,6 @@ private:
|
|||||||
};
|
};
|
||||||
|
|
||||||
slot.reset();
|
slot.reset();
|
||||||
|
|
||||||
slots.push_back(std::move(slot));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
@@ -854,6 +958,9 @@ private:
|
|||||||
model_aliases = params_base.model_alias;
|
model_aliases = params_base.model_alias;
|
||||||
model_tags = params_base.model_tags;
|
model_tags = params_base.model_tags;
|
||||||
|
|
||||||
|
// propagate new defaults back to caller
|
||||||
|
params = params_base;
|
||||||
|
|
||||||
if (!is_resume) {
|
if (!is_resume) {
|
||||||
return init();
|
return init();
|
||||||
}
|
}
|
||||||
@@ -1197,7 +1304,7 @@ private:
|
|||||||
backend_sampling &= task.params.sampling.backend_sampling;
|
backend_sampling &= task.params.sampling.backend_sampling;
|
||||||
|
|
||||||
// TODO: speculative decoding requires multiple samples per batch - not supported yet
|
// TODO: speculative decoding requires multiple samples per batch - not supported yet
|
||||||
backend_sampling &= !(slot.spec && task.params.speculative.n_max > 0);
|
backend_sampling &= !(slot.can_speculate() && task.params.speculative.n_max > 0);
|
||||||
|
|
||||||
// TODO: getting post/pre sampling logits is not yet supported with backend sampling
|
// TODO: getting post/pre sampling logits is not yet supported with backend sampling
|
||||||
backend_sampling &= !need_logits;
|
backend_sampling &= !need_logits;
|
||||||
@@ -1703,6 +1810,26 @@ private:
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// n_tokens_cur: the number of tokens added to the batch for the current slot
|
||||||
|
void create_checkpoint(server_slot & slot, const int64_t n_tokens_cur, llama_pos pos_min, llama_pos pos_max) {
|
||||||
|
while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
|
||||||
|
// make room for the new checkpoint, if needed
|
||||||
|
const auto & cur = slot.prompt.checkpoints.front();
|
||||||
|
|
||||||
|
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
|
||||||
|
cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
|
||||||
|
|
||||||
|
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto & cur = slot.prompt.checkpoints.emplace_back(server_get_checkpoint(ctx, slot.id, slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max));
|
||||||
|
|
||||||
|
SLT_WRN(slot,
|
||||||
|
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
|
||||||
|
(int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min,
|
||||||
|
cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
|
||||||
|
}
|
||||||
|
|
||||||
void process_single_task(server_task && task) {
|
void process_single_task(server_task && task) {
|
||||||
switch (task.type) {
|
switch (task.type) {
|
||||||
case SERVER_TASK_TYPE_COMPLETION:
|
case SERVER_TASK_TYPE_COMPLETION:
|
||||||
@@ -1854,7 +1981,7 @@ private:
|
|||||||
std::string filename = task.slot_action.filename;
|
std::string filename = task.slot_action.filename;
|
||||||
std::string filepath = task.slot_action.filepath;
|
std::string filepath = task.slot_action.filepath;
|
||||||
|
|
||||||
const llama_tokens & tokens = slot->prompt.tokens.get_text_tokens();
|
const llama_tokens & tokens = slot->prompt.tokens.get_tokens();
|
||||||
const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count);
|
const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count);
|
||||||
|
|
||||||
const int64_t t_end = ggml_time_us();
|
const int64_t t_end = ggml_time_us();
|
||||||
@@ -2061,7 +2188,7 @@ private:
|
|||||||
{
|
{
|
||||||
GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
|
GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
|
||||||
|
|
||||||
llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy
|
llama_tokens new_tokens = slot.prompt.tokens.get_tokens(); // copy
|
||||||
for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
|
for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
|
||||||
new_tokens[i - n_discard] = new_tokens[i];
|
new_tokens[i - n_discard] = new_tokens[i];
|
||||||
}
|
}
|
||||||
@@ -2100,61 +2227,7 @@ private:
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// generate draft tokens in speculative decoding mode
|
slot.update_batch(batch);
|
||||||
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
|
|
||||||
// perform the speculative drafting for all sequences at the same time in a single batch
|
|
||||||
const int n_draft_max = slot.get_n_draft_max();
|
|
||||||
if (n_draft_max > 0) {
|
|
||||||
if (mctx) {
|
|
||||||
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
|
|
||||||
GGML_ABORT("not supported by multimodal");
|
|
||||||
}
|
|
||||||
|
|
||||||
const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
|
|
||||||
|
|
||||||
const auto & params_spec = slot.task->params.speculative;
|
|
||||||
|
|
||||||
llama_tokens draft = common_speculative_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
|
|
||||||
|
|
||||||
if (draft.size() > (size_t) n_draft_max) {
|
|
||||||
SLT_WRN(slot, "draft size %d exceeds max %d, truncating\n", (int) draft.size(), n_draft_max);
|
|
||||||
draft.resize(n_draft_max);
|
|
||||||
}
|
|
||||||
|
|
||||||
// add the sampled token to the batch
|
|
||||||
slot.i_batch_dft.push_back(batch.n_tokens);
|
|
||||||
common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
|
|
||||||
slot.prompt.tokens.push_back(slot.sampled);
|
|
||||||
|
|
||||||
if (slot.task->params.speculative.n_min > (int) draft.size()) {
|
|
||||||
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);
|
|
||||||
// fallback to normal decoding
|
|
||||||
slot.i_batch = slot.i_batch_dft[0];
|
|
||||||
slot.drafted.clear();
|
|
||||||
slot.i_batch_dft.clear();
|
|
||||||
} else {
|
|
||||||
// keep track of total number of drafted tokens tested
|
|
||||||
slot.n_draft_total += draft.size();
|
|
||||||
|
|
||||||
// add all drafted tokens to the batch
|
|
||||||
for (size_t i = 0; i < draft.size(); i++) {
|
|
||||||
slot.i_batch_dft.push_back(batch.n_tokens);
|
|
||||||
common_batch_add(batch, draft[i], slot.prompt.tokens.pos_next(), { slot.id }, true);
|
|
||||||
slot.prompt.tokens.push_back(draft[i]);
|
|
||||||
}
|
|
||||||
slot.drafted = std::move(draft);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// no speculative decoding
|
|
||||||
slot.i_batch = batch.n_tokens;
|
|
||||||
|
|
||||||
common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
|
|
||||||
|
|
||||||
slot.prompt.tokens.push_back(slot.sampled);
|
|
||||||
|
|
||||||
SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
|
|
||||||
slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// process in chunks of params.n_batch
|
// process in chunks of params.n_batch
|
||||||
@@ -2651,40 +2724,12 @@ private:
|
|||||||
|
|
||||||
// no need to create checkpoints that are too close together
|
// no need to create checkpoints that are too close together
|
||||||
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || slot.prompt.n_tokens() - n_tokens_cur > slot.prompt.checkpoints.back().n_tokens + 64);
|
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || slot.prompt.n_tokens() - n_tokens_cur > slot.prompt.checkpoints.back().n_tokens + 64);
|
||||||
|
SLT_DBG(slot, "main/do_checkpoint = %s, pos_min = %d, pos_max = %d\n", do_checkpoint ? "yes" : "no", pos_min, pos_max);
|
||||||
|
|
||||||
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
|
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
|
||||||
// yet processed and therefore it is not part of the checkpoint.
|
// yet processed and therefore it is not part of the checkpoint.
|
||||||
if (do_checkpoint) {
|
if (do_checkpoint) {
|
||||||
while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
|
create_checkpoint(slot, n_tokens_cur, pos_min, pos_max);
|
||||||
// make room for the new checkpoint, if needed
|
|
||||||
const auto & cur = slot.prompt.checkpoints.front();
|
|
||||||
|
|
||||||
SLT_WRN(slot,
|
|
||||||
"erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64
|
|
||||||
", size = %.3f MiB)\n",
|
|
||||||
cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
|
|
||||||
|
|
||||||
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
|
|
||||||
}
|
|
||||||
|
|
||||||
const size_t checkpoint_size =
|
|
||||||
llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
|
||||||
|
|
||||||
auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
|
|
||||||
/*.pos_min = */ pos_min,
|
|
||||||
/*.pos_max = */ pos_max,
|
|
||||||
/*.n_tokens = */ slot.prompt.n_tokens() - n_tokens_cur,
|
|
||||||
/*.data = */ std::vector<uint8_t>(checkpoint_size),
|
|
||||||
});
|
|
||||||
|
|
||||||
llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id,
|
|
||||||
LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
|
||||||
|
|
||||||
SLT_WRN(slot,
|
|
||||||
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64
|
|
||||||
", size = %.3f MiB)\n",
|
|
||||||
(int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min,
|
|
||||||
cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2856,19 +2901,19 @@ private:
|
|||||||
slot.state = SLOT_STATE_GENERATING;
|
slot.state = SLOT_STATE_GENERATING;
|
||||||
|
|
||||||
if (slot.can_speculate()) {
|
if (slot.can_speculate()) {
|
||||||
common_speculative_begin(slot.spec, slot.prompt.tokens.get_text_tokens());
|
common_speculative_begin(slot.spec.get(), slot.prompt.tokens.get_text_tokens());
|
||||||
}
|
}
|
||||||
} else if (slot.state != SLOT_STATE_GENERATING) {
|
} else if (slot.state != SLOT_STATE_GENERATING) {
|
||||||
continue; // continue loop of slots
|
continue; // continue loop of slots
|
||||||
}
|
}
|
||||||
|
|
||||||
if (slot.i_batch_dft.size() > 0) {
|
if (slot.can_speculate() && !slot.spec_draft.empty()) {
|
||||||
continue; // sample using speculative decoding
|
continue; // sample using speculative decoding
|
||||||
}
|
}
|
||||||
|
|
||||||
const int tok_idx = slot.i_batch - i;
|
const int tok_idx = slot.i_batch - i;
|
||||||
|
|
||||||
llama_token id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx);
|
llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx, tok_idx);
|
||||||
|
|
||||||
slot.i_batch = -1;
|
slot.i_batch = -1;
|
||||||
|
|
||||||
@@ -2889,7 +2934,7 @@ private:
|
|||||||
|
|
||||||
completion_token_output result;
|
completion_token_output result;
|
||||||
result.tok = id;
|
result.tok = id;
|
||||||
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
|
result.text_to_send = common_token_to_piece(slot.ctx, result.tok, accept_special_token(slot, result.tok));
|
||||||
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
|
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
|
||||||
|
|
||||||
if (slot.task->params.sampling.n_probs > 0) {
|
if (slot.task->params.sampling.n_probs > 0) {
|
||||||
@@ -2909,43 +2954,86 @@ private:
|
|||||||
|
|
||||||
// speculative decoding - main model sample and accept
|
// speculative decoding - main model sample and accept
|
||||||
for (auto & slot : slots) {
|
for (auto & slot : slots) {
|
||||||
if (slot.state != SLOT_STATE_GENERATING || slot.i_batch_dft.empty()) {
|
if (slot.state != SLOT_STATE_GENERATING || !slot.can_speculate() || slot.spec_draft.empty()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const size_t n_draft = slot.drafted.size();
|
// save the original draft size
|
||||||
|
const size_t n_draft = slot.spec_draft.size();
|
||||||
|
|
||||||
// the accepted tokens from the speculation
|
GGML_ASSERT(n_draft > 0);
|
||||||
const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted);
|
|
||||||
slot.i_batch_dft.clear();
|
// verify and try to accept the draft
|
||||||
slot.drafted.clear();
|
{
|
||||||
|
const auto & params_spec = slot.task->params.speculative;
|
||||||
|
|
||||||
|
common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get()));
|
||||||
|
|
||||||
|
GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1);
|
||||||
|
auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx, slot.spec_i_batch, slot.spec_draft);
|
||||||
|
slot.spec_i_batch.clear();
|
||||||
|
|
||||||
|
SLT_DBG(slot, "%s: n_draft=%zu, accepted=%zu\n", __func__, slot.spec_draft.size(), accepted.size());
|
||||||
|
|
||||||
|
GGML_ASSERT(accepted.size() >= 1);
|
||||||
|
|
||||||
|
// check for partial draft acceptance
|
||||||
|
if (accepted.size() < slot.spec_draft.size() + 1) {
|
||||||
|
if (params_spec.use_checkpoints) {
|
||||||
|
// partial acceptance is not supported by the context -> truncate the draft and restore the state
|
||||||
|
slot.spec_draft = std::move(accepted);
|
||||||
|
|
||||||
|
auto & ckpt = slot.spec_ckpt;
|
||||||
|
|
||||||
|
SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size());
|
||||||
|
|
||||||
|
const size_t n = llama_state_seq_set_data_ext(slot.ctx, ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||||
|
if (n != ckpt.size()) {
|
||||||
|
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
|
||||||
|
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.size(), n);
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, ckpt.pos_max + 1, -1);
|
||||||
|
|
||||||
|
slot.prompt.tokens.keep_first(ckpt.n_tokens);
|
||||||
|
slot.smpl = std::move(smpl_save);
|
||||||
|
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_DBG("%s: partial acceptance: %zu < %zu\n", __func__, accepted.size(), slot.spec_draft.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
common_speculative_accept(slot.spec.get(), accepted.size() - 1);
|
||||||
|
|
||||||
|
slot.spec_draft = std::move(accepted);
|
||||||
|
}
|
||||||
|
|
||||||
const int64_t t_current = ggml_time_us();
|
const int64_t t_current = ggml_time_us();
|
||||||
|
|
||||||
slot.n_decoded += ids.size();
|
const auto ids = std::move(slot.spec_draft);
|
||||||
|
|
||||||
|
slot.n_decoded += ids.size();
|
||||||
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
|
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
|
||||||
|
|
||||||
// update how many tokens out of those tested were accepted
|
// update how many tokens out of those tested were accepted
|
||||||
slot.n_draft_accepted += ids.size() - 1;
|
slot.n_draft_accepted += ids.size() - 1;
|
||||||
|
slot.n_draft_total += n_draft;
|
||||||
// inform the speculative decoding about the number of accepted tokens
|
|
||||||
common_speculative_accept(slot.spec, ids.size() - 1);
|
|
||||||
|
|
||||||
// rollback to the state before sampling the draft tokens
|
|
||||||
slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
|
|
||||||
|
|
||||||
// add accepted tokens to the prompt
|
// add accepted tokens to the prompt
|
||||||
|
slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
|
||||||
slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
|
slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
|
||||||
slot.sampled = ids.back(); // last accepted token
|
|
||||||
|
|
||||||
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
|
slot.sampled = ids.back(); // last accepted token
|
||||||
|
SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft);
|
||||||
|
|
||||||
|
llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, slot.prompt.n_tokens(), -1);
|
||||||
|
|
||||||
for (size_t i = 0; i < ids.size(); ++i) {
|
for (size_t i = 0; i < ids.size(); ++i) {
|
||||||
completion_token_output result;
|
completion_token_output result;
|
||||||
|
|
||||||
result.tok = ids[i];
|
result.tok = ids[i];
|
||||||
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
|
result.text_to_send = common_token_to_piece(slot.ctx, result.tok, accept_special_token(slot, result.tok));
|
||||||
result.prob = 1.0f; // set later
|
result.prob = 1.0f; // set later
|
||||||
|
|
||||||
// TODO: set result.probs
|
// TODO: set result.probs
|
||||||
@@ -3665,7 +3753,7 @@ void server_routes::init_routes() {
|
|||||||
params.n_predict,
|
params.n_predict,
|
||||||
meta->slot_n_ctx,
|
meta->slot_n_ctx,
|
||||||
params.spm_infill,
|
params.spm_infill,
|
||||||
tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal.
|
tokenized_prompts[0].get_tokens() // TODO: this could maybe be multimodal.
|
||||||
);
|
);
|
||||||
|
|
||||||
std::vector<raw_buffer> files; // dummy
|
std::vector<raw_buffer> files; // dummy
|
||||||
|
|||||||
@@ -162,7 +162,7 @@ common_chat_msg task_result_state::update_chat_msg(
|
|||||||
bool filter_tool_calls) {
|
bool filter_tool_calls) {
|
||||||
generated_text += text_added;
|
generated_text += text_added;
|
||||||
auto msg_prv_copy = chat_msg;
|
auto msg_prv_copy = chat_msg;
|
||||||
SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
|
//SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
|
||||||
auto new_msg = common_chat_parse(
|
auto new_msg = common_chat_parse(
|
||||||
generated_text,
|
generated_text,
|
||||||
is_partial,
|
is_partial,
|
||||||
@@ -304,6 +304,8 @@ task_params server_task::params_from_json_cmpl(
|
|||||||
params.sampling.backend_sampling = json_value(data, "backend_sampling", defaults.sampling.backend_sampling);
|
params.sampling.backend_sampling = json_value(data, "backend_sampling", defaults.sampling.backend_sampling);
|
||||||
params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
|
params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
|
||||||
|
|
||||||
|
params.speculative = defaults.speculative;
|
||||||
|
|
||||||
params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
|
params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
|
||||||
params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
|
params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
|
||||||
params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
|
params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
|
||||||
|
|||||||
@@ -576,6 +576,17 @@ struct server_prompt_checkpoint {
|
|||||||
size_t size() const {
|
size_t size() const {
|
||||||
return data.size();
|
return data.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool empty() const {
|
||||||
|
return data.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
void clear() {
|
||||||
|
pos_min = 0;
|
||||||
|
pos_max = 0;
|
||||||
|
n_tokens = 0;
|
||||||
|
data.clear();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct server_prompt {
|
struct server_prompt {
|
||||||
|
|||||||
Reference in New Issue
Block a user