server: support multiple generations from one prompt (OAI "n" option) (#17775)
* backend support * server: support multiple generations from one prompt (OAI "n" option) * fix invalid batch * format oai * clean up * disable ctx shift * add test * update comments * fix style * add n_cmpl to docs [no ci] * allowing using both n_cmpl and n
This commit is contained in:
@@ -35,7 +35,8 @@ constexpr int HTTP_POLLING_SECONDS = 1;
|
||||
// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
|
||||
enum slot_state {
|
||||
SLOT_STATE_IDLE,
|
||||
SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
|
||||
SLOT_STATE_WAIT_OTHER, // after assigning a task, but waiting for parent slot to process prompt
|
||||
SLOT_STATE_STARTED, // after assigning a task and about to process prompt
|
||||
SLOT_STATE_PROCESSING_PROMPT,
|
||||
SLOT_STATE_DONE_PROMPT,
|
||||
SLOT_STATE_GENERATING,
|
||||
@@ -254,6 +255,15 @@ struct server_slot {
|
||||
generated_token_probs.push_back(token);
|
||||
}
|
||||
|
||||
// note: a slot can also be either a parent or a child
|
||||
bool is_parent() const {
|
||||
return is_processing() && task->n_children > 0;
|
||||
}
|
||||
|
||||
bool is_child() const {
|
||||
return is_processing() && task->id_parent >= 0;
|
||||
}
|
||||
|
||||
void release() {
|
||||
if (is_processing()) {
|
||||
GGML_ASSERT(task);
|
||||
@@ -383,6 +393,17 @@ struct server_slot {
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
void copy_state_to(server_slot & other) const {
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), other.id, 0, -1);
|
||||
llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, 0, -1);
|
||||
other.n_decoded = n_decoded;
|
||||
other.n_remaining = n_remaining;
|
||||
other.i_batch = i_batch;
|
||||
other.n_prompt_tokens_cache = n_prompt_tokens_cache;
|
||||
other.n_prompt_tokens_processed = n_prompt_tokens_processed;
|
||||
other.prompt = prompt.clone();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1022,7 +1043,9 @@ struct server_context_impl {
|
||||
|
||||
slot.task = std::make_unique<const server_task>(std::move(task));
|
||||
|
||||
slot.state = SLOT_STATE_STARTED;
|
||||
slot.state = slot.is_child()
|
||||
? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt
|
||||
: SLOT_STATE_STARTED;
|
||||
|
||||
SLT_INF(slot, "%s", "processing task\n");
|
||||
|
||||
@@ -1684,6 +1707,12 @@ struct server_context_impl {
|
||||
GGML_ABORT("not supported by multimodal");
|
||||
}
|
||||
|
||||
if (slot.is_parent() || slot.is_child()) {
|
||||
send_error(slot, "context shift cannot be used for shared prompt", ERROR_TYPE_SERVER);
|
||||
slot.release();
|
||||
continue;
|
||||
}
|
||||
|
||||
// Shift context
|
||||
int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep;
|
||||
|
||||
@@ -2308,6 +2337,26 @@ struct server_context_impl {
|
||||
n_batch = llama_n_batch(ctx);
|
||||
|
||||
for (auto & slot : slots) {
|
||||
// may need to copy state to other slots
|
||||
if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) {
|
||||
std::vector<server_slot *> child_slots;
|
||||
for (auto & other : slots) {
|
||||
if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) {
|
||||
child_slots.push_back(&other);
|
||||
}
|
||||
}
|
||||
|
||||
// we can only proceed if all child slots are having the correct tasks
|
||||
if (child_slots.size() == slot.task->n_children) {
|
||||
// copy state to the child slots
|
||||
for (auto & child : child_slots) {
|
||||
SLT_INF(slot, "copying state to child %d\n", child->id);
|
||||
slot.copy_state_to(*child);
|
||||
child->state = SLOT_STATE_DONE_PROMPT;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// optionally send prompt processing progress
|
||||
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
|
||||
if (slot.task->params.stream && slot.task->params.return_progress) {
|
||||
@@ -2593,11 +2642,12 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
||||
}
|
||||
tasks.reserve(inputs.size());
|
||||
states.reserve(inputs.size());
|
||||
int idx = 0;
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
server_task task = server_task(type);
|
||||
|
||||
task.id = ctx_server.queue_tasks.get_new_id();
|
||||
task.index = i;
|
||||
task.index = idx++;
|
||||
|
||||
task.tokens = std::move(inputs[i]);
|
||||
task.params = server_task::params_from_json_cmpl(
|
||||
@@ -2612,6 +2662,18 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
||||
task.params.oaicompat_model = ctx_server.model_name;
|
||||
states.push_back(task.params.oaicompat_chat_syntax);
|
||||
|
||||
if (task.params.n_cmpl > 1) {
|
||||
task.n_children = task.params.n_cmpl - 1;
|
||||
for (size_t j = 0; j < task.n_children; j++) {
|
||||
server_task child = task.create_child(
|
||||
task.id,
|
||||
ctx_server.queue_tasks.get_new_id(),
|
||||
idx++);
|
||||
states.push_back(child.params.oaicompat_chat_syntax);
|
||||
tasks.push_back(std::move(child));
|
||||
}
|
||||
}
|
||||
|
||||
tasks.push_back(std::move(task));
|
||||
}
|
||||
|
||||
@@ -2638,8 +2700,21 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
||||
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
|
||||
arr.push_back(res->to_json());
|
||||
}
|
||||
// if single request, return single object instead of array
|
||||
res->ok(arr.size() == 1 ? arr[0] : arr);
|
||||
GGML_ASSERT(!arr.empty() && "empty results");
|
||||
if (arr.size() == 1) {
|
||||
// if single request, return single object instead of array
|
||||
res->ok(arr[0]);
|
||||
} else if (res_type == TASK_RESPONSE_TYPE_OAI_CHAT || res_type == TASK_RESPONSE_TYPE_OAI_CMPL) {
|
||||
// if multiple results in OAI format, we need to re-format them
|
||||
json & choices = arr[0]["choices"];
|
||||
for (size_t i = 1; i < arr.size(); i++) {
|
||||
choices.push_back(std::move(arr[i]["choices"][0]));
|
||||
}
|
||||
res->ok(arr[0]);
|
||||
} else {
|
||||
// multi-results, non-OAI compat
|
||||
res->ok(arr);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// in streaming mode, the first error must be treated as non-stream response
|
||||
|
||||
Reference in New Issue
Block a user