spec : add self‑speculative decoding (no draft model required) + refactor (#18471)
* server: introduce self-speculative decoding * server: moved self-call into speculative.cpp * can_speculate() includes self-speculation Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * server: can_speculate() tests self-spec * server: replace can_speculate() with slot.can_speculate() Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * common: use %zu format specifier for size_t in logging Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * server: can_speculate() requires a task instance * common: ngram map, config self-speculative decoding * common: add enum common_speculative_type * common: add vector of speculative states * common: add option --spec-draftless * server: cleanup (remove slot.batch_spec, rename) * common: moved self-spec impl to ngram-map * common: cleanup (use common_speculative_state_draft) * spec : refactor * cont : naming * spec: remove --spec-config * doc: (draftless) speculative decoding * common: print performance in spec decoding * minor : cleanup * common : better names * minor : cleanup + fix build * minor: comments * CODEOWNERS: add common/ngram-map.* (#18471) * common : rename speculative.draftless_type -> speculative.type * ngram-map : fix uninitialized values * ngram-map : take into account the input can become shorter * ngram-map : revert len check for now * arg : change `--spec-draftless` -> `--spec-type` * spec : add common_speculative_state::accept() * spec : refactor + add common_speculative_begin() * spec : fix begin() call with mtmd * spec : additional refactor + remove common_speculative_params --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
@@ -73,6 +73,8 @@ add_library(${TARGET} STATIC
|
||||
log.h
|
||||
ngram-cache.cpp
|
||||
ngram-cache.h
|
||||
ngram-map.cpp
|
||||
ngram-map.h
|
||||
peg-parser.cpp
|
||||
peg-parser.h
|
||||
preset.cpp
|
||||
|
||||
+74
-13
@@ -6,6 +6,7 @@
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "log.h"
|
||||
#include "sampling.h"
|
||||
#include "speculative.h"
|
||||
#include "preset.h"
|
||||
|
||||
// fix problem with std::min and std::max
|
||||
@@ -579,14 +580,14 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
params.mmproj = res.mmproj;
|
||||
}
|
||||
// only download mmproj if the current example is using it
|
||||
for (auto & ex : mmproj_examples) {
|
||||
for (const auto & ex : mmproj_examples) {
|
||||
if (ctx_arg.ex == ex) {
|
||||
common_params_handle_model(params.mmproj, params.hf_token, params.offline);
|
||||
break;
|
||||
}
|
||||
}
|
||||
common_params_handle_model(params.speculative.model, params.hf_token, params.offline);
|
||||
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
|
||||
common_params_handle_model(params.speculative.mparams_dft, params.hf_token, params.offline);
|
||||
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
|
||||
}
|
||||
|
||||
// model is required (except for server)
|
||||
@@ -1216,16 +1217,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"-lcs", "--lookup-cache-static"}, "FNAME",
|
||||
"path to static lookup cache to use for lookup decoding (not updated by generation)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.lookup_cache_static = value;
|
||||
params.speculative.lookup_cache_static = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_LOOKUP}));
|
||||
).set_examples({LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"-lcd", "--lookup-cache-dynamic"}, "FNAME",
|
||||
"path to dynamic lookup cache to use for lookup decoding (updated by generation)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.lookup_cache_dynamic = value;
|
||||
params.speculative.lookup_cache_dynamic = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_LOOKUP}));
|
||||
).set_examples({LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"-c", "--ctx-size"}, "N",
|
||||
string_format("size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx),
|
||||
@@ -2563,7 +2564,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"-hfd", "-hfrd", "--hf-repo-draft"}, "<user>/<model>[:quant]",
|
||||
"Same as --hf-repo, but for the draft model (default: unused)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.speculative.model.hf_repo = value;
|
||||
params.speculative.mparams_dft.hf_repo = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_HFD_REPO"));
|
||||
add_opt(common_arg(
|
||||
@@ -3384,7 +3385,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"-md", "--model-draft"}, "FNAME",
|
||||
"draft model for speculative decoding (default: unused)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.speculative.model.path = value;
|
||||
params.speculative.mparams_dft.path = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_MODEL_DRAFT"));
|
||||
add_opt(common_arg(
|
||||
@@ -3394,6 +3395,66 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.speculative.replacements.push_back({ tgt, dft });
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v]",
|
||||
string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
|
||||
common_speculative_type_to_str(params.speculative.type).c_str()),
|
||||
[](common_params & params, const std::string & value) {
|
||||
if (value == "none") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
|
||||
} else if (value == "ngram-cache") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE;
|
||||
} else if (value == "ngram-simple") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE;
|
||||
} else if (value == "ngram-map-k") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K;
|
||||
} else if (value == "ngram-map-k4v") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V;
|
||||
} else {
|
||||
throw std::invalid_argument("unknown speculative decoding type without draft model");
|
||||
}
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-size-n"}, "N",
|
||||
string_format("ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)", params.speculative.ngram_size_n),
|
||||
[](common_params & params, int value) {
|
||||
if (value < 1 || value > 1024) {
|
||||
throw std::invalid_argument("ngram size N must be between 1 and 1024 inclusive");
|
||||
}
|
||||
params.speculative.ngram_size_n = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-size-m"}, "N",
|
||||
string_format("ngram size M for ngram-simple/ngram-map speculative decoding, length of draft m-gram (default: %d)", params.speculative.ngram_size_m),
|
||||
[](common_params & params, int value) {
|
||||
if (value < 1 || value > 1024) {
|
||||
throw std::invalid_argument("ngram size M must be between 1 and 1024 inclusive");
|
||||
}
|
||||
params.speculative.ngram_size_m = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-check-rate"}, "N",
|
||||
string_format("ngram check rate for ngram-simple/ngram-map speculative decoding (default: %d)", params.speculative.ngram_check_rate),
|
||||
[](common_params & params, int value) {
|
||||
if (value < 1) {
|
||||
throw std::invalid_argument("ngram check rate must be at least 1");
|
||||
}
|
||||
params.speculative.ngram_check_rate = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-min-hits"}, "N",
|
||||
string_format("minimum hits for ngram-map speculative decoding (default: %d)", params.speculative.ngram_min_hits),
|
||||
[](common_params & params, int value) {
|
||||
if (value < 1) {
|
||||
throw std::invalid_argument("ngram min hits must be at least 1");
|
||||
}
|
||||
params.speculative.ngram_min_hits = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"-ctkd", "--cache-type-k-draft"}, "TYPE",
|
||||
string_format(
|
||||
@@ -3620,8 +3681,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF";
|
||||
params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
|
||||
params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
|
||||
params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
|
||||
params.speculative.mparams_dft.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
|
||||
params.speculative.mparams_dft.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
|
||||
params.port = 8012;
|
||||
params.n_ubatch = 1024;
|
||||
params.n_batch = 1024;
|
||||
@@ -3636,8 +3697,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-14B-Q8_0-GGUF";
|
||||
params.model.hf_file = "qwen2.5-coder-14b-q8_0.gguf";
|
||||
params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
|
||||
params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
|
||||
params.speculative.mparams_dft.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
|
||||
params.speculative.mparams_dft.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
|
||||
params.port = 8012;
|
||||
params.n_ubatch = 1024;
|
||||
params.n_batch = 1024;
|
||||
|
||||
+4
-5
@@ -1097,7 +1097,10 @@ common_init_result::common_init_result(common_params & params) :
|
||||
if (params.fit_params) {
|
||||
LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__);
|
||||
llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
|
||||
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx,
|
||||
params.tensor_split,
|
||||
params.tensor_buft_overrides.data(),
|
||||
params.fit_params_target.data(),
|
||||
params.fit_params_min_ctx,
|
||||
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
|
||||
}
|
||||
|
||||
@@ -1208,10 +1211,6 @@ std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
|
||||
return pimpl->lora;
|
||||
}
|
||||
|
||||
void common_init_result::free_context() {
|
||||
pimpl->context.reset();
|
||||
}
|
||||
|
||||
common_init_result_ptr common_init_from_params(common_params & params) {
|
||||
common_init_result_ptr res(new common_init_result(params));
|
||||
|
||||
|
||||
+46
-18
@@ -164,6 +164,16 @@ enum common_params_sampling_config : uint64_t {
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA = 1 << 11,
|
||||
};
|
||||
|
||||
enum common_speculative_type {
|
||||
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
|
||||
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, // self-speculative decoding with 3-level n-gram cache
|
||||
COMMON_SPECULATIVE_TYPE_COUNT // number of types, unknown type
|
||||
};
|
||||
|
||||
// sampling parameters
|
||||
struct common_params_sampling {
|
||||
@@ -243,16 +253,35 @@ struct common_params_model {
|
||||
};
|
||||
|
||||
struct common_params_speculative {
|
||||
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
||||
common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE; // type of speculative decoding
|
||||
|
||||
int32_t n_ctx = 0; // draft context size
|
||||
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
|
||||
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
|
||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
||||
float p_split = 0.1f; // speculative decoding split probability
|
||||
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
|
||||
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
|
||||
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
|
||||
// general-purpose speculative decoding parameters
|
||||
|
||||
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
|
||||
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
|
||||
float p_split = 0.1f; // speculative decoding split probability
|
||||
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
|
||||
|
||||
// ngram-based speculative decoding
|
||||
|
||||
uint16_t ngram_size_n = 12; // ngram size for lookup
|
||||
uint16_t ngram_size_m = 48; // mgram size for speculative tokens
|
||||
uint16_t ngram_check_rate = 1; // check rate for ngram lookup
|
||||
uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
|
||||
|
||||
std::string lookup_cache_static; // path of static ngram cache file for lookup decoding // NOLINT
|
||||
std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding // NOLINT
|
||||
|
||||
// draft-model speculative decoding
|
||||
|
||||
struct common_params_model mparams_dft;
|
||||
|
||||
llama_model * model_dft = nullptr; // a llama_model that can be shared by multiple speculative contexts
|
||||
|
||||
llama_context_params cparams_dft; // these are the parameters for the draft llama_context
|
||||
|
||||
int32_t n_ctx = 0; // draft context size
|
||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
||||
|
||||
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
|
||||
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
|
||||
@@ -260,7 +289,14 @@ struct common_params_speculative {
|
||||
struct cpu_params cpuparams;
|
||||
struct cpu_params cpuparams_batch;
|
||||
|
||||
struct common_params_model model;
|
||||
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
||||
|
||||
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
|
||||
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
|
||||
|
||||
bool has_dft() const {
|
||||
return !mparams_dft.path.empty() || !mparams_dft.hf_repo.empty();
|
||||
}
|
||||
};
|
||||
|
||||
struct common_params_vocoder {
|
||||
@@ -378,8 +414,6 @@ struct common_params {
|
||||
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT
|
||||
std::string input_prefix = ""; // string to prefix user inputs with // NOLINT
|
||||
std::string input_suffix = ""; // string to suffix user inputs with // NOLINT
|
||||
std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
|
||||
std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
|
||||
std::string logits_file = ""; // file for saving *all* logits // NOLINT
|
||||
|
||||
// llama-debug specific options
|
||||
@@ -575,10 +609,6 @@ struct common_params {
|
||||
// return false from callback to abort model loading or true to continue
|
||||
llama_progress_callback load_progress_callback = NULL;
|
||||
void * load_progress_callback_user_data = NULL;
|
||||
|
||||
bool has_speculative() const {
|
||||
return !speculative.model.path.empty() || !speculative.model.hf_repo.empty();
|
||||
}
|
||||
};
|
||||
|
||||
// call once at the start of a program if it uses libcommon
|
||||
@@ -714,8 +744,6 @@ struct common_init_result {
|
||||
|
||||
std::vector<llama_adapter_lora_ptr> & lora();
|
||||
|
||||
void free_context();
|
||||
|
||||
private:
|
||||
struct impl;
|
||||
std::unique_ptr<impl> pimpl;
|
||||
|
||||
@@ -192,12 +192,12 @@ void common_ngram_cache_draft(
|
||||
break;
|
||||
}
|
||||
|
||||
LOG(" - draft candidate: token=%d\n", drafted_token);
|
||||
LOG_DBG(" - draft candidate: token=%d\n", drafted_token);
|
||||
draft.push_back(drafted_token);
|
||||
}
|
||||
}
|
||||
|
||||
void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename) {
|
||||
void common_ngram_cache_save(common_ngram_cache & ngram_cache, const std::string & filename) {
|
||||
std::ofstream file_out(filename, std::ios::binary);
|
||||
for (std::pair<common_ngram, common_ngram_cache_part> item : ngram_cache) {
|
||||
const common_ngram ngram = item.first;
|
||||
@@ -217,10 +217,9 @@ void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & fil
|
||||
file_out.write(reinterpret_cast<const char *>(&count), sizeof(int32_t));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
common_ngram_cache common_ngram_cache_load(std::string & filename) {
|
||||
common_ngram_cache common_ngram_cache_load(const std::string & filename) {
|
||||
std::ifstream hashmap_file(filename, std::ios::binary);
|
||||
if (!hashmap_file) {
|
||||
throw std::ifstream::failure("Unable to open file " + filename);
|
||||
|
||||
@@ -88,12 +88,12 @@ void common_ngram_cache_draft(
|
||||
// Save an ngram cache to a file.
|
||||
// ngram_cache: the ngram cache to save.
|
||||
// filename: the path under which to save the ngram cache.
|
||||
void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename);
|
||||
void common_ngram_cache_save(common_ngram_cache & ngram_cache, const std::string & filename);
|
||||
|
||||
// Load an ngram cache saved with common_ngram_cache_save.
|
||||
// filename: the path from which to load the ngram cache.
|
||||
// returns: an ngram cache containing the information saved to filename.
|
||||
common_ngram_cache common_ngram_cache_load(std::string & filename);
|
||||
common_ngram_cache common_ngram_cache_load(const std::string & filename);
|
||||
|
||||
// Merge two ngram caches.
|
||||
// ngram_cache_target: the ngram cache to which to add the information from ngram_cache_add.
|
||||
|
||||
@@ -0,0 +1,367 @@
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "ngram-map.h"
|
||||
|
||||
#include <cinttypes>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <sstream>
|
||||
|
||||
// n-gram simple
|
||||
//
|
||||
|
||||
/**
|
||||
* Perform speculative generation using the model's own token history.
|
||||
* Searches for a matching pattern in the token history and returns draft tokens.
|
||||
*
|
||||
* @param state Current state of this implementation
|
||||
* @param tokens Token history to search in
|
||||
* @param sampled Last sampled token
|
||||
* @return Vector of draft tokens, empty if no matching pattern is found
|
||||
*/
|
||||
llama_tokens common_ngram_simple_draft(
|
||||
common_ngram_simple_state & state,
|
||||
const llama_tokens & tokens, llama_token sampled) {
|
||||
|
||||
// Simple implementation of self-speculative decoding without a draft model.
|
||||
//
|
||||
const size_t cur_len = tokens.size();
|
||||
// Only check every check_rate tokens to save compute
|
||||
// i.e., perform check if (cur_len - idx_last_check) >= check_rate
|
||||
if (state.idx_last_check + state.config.check_rate > cur_len) {
|
||||
llama_tokens draft_tokens;
|
||||
return draft_tokens;
|
||||
}
|
||||
|
||||
size_t n_draft_min = state.config.size_ngram; // size of n-gram to lookup in token history
|
||||
size_t n_draft_max = state.config.size_mgram; // the m-gram following the found n-gram is used for draft
|
||||
|
||||
// vector for tokens we want to verify.
|
||||
// return empty vector if there is no match.
|
||||
llama_tokens draft_tokens;
|
||||
|
||||
// We need at least n_draft_min + n_draft_max + 1 tokens.
|
||||
if (cur_len <= static_cast<size_t>(n_draft_min + n_draft_max + 1)) {
|
||||
return draft_tokens;
|
||||
}
|
||||
|
||||
// pattern search
|
||||
llama_tokens pattern;
|
||||
pattern.reserve(n_draft_min);
|
||||
for (size_t j = cur_len - n_draft_min + 1; j < cur_len; ++j) {
|
||||
pattern.push_back(tokens[j]);
|
||||
}
|
||||
pattern.push_back(sampled); // add the last token to the pattern
|
||||
|
||||
// We do a search in the token history.
|
||||
state.idx_last_check = cur_len;
|
||||
|
||||
size_t match_pos = 0; // we ignore position 0, position 0 == no match
|
||||
// search backwards, but skip the current match (we are currently there)
|
||||
for (size_t j = cur_len - n_draft_min - 1; j > 0; --j) {
|
||||
bool match = true;
|
||||
for (size_t k = 0; k < pattern.size(); ++k) {
|
||||
if (tokens[j + k] != pattern[k]) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
match_pos = j;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match_pos == 0) {
|
||||
return draft_tokens;
|
||||
}
|
||||
|
||||
const size_t copy_max = std::min(
|
||||
n_draft_max,
|
||||
cur_len - (match_pos + n_draft_min)
|
||||
);
|
||||
if (copy_max < n_draft_min) {
|
||||
return draft_tokens;
|
||||
}
|
||||
LOG_DBG("%s: #tokens = %zu: found matching pattern at pos %zu, length %zu, draft length %zu\n",
|
||||
__func__, cur_len,
|
||||
match_pos, pattern.size(), copy_max);
|
||||
|
||||
draft_tokens.reserve(copy_max);
|
||||
for (size_t j = 0; j < copy_max; ++j) {
|
||||
draft_tokens.push_back(tokens[match_pos + n_draft_min + j]);
|
||||
}
|
||||
return draft_tokens;
|
||||
}
|
||||
|
||||
|
||||
// n-gram map
|
||||
//
|
||||
|
||||
// maximum number of counted values of a ngram map value.
|
||||
#define COMMON_NGRAM_MAX_VALUE_COUNT 16380
|
||||
|
||||
static std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length);
|
||||
|
||||
void common_ngram_map_draft(common_ngram_map & map,
|
||||
const llama_tokens & inp, llama_token sampled,
|
||||
llama_tokens & draft) {
|
||||
// reset last key and value.
|
||||
map.last_draft_created = false;
|
||||
map.last_draft_key_idx = 0;
|
||||
map.last_draft_value_idx = 0;
|
||||
|
||||
const size_t cur_len = inp.size();
|
||||
const uint16_t n = map.size_key;
|
||||
const uint16_t m = map.size_value;
|
||||
if (cur_len < static_cast<size_t>(2 * n + m)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Only check every check_rate tokens to save compute
|
||||
// i.e., perform check if (cur_len - idx_last_check) >= check_rate
|
||||
if (map.idx_last_check + map.check_rate > cur_len) {
|
||||
return;
|
||||
}
|
||||
map.idx_last_check = cur_len;
|
||||
|
||||
// search pattern, the key n-gram
|
||||
std::vector<llama_token> key_tokens;
|
||||
key_tokens.reserve(n);
|
||||
for (size_t j = cur_len - n + 1; j < cur_len; ++j) {
|
||||
key_tokens.push_back(inp[j]);
|
||||
}
|
||||
key_tokens.push_back(sampled);
|
||||
|
||||
// search for the key in the map
|
||||
size_t match_pos = 0;
|
||||
for (size_t j = cur_len - n - m - 1; j > 0; --j) {
|
||||
bool match = true;
|
||||
for (size_t k = 0; k < n; ++k) {
|
||||
if (inp[j + k] != key_tokens[k]) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
match_pos = j;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match_pos > 0) {
|
||||
LOG_INF("%s: cur_len = %zu, n = %d, m = %d, sz_tkns = %zu, sampled = %d, match_pos = %zu\n", __func__,
|
||||
cur_len, n, m, key_tokens.size(), sampled, match_pos);
|
||||
}
|
||||
|
||||
if (match_pos == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// We have a match, now we look for the statistics of the key.
|
||||
size_t key_offset = map.keys.size(); // offset in the map
|
||||
// We iterate through the std::vector<common_ngram_map_key> map->keys.
|
||||
for (size_t i = 0; i < map.keys.size(); ++i) {
|
||||
bool match = true;
|
||||
for (size_t j = 0; j < n; ++j) {
|
||||
if (inp[map.keys[i].key_idx + j] != key_tokens[j]) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
key_offset = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (key_offset == map.keys.size()) {
|
||||
// We create a new key-entry, it will get offset key_offset.
|
||||
common_ngram_map_key new_key;
|
||||
new_key.key_idx = match_pos;
|
||||
new_key.stat_idx = 0;
|
||||
new_key.key_num = 0;
|
||||
for (int i = 0; i < COMMON_NGRAM_MAX_VALUES; ++i) {
|
||||
new_key.values[i].value_num = 0;
|
||||
new_key.values[i].n_accepted = m;
|
||||
}
|
||||
map.keys.push_back(new_key);
|
||||
}
|
||||
|
||||
// our key n-gram:
|
||||
common_ngram_map_key & curr_key = map.keys[key_offset];
|
||||
|
||||
// update number of key hits
|
||||
curr_key.key_num = (uint16_t) std::min((int) map.keys[key_offset].key_num + 1,
|
||||
(int) COMMON_NGRAM_MAX_VALUE_COUNT);
|
||||
|
||||
if (map.key_only) {
|
||||
// simple mode:
|
||||
// Fill in the draft with the m tokens following the key.
|
||||
// We work with value values[0] only.
|
||||
int n_draft_tokens = std::min((int) m, (int) curr_key.values[0].n_accepted);
|
||||
|
||||
for (int i = 0; i < n_draft_tokens; ++i) {
|
||||
draft.push_back(inp[match_pos + n + i]);
|
||||
}
|
||||
|
||||
LOG_INF("%s: key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__,
|
||||
key_offset, curr_key.key_num, draft.size());
|
||||
|
||||
map.last_draft_created = false;
|
||||
map.last_draft_key_idx = key_offset;
|
||||
map.last_draft_value_idx = 0; // value 0 is used for simple mode
|
||||
return;
|
||||
}
|
||||
|
||||
if (curr_key.key_num < map.min_hits) {
|
||||
// not enough hits to consider this a good draft
|
||||
LOG_DBG("%s: key_offset = %zu, key_num = %d, min_hits = %d, no draft\n", __func__,
|
||||
key_offset, curr_key.key_num, map.min_hits);
|
||||
return;
|
||||
}
|
||||
|
||||
// complex mode: examine the different m-grams after this key n-gram.
|
||||
//
|
||||
|
||||
// determine all (max COMMON_NGRAM_MAX_VALUES) m-grams after the key n-gram.
|
||||
for (size_t i = curr_key.stat_idx; i <= match_pos; ++i) {
|
||||
// begins the key n-gram at index i?
|
||||
bool match_key = true;
|
||||
for (size_t k = 0; k < n; ++k) {
|
||||
if (inp[i + k] != key_tokens[k]) {
|
||||
match_key = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!match_key) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Do we haven a existing value m-gram or a new one after the key at index i?
|
||||
size_t idx_begin_value_key = i + n;
|
||||
int idx_value = -1;
|
||||
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
|
||||
size_t idx_begin_value_v = curr_key.values[v].value_idx;
|
||||
if (idx_begin_value_v == 0) {
|
||||
// We found an empty value slot => we found a new value m-gram after the key n-gram.
|
||||
curr_key.values[v].value_idx = idx_begin_value_key;
|
||||
curr_key.values[v].value_num = 0;
|
||||
curr_key.values[v].n_accepted = m;
|
||||
idx_value = v;
|
||||
break;
|
||||
}
|
||||
bool match = true;
|
||||
for (size_t j = 0; j < m; ++j) {
|
||||
if (inp[idx_begin_value_key + j] != inp[idx_begin_value_v + j]) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
// We found an existing value m-gram after the key n-gram.
|
||||
idx_value = v;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (idx_value >= 0) {
|
||||
// We found a value m-gram of the key n-gram.
|
||||
curr_key.values[idx_value].value_num = (uint16_t) std::min((int) curr_key.values[idx_value].value_num + 1,
|
||||
(int) COMMON_NGRAM_MAX_VALUE_COUNT);
|
||||
}
|
||||
}
|
||||
// the statistics are updated up to match_pos.
|
||||
curr_key.stat_idx = match_pos;
|
||||
|
||||
// Do we have a value we could use for the draft?
|
||||
uint16_t max_occur = 0;
|
||||
int slot_max = 0;
|
||||
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
|
||||
uint16_t curr_occur = curr_key.values[v].value_num;
|
||||
if (curr_occur > max_occur) {
|
||||
max_occur = curr_occur;
|
||||
slot_max = v;
|
||||
}
|
||||
}
|
||||
// What is sum of the other occurences?
|
||||
uint32_t sum_occur = 0;
|
||||
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
|
||||
if (v == slot_max) {
|
||||
continue;
|
||||
}
|
||||
uint16_t curr_occur = curr_key.values[v].value_num;
|
||||
sum_occur += curr_occur;
|
||||
}
|
||||
|
||||
LOG_INF("%s: key_offset = %zu, max_occur = %d, sum_occur = %d, slot_max = %d [%zu/%d, %zu/%d, %zu/%d, %zu/%d]\n", __func__,
|
||||
key_offset,
|
||||
max_occur, sum_occur, slot_max,
|
||||
curr_key.values[0].value_idx, curr_key.values[0].value_num,
|
||||
curr_key.values[1].value_idx, curr_key.values[1].value_num,
|
||||
curr_key.values[2].value_idx, curr_key.values[2].value_num,
|
||||
curr_key.values[3].value_idx, curr_key.values[3].value_num
|
||||
);
|
||||
// Print the tokens of the four values (if idx != 0), use LOG_INF
|
||||
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
|
||||
if (curr_key.values[v].value_idx != 0) {
|
||||
LOG_INF("%s: value[%d] = %s\n", __func__, v, common_tokens_to_str(inp, curr_key.values[v].value_idx, m).c_str());
|
||||
}
|
||||
}
|
||||
|
||||
if (sum_occur > 0 && max_occur < 3 * sum_occur) {
|
||||
// The most frequent value is not much more frequent than the other values.
|
||||
// We do not use the draft.
|
||||
return;
|
||||
}
|
||||
|
||||
// We use the most frequent value values[slot_max] for the draft.
|
||||
// Fill in the draft with the m tokens following the key.
|
||||
int n_draft_tokens = std::min((int) m, (int) curr_key.values[slot_max].n_accepted);
|
||||
|
||||
for (int i = 0; i < n_draft_tokens; ++i) {
|
||||
draft.push_back(inp[match_pos + n + i]);
|
||||
}
|
||||
|
||||
LOG_INF("%s: key_offset = %zu, slot_max = %d, key_num = %d, draft.size = %zu\n", __func__,
|
||||
key_offset, slot_max,
|
||||
curr_key.key_num, draft.size());
|
||||
|
||||
map.last_draft_created = true;
|
||||
map.last_draft_key_idx = key_offset;
|
||||
map.last_draft_value_idx = slot_max; // value used for draft generation.
|
||||
}
|
||||
|
||||
void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted) {
|
||||
if (!map.last_draft_created) {
|
||||
return;
|
||||
}
|
||||
|
||||
// find the key and its chosen value.
|
||||
const size_t key_idx = map.last_draft_key_idx;
|
||||
const size_t val_idx = map.last_draft_value_idx;
|
||||
|
||||
// find key corresponding to key_idx.
|
||||
common_ngram_map_key & curr_key = map.keys[key_idx];
|
||||
// find value corresponding to val_idx.
|
||||
struct common_ngram_map_value & curr_value = curr_key.values[val_idx]; // value used for draft generation.
|
||||
|
||||
// update the value statistics
|
||||
LOG_INF("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n",
|
||||
n_accepted, curr_value.n_accepted);
|
||||
curr_value.n_accepted = n_accepted;
|
||||
}
|
||||
|
||||
// Helper functions.
|
||||
//
|
||||
|
||||
// Print the values of a sublist of `llama_tokens & inp` to a string in the form [v0, v1, v2, ...].
|
||||
std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length) {
|
||||
std::ostringstream oss;
|
||||
oss << '[';
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
if (i > 0) {
|
||||
oss << ", ";
|
||||
}
|
||||
oss << inp[start + i];
|
||||
}
|
||||
oss << ']';
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
#pragma once
|
||||
//
|
||||
// common/ngram-map.h: structures used to manage a map from n-grams to a list of m-grams
|
||||
//
|
||||
// These structures are used to do a lookup of n-grams followed by m-grams in token history.
|
||||
//
|
||||
// There are two algorithms implemented:
|
||||
// 1. ngram_simple: lookup of n-grams followed by m-grams in token history.
|
||||
// 2. ngram_map: lookup of n-grams followed by m-grams in token history using a map.
|
||||
// The map is a vector of key n-grams, and for each key n-gram there is a list of value m-grams.
|
||||
//
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
// n-gram simple
|
||||
//
|
||||
|
||||
// config of n-gram simple.
|
||||
struct common_ngram_simple_config {
|
||||
uint16_t size_ngram; // size of n-grams to lookup in self-mode
|
||||
uint16_t size_mgram; // size of m-grams to draft in self-mode
|
||||
uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token
|
||||
};
|
||||
|
||||
// current state (and config) of n-gram simple.
|
||||
struct common_ngram_simple_state {
|
||||
common_ngram_simple_config config;
|
||||
|
||||
size_t idx_last_check = 0; // index of last check in context history (mutable)
|
||||
|
||||
common_ngram_simple_state(const common_ngram_simple_config & config)
|
||||
: config(config) {}
|
||||
};
|
||||
|
||||
// Searches for a n-gram in the history and checks whether a draft sequence should be generated.
|
||||
// state: the ngram simple state to search in.
|
||||
// inp: the tokens generated so far.
|
||||
// sampled: the token that was just sampled.
|
||||
// draft: vector to store the draft tokens, initially empty.
|
||||
llama_tokens common_ngram_simple_draft(
|
||||
common_ngram_simple_state & state,
|
||||
const llama_tokens & tokens, llama_token sampled);
|
||||
|
||||
|
||||
// n-gram map
|
||||
//
|
||||
|
||||
// maximum number of m-gram values stored for each key n-gram.
|
||||
#define COMMON_NGRAM_MAX_VALUES 4
|
||||
|
||||
// statistics of a m-gram after a known n-gram
|
||||
struct common_ngram_map_value {
|
||||
size_t value_idx = 0; // index of value m-gram in token-history (0 if unused)
|
||||
uint16_t value_num = 0; // number of occurences of this value m-gram after the key n-gram (0 in an unused values-slot)
|
||||
int16_t n_accepted = -1; // number of accepted tokens at last draft (-1 if unused)
|
||||
};
|
||||
|
||||
// statistics of a n-gram
|
||||
struct common_ngram_map_key {
|
||||
size_t key_idx; // index of key n-gram in token-history
|
||||
size_t stat_idx; // index of last token of stastistics computation (key_num, values)
|
||||
|
||||
uint16_t key_num; // number of occurences of this key n-gram in token-history
|
||||
common_ngram_map_value values[COMMON_NGRAM_MAX_VALUES]; // some known values after the key
|
||||
};
|
||||
|
||||
// map from n-grams to following m-grams in token-history
|
||||
struct common_ngram_map {
|
||||
uint16_t size_key; // size of key n-grams
|
||||
uint16_t size_value; // size of value m-grams
|
||||
|
||||
bool key_only; // true if only key n-grams are used, no values.
|
||||
|
||||
// first draft: vector only, no map.
|
||||
std::vector<common_ngram_map_key> keys; // key n-grams which occur several times in token-history
|
||||
uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token
|
||||
uint16_t min_hits; // minimum number of key hits to consider a draft
|
||||
|
||||
common_ngram_map(uint16_t sz_key, uint16_t sz_value, bool only_keys,
|
||||
uint16_t check_rate, uint16_t min_hits)
|
||||
: size_key(sz_key), size_value(sz_value), key_only(only_keys),
|
||||
check_rate(check_rate), min_hits(min_hits) {}
|
||||
|
||||
bool last_draft_created = false; // true if a draft was created at last call.
|
||||
size_t last_draft_key_idx = 0; // index of last key used for draft generation.
|
||||
uint16_t last_draft_value_idx = 0; // index of last value used for draft generation.
|
||||
|
||||
size_t idx_last_check = 0; // index of last check in context history
|
||||
};
|
||||
|
||||
|
||||
// Searches for the n-gram in the history and checks whether a draft sequence should be generated.
|
||||
// map: the ngram map to search in.
|
||||
// inp: the tokens generated so far.
|
||||
// sampled: the token that was just sampled.
|
||||
// draft: vector to store the draft tokens, initially empty.
|
||||
void common_ngram_map_draft(
|
||||
common_ngram_map & map,
|
||||
const llama_tokens & inp, llama_token sampled,
|
||||
llama_tokens & draft);
|
||||
|
||||
// Update the statistics of a value after a draft was processed.
|
||||
void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted);
|
||||
+766
-246
File diff suppressed because it is too large
Load Diff
+23
-21
@@ -5,31 +5,33 @@
|
||||
|
||||
struct common_speculative;
|
||||
|
||||
struct common_speculative_params {
|
||||
int n_draft = 16; // max drafted tokens
|
||||
int n_reuse = 256;
|
||||
// comma separated list of all types
|
||||
std::string common_speculative_type_name_str();
|
||||
|
||||
float p_min = 0.75f; // min probability required to accept a token in the draft
|
||||
};
|
||||
// convert string to type
|
||||
enum common_speculative_type common_speculative_type_from_name(const std::string & name);
|
||||
|
||||
struct common_speculative * common_speculative_init(
|
||||
struct llama_context * ctx_tgt,
|
||||
struct llama_context * ctx_dft
|
||||
);
|
||||
// convert type to string
|
||||
std::string common_speculative_type_to_str(enum common_speculative_type type);
|
||||
|
||||
void common_speculative_free(struct common_speculative * spec);
|
||||
common_speculative * common_speculative_init(
|
||||
const common_params_speculative & params,
|
||||
llama_context * ctx_tgt);
|
||||
|
||||
bool common_speculative_are_compatible(
|
||||
const struct llama_context * ctx_tgt,
|
||||
const struct llama_context * ctx_dft);
|
||||
void common_speculative_free(common_speculative * spec);
|
||||
|
||||
void common_speculative_add_replacement_tgt_dft(
|
||||
struct common_speculative * spec,
|
||||
const char *source, const char *dest);
|
||||
// optionally call once at the beginning of a new generation
|
||||
void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt);
|
||||
|
||||
// sample up to n_draft tokens and add them to the batch using the draft model
|
||||
llama_tokens common_speculative_gen_draft(
|
||||
struct common_speculative * spec,
|
||||
struct common_speculative_params params,
|
||||
const llama_tokens & prompt,
|
||||
llama_token id_last);
|
||||
llama_tokens common_speculative_draft(
|
||||
common_speculative * spec,
|
||||
const common_params_speculative & params,
|
||||
const llama_tokens & prompt,
|
||||
llama_token id_last);
|
||||
|
||||
// informs the speculative decoder that n_accepted tokens were accepted by the target model
|
||||
void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
|
||||
|
||||
// print statistics about the speculative decoding
|
||||
void common_speculative_print_stats(const common_speculative * spec);
|
||||
|
||||
Reference in New Issue
Block a user