spec : refactor params (#22397)

* spec : refactor params

* cont : fix

* cont : rename "sparam" to "sampling"

* cont : add spec params category

* cont : add info about removed arguments

* cont : skip param length check for spec params

* cont : adapt server tests
This commit is contained in:
Georgi Gerganov
2026-04-28 09:07:33 +03:00
committed by GitHub
parent 516e8d7a8a
commit 14e733e36f
18 changed files with 661 additions and 409 deletions
+377 -217
View File
File diff suppressed because it is too large Load Diff
+4 -2
View File
@@ -25,7 +25,8 @@ struct common_arg {
const char * value_hint_2 = nullptr; // for second arg value
const char * env = nullptr;
std::string help;
bool is_sparam = false; // is current arg a sampling param?
bool is_sampling = false; // is current arg a sampling param?
bool is_spec = false; // is current arg a speculative decoding param?
bool is_preset_only = false; // is current arg preset-only (not treated as CLI arg)
void (*handler_void) (common_params & params) = nullptr;
void (*handler_string) (common_params & params, const std::string &) = nullptr;
@@ -74,7 +75,8 @@ struct common_arg {
common_arg & set_examples(std::initializer_list<enum llama_example> examples);
common_arg & set_excludes(std::initializer_list<enum llama_example> excludes);
common_arg & set_env(const char * env);
common_arg & set_sparam();
common_arg & set_sampling();
common_arg & set_spec();
common_arg & set_preset_only();
bool in_example(enum llama_example ex);
bool is_exclude(enum llama_example ex);
+7 -7
View File
@@ -70,7 +70,7 @@ common_time_meas::~common_time_meas() {
// CPU utils
//
int32_t cpu_get_num_physical_cores() {
int32_t common_cpu_get_num_physical_cores() {
#ifdef __linux__
// enumerate the set of thread siblings, num entries is num cores
std::unordered_set<std::string> siblings;
@@ -185,11 +185,11 @@ static int cpu_count_math_cpus(int n_cpu) {
/**
* Returns number of CPUs on system that are useful for math.
*/
int32_t cpu_get_num_math() {
int32_t common_cpu_get_num_math() {
#if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__)
int n_cpu = sysconf(_SC_NPROCESSORS_ONLN);
if (n_cpu < 1) {
return cpu_get_num_physical_cores();
return common_cpu_get_num_physical_cores();
}
if (is_hybrid_cpu()) {
cpu_set_t affinity;
@@ -202,7 +202,7 @@ int32_t cpu_get_num_math() {
}
}
#endif
return cpu_get_num_physical_cores();
return common_cpu_get_num_physical_cores();
}
// Helper for setting process priority
@@ -263,7 +263,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
//
void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model) {
void postprocess_cpu_params(common_cpu_params & cpuparams, const common_cpu_params * role_model) {
int32_t n_set = 0;
if (cpuparams.n_threads < 0) {
@@ -271,7 +271,7 @@ void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model)
if (role_model != nullptr) {
cpuparams = *role_model;
} else {
cpuparams.n_threads = cpu_get_num_math();
cpuparams.n_threads = common_cpu_get_num_math();
}
}
@@ -1521,7 +1521,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
return cparams;
}
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params) {
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const common_cpu_params & params) {
struct ggml_threadpool_params tpp;
ggml_threadpool_params_init(&tpp, params.n_threads); // setup the defaults
+56 -36
View File
@@ -54,7 +54,7 @@ struct common_control_vector_load_info;
// CPU utils
//
struct cpu_params {
struct common_cpu_params {
int n_threads = -1;
bool cpumask[GGML_MAX_N_THREADS] = {false}; // CPU affinity mask.
bool mask_valid = false; // Default: any CPU
@@ -63,8 +63,8 @@ struct cpu_params {
uint32_t poll = 50; // Polling (busywait) level (0 - no polling, 100 - mostly polling)
};
int32_t cpu_get_num_physical_cores();
int32_t cpu_get_num_math();
int32_t common_cpu_get_num_physical_cores();
int32_t common_cpu_get_num_math();
//
// Common params
@@ -297,34 +297,19 @@ struct common_params_model {
struct common_ngram_mod;
struct common_params_speculative {
common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE; // type of speculative decoding
// draft-model-based speculative decoding parameters
struct common_params_speculative_draft {
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
// general-purpose speculative decoding parameters
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
common_params_model mparams;
// ngram-based speculative decoding
llama_model * model = nullptr; // a llama_model that can be shared by multiple speculative contexts
uint16_t ngram_size_n = 12; // ngram size for lookup
uint16_t ngram_size_m = 48; // mgram size for speculative tokens
uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
std::shared_ptr<common_ngram_mod> ngram_mod;
std::string lookup_cache_static; // path of static ngram cache file for lookup decoding // NOLINT
std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding // NOLINT
// draft-model speculative decoding
struct common_params_model mparams_dft;
llama_model * model_dft = nullptr; // a llama_model that can be shared by multiple speculative contexts
llama_context_params cparams_dft; // these are the parameters for the draft llama_context
llama_context_params cparams; // these are the parameters for the draft llama_context
int32_t n_ctx = 0; // draft context size
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
@@ -332,25 +317,60 @@ struct common_params_speculative {
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
struct cpu_params cpuparams;
struct cpu_params cpuparams_batch;
common_cpu_params cpuparams;
common_cpu_params cpuparams_batch;
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
};
struct common_params_speculative_ngram_mod {
int32_t n_match = 24;
int32_t n_max = 64;
int32_t n_min = 48;
// shared instance of the ngram container for all speculative decoding contexts
std::shared_ptr<common_ngram_mod> obj;
};
struct common_params_speculative_ngram_map {
uint16_t size_n = 12; // ngram size for lookup
uint16_t size_m = 48; // mgram size for speculative tokens
uint16_t min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
};
struct common_params_speculative_ngram_cache {
std::string lookup_cache_static; // path of static ngram cache file for lookup decoding
std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding
};
struct common_params_speculative {
// TODO: become a vector in order to support "chains of speculators"
common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE;
common_params_speculative_draft draft;
common_params_speculative_ngram_mod ngram_mod;
common_params_speculative_ngram_map ngram_simple;
common_params_speculative_ngram_map ngram_map_k;
common_params_speculative_ngram_map ngram_map_k4v;
common_params_speculative_ngram_cache ngram_cache;
bool has_dft() const {
return !mparams_dft.path.empty() || !mparams_dft.hf_repo.empty();
return !draft.mparams.path.empty() || !draft.mparams.hf_repo.empty();
}
};
struct common_params_vocoder {
struct common_params_model model;
std::string speaker_file = ""; // speaker file path // NOLINT
std::string speaker_file; // speaker file path
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy
};
struct common_params_diffusion {
@@ -433,8 +453,8 @@ struct common_params {
enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
struct cpu_params cpuparams;
struct cpu_params cpuparams_batch;
common_cpu_params cpuparams;
common_cpu_params cpuparams_batch;
ggml_backend_sched_eval_callback cb_eval = nullptr;
void * cb_eval_user_data = nullptr;
@@ -678,7 +698,7 @@ std::string common_params_get_system_info(const common_params & params);
bool parse_cpu_range(const std::string & range, bool(&boolmask)[GGML_MAX_N_THREADS]);
bool parse_cpu_mask(const std::string & mask, bool(&boolmask)[GGML_MAX_N_THREADS]);
void postprocess_cpu_params(cpu_params & cpuparams, const cpu_params * role_model = nullptr);
void postprocess_cpu_params(common_cpu_params & cpuparams, const common_cpu_params * role_model = nullptr);
bool set_process_priority(enum ggml_sched_priority prio);
//
@@ -846,7 +866,7 @@ common_init_result_ptr common_init_from_params(common_params & params);
struct llama_model_params common_model_params_to_llama ( common_params & params);
struct llama_context_params common_context_params_to_llama(const common_params & params);
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const common_cpu_params & params);
// clear LoRA adapters from context, then apply new list of adapters
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
+1 -1
View File
@@ -43,7 +43,7 @@ static std::set<std::string> get_remote_preset_whitelist(const std::map<std::str
for (const auto & it : key_to_opt) {
const std::string & key = it.first;
const common_arg & opt = it.second;
if (allowed_options.find(key) != allowed_options.end() || opt.is_sparam) {
if (allowed_options.find(key) != allowed_options.end() || opt.is_sampling) {
allowed_keys.insert(key);
// also add variant keys (args without leading dashes and env vars)
for (const auto & arg : opt.get_args()) {
+140 -46
View File
@@ -151,6 +151,9 @@ struct common_speculative_state {
llama_tokens & result) = 0;
virtual void accept(uint16_t n_accepted) = 0;
virtual int32_t n_max(const common_params_speculative & params) const = 0;
virtual int32_t n_min(const common_params_speculative & params) const = 0;
};
struct common_speculative_checkpoint {
@@ -296,6 +299,8 @@ struct common_speculative_state_draft : public common_speculative_state {
const llama_tokens & prompt_tgt,
llama_token id_last,
llama_tokens & result) override {
const auto & sparams = params.draft;
auto * spec = this;
auto & batch = spec->batch;
@@ -309,7 +314,7 @@ struct common_speculative_state_draft : public common_speculative_state {
int reuse_i = 0; // index of part to be reused in prompt_dft
int reuse_n = 0; // length of part to be reused in prompt_dft
const int n_ctx = llama_n_ctx(ctx_dft) - params.n_max;
const int n_ctx = llama_n_ctx(ctx_dft) - sparams.n_max;
llama_tokens prompt_cnv;
if (!spec->vocab_cmpt) {
@@ -367,7 +372,7 @@ struct common_speculative_state_draft : public common_speculative_state {
}
result.clear();
result.reserve(params.n_max);
result.reserve(sparams.n_max);
bool needs_ckpt = use_ckpt && prompt_dft.size() > 0;
if (reuse_n == 0 || (use_ckpt && reuse_i > 0)) {
@@ -380,7 +385,7 @@ struct common_speculative_state_draft : public common_speculative_state {
for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) {
result.push_back(prompt_dft[i]);
if (params.n_max <= (int) result.size()) {
if (sparams.n_max <= (int) result.size()) {
break;
}
}
@@ -473,7 +478,7 @@ struct common_speculative_state_draft : public common_speculative_state {
common_sampler_reset(smpl);
// sample n_draft tokens from the draft model
for (int i = 0; i < params.n_max; ++i) {
for (int i = 0; i < sparams.n_max; ++i) {
common_batch_clear(batch);
common_sampler_sample(smpl, ctx_dft, 0, true);
@@ -492,12 +497,12 @@ struct common_speculative_state_draft : public common_speculative_state {
result.push_back(id);
if (params.n_max <= (int) result.size()) {
if (sparams.n_max <= (int) result.size()) {
break;
}
// only collect very high-confidence draft tokens
if (cur_p->data[0].p < params.p_min) {
if (cur_p->data[0].p < sparams.p_min) {
break;
}
@@ -518,10 +523,14 @@ struct common_speculative_state_draft : public common_speculative_state {
detokenized = replace_to_tgt(detokenized);
LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str());
result = common_tokenize(ctx_tgt, detokenized, false, true);
if (result.size() > (size_t)params.n_max) {
result.resize(params.n_max);
if (result.size() > (size_t) sparams.n_max) {
result.resize(sparams.n_max);
}
}
if (result.size() < (size_t) sparams.n_min) {
result.clear();
}
}
void accept(uint16_t n_accepted) override {
@@ -529,6 +538,14 @@ struct common_speculative_state_draft : public common_speculative_state {
GGML_UNUSED(n_accepted);
}
int32_t n_max(const common_params_speculative & params) const override {
return params.draft.n_max;
}
int32_t n_min(const common_params_speculative & params) const override {
return params.draft.n_min;
}
std::string replace_to_dft(const std::string & input) const {
std::string result = input;
@@ -581,6 +598,14 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
// noop
GGML_UNUSED(n_accepted);
}
int32_t n_max(const common_params_speculative & params) const override {
return params.draft.n_max;
}
int32_t n_min(const common_params_speculative & params) const override {
return params.draft.n_min;
}
};
// state of self-speculation (simple implementation, not ngram-map)
@@ -610,19 +635,27 @@ struct common_speculative_state_ngram_simple : public common_speculative_state {
// noop
GGML_UNUSED(n_accepted);
}
int32_t n_max(const common_params_speculative & /*params*/) const override {
return config.size_mgram;
}
int32_t n_min(const common_params_speculative & /*params*/) const override {
return config.size_mgram;
}
};
struct common_speculative_state_ngram_map_k : public common_speculative_state {
// draft ngram map for speculative decoding without draft model
common_ngram_map map;
common_ngram_map config;
common_speculative_state_ngram_map_k(
enum common_speculative_type type,
common_ngram_map map)
: common_speculative_state(type), map(std::move(map)) {}
common_ngram_map config)
: common_speculative_state(type), config(std::move(config)) {}
void begin(const llama_tokens & prompt) override {
common_ngram_map_begin(map, prompt);
common_ngram_map_begin(config, prompt);
}
void draft(
@@ -630,12 +663,20 @@ struct common_speculative_state_ngram_map_k : public common_speculative_state {
const llama_tokens & prompt_tgt,
llama_token id_last,
llama_tokens & result) override {
common_ngram_map_draft(map, prompt_tgt, id_last, result);
common_ngram_map_draft(config, prompt_tgt, id_last, result);
GGML_UNUSED(params);
}
void accept(uint16_t n_accepted) override {
common_ngram_map_accept(map, n_accepted);
common_ngram_map_accept(config, n_accepted);
}
int32_t n_max(const common_params_speculative & /*params*/) const override {
return config.size_value;
}
int32_t n_min(const common_params_speculative & /*params*/) const override {
return config.size_value;
}
};
@@ -692,7 +733,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
const llama_tokens & prompt_tgt,
llama_token id_last,
llama_tokens & result) override {
GGML_UNUSED(params);
const auto & sparams = params.ngram_mod;
n_draft_last = 0;
@@ -712,16 +753,16 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
i_last = cur_len - n;
}
result.resize(n + params.n_max);
result.resize(n + sparams.n_max);
for (size_t i = 0; i < n - 1; ++i) {
result[i] = prompt_tgt[cur_len - n + 1 + i];
}
result[n - 1] = id_last;
for (int i = 0; i < params.n_max; ++i) {
for (int i = 0; i < sparams.n_max; ++i) {
const llama_token token = mod.get(result.data() + i);
if (token == common_ngram_mod::EMPTY) {
if (i < params.n_min) {
if (i < sparams.n_min) {
result.clear();
return;
}
@@ -764,6 +805,14 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
}
}
}
int32_t n_max(const common_params_speculative & params) const override {
return params.ngram_mod.n_max;
}
int32_t n_min(const common_params_speculative & params) const override {
return params.ngram_mod.n_min;
}
};
struct common_speculative_state_ngram_cache : public common_speculative_state {
@@ -857,6 +906,14 @@ struct common_speculative_state_ngram_cache : public common_speculative_state {
// TODO: noop
GGML_UNUSED(n_accepted);
}
int32_t n_max(const common_params_speculative & /*params*/) const override {
return n_draft;
}
int32_t n_min(const common_params_speculative & /*params*/) const override {
return 0;
}
};
struct common_speculative {
@@ -865,11 +922,13 @@ struct common_speculative {
common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats)
};
static common_ngram_map get_common_ngram_map(const common_speculative_config & config) {
uint16_t size_key = config.params.ngram_size_n;
uint16_t size_value = config.params.ngram_size_m;
bool key_only = (config.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K);
uint16_t min_hits = config.params.ngram_min_hits;
static common_ngram_map get_common_ngram_map(
common_speculative_type type,
const common_params_speculative_ngram_map & config) {
uint16_t size_key = config.size_n;
uint16_t size_value = config.size_m;
bool key_only = type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K;
uint16_t min_hits = config.min_hits;
return common_ngram_map(size_key, size_value, key_only, min_hits);
}
@@ -927,8 +986,8 @@ common_speculative * common_speculative_init(
common_params_speculative & params,
llama_context * ctx_tgt) {
llama_context * ctx_dft = nullptr;
if (params.model_dft) {
ctx_dft = llama_init_from_model(params.model_dft, params.cparams_dft);
if (params.draft.model) {
ctx_dft = llama_init_from_model(params.draft.model, params.draft.cparams);
if (ctx_dft == nullptr) {
LOG_ERR("%s", "failed to create draft context\n");
return nullptr;
@@ -938,7 +997,7 @@ common_speculative * common_speculative_init(
// Compute the implementations to use based on the config and their order of preference
std::vector<common_speculative_config> configs = {}; // list of speculative configs to try
{
bool has_draft = !params.mparams_dft.path.empty();
bool has_draft = !params.draft.mparams.path.empty();
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
@@ -961,16 +1020,17 @@ common_speculative * common_speculative_init(
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, params));
}
if (has_ngram_mod) {
// shared instance for all speculative decoding contexts
if (!params.ngram_mod) {
params.ngram_mod = std::make_shared<common_ngram_mod>(params.ngram_size_n, 4*1024*1024);
auto & sparams = params.ngram_mod;
LOG_INF("%s: initialized ngram_mod with n=%d, size=%zu (%.3f MB)\n", __func__,
params.ngram_size_n, params.ngram_mod->size(),
(float)(params.ngram_mod->size_bytes())/1024/1024);
if (!sparams.obj) {
sparams.obj = std::make_shared<common_ngram_mod>(sparams.n_match, 4*1024*1024);
if (params.ngram_size_n < 16) {
LOG_WRN("%s: ngram_mod n=%d is too small - poor quality is possible, see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, params.ngram_size_n);
LOG_INF("%s: initialized ngram_mod with n_match=%d, size=%zu (%.3f MB)\n", __func__,
sparams.n_match, sparams.obj->size(), (float)(sparams.obj->size_bytes())/1024/1024);
if (sparams.n_match < 16) {
LOG_WRN("%s: ngram_mod n_match=%d is too small - poor quality is possible, "
"see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, sparams.n_match);
}
}
@@ -1000,7 +1060,7 @@ common_speculative * common_speculative_init(
impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
/* .ctx_tgt = */ ctx_tgt,
/* .ctx_dft = */ ctx_dft,
/* .replacements = */ params.replacements,
/* .replacements = */ params.draft.replacements,
/* .use_ckpt = */ use_ckpt
));
break;
@@ -1010,18 +1070,18 @@ common_speculative * common_speculative_init(
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
common_ngram_map ngram_map = get_common_ngram_map(config);
common_ngram_map ngram_map = get_common_ngram_map(config.type, config.params.ngram_simple);
uint16_t ngram_size_key = ngram_map.size_key;
uint16_t mgram_size_value = ngram_map.size_value;
auto config_simple = common_ngram_simple_config {
/* .size_ngram = */ ngram_size_key,
/* .size_mgram = */ mgram_size_value
/* .size_ngram = */ ngram_size_key,
/* .size_mgram = */ mgram_size_value
};
auto state = std::make_unique<common_speculative_state_ngram_simple>(
/* .type = */ config.type,
/* .state = */ config_simple
/* .type = */ config.type,
/* .state = */ config_simple
);
impls.push_back(std::move(state));
break;
@@ -1030,18 +1090,17 @@ common_speculative * common_speculative_init(
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: {
impls.push_back(std::make_unique<common_speculative_state_ngram_map_k>(
(config.type),
get_common_ngram_map(config)
get_common_ngram_map(config.type, config.params.ngram_map_k)
));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: {
GGML_ASSERT(config.params.ngram_mod);
impls.push_back(std::make_unique<common_speculative_state_ngram_mod>(config.type, *config.params.ngram_mod));
GGML_ASSERT(config.params.ngram_mod.obj);
impls.push_back(std::make_unique<common_speculative_state_ngram_mod>(config.type, *config.params.ngram_mod.obj));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: {
auto state = create_state_ngram_cache(
params.lookup_cache_static, params.lookup_cache_dynamic, config);
auto state = create_state_ngram_cache(params.ngram_cache.lookup_cache_static, params.ngram_cache.lookup_cache_dynamic, config);
impls.push_back(std::make_unique<common_speculative_state_ngram_cache>(state));
break;
}
@@ -1099,6 +1158,15 @@ llama_tokens common_speculative_draft(
impl->n_call_draft++;
}
{
const int n_min = impl->n_min(params);
if (!result.empty() && (int) result.size() < n_min) {
LOG_DBG("%s: ignoring small draft: %d < %d\n", __func__, (int) result.size(), n_min);
result.clear();
}
}
if (!result.empty()) {
LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__,
common_speculative_type_to_str(impl.get()->type).c_str(), prompt_tgt.size(),
@@ -1108,7 +1176,7 @@ llama_tokens common_speculative_draft(
impl->n_gen_drafts++;
impl->n_gen_tokens += result.size();
break; // We have a draft, so break out of the loop and return it.
break; // we have a draft, so break out of the loop and return it.
}
}
@@ -1136,6 +1204,32 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) {
}
}
int32_t common_speculative_n_max(const common_speculative * spec, const common_params_speculative & params) {
if (spec == nullptr) {
return 0;
}
int32_t n_max = 0;
for (const auto & impl : spec->impls) {
n_max = std::max(n_max, impl->n_max(params));
}
return n_max;
}
int32_t common_speculative_n_min(const common_speculative * spec, const common_params_speculative & params) {
if (spec == nullptr) {
return 0;
}
int32_t n_min = 0;
for (const auto & impl : spec->impls) {
n_min = std::max(n_min, impl->n_min(params));
}
return n_min;
}
void common_speculative_print_stats(const common_speculative * spec) {
if (spec == nullptr) {
return;
+3
View File
@@ -33,6 +33,9 @@ llama_tokens common_speculative_draft(
// informs the speculative decoder that n_accepted tokens were accepted by the target model
void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
int32_t common_speculative_n_max(const common_speculative * spec, const common_params_speculative & params);
int32_t common_speculative_n_min(const common_speculative * spec, const common_params_speculative & params);
// print statistics about the speculative decoding
void common_speculative_print_stats(const common_speculative * spec);