spec : refactor params (#22397)
* spec : refactor params
* cont : fix
* cont : rename "sparam" to "sampling"
* cont : add spec params category
* cont : add info about removed arguments
* cont : skip param length check for spec params
* cont : adapt server tests
This commit is contained in:
+377
-217
File diff suppressed because it is too large
Load Diff
+4
-2
@@ -25,7 +25,8 @@ struct common_arg {
|
||||
const char * value_hint_2 = nullptr; // for second arg value
|
||||
const char * env = nullptr;
|
||||
std::string help;
|
||||
bool is_sparam = false; // is current arg a sampling param?
|
||||
bool is_sampling = false; // is current arg a sampling param?
|
||||
bool is_spec = false; // is current arg a speculative decoding param?
|
||||
bool is_preset_only = false; // is current arg preset-only (not treated as CLI arg)
|
||||
void (*handler_void) (common_params & params) = nullptr;
|
||||
void (*handler_string) (common_params & params, const std::string &) = nullptr;
|
||||
@@ -74,7 +75,8 @@ struct common_arg {
|
||||
common_arg & set_examples(std::initializer_list<enum llama_example> examples);
|
||||
common_arg & set_excludes(std::initializer_list<enum llama_example> excludes);
|
||||
common_arg & set_env(const char * env);
|
||||
common_arg & set_sparam();
|
||||
common_arg & set_sampling();
|
||||
common_arg & set_spec();
|
||||
common_arg & set_preset_only();
|
||||
bool in_example(enum llama_example ex);
|
||||
bool is_exclude(enum llama_example ex);
|
||||
|
||||
+7
-7
@@ -70,7 +70,7 @@ common_time_meas::~common_time_meas() {
|
||||
// CPU utils
|
||||
//
|
||||
|
||||
int32_t cpu_get_num_physical_cores() {
|
||||
int32_t common_cpu_get_num_physical_cores() {
|
||||
#ifdef __linux__
|
||||
// enumerate the set of thread siblings, num entries is num cores
|
||||
std::unordered_set<std::string> siblings;
|
||||
@@ -185,11 +185,11 @@ static int cpu_count_math_cpus(int n_cpu) {
|
||||
/**
|
||||
* Returns number of CPUs on system that are useful for math.
|
||||
*/
|
||||
int32_t cpu_get_num_math() {
|
||||
int32_t common_cpu_get_num_math() {
|
||||
#if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__)
|
||||
int n_cpu = sysconf(_SC_NPROCESSORS_ONLN);
|
||||
if (n_cpu < 1) {
|
||||
return cpu_get_num_physical_cores();
|
||||
return common_cpu_get_num_physical_cores();
|
||||
}
|
||||
if (is_hybrid_cpu()) {
|
||||
cpu_set_t affinity;
|
||||
@@ -202,7 +202,7 @@ int32_t cpu_get_num_math() {
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return cpu_get_num_physical_cores();
|
||||
return common_cpu_get_num_physical_cores();
|
||||
}
|
||||
|
||||
// Helper for setting process priority
|
||||
@@ -263,7 +263,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
|
||||
//
|
||||
|
||||
|
||||
void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model) {
|
||||
void postprocess_cpu_params(common_cpu_params & cpuparams, const common_cpu_params * role_model) {
|
||||
int32_t n_set = 0;
|
||||
|
||||
if (cpuparams.n_threads < 0) {
|
||||
@@ -271,7 +271,7 @@ void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model)
|
||||
if (role_model != nullptr) {
|
||||
cpuparams = *role_model;
|
||||
} else {
|
||||
cpuparams.n_threads = cpu_get_num_math();
|
||||
cpuparams.n_threads = common_cpu_get_num_math();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1521,7 +1521,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
|
||||
return cparams;
|
||||
}
|
||||
|
||||
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params) {
|
||||
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const common_cpu_params & params) {
|
||||
struct ggml_threadpool_params tpp;
|
||||
|
||||
ggml_threadpool_params_init(&tpp, params.n_threads); // setup the defaults
|
||||
|
||||
+56
-36
@@ -54,7 +54,7 @@ struct common_control_vector_load_info;
|
||||
// CPU utils
|
||||
//
|
||||
|
||||
struct cpu_params {
|
||||
struct common_cpu_params {
|
||||
int n_threads = -1;
|
||||
bool cpumask[GGML_MAX_N_THREADS] = {false}; // CPU affinity mask.
|
||||
bool mask_valid = false; // Default: any CPU
|
||||
@@ -63,8 +63,8 @@ struct cpu_params {
|
||||
uint32_t poll = 50; // Polling (busywait) level (0 - no polling, 100 - mostly polling)
|
||||
};
|
||||
|
||||
int32_t cpu_get_num_physical_cores();
|
||||
int32_t cpu_get_num_math();
|
||||
int32_t common_cpu_get_num_physical_cores();
|
||||
int32_t common_cpu_get_num_math();
|
||||
|
||||
//
|
||||
// Common params
|
||||
@@ -297,34 +297,19 @@ struct common_params_model {
|
||||
|
||||
struct common_ngram_mod;
|
||||
|
||||
struct common_params_speculative {
|
||||
common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE; // type of speculative decoding
|
||||
// draft-model-based speculative decoding parameters
|
||||
struct common_params_speculative_draft {
|
||||
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
|
||||
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
|
||||
|
||||
// general-purpose speculative decoding parameters
|
||||
float p_split = 0.1f; // speculative decoding split probability
|
||||
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
|
||||
|
||||
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
|
||||
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
|
||||
float p_split = 0.1f; // speculative decoding split probability
|
||||
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
|
||||
common_params_model mparams;
|
||||
|
||||
// ngram-based speculative decoding
|
||||
llama_model * model = nullptr; // a llama_model that can be shared by multiple speculative contexts
|
||||
|
||||
uint16_t ngram_size_n = 12; // ngram size for lookup
|
||||
uint16_t ngram_size_m = 48; // mgram size for speculative tokens
|
||||
uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
|
||||
|
||||
std::shared_ptr<common_ngram_mod> ngram_mod;
|
||||
|
||||
std::string lookup_cache_static; // path of static ngram cache file for lookup decoding // NOLINT
|
||||
std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding // NOLINT
|
||||
|
||||
// draft-model speculative decoding
|
||||
|
||||
struct common_params_model mparams_dft;
|
||||
|
||||
llama_model * model_dft = nullptr; // a llama_model that can be shared by multiple speculative contexts
|
||||
|
||||
llama_context_params cparams_dft; // these are the parameters for the draft llama_context
|
||||
llama_context_params cparams; // these are the parameters for the draft llama_context
|
||||
|
||||
int32_t n_ctx = 0; // draft context size
|
||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
||||
@@ -332,25 +317,60 @@ struct common_params_speculative {
|
||||
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
|
||||
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
|
||||
|
||||
struct cpu_params cpuparams;
|
||||
struct cpu_params cpuparams_batch;
|
||||
common_cpu_params cpuparams;
|
||||
common_cpu_params cpuparams_batch;
|
||||
|
||||
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
||||
|
||||
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
|
||||
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
|
||||
};
|
||||
|
||||
// parameters of the ngram_mod speculator
struct common_params_speculative_ngram_mod {
|
||||
int32_t n_match = 24; // passed as the first constructor argument of common_ngram_mod; a warning is logged when it is below 16
|
||||
|
||||
int32_t n_max = 64; // loop bound on the number of tokens looked up per draft call
|
||||
int32_t n_min = 48; // if the lookup comes back empty before this many tokens were produced, the draft is cleared
|
||||
|
||||
// shared instance of the ngram container for all speculative decoding contexts
|
||||
std::shared_ptr<common_ngram_mod> obj;
|
||||
};
|
||||
|
||||
// parameters of an ngram-map based speculator
// common_params_speculative holds one instance each for ngram_simple, ngram_map_k and ngram_map_k4v
// get_common_ngram_map() forwards size_n as the key size and size_m as the value size of the common_ngram_map
struct common_params_speculative_ngram_map {
|
||||
uint16_t size_n = 12; // ngram size for lookup
|
||||
uint16_t size_m = 48; // mgram size for speculative tokens
|
||||
uint16_t min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
|
||||
};
|
||||
|
||||
// cache file paths of the ngram-cache speculator
// both strings are handed to create_state_ngram_cache() when the COMMON_SPECULATIVE_TYPE_NGRAM_CACHE implementation is set up
struct common_params_speculative_ngram_cache {
|
||||
std::string lookup_cache_static; // path of static ngram cache file for lookup decoding
|
||||
std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding
|
||||
};
|
||||
|
||||
// top-level speculative decoding parameters: the selected type plus one parameter group per implementation
struct common_params_speculative {
|
||||
// TODO: become a vector in order to support "chains of speculators"
|
||||
common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE;
|
||||
|
||||
common_params_speculative_draft draft; // draft-model parameters
|
||||
|
||||
common_params_speculative_ngram_mod ngram_mod;
|
||||
common_params_speculative_ngram_map ngram_simple;
|
||||
common_params_speculative_ngram_map ngram_map_k;
|
||||
common_params_speculative_ngram_map ngram_map_k4v;
|
||||
|
||||
common_params_speculative_ngram_cache ngram_cache;
|
||||
|
||||
// true when either draft.mparams.path or draft.mparams.hf_repo is non-empty
|
||||
bool has_dft() const {
|
||||
// NOTE(review): the next line reads mparams_dft, which is not a member of this struct as shown here;
// it looks like the removed side of the diff, superseded by the draft.mparams line below - confirm
return !mparams_dft.path.empty() || !mparams_dft.hf_repo.empty();
|
||||
return !draft.mparams.path.empty() || !draft.mparams.hf_repo.empty();
|
||||
}
|
||||
};
|
||||
|
||||
struct common_params_vocoder {
|
||||
struct common_params_model model;
|
||||
|
||||
std::string speaker_file = ""; // speaker file path // NOLINT
|
||||
std::string speaker_file; // speaker file path
|
||||
|
||||
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
|
||||
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy
|
||||
};
|
||||
|
||||
struct common_params_diffusion {
|
||||
@@ -433,8 +453,8 @@ struct common_params {
|
||||
|
||||
enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
|
||||
|
||||
struct cpu_params cpuparams;
|
||||
struct cpu_params cpuparams_batch;
|
||||
common_cpu_params cpuparams;
|
||||
common_cpu_params cpuparams_batch;
|
||||
|
||||
ggml_backend_sched_eval_callback cb_eval = nullptr;
|
||||
void * cb_eval_user_data = nullptr;
|
||||
@@ -678,7 +698,7 @@ std::string common_params_get_system_info(const common_params & params);
|
||||
|
||||
bool parse_cpu_range(const std::string & range, bool(&boolmask)[GGML_MAX_N_THREADS]);
|
||||
bool parse_cpu_mask(const std::string & mask, bool(&boolmask)[GGML_MAX_N_THREADS]);
|
||||
void postprocess_cpu_params(cpu_params & cpuparams, const cpu_params * role_model = nullptr);
|
||||
void postprocess_cpu_params(common_cpu_params & cpuparams, const common_cpu_params * role_model = nullptr);
|
||||
bool set_process_priority(enum ggml_sched_priority prio);
|
||||
|
||||
//
|
||||
@@ -846,7 +866,7 @@ common_init_result_ptr common_init_from_params(common_params & params);
|
||||
|
||||
struct llama_model_params common_model_params_to_llama ( common_params & params);
|
||||
struct llama_context_params common_context_params_to_llama(const common_params & params);
|
||||
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);
|
||||
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const common_cpu_params & params);
|
||||
|
||||
// clear LoRA adapters from context, then apply new list of adapters
|
||||
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
|
||||
|
||||
+1
-1
@@ -43,7 +43,7 @@ static std::set<std::string> get_remote_preset_whitelist(const std::map<std::str
|
||||
for (const auto & it : key_to_opt) {
|
||||
const std::string & key = it.first;
|
||||
const common_arg & opt = it.second;
|
||||
if (allowed_options.find(key) != allowed_options.end() || opt.is_sparam) {
|
||||
if (allowed_options.find(key) != allowed_options.end() || opt.is_sampling) {
|
||||
allowed_keys.insert(key);
|
||||
// also add variant keys (args without leading dashes and env vars)
|
||||
for (const auto & arg : opt.get_args()) {
|
||||
|
||||
+140
-46
@@ -151,6 +151,9 @@ struct common_speculative_state {
|
||||
llama_tokens & result) = 0;
|
||||
|
||||
virtual void accept(uint16_t n_accepted) = 0;
|
||||
|
||||
// upper bound on the draft length reported by this implementation (aggregated by common_speculative_n_max)
virtual int32_t n_max(const common_params_speculative & params) const = 0;
|
||||
// minimum draft length for this implementation; common_speculative_draft clears non-empty drafts shorter than this
virtual int32_t n_min(const common_params_speculative & params) const = 0;
|
||||
};
|
||||
|
||||
struct common_speculative_checkpoint {
|
||||
@@ -296,6 +299,8 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
const llama_tokens & prompt_tgt,
|
||||
llama_token id_last,
|
||||
llama_tokens & result) override {
|
||||
const auto & sparams = params.draft;
|
||||
|
||||
auto * spec = this;
|
||||
|
||||
auto & batch = spec->batch;
|
||||
@@ -309,7 +314,7 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
int reuse_i = 0; // index of part to be reused in prompt_dft
|
||||
int reuse_n = 0; // length of part to be reused in prompt_dft
|
||||
|
||||
const int n_ctx = llama_n_ctx(ctx_dft) - params.n_max;
|
||||
const int n_ctx = llama_n_ctx(ctx_dft) - sparams.n_max;
|
||||
|
||||
llama_tokens prompt_cnv;
|
||||
if (!spec->vocab_cmpt) {
|
||||
@@ -367,7 +372,7 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
}
|
||||
|
||||
result.clear();
|
||||
result.reserve(params.n_max);
|
||||
result.reserve(sparams.n_max);
|
||||
|
||||
bool needs_ckpt = use_ckpt && prompt_dft.size() > 0;
|
||||
if (reuse_n == 0 || (use_ckpt && reuse_i > 0)) {
|
||||
@@ -380,7 +385,7 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) {
|
||||
result.push_back(prompt_dft[i]);
|
||||
|
||||
if (params.n_max <= (int) result.size()) {
|
||||
if (sparams.n_max <= (int) result.size()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -473,7 +478,7 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
common_sampler_reset(smpl);
|
||||
|
||||
// sample n_draft tokens from the draft model
|
||||
for (int i = 0; i < params.n_max; ++i) {
|
||||
for (int i = 0; i < sparams.n_max; ++i) {
|
||||
common_batch_clear(batch);
|
||||
|
||||
common_sampler_sample(smpl, ctx_dft, 0, true);
|
||||
@@ -492,12 +497,12 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
|
||||
result.push_back(id);
|
||||
|
||||
if (params.n_max <= (int) result.size()) {
|
||||
if (sparams.n_max <= (int) result.size()) {
|
||||
break;
|
||||
}
|
||||
|
||||
// only collect very high-confidence draft tokens
|
||||
if (cur_p->data[0].p < params.p_min) {
|
||||
if (cur_p->data[0].p < sparams.p_min) {
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -518,10 +523,14 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
detokenized = replace_to_tgt(detokenized);
|
||||
LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str());
|
||||
result = common_tokenize(ctx_tgt, detokenized, false, true);
|
||||
if (result.size() > (size_t)params.n_max) {
|
||||
result.resize(params.n_max);
|
||||
if (result.size() > (size_t) sparams.n_max) {
|
||||
result.resize(sparams.n_max);
|
||||
}
|
||||
}
|
||||
|
||||
if (result.size() < (size_t) sparams.n_min) {
|
||||
result.clear();
|
||||
}
|
||||
}
|
||||
|
||||
void accept(uint16_t n_accepted) override {
|
||||
@@ -529,6 +538,14 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
GGML_UNUSED(n_accepted);
|
||||
}
|
||||
|
||||
// draft-model speculator: the upper bound comes straight from the draft parameter group
int32_t n_max(const common_params_speculative & params) const override {
|
||||
return params.draft.n_max;
|
||||
}
|
||||
|
||||
// draft-model speculator: the lower bound comes straight from the draft parameter group
int32_t n_min(const common_params_speculative & params) const override {
|
||||
return params.draft.n_min;
|
||||
}
|
||||
|
||||
std::string replace_to_dft(const std::string & input) const {
|
||||
std::string result = input;
|
||||
|
||||
@@ -581,6 +598,14 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
|
||||
// noop
|
||||
GGML_UNUSED(n_accepted);
|
||||
}
|
||||
|
||||
// eagle3 speculator: reports the same upper bound as the draft parameter group
int32_t n_max(const common_params_speculative & params) const override {
|
||||
return params.draft.n_max;
|
||||
}
|
||||
|
||||
// eagle3 speculator: reports the same lower bound as the draft parameter group
int32_t n_min(const common_params_speculative & params) const override {
|
||||
return params.draft.n_min;
|
||||
}
|
||||
};
|
||||
|
||||
// state of self-speculation (simple implementation, not ngram-map)
|
||||
@@ -610,19 +635,27 @@ struct common_speculative_state_ngram_simple : public common_speculative_state {
|
||||
// noop
|
||||
GGML_UNUSED(n_accepted);
|
||||
}
|
||||
|
||||
// simple ngram speculator: the bound is the configured mgram size; params is not consulted
int32_t n_max(const common_params_speculative & /*params*/) const override {
|
||||
return config.size_mgram;
|
||||
}
|
||||
|
||||
// simple ngram speculator: same value as n_max (the configured mgram size); params is not consulted
int32_t n_min(const common_params_speculative & /*params*/) const override {
|
||||
return config.size_mgram;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_speculative_state_ngram_map_k : public common_speculative_state {
|
||||
// draft ngram map for speculative decoding without draft model
|
||||
common_ngram_map map;
|
||||
common_ngram_map config;
|
||||
|
||||
common_speculative_state_ngram_map_k(
|
||||
enum common_speculative_type type,
|
||||
common_ngram_map map)
|
||||
: common_speculative_state(type), map(std::move(map)) {}
|
||||
common_ngram_map config)
|
||||
: common_speculative_state(type), config(std::move(config)) {}
|
||||
|
||||
void begin(const llama_tokens & prompt) override {
|
||||
common_ngram_map_begin(map, prompt);
|
||||
common_ngram_map_begin(config, prompt);
|
||||
}
|
||||
|
||||
void draft(
|
||||
@@ -630,12 +663,20 @@ struct common_speculative_state_ngram_map_k : public common_speculative_state {
|
||||
const llama_tokens & prompt_tgt,
|
||||
llama_token id_last,
|
||||
llama_tokens & result) override {
|
||||
common_ngram_map_draft(map, prompt_tgt, id_last, result);
|
||||
common_ngram_map_draft(config, prompt_tgt, id_last, result);
|
||||
GGML_UNUSED(params);
|
||||
}
|
||||
|
||||
void accept(uint16_t n_accepted) override {
|
||||
common_ngram_map_accept(map, n_accepted);
|
||||
common_ngram_map_accept(config, n_accepted);
|
||||
}
|
||||
|
||||
// ngram-map speculator: the bound is the value size stored in the map config; params is not consulted
int32_t n_max(const common_params_speculative & /*params*/) const override {
|
||||
return config.size_value;
|
||||
}
|
||||
|
||||
// ngram-map speculator: same value as n_max (the map's value size); params is not consulted
int32_t n_min(const common_params_speculative & /*params*/) const override {
|
||||
return config.size_value;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -692,7 +733,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
|
||||
const llama_tokens & prompt_tgt,
|
||||
llama_token id_last,
|
||||
llama_tokens & result) override {
|
||||
GGML_UNUSED(params);
|
||||
const auto & sparams = params.ngram_mod;
|
||||
|
||||
n_draft_last = 0;
|
||||
|
||||
@@ -712,16 +753,16 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
|
||||
i_last = cur_len - n;
|
||||
}
|
||||
|
||||
result.resize(n + params.n_max);
|
||||
result.resize(n + sparams.n_max);
|
||||
for (size_t i = 0; i < n - 1; ++i) {
|
||||
result[i] = prompt_tgt[cur_len - n + 1 + i];
|
||||
}
|
||||
result[n - 1] = id_last;
|
||||
|
||||
for (int i = 0; i < params.n_max; ++i) {
|
||||
for (int i = 0; i < sparams.n_max; ++i) {
|
||||
const llama_token token = mod.get(result.data() + i);
|
||||
if (token == common_ngram_mod::EMPTY) {
|
||||
if (i < params.n_min) {
|
||||
if (i < sparams.n_min) {
|
||||
result.clear();
|
||||
return;
|
||||
}
|
||||
@@ -764,6 +805,14 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ngram_mod speculator: the upper bound comes from the ngram_mod parameter group
int32_t n_max(const common_params_speculative & params) const override {
|
||||
return params.ngram_mod.n_max;
|
||||
}
|
||||
|
||||
// ngram_mod speculator: the lower bound comes from the ngram_mod parameter group
int32_t n_min(const common_params_speculative & params) const override {
|
||||
return params.ngram_mod.n_min;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_speculative_state_ngram_cache : public common_speculative_state {
|
||||
@@ -857,6 +906,14 @@ struct common_speculative_state_ngram_cache : public common_speculative_state {
|
||||
// TODO: noop
|
||||
GGML_UNUSED(n_accepted);
|
||||
}
|
||||
|
||||
// ngram-cache speculator: returns n_draft (declared outside this excerpt); params is not consulted
int32_t n_max(const common_params_speculative & /*params*/) const override {
|
||||
return n_draft;
|
||||
}
|
||||
|
||||
// ngram-cache speculator: no minimum draft length
int32_t n_min(const common_params_speculative & /*params*/) const override {
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_speculative {
|
||||
@@ -865,11 +922,13 @@ struct common_speculative {
|
||||
common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats)
|
||||
};
|
||||
|
||||
// Build a common_ngram_map from ngram-map parameters.
// key_only is set only when the speculative type is COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K.
// NOTE(review): this span appears to interleave the pre-refactor lines (reading config.params.ngram_*)
// with the post-refactor lines (taking the type and a common_params_speculative_ngram_map) - confirm
static common_ngram_map get_common_ngram_map(const common_speculative_config & config) {
|
||||
uint16_t size_key = config.params.ngram_size_n;
|
||||
uint16_t size_value = config.params.ngram_size_m;
|
||||
bool key_only = (config.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K);
|
||||
uint16_t min_hits = config.params.ngram_min_hits;
|
||||
static common_ngram_map get_common_ngram_map(
|
||||
common_speculative_type type,
|
||||
const common_params_speculative_ngram_map & config) {
|
||||
uint16_t size_key = config.size_n;
|
||||
uint16_t size_value = config.size_m;
|
||||
bool key_only = type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K;
|
||||
uint16_t min_hits = config.min_hits;
|
||||
|
||||
return common_ngram_map(size_key, size_value, key_only, min_hits);
|
||||
}
|
||||
@@ -927,8 +986,8 @@ common_speculative * common_speculative_init(
|
||||
common_params_speculative & params,
|
||||
llama_context * ctx_tgt) {
|
||||
llama_context * ctx_dft = nullptr;
|
||||
if (params.model_dft) {
|
||||
ctx_dft = llama_init_from_model(params.model_dft, params.cparams_dft);
|
||||
if (params.draft.model) {
|
||||
ctx_dft = llama_init_from_model(params.draft.model, params.draft.cparams);
|
||||
if (ctx_dft == nullptr) {
|
||||
LOG_ERR("%s", "failed to create draft context\n");
|
||||
return nullptr;
|
||||
@@ -938,7 +997,7 @@ common_speculative * common_speculative_init(
|
||||
// Compute the implementations to use based on the config and their order of preference
|
||||
std::vector<common_speculative_config> configs = {}; // list of speculative configs to try
|
||||
{
|
||||
bool has_draft = !params.mparams_dft.path.empty();
|
||||
bool has_draft = !params.draft.mparams.path.empty();
|
||||
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
|
||||
|
||||
bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
|
||||
@@ -961,16 +1020,17 @@ common_speculative * common_speculative_init(
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, params));
|
||||
}
|
||||
if (has_ngram_mod) {
|
||||
// shared instance for all speculative decoding contexts
|
||||
if (!params.ngram_mod) {
|
||||
params.ngram_mod = std::make_shared<common_ngram_mod>(params.ngram_size_n, 4*1024*1024);
|
||||
auto & sparams = params.ngram_mod;
|
||||
|
||||
LOG_INF("%s: initialized ngram_mod with n=%d, size=%zu (%.3f MB)\n", __func__,
|
||||
params.ngram_size_n, params.ngram_mod->size(),
|
||||
(float)(params.ngram_mod->size_bytes())/1024/1024);
|
||||
if (!sparams.obj) {
|
||||
sparams.obj = std::make_shared<common_ngram_mod>(sparams.n_match, 4*1024*1024);
|
||||
|
||||
if (params.ngram_size_n < 16) {
|
||||
LOG_WRN("%s: ngram_mod n=%d is too small - poor quality is possible, see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, params.ngram_size_n);
|
||||
LOG_INF("%s: initialized ngram_mod with n_match=%d, size=%zu (%.3f MB)\n", __func__,
|
||||
sparams.n_match, sparams.obj->size(), (float)(sparams.obj->size_bytes())/1024/1024);
|
||||
|
||||
if (sparams.n_match < 16) {
|
||||
LOG_WRN("%s: ngram_mod n_match=%d is too small - poor quality is possible, "
|
||||
"see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, sparams.n_match);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1000,7 +1060,7 @@ common_speculative * common_speculative_init(
|
||||
impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
|
||||
/* .ctx_tgt = */ ctx_tgt,
|
||||
/* .ctx_dft = */ ctx_dft,
|
||||
/* .replacements = */ params.replacements,
|
||||
/* .replacements = */ params.draft.replacements,
|
||||
/* .use_ckpt = */ use_ckpt
|
||||
));
|
||||
break;
|
||||
@@ -1010,18 +1070,18 @@ common_speculative * common_speculative_init(
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
|
||||
common_ngram_map ngram_map = get_common_ngram_map(config);
|
||||
common_ngram_map ngram_map = get_common_ngram_map(config.type, config.params.ngram_simple);
|
||||
|
||||
uint16_t ngram_size_key = ngram_map.size_key;
|
||||
uint16_t mgram_size_value = ngram_map.size_value;
|
||||
|
||||
auto config_simple = common_ngram_simple_config {
|
||||
/* .size_ngram = */ ngram_size_key,
|
||||
/* .size_mgram = */ mgram_size_value
|
||||
/* .size_ngram = */ ngram_size_key,
|
||||
/* .size_mgram = */ mgram_size_value
|
||||
};
|
||||
auto state = std::make_unique<common_speculative_state_ngram_simple>(
|
||||
/* .type = */ config.type,
|
||||
/* .state = */ config_simple
|
||||
/* .type = */ config.type,
|
||||
/* .state = */ config_simple
|
||||
);
|
||||
impls.push_back(std::move(state));
|
||||
break;
|
||||
@@ -1030,18 +1090,17 @@ common_speculative * common_speculative_init(
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: {
|
||||
impls.push_back(std::make_unique<common_speculative_state_ngram_map_k>(
|
||||
(config.type),
|
||||
get_common_ngram_map(config)
|
||||
// NOTE(review): this K4V case reads config.params.ngram_map_k, while common_params_speculative also
// declares a separate ngram_map_k4v group that is not read here. If the K4V settings are meant to
// apply, select by config.type (ngram_map_k4v for K4V, ngram_map_k otherwise) - confirm against the
// case labels above this excerpt
get_common_ngram_map(config.type, config.params.ngram_map_k)
|
||||
));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: {
|
||||
GGML_ASSERT(config.params.ngram_mod);
|
||||
impls.push_back(std::make_unique<common_speculative_state_ngram_mod>(config.type, *config.params.ngram_mod));
|
||||
GGML_ASSERT(config.params.ngram_mod.obj);
|
||||
impls.push_back(std::make_unique<common_speculative_state_ngram_mod>(config.type, *config.params.ngram_mod.obj));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: {
|
||||
auto state = create_state_ngram_cache(
|
||||
params.lookup_cache_static, params.lookup_cache_dynamic, config);
|
||||
auto state = create_state_ngram_cache(params.ngram_cache.lookup_cache_static, params.ngram_cache.lookup_cache_dynamic, config);
|
||||
impls.push_back(std::make_unique<common_speculative_state_ngram_cache>(state));
|
||||
break;
|
||||
}
|
||||
@@ -1099,6 +1158,15 @@ llama_tokens common_speculative_draft(
|
||||
impl->n_call_draft++;
|
||||
}
|
||||
|
||||
{
|
||||
const int n_min = impl->n_min(params);
|
||||
|
||||
if (!result.empty() && (int) result.size() < n_min) {
|
||||
LOG_DBG("%s: ignoring small draft: %d < %d\n", __func__, (int) result.size(), n_min);
|
||||
result.clear();
|
||||
}
|
||||
}
|
||||
|
||||
if (!result.empty()) {
|
||||
LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__,
|
||||
common_speculative_type_to_str(impl.get()->type).c_str(), prompt_tgt.size(),
|
||||
@@ -1108,7 +1176,7 @@ llama_tokens common_speculative_draft(
|
||||
impl->n_gen_drafts++;
|
||||
impl->n_gen_tokens += result.size();
|
||||
|
||||
break; // We have a draft, so break out of the loop and return it.
|
||||
break; // we have a draft, so break out of the loop and return it.
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1136,6 +1204,32 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) {
|
||||
}
|
||||
}
|
||||
|
||||
int32_t common_speculative_n_max(const common_speculative * spec, const common_params_speculative & params) {
|
||||
if (spec == nullptr) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int32_t n_max = 0;
|
||||
for (const auto & impl : spec->impls) {
|
||||
n_max = std::max(n_max, impl->n_max(params));
|
||||
}
|
||||
|
||||
return n_max;
|
||||
}
|
||||
|
||||
int32_t common_speculative_n_min(const common_speculative * spec, const common_params_speculative & params) {
|
||||
if (spec == nullptr) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int32_t n_min = 0;
|
||||
for (const auto & impl : spec->impls) {
|
||||
n_min = std::max(n_min, impl->n_min(params));
|
||||
}
|
||||
|
||||
return n_min;
|
||||
}
|
||||
|
||||
void common_speculative_print_stats(const common_speculative * spec) {
|
||||
if (spec == nullptr) {
|
||||
return;
|
||||
|
||||
@@ -33,6 +33,9 @@ llama_tokens common_speculative_draft(
|
||||
// informs the speculative decoder that n_accepted tokens were accepted by the target model
|
||||
void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
|
||||
|
||||
// returns the largest n_max over all implementations of spec (0 if spec is null)
int32_t common_speculative_n_max(const common_speculative * spec, const common_params_speculative & params);
|
||||
// returns the largest n_min over all implementations of spec (0 if spec is null)
int32_t common_speculative_n_min(const common_speculative * spec, const common_params_speculative & params);
|
||||
|
||||
// print statistics about the speculative decoding
|
||||
void common_speculative_print_stats(const common_speculative * spec);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user