server: move msg diffs tracking to HTTP thread (#17740)
* server: move msg diffs tracking to HTTP thread * wip * tool call tests ok * minor : style * cont : fix * move states to server_response_reader * add safe-guard * fix * fix 2 --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
@@ -161,6 +161,25 @@ struct result_prompt_progress {
|
||||
json to_json() const;
|
||||
};
|
||||
|
||||
// struct for tracking the state of a task (e.g., for streaming)
|
||||
struct task_result_state {
|
||||
// tracking diffs for partial tool calls
|
||||
std::vector<common_chat_msg_diff> diffs;
|
||||
common_chat_syntax oaicompat_chat_syntax;
|
||||
common_chat_msg chat_msg;
|
||||
std::string generated_text; // append new chunks of generated text here
|
||||
std::vector<std::string> generated_tool_call_ids;
|
||||
|
||||
task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
|
||||
: oaicompat_chat_syntax(oaicompat_chat_syntax) {}
|
||||
|
||||
// parse partial tool calls and update the internal state
|
||||
common_chat_msg update_chat_msg(
|
||||
const std::string & text_added,
|
||||
bool is_partial,
|
||||
std::vector<common_chat_msg_diff> & diffs);
|
||||
};
|
||||
|
||||
struct server_task_result {
|
||||
int id = -1;
|
||||
int id_slot = -1;
|
||||
@@ -175,6 +194,9 @@ struct server_task_result {
|
||||
virtual int get_index() {
|
||||
return -1;
|
||||
}
|
||||
virtual void update(task_result_state &) {
|
||||
// only used by server_task_result_cmpl_*
|
||||
}
|
||||
virtual json to_json() = 0;
|
||||
virtual ~server_task_result() = default;
|
||||
};
|
||||
@@ -233,9 +255,10 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_msg oaicompat_msg;
|
||||
common_chat_msg oaicompat_msg; // to be populated by update()
|
||||
|
||||
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
|
||||
std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
|
||||
bool is_updated = false;
|
||||
|
||||
virtual int get_index() override {
|
||||
return index;
|
||||
@@ -247,6 +270,11 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
|
||||
virtual json to_json() override;
|
||||
|
||||
virtual void update(task_result_state & state) override {
|
||||
is_updated = true;
|
||||
oaicompat_msg = state.update_chat_msg(content, false, oaicompat_msg_diffs);
|
||||
}
|
||||
|
||||
json to_json_non_oaicompat();
|
||||
|
||||
json to_json_oaicompat();
|
||||
@@ -280,7 +308,8 @@ struct server_task_result_cmpl_partial : server_task_result {
|
||||
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
|
||||
std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
|
||||
bool is_updated = false;
|
||||
|
||||
virtual int get_index() override {
|
||||
return index;
|
||||
@@ -292,6 +321,11 @@ struct server_task_result_cmpl_partial : server_task_result {
|
||||
|
||||
virtual json to_json() override;
|
||||
|
||||
virtual void update(task_result_state & state) override {
|
||||
is_updated = true;
|
||||
state.update_chat_msg(content, true, oaicompat_msg_diffs);
|
||||
}
|
||||
|
||||
json to_json_non_oaicompat();
|
||||
|
||||
json to_json_oaicompat();
|
||||
|
||||
Reference in New Issue
Block a user