server: delegate result_state creation to server_task (#17835)
* server: delegate result_state creation to server_task * remove unued states * add more docs
This commit is contained in:
@@ -2589,6 +2589,10 @@ struct server_context_impl {
|
||||
int get_slot_n_ctx() {
|
||||
return slots.back().n_ctx;
|
||||
}
|
||||
|
||||
server_response_reader get_response_reader() {
|
||||
return server_response_reader(queue_tasks, queue_results, HTTP_POLLING_SECONDS);
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
@@ -2618,8 +2622,8 @@ llama_context * server_context::get_llama_context() const {
|
||||
return impl->ctx;
|
||||
}
|
||||
|
||||
std::pair<server_queue &, server_response &> server_context::get_queues() {
|
||||
return { impl->queue_tasks, impl->queue_results };
|
||||
server_response_reader server_context::get_response_reader() {
|
||||
return impl->get_response_reader();
|
||||
}
|
||||
|
||||
|
||||
@@ -2628,7 +2632,7 @@ std::pair<server_queue &, server_response &> server_context::get_queues() {
|
||||
struct server_res_generator : server_http_res {
|
||||
server_response_reader rd;
|
||||
server_res_generator(server_context_impl & ctx_server)
|
||||
: rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS) {}
|
||||
: rd(ctx_server.queue_tasks, ctx_server.queue_results, HTTP_POLLING_SECONDS) {}
|
||||
void ok(const json & response_data) {
|
||||
status = 200;
|
||||
data = safe_json_to_str(response_data);
|
||||
@@ -2661,9 +2665,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
||||
try {
|
||||
std::vector<server_task> tasks;
|
||||
|
||||
// tracking generation state and partial tool calls
|
||||
std::vector<task_result_state> states;
|
||||
|
||||
const auto & prompt = data.at("prompt");
|
||||
// TODO: this log can become very long, put it behind a flag or think about a more compact format
|
||||
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
|
||||
@@ -2679,7 +2680,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
||||
inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
|
||||
}
|
||||
tasks.reserve(inputs.size());
|
||||
states.reserve(inputs.size());
|
||||
int idx = 0;
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
server_task task = server_task(type);
|
||||
@@ -2698,7 +2698,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
||||
task.params.res_type = res_type;
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
task.params.oaicompat_model = ctx_server.model_name;
|
||||
states.push_back(task.params.oaicompat_chat_syntax);
|
||||
|
||||
if (task.params.n_cmpl > 1) {
|
||||
task.n_children = task.params.n_cmpl - 1;
|
||||
@@ -2707,7 +2706,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
||||
task.id,
|
||||
ctx_server.queue_tasks.get_new_id(),
|
||||
idx++);
|
||||
states.push_back(child.params.oaicompat_chat_syntax);
|
||||
tasks.push_back(std::move(child));
|
||||
}
|
||||
}
|
||||
@@ -2715,7 +2713,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
||||
tasks.push_back(std::move(task));
|
||||
}
|
||||
|
||||
rd.set_states(std::move(states));
|
||||
rd.post_tasks(std::move(tasks));
|
||||
} catch (const std::exception & e) {
|
||||
res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
|
||||
@@ -3445,7 +3442,7 @@ void server_routes::init_routes() {
|
||||
|
||||
// create and queue the task
|
||||
json responses = json::array();
|
||||
server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS);
|
||||
server_response_reader rd = ctx_server.get_response_reader();
|
||||
{
|
||||
std::vector<server_task> tasks;
|
||||
tasks.reserve(documents.size());
|
||||
@@ -3705,7 +3702,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_embeddings_impl(cons
|
||||
|
||||
// create and queue the task
|
||||
json responses = json::array();
|
||||
server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS);
|
||||
server_response_reader rd = ctx_server.get_response_reader();
|
||||
{
|
||||
std::vector<server_task> tasks;
|
||||
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
||||
|
||||
Reference in New Issue
Block a user