llama : add adaptive-p sampler (#17927)
* initial commit for branch
* simplify constants
* add params to `struct common_params_sampling`, add reference to PR
* explicitly clamp `min_target` and `max_target` to `[0.0, 1.0]`
* add args, rename `queue_size` -> `window_size`
* improved comments
* minor
* remove old unused code from algorithm
* minor
* add power law case to `common_sampler_init`, add sampler name mappings
* clarify behaviour when `window_size = 0`
* add missing enums
* remove `target_range` param, make `target == 1` no-op, cleanup code
* oops, straggler
* add missing parameters in `server-task.cpp`
* copy from author ref: https://gist.github.com/MrJackSpade/9be99c7efbba7b95a41377e123b7b069
* remove old debug log, style nit
* fix compiler warning, add commented-out logging per token
* re-write + change parameters + simplify
* oops, forgot args.cpp
* fix leftover `window_size`
* add missing values to `common_params_sampling::print()`
* with logging
* does this fix it?
* no, but does this?
* update default decay
* optimize
* fix bad merge (my git skills are lacking)
* silence `missing initializer for member` warning
* update default decay to 0.9
* fix logging
* format (double)
* add power law to the new `samplers` vector
* log sampler init values
* improve logging messages in `llama_sampler_power_law`
* remove extraneous logging
* simplify target computation (last commit with debug logging!)
* remove debug logging, explicitly clamp params at init
* add `use_power_law` flag + logic, minor cleanup
* rename `power-law` -> `adaptive-p`
* fix cold start EMA: `ctx->weighted_sum` is now initialized and reset to `target / (1.0f - clamped_decay)`, and `ctx->total_weight` to `1.0f / (1.0f - clamped_decay)`; this fixes a "cold start" problem with the moving average (see the sketch after this list)
* update `SHARPNESS` constant to `10.0f`
* minor style fixes (no functional changes)
* minor style fixes cont.
* update `llama_sampler_adaptive_p_i` for backend sampling (ref: #17004)
* separate into `apply` + `accept` functions
* `pending_token_idx`: switch from `llama_token` to `int32_t`; functionally identical (`llama.h` has `typedef int32_t llama_token;`), but it's more correct now
* don't transform logits <= -1e9f
* fix masking in backend top-p, min-p
* address review comments
* fix typo in comments: `RND` -> `RNG`
* add docs
* add recommended values in completion docs
* address PR feedback
* remove trailing whitespace (for CI `editorconfig`)
* add `adaptive-p` to `common_sampler_types_from_chars`
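For reference, a minimal sketch of the cold-start fix and EMA update implied by the notes above. The struct and function names are hypothetical; only the `weighted_sum` / `total_weight` fields, their reset values, and the clamped decay come from the commit log, and the update rule is the debiased EMA those reset values imply, not necessarily the actual implementation:

```c
// hypothetical standalone reconstruction -- NOT the actual llama.cpp code.
typedef struct {
    float weighted_sum; // decayed sum of the ORIGINAL probabilities of selected tokens
    float total_weight; // decayed number of observations (debiasing term)
    float decay;        // clamped at init (log mentions a 0.99 cap and a 0.9 default)
} adaptive_p_ema;

// cold-start fix from the log: seed both sums so that the initial average
// weighted_sum / total_weight equals `target` exactly, instead of drifting
// up from zero over the first few sampled tokens
static void adaptive_p_ema_reset(adaptive_p_ema * e, float target, float decay) {
    e->decay        = decay;
    e->weighted_sum = target / (1.0f - decay);
    e->total_weight = 1.0f   / (1.0f - decay);
}

// fold in the original probability of the token that was just accepted
// and return the updated moving average
static float adaptive_p_ema_accept(adaptive_p_ema * e, float p_selected) {
    e->weighted_sum = e->decay * e->weighted_sum + p_selected;
    e->total_weight = e->decay * e->total_weight + 1.0f;
    return e->weighted_sum / e->total_weight;
}
```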
@@ -1395,6 +1395,33 @@ extern "C" {
                        const char ** seq_breakers,
                             size_t   num_breakers);
 
+    /// adaptive-p: select tokens near a configurable target probability over time.
+    ///
+    /// the adaptive-p sampler transforms the token probability distribution to favor tokens
+    /// that fall near a user-configurable probability target.
+    ///
+    /// internally, the sampler maintains an exponential moving average of the *ORIGINAL*
+    /// probabilities of selected tokens at each sampling step. it uses this EMA to compute an
+    /// adapted target probability at each sampling step, thus maintaining the desired target
+    /// probability over time.
+    ///
+    /// adaptive-p selects a token ID rather than just mutating candidates, so it must be last
+    /// in the sampler chain (like mirostat, dist, greedy).
+    ///
+    /// only mild truncation before this sampler is recommended. we suggest applying min-p
+    /// before adaptive-p as the only other active sampler in the chain.
+    ///
+    /// @param target select tokens near this probability (valid range 0.0 to 1.0; negative = disabled)
+    /// @param decay  EMA decay for adaptation; history ≈ 1/(1-decay) tokens (valid range 0.0 - 0.99)
+    /// @param seed   RNG seed
+    ///
+    /// ref: https://github.com/ggml-org/llama.cpp/pull/17927
+    ///
+    LLAMA_API struct llama_sampler * llama_sampler_init_adaptive_p(
+                               float   target,
+                               float   decay,
+                            uint32_t   seed);
+
     LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
                             int32_t   n_vocab,
                             int32_t   n_logit_bias,
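The header comment above recommends min-p as the only other active sampler, with adaptive-p terminating the chain. A minimal sketch of building such a chain against the public API; the signature of `llama_sampler_init_adaptive_p` comes from the diff, while the parameter values are illustrative (only the 0.9 decay is a documented default from the commit log):

```c
#include "llama.h"

// build: min-p (mild truncation) -> adaptive-p (selects the token, so it goes last)
static struct llama_sampler * make_adaptive_p_chain(void) {
    struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    // mild truncation only, per the recommendation in the header comment
    llama_sampler_chain_add(chain, llama_sampler_init_min_p(/*p=*/0.05f, /*min_keep=*/1));

    // adaptive-p replaces dist/greedy as the terminal, token-selecting sampler
    llama_sampler_chain_add(chain, llama_sampler_init_adaptive_p(
        /*target=*/0.30f,   // illustrative; negative would disable the sampler
        /*decay =*/0.90f,   // default decay per the commit log
        /*seed  =*/LLAMA_DEFAULT_SEED));

    return chain; // release with llama_sampler_free(chain)
}
```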