server: move msg diffs tracking to HTTP thread (#17740)

* server: move msg diffs tracking to HTTP thread

* wip

* tool call tests ok

* minor : style

* cont : fix

* move states to server_response_reader

* add safe-guard

* fix

* fix 2

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Author: Xuan-Son Nguyen
Date: 2025-12-04 15:46:08 +01:00
Committed by: GitHub
Parent: 817d743cc1
Commit: c4c10bfb86
5 changed files with 167 additions and 94 deletions
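In short: chat-message parsing and streaming diff computation used to run on the inference thread, inside server_slot. This commit moves that state to the HTTP side: each request builds one task_result_state per task and hands the batch to the server_response_reader, which owns the states for the lifetime of the request. A minimal sketch of the shape this implies — only task_result_state, set_states(), and construction from task.params.oaicompat_chat_syntax are visible in the diff; every member name below is an assumption, not the actual implementation:

    // Sketch only: per-task parsing state owned by the HTTP-side reader.
    struct task_result_state {
        common_chat_syntax       syntax;         // type taken from task.params.oaicompat_chat_syntax
        std::string              generated_text; // accumulated text, parsed incrementally (assumed)
        common_chat_msg          chat_msg;       // last parsed message, moved here from server_slot
        std::vector<std::string> tool_call_ids;  // moved here from server_slot (assumed name)

        // non-explicit, so states.push_back(task.params.oaicompat_chat_syntax) works
        task_result_state(const common_chat_syntax & s) : syntax(s) {}
    };

    struct server_response_reader {
        // one state per posted task, paired by index (assumed pairing)
        std::vector<task_result_state> states;

        void set_states(std::vector<task_result_state> && s) { states = std::move(s); }
        // ...
    };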
@@ -101,8 +101,6 @@ struct server_slot {
     std::string generated_text;
     llama_tokens generated_tokens;
-    common_chat_msg chat_msg;
-
     std::vector<completion_token_output> generated_token_probs;
 
     bool has_next_token = true;
@@ -153,9 +151,6 @@ struct server_slot {
     llama_token sampled;
 
-    common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
-    std::vector<std::string> generated_tool_call_ids;
-
     // stats
     size_t n_sent_text = 0; // number of sent text character
@@ -183,13 +178,10 @@ struct server_slot {
         stop          = STOP_TYPE_NONE;
         stopping_word = "";
         n_sent_text   = 0;
-        chat_format   = COMMON_CHAT_FORMAT_CONTENT_ONLY;
 
         generated_tokens.clear();
         generated_token_probs.clear();
-        chat_msg = {};
         json_schema = json();
-        generated_tool_call_ids.clear();
 
         // clear speculative decoding stats
         n_draft_total = 0;
@@ -302,23 +294,6 @@ struct server_slot {
         return timings;
     }
 
-    const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
-        GGML_ASSERT(task);
-        auto previous_msg = chat_msg;
-        SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
-        auto new_msg = common_chat_parse(
-                generated_text,
-                /* is_partial= */ stop != STOP_TYPE_EOS,
-                task->params.oaicompat_chat_syntax);
-        if (!new_msg.empty()) {
-            new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
-            chat_msg = new_msg;
-            diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
-        }
-        return chat_msg;
-    }
-
     size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
         GGML_ASSERT(task);
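The removed update_chat_msg() above is the heart of the change: parsing the accumulated generated_text and diffing it against the previous message no longer happens on the slot, and therefore no longer on the inference thread. A hedged sketch of the same logic re-homed on the HTTP-side state — the update() name and signature are assumptions; the parsing and diffing calls mirror the deleted code:

    // Sketch only: the deleted slot method, restated on task_result_state so
    // it runs on the HTTP thread. Method name and signature are assumed.
    const common_chat_msg & task_result_state::update(bool is_partial, std::vector<common_chat_msg_diff> & diffs) {
        auto previous_msg = chat_msg;
        auto new_msg = common_chat_parse(generated_text, is_partial, syntax);
        if (!new_msg.empty()) {
            new_msg.set_tool_call_ids(tool_call_ids, gen_tool_call_id);
            chat_msg = new_msg;
            diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg);
        }
        return chat_msg;
    }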
@@ -1284,8 +1259,6 @@ struct server_context_impl {
         } else {
             res->content = tkn.text_to_send;
             res->tokens  = { tkn.tok };
-
-            slot.update_chat_msg(res->oaicompat_msg_diffs);
         }
 
         res->n_decoded = slot.n_decoded;
@@ -1317,8 +1290,14 @@ struct server_context_impl {
         res->id_slot = slot.id;
 
         res->index   = slot.task->index;
-        res->content = slot.generated_text;
-        res->tokens  = std::move(slot.generated_tokens);
+        // in stream mode, content and tokens are already in last partial chunk
+        if (slot.task->params.stream) {
+            res->content = "";
+            res->tokens  = llama_tokens{};
+        } else {
+            res->content = std::move(slot.generated_text);
+            res->tokens  = std::move(slot.generated_tokens);
+        }
         res->timings = slot.get_timings();
         res->prompt  = slot.task->tokens.detokenize(ctx, true);
         res->response_fields = std::move(slot.task->params.response_fields);
@@ -1338,7 +1317,6 @@ struct server_context_impl {
         res->res_type          = slot.task->params.res_type;
         res->oaicompat_model   = slot.task->params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
-        res->oaicompat_msg     = slot.update_chat_msg(res->oaicompat_msg_diffs);
 
         // populate res.probs_output
         if (slot.task->params.sampling.n_probs > 0) {
@@ -2596,6 +2574,9 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
     try {
         std::vector<server_task> tasks;
 
+        // tracking generation state and partial tool calls
+        std::vector<task_result_state> states;
+
         const auto & prompt = data.at("prompt");
         // TODO: this log can become very long, put it behind a flag or think about a more compact format
         //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
@@ -2611,6 +2592,7 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
         }
         tasks.reserve(inputs.size());
+        states.reserve(inputs.size());
 
         for (size_t i = 0; i < inputs.size(); i++) {
             server_task task = server_task(type);
@@ -2628,10 +2610,12 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             task.params.res_type          = res_type;
             task.params.oaicompat_cmpl_id = completion_id;
             task.params.oaicompat_model   = ctx_server.model_name;
+            states.push_back(task.params.oaicompat_chat_syntax);
 
             tasks.push_back(std::move(task));
         }
 
+        rd.set_states(std::move(states));
         rd.post_tasks(std::move(tasks));
     } catch (const std::exception & e) {
         res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
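With the states registered before post_tasks(), the reader can apply them as results arrive on the HTTP thread. Roughly — the states[index] pairing and the update() call are assumptions consistent with set_states() above, not code shown in this diff:

    // Sketch only: applying per-task state as results are read on the HTTP thread.
    auto result = rd.next(should_stop);
    if (auto * partial = dynamic_cast<server_task_result_cmpl_partial *>(result.get())) {
        task_result_state & st = states[partial->index]; // tasks and states share indices
        st.generated_text += partial->content;           // accumulate the full text so far
        st.update(/* is_partial= */ true, partial->oaicompat_msg_diffs);
    }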
@@ -2657,7 +2641,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             // if single request, return single object instead of array
             res->ok(arr.size() == 1 ? arr[0] : arr);
         }
-
     } else {
         // in streaming mode, the first error must be treated as non-stream response
         // this is to match the OAI API behavior
@@ -2676,76 +2659,92 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
         }
 
         // next responses are streamed
+        // to be sent immediately
+        json first_result_json = first_result->to_json();
         if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-            res->data = format_anthropic_sse(first_result->to_json());
+            res->data = format_anthropic_sse(first_result_json);
         } else {
-            res->data = format_oai_sse(first_result->to_json()); // to be sent immediately
+            res->data = format_oai_sse(first_result_json);
         }
 
         res->status       = 200;
         res->content_type = "text/event-stream";
         res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool {
-            if (should_stop()) {
-                SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
-                return false; // should_stop condition met
-            }
-
-            if (!res_this->data.empty()) {
-                // flush the first chunk
-                output = std::move(res_this->data);
-                res_this->data.clear();
-                return true;
-            }
-
-            server_response_reader & rd = res_this->rd;
-
-            // check if there is more data
-            if (!rd.has_next()) {
+            static auto format_error = [](task_response_type res_type, const json & res_json) {
                 if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-                    // Anthropic doesn't send [DONE], message_stop was already sent
-                    output = "";
-                } else if (res_type != TASK_RESPONSE_TYPE_NONE) {
-                    output = "data: [DONE]\n\n";
-                } else {
-                    output = "";
-                }
-                SRV_DBG("%s", "all results received, terminating stream\n");
-                return false; // no more data, terminate
-            }
-
-            // receive subsequent results
-            auto result = rd.next(should_stop);
-            if (result == nullptr) {
-                SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
-                return false; // should_stop condition met
-            }
-
-            // send the results
-            json res_json = result->to_json();
-            if (result->is_error()) {
-                if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-                    output = format_anthropic_sse({
+                    return format_anthropic_sse({
                         {"event", "error"},
                         {"data", res_json},
                     });
                 } else {
-                    output = format_oai_sse(json {{ "error", res_json }});
+                    return format_oai_sse(json {{ "error", res_json }});
                 }
-                SRV_DBG("%s", "error received during streaming, terminating stream\n");
-                return false; // terminate on error
-            } else {
-                GGML_ASSERT(
-                    dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
-                    || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
-                );
-                if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-                    output = format_anthropic_sse(res_json);
-                } else {
-                    output = format_oai_sse(res_json);
-                }
-            }
-
-            // has next data, continue
-            return true;
+            };
+
+            try {
+                if (should_stop()) {
+                    SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
+                    return false; // should_stop condition met
+                }
+
+                if (!res_this->data.empty()) {
+                    // flush the first chunk
+                    output = std::move(res_this->data);
+                    res_this->data.clear();
+                    return true;
+                }
+
+                server_response_reader & rd = res_this->rd;
+
+                // check if there is more data
+                if (!rd.has_next()) {
+                    if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
+                        // Anthropic doesn't send [DONE], message_stop was already sent
+                        output = "";
+                    } else if (res_type != TASK_RESPONSE_TYPE_NONE) {
+                        output = "data: [DONE]\n\n";
+                    } else {
+                        output = "";
+                    }
+                    SRV_DBG("%s", "all results received, terminating stream\n");
+                    return false; // no more data, terminate
+                }
+
+                // receive subsequent results
+                auto result = rd.next(should_stop);
+                if (result == nullptr) {
+                    SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
+                    return false; // should_stop condition met
+                }
+
+                // send the results
+                if (result->is_error()) {
+                    json res_json = result->to_json();
+                    output = format_error(res_type, res_json);
+                    SRV_DBG("%s", "error received during streaming, terminating stream\n");
+                    return false; // terminate on error
+                } else {
+                    GGML_ASSERT(
+                        dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
+                        || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
+                    );
+                    json res_json = result->to_json();
+                    if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
+                        output = format_anthropic_sse(res_json);
+                    } else {
+                        output = format_oai_sse(res_json);
+                    }
+                }
+
+                // has next data, continue
+                return true;
+            } catch (const std::exception & e) {
+                json error_json = format_error_response(e.what(), ERROR_TYPE_SERVER);
+                output = format_error(res_type, error_json);
+                // terminate on exception
+                return false;
+            }
         };
     }
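Two details of the reworked generator are worth spelling out: format_error is declared static because it captures nothing, so a single instance serves every invocation of the lambda, and the try/catch guarantees that an exception thrown while reading or serializing a result still yields a well-formed final SSE frame instead of a dropped connection. For illustration, the two error-frame shapes it can emit, assuming format_error_response() produces its usual {"code", "message", "type"} payload:

    // Illustration only; frame contents depend on format_error_response().
    json err = format_error_response("unexpected failure", ERROR_TYPE_SERVER);
    std::string oai_frame       = format_oai_sse(json {{ "error", err }}); // data: {"error":{...}}
    std::string anthropic_frame = format_anthropic_sse(json {
        {"event", "error"},
        {"data",  err},
    });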