server: support multiple generations from one prompt (OAI "n" option) (#17775)

* backend support

* server: support multiple generations from one prompt (OAI "n" option)

* fix invalid batch

* format oai

* clean up

* disable ctx shift

* add test

* update comments

* fix style

* add n_cmpl to docs [no ci]

* allowing using both n_cmpl and n
This commit is contained in:
Xuan-Son Nguyen
2025-12-06 15:54:38 +01:00
committed by GitHub
parent 09c7c50e64
commit c42712b056
7 changed files with 146 additions and 19 deletions
+80 -5
View File
@@ -35,7 +35,8 @@ constexpr int HTTP_POLLING_SECONDS = 1;
// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
enum slot_state {
SLOT_STATE_IDLE,
SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
SLOT_STATE_WAIT_OTHER, // after assigning a task, but waiting for parent slot to process prompt
SLOT_STATE_STARTED, // after assigning a task and about to process prompt
SLOT_STATE_PROCESSING_PROMPT,
SLOT_STATE_DONE_PROMPT,
SLOT_STATE_GENERATING,
@@ -254,6 +255,15 @@ struct server_slot {
generated_token_probs.push_back(token);
}
// note: a slot can also be either a parent or a child
bool is_parent() const {
return is_processing() && task->n_children > 0;
}
bool is_child() const {
return is_processing() && task->id_parent >= 0;
}
void release() {
if (is_processing()) {
GGML_ASSERT(task);
@@ -383,6 +393,17 @@ struct server_slot {
return res;
}
void copy_state_to(server_slot & other) const {
llama_memory_seq_rm(llama_get_memory(ctx), other.id, 0, -1);
llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, 0, -1);
other.n_decoded = n_decoded;
other.n_remaining = n_remaining;
other.i_batch = i_batch;
other.n_prompt_tokens_cache = n_prompt_tokens_cache;
other.n_prompt_tokens_processed = n_prompt_tokens_processed;
other.prompt = prompt.clone();
}
};
@@ -1022,7 +1043,9 @@ struct server_context_impl {
slot.task = std::make_unique<const server_task>(std::move(task));
slot.state = SLOT_STATE_STARTED;
slot.state = slot.is_child()
? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt
: SLOT_STATE_STARTED;
SLT_INF(slot, "%s", "processing task\n");
@@ -1684,6 +1707,12 @@ struct server_context_impl {
GGML_ABORT("not supported by multimodal");
}
if (slot.is_parent() || slot.is_child()) {
send_error(slot, "context shift cannot be used for shared prompt", ERROR_TYPE_SERVER);
slot.release();
continue;
}
// Shift context
int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep;
@@ -2308,6 +2337,26 @@ struct server_context_impl {
n_batch = llama_n_batch(ctx);
for (auto & slot : slots) {
// may need to copy state to other slots
if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) {
std::vector<server_slot *> child_slots;
for (auto & other : slots) {
if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) {
child_slots.push_back(&other);
}
}
// we can only proceed if all child slots are having the correct tasks
if (child_slots.size() == slot.task->n_children) {
// copy state to the child slots
for (auto & child : child_slots) {
SLT_INF(slot, "copying state to child %d\n", child->id);
slot.copy_state_to(*child);
child->state = SLOT_STATE_DONE_PROMPT;
}
}
}
// optionally send prompt processing progress
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
if (slot.task->params.stream && slot.task->params.return_progress) {
@@ -2593,11 +2642,12 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
}
tasks.reserve(inputs.size());
states.reserve(inputs.size());
int idx = 0;
for (size_t i = 0; i < inputs.size(); i++) {
server_task task = server_task(type);
task.id = ctx_server.queue_tasks.get_new_id();
task.index = i;
task.index = idx++;
task.tokens = std::move(inputs[i]);
task.params = server_task::params_from_json_cmpl(
@@ -2612,6 +2662,18 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
task.params.oaicompat_model = ctx_server.model_name;
states.push_back(task.params.oaicompat_chat_syntax);
if (task.params.n_cmpl > 1) {
task.n_children = task.params.n_cmpl - 1;
for (size_t j = 0; j < task.n_children; j++) {
server_task child = task.create_child(
task.id,
ctx_server.queue_tasks.get_new_id(),
idx++);
states.push_back(child.params.oaicompat_chat_syntax);
tasks.push_back(std::move(child));
}
}
tasks.push_back(std::move(task));
}
@@ -2638,8 +2700,21 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
arr.push_back(res->to_json());
}
// if single request, return single object instead of array
res->ok(arr.size() == 1 ? arr[0] : arr);
GGML_ASSERT(!arr.empty() && "empty results");
if (arr.size() == 1) {
// if single request, return single object instead of array
res->ok(arr[0]);
} else if (res_type == TASK_RESPONSE_TYPE_OAI_CHAT || res_type == TASK_RESPONSE_TYPE_OAI_CMPL) {
// if multiple results in OAI format, we need to re-format them
json & choices = arr[0]["choices"];
for (size_t i = 1; i < arr.size(); i++) {
choices.push_back(std::move(arr[i]["choices"][0]));
}
res->ok(arr[0]);
} else {
// multi-results, non-OAI compat
res->ok(arr);
}
}
} else {
// in streaming mode, the first error must be treated as non-stream response