server : add Anthropic Messages API support (#17570)
* server : add Anthropic Messages API support
* remove @pytest.mark.slow from tool calling/jinja tests
* server : remove unused code and mark test_anthropic_vision_base64_with_multimodal_model as slow/skip in test_anthropic_api.py
* server : remove redundant n field logic in anthropic_params_from_json
* server : use a single error object instead of error_array in the streaming response handler for /v1/chat/completions, and use unordered_set instead of set in to_json_anthropic_stream()
* server : refactor the Anthropic API to use OAI conversion
* make sure the basic test always goes first
* clean up
* clean up the API key check, add a test

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
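For context, the new route accepts requests in the Anthropic Messages API shape and internally converts them to the OpenAI-compatible chat path (see convert_anthropic_to_oai in the diff below). A minimal usage sketch in Python, matching the style of the commit's pytest suite; the host, port, and urllib-based client are assumptions for illustration, not part of the commit:

# Minimal sketch: POST an Anthropic-style Messages request to llama-server.
# Host, port, and model name are assumptions; field names follow the
# Anthropic Messages API that this commit emulates.
import json
import urllib.request

req = urllib.request.Request(
    "http://localhost:8080/v1/messages",
    data=json.dumps({
        "model": "any",    # llama-server serves whatever model it was started with
        "max_tokens": 64,  # required by the Anthropic Messages API
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
    }).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    body = json.loads(resp.read())
    # Anthropic-shaped responses carry content blocks, not OAI choices.
    print(body["content"][0]["text"])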
+74 -23
@@ -1255,7 +1255,7 @@ struct server_context {
         res->post_sampling_probs = slot.task->params.post_sampling_probs;

         res->verbose = slot.task->params.verbose;
-        res->oaicompat = slot.task->params.oaicompat;
+        res->res_type = slot.task->params.res_type;
         res->oaicompat_model = slot.task->params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;

@@ -1297,7 +1297,7 @@ struct server_context {
         res->verbose = slot.task->params.verbose;
         res->stream = slot.task->params.stream;
         res->include_usage = slot.task->params.include_usage;
-        res->oaicompat = slot.task->params.oaicompat;
+        res->res_type = slot.task->params.res_type;
         res->oaicompat_model = slot.task->params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
         res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
@@ -1328,7 +1328,7 @@ struct server_context {
         res->id = slot.task->id;
         res->index = slot.task->index;
         res->n_tokens = slot.task->n_tokens();
-        res->oaicompat = slot.task->params.oaicompat;
+        res->res_type = slot.task->params.res_type;

         const int n_embd = llama_model_n_embd(model);

@@ -2951,7 +2951,7 @@ public:
             data,
             files,
             req.should_stop,
-            OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
+            TASK_RESPONSE_TYPE_NONE); // infill is not OAI compatible
     };

     server_http_context::handler_t post_completions = [this](const server_http_req & req) {
@@ -2962,7 +2962,7 @@ public:
             body,
             files,
             req.should_stop,
-            OAICOMPAT_TYPE_NONE);
+            TASK_RESPONSE_TYPE_NONE);
     };

     server_http_context::handler_t post_completions_oai = [this](const server_http_req & req) {
@@ -2973,7 +2973,7 @@ public:
             body,
             files,
             req.should_stop,
-            OAICOMPAT_TYPE_COMPLETION);
+            TASK_RESPONSE_TYPE_OAI_CMPL);
     };

     server_http_context::handler_t post_chat_completions = [this](const server_http_req & req) {
@@ -2988,7 +2988,38 @@ public:
             body_parsed,
             files,
             req.should_stop,
-            OAICOMPAT_TYPE_CHAT);
+            TASK_RESPONSE_TYPE_OAI_CHAT);
     };

+    server_http_context::handler_t post_anthropic_messages = [this](const server_http_req & req) {
+        std::vector<raw_buffer> files;
+        json body = convert_anthropic_to_oai(json::parse(req.body));
+        json body_parsed = oaicompat_chat_params_parse(
+            body,
+            ctx_server.oai_parser_opt,
+            files);
+        return handle_completions_impl(
+            SERVER_TASK_TYPE_COMPLETION,
+            body_parsed,
+            files,
+            req.should_stop,
+            TASK_RESPONSE_TYPE_ANTHROPIC);
+    };
+
+    server_http_context::handler_t post_anthropic_count_tokens = [this](const server_http_req & req) {
+        auto res = std::make_unique<server_res_generator>(ctx_server);
+        std::vector<raw_buffer> files;
+        json body = convert_anthropic_to_oai(json::parse(req.body));
+        json body_parsed = oaicompat_chat_params_parse(
+            body,
+            ctx_server.oai_parser_opt,
+            files);
+
+        json prompt = body_parsed.at("prompt");
+        llama_tokens tokens = tokenize_mixed(ctx_server.vocab, prompt, true, true);
+
+        res->ok({{"input_tokens", static_cast<int>(tokens.size())}});
+        return res;
+    };
+
     // same with handle_chat_completions, but without inference part
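The count_tokens handler above reuses the same Anthropic-to-OAI conversion and chat templating as /v1/messages, then tokenizes the rendered prompt without scheduling any inference. A hedged request sketch (server address assumed):

# Sketch: count input tokens for an Anthropic-style request without inference.
import json
import urllib.request

req = urllib.request.Request(
    "http://localhost:8080/v1/messages/count_tokens",
    data=json.dumps({
        "model": "any",
        "messages": [{"role": "user", "content": "How many tokens is this?"}],
    }).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read()))  # per the handler above: {"input_tokens": <n>}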
@@ -3107,11 +3138,11 @@ public:
     };

     server_http_context::handler_t post_embeddings = [this](const server_http_req & req) {
-        return handle_embeddings_impl(req, OAICOMPAT_TYPE_NONE);
+        return handle_embeddings_impl(req, TASK_RESPONSE_TYPE_NONE);
     };

     server_http_context::handler_t post_embeddings_oai = [this](const server_http_req & req) {
-        return handle_embeddings_impl(req, OAICOMPAT_TYPE_EMBEDDING);
+        return handle_embeddings_impl(req, TASK_RESPONSE_TYPE_OAI_EMBD);
     };

     server_http_context::handler_t post_rerank = [this](const server_http_req & req) {
@@ -3262,7 +3293,7 @@ private:
             const json & data,
             const std::vector<raw_buffer> & files,
             const std::function<bool()> & should_stop,
-            oaicompat_type oaicompat) {
+            task_response_type res_type) {
         GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);

         auto res = std::make_unique<server_res_generator>(ctx_server);
@@ -3279,7 +3310,7 @@ private:
         // process prompt
         std::vector<server_tokens> inputs;

-        if (oaicompat && ctx_server.mctx != nullptr) {
+        if (res_type != TASK_RESPONSE_TYPE_NONE && ctx_server.mctx != nullptr) {
             // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below.
             inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get<std::string>(), files));
         } else {
@@ -3301,8 +3332,8 @@ private:
             task.id_slot = json_value(data, "id_slot", -1);

             // OAI-compat
-            task.params.oaicompat = oaicompat;
-            task.params.oaicompat_cmpl_id = completion_id;
+            task.params.res_type = res_type;
+            task.params.oaicompat_cmpl_id = completion_id;
             // oaicompat_model is already populated by params_from_json_cmpl

             tasks.push_back(std::move(task));
@@ -3352,10 +3383,14 @@ private:
         }

         // next responses are streamed
-        res->data = format_sse(first_result->to_json()); // to be sent immediately
+        if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
+            res->data = format_anthropic_sse(first_result->to_json());
+        } else {
+            res->data = format_oai_sse(first_result->to_json()); // to be sent immediately
+        }
         res->status = 200;
         res->content_type = "text/event-stream";
-        res->next = [res_this = res.get(), oaicompat, &should_stop](std::string & output) -> bool {
+        res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool {
             if (should_stop()) {
                 SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
                 return false; // should_stop condition met
@@ -3372,7 +3407,10 @@ private:

             // check if there is more data
             if (!rd.has_next()) {
-                if (oaicompat != OAICOMPAT_TYPE_NONE) {
+                if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
+                    // Anthropic doesn't send [DONE], message_stop was already sent
+                    output = "";
+                } else if (res_type != TASK_RESPONSE_TYPE_NONE) {
                     output = "data: [DONE]\n\n";
                 } else {
                     output = "";
@@ -3391,7 +3429,14 @@ private:
             // send the results
             json res_json = result->to_json();
             if (result->is_error()) {
-                output = format_sse(json {{ "error", res_json }});
+                if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
+                    output = format_anthropic_sse({
+                        {"event", "error"},
+                        {"data", res_json},
+                    });
+                } else {
+                    output = format_oai_sse(json {{ "error", res_json }});
+                }
                 SRV_DBG("%s", "error received during streaming, terminating stream\n");
                 return false; // terminate on error
             } else {
@@ -3399,7 +3444,11 @@ private:
                     dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
                     || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
                 );
-                output = format_sse(res_json);
+                if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
+                    output = format_anthropic_sse(res_json);
+                } else {
+                    output = format_oai_sse(res_json);
+                }
             }

             // has next data, continue
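The streaming branches above exist because the two protocols frame Server-Sent Events differently: the OAI path emits bare data: lines and closes with data: [DONE], while the Anthropic path prefixes each payload with an event: line and ends with a message_stop event instead of a sentinel. A rough Python mirror of that framing; the names echo format_oai_sse/format_anthropic_sse from the diff, but the bodies are illustrative assumptions rather than the C++ implementation:

import json

def format_oai_sse(payload: dict) -> str:
    # OAI-style frame: payload only; the stream later closes with "data: [DONE]".
    return f"data: {json.dumps(payload)}\n\n"

def format_anthropic_sse(payload: dict) -> str:
    # Anthropic-style frame: a named "event:" line precedes each "data:" line.
    # The error path in the diff passes {"event": ..., "data": ...} explicitly;
    # otherwise assume the event name comes from the payload's "type" field.
    event = payload.get("event") or payload.get("type", "message_delta")
    data = payload["data"] if "event" in payload else payload
    return f"event: {event}\ndata: {json.dumps(data)}\n\n"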
@@ -3507,14 +3556,14 @@ private:
         return res;
     }

-    std::unique_ptr<server_res_generator> handle_embeddings_impl(const server_http_req & req, oaicompat_type oaicompat) {
+    std::unique_ptr<server_res_generator> handle_embeddings_impl(const server_http_req & req, task_response_type res_type) {
         auto res = std::make_unique<server_res_generator>(ctx_server);
         if (!ctx_server.params_base.embedding) {
             res->error(format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
             return res;
         }

-        if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
+        if (res_type != TASK_RESPONSE_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
             res->error(format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
             return res;
         }
@@ -3526,7 +3575,7 @@ private:
         if (body.count("input") != 0) {
             prompt = body.at("input");
         } else if (body.contains("content")) {
-            oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible
+            res_type = TASK_RESPONSE_TYPE_NONE; // "content" field is not OAI compatible
             prompt = body.at("content");
         } else {
             res->error(format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
@@ -3574,7 +3623,7 @@ private:
             task.tokens = std::move(tokenized_prompts[i]);

             // OAI-compat
-            task.params.oaicompat = oaicompat;
+            task.params.res_type = res_type;
             task.params.embd_normalize = embd_normalize;

             tasks.push_back(std::move(task));
@@ -3599,7 +3648,7 @@ private:
         }

         // write JSON response
-        json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING
+        json root = res_type == TASK_RESPONSE_TYPE_OAI_EMBD
             ? format_embeddings_response_oaicompat(body, responses, use_base64)
             : json(responses);
         res->ok(root);
@@ -3712,6 +3761,8 @@ int main(int argc, char ** argv) {
     ctx_http.post("/chat/completions", ex_wrapper(routes.post_chat_completions));
     ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions));
     ctx_http.post("/api/chat", ex_wrapper(routes.post_chat_completions)); // ollama specific endpoint
+    ctx_http.post("/v1/messages", ex_wrapper(routes.post_anthropic_messages)); // anthropic messages API
+    ctx_http.post("/v1/messages/count_tokens", ex_wrapper(routes.post_anthropic_count_tokens)); // anthropic token counting
     ctx_http.post("/infill", ex_wrapper(routes.post_infill));
     ctx_http.post("/embedding", ex_wrapper(routes.post_embeddings)); // legacy
     ctx_http.post("/embeddings", ex_wrapper(routes.post_embeddings));
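With the routes registered, the Anthropic endpoints sit alongside the existing OAI-compatible and ollama ones on the same server. A hedged sketch of consuming the streaming variant, which, per the handler changes above, terminates on a message_stop event rather than data: [DONE] (host/port assumed):

# Sketch: consume a streamed Anthropic-style response line by line.
import json
import urllib.request

req = urllib.request.Request(
    "http://localhost:8080/v1/messages",
    data=json.dumps({
        "model": "any",
        "max_tokens": 64,
        "stream": True,
        "messages": [{"role": "user", "content": "Count to three."}],
    }).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    for raw in resp:
        line = raw.decode().strip()
        if line.startswith("event: message_stop"):
            break  # Anthropic streams end with message_stop, not "data: [DONE]"
        if line.startswith("data: "):
            print(json.loads(line[len("data: "):]))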