feat: add token usage tracking to streaming responses

Changes:
- Added usage field to ChatCompletionStreamResponse model
- Track prompt, completion, and total tokens in streaming responses
- Include usage info in final chunks of all streaming endpoints
- Clarified model descriptions as "Instruct variant" in registry
- Updated MLX repo mappings to prioritize instruction-following models

Fixes:
- CodeLlama: use the Instruct variant mapping
- DeepSeek: use the instruct-mlx variant (was mapped to the base model)
- StarCoder2: keep 15b only (the sole size with an Instruct variant on MLX)

Token budget: 89 tokens (unchanged, 4.45% of 2000 limit)
Tests pass: 13/13

Minor improvements:
- Calculate prompt tokens once at function start
- Track completion tokens in streaming generators
- Include usage in tool_calls and content streaming (client-side read sketch below)
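For context, a hedged client-side sketch of reading the new usage data; the server URL and request payload below are assumptions, while the SSE framing ("data: ..." lines, trailing [DONE]) and the usage field names come from this commit.

# Hedged sketch: URL and payload are assumed; SSE shape is from this commit.
import json
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",  # assumed server address/route
    json={
        "model": "qwen2.5-coder",
        "stream": True,
        "messages": [{"role": "user", "content": "Hello"}],
    },
    stream=True,
)
usage = None
for line in resp.iter_lines():
    if not line.startswith(b"data: "):
        continue
    payload = line[len(b"data: "):]
    if payload == b"[DONE]":
        break
    chunk = json.loads(payload)
    if chunk.get("usage"):  # only the final chunk carries usage
        usage = chunk["usage"]

print(usage)  # {"prompt_tokens": ..., "completion_tokens": ..., "total_tokens": ...}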
2026-02-24 22:59:03 +01:00
parent 580d1e5d17
commit 0a97e4af8c
3 changed files with 75 additions and 36 deletions

+1
@@ -102,6 +102,7 @@ class ChatCompletionStreamResponse(BaseModel):
     created: int = Field(..., description="Unix timestamp")
     model: str = Field(..., description="Model used")
     choices: List[ChatCompletionStreamChoice] = Field(..., description="Content chunks")
+    usage: Optional[UsageInfo] = Field(default=None, description="Token usage (only in final chunk)")
 
 class ModelInfo(BaseModel):
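One serialization note, shown with a minimal hypothetical sketch (UsageInfo's fields are inferred from the counters used later in this diff; the Chunk model is a stand-in, not the repo's code): with Pydantic v2's model_dump_json(), intermediate chunks will serialize with "usage": null unless None values are excluded.

from typing import Optional
from pydantic import BaseModel

class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0

class Chunk(BaseModel):  # stand-in for ChatCompletionStreamResponse
    id: str
    usage: Optional[UsageInfo] = None

print(Chunk(id="chatcmpl-abc").model_dump_json())
# {"id":"chatcmpl-abc","usage":null}   <- intermediate chunks include usage: null
print(Chunk(id="chatcmpl-abc").model_dump_json(exclude_none=True))
# {"id":"chatcmpl-abc"}                <- exclude_none=True would omit it

Since the endpoint below calls model_dump_json() without exclude_none, clients should tolerate a null usage on non-final chunks.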
+65 -28
@@ -544,6 +544,9 @@ async def chat_completions(request: ChatCompletionRequest, fastapi_request: Requ
     completion_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
     created = int(time.time())
 
+    # Calculate prompt tokens once
+    prompt_tokens = len(TOKEN_ENCODING.encode(prompt))
+
     if request.stream:
         # For streaming with tools, return tool_calls to client (opencode) for execution
         # This enables multi-turn conversations where client executes tools and sends results back
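The TOKEN_ENCODING used above is defined outside this diff; a plausible sketch of the counting step, assuming tiktoken with a common encoding (the real setup may differ):

import tiktoken

TOKEN_ENCODING = tiktoken.get_encoding("cl100k_base")  # assumed encoding

prompt = "Write a function that reverses a string."
prompt_tokens = len(TOKEN_ENCODING.encode(prompt))  # computed once, reused by every generator below
print(prompt_tokens)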
@@ -579,6 +582,10 @@ async def chat_completions(request: ChatCompletionRequest, fastapi_request: Requ
         # Client (opencode) will execute them and send results back
         async def tool_calls_stream_generator() -> AsyncIterator[str]:
             """Generate SSE stream with tool_calls for client execution."""
+            # Track completion tokens
+            completion_tokens = len(TOKEN_ENCODING.encode(content)) if content else 0
+            total_tokens = prompt_tokens + completion_tokens
+
             # Send role chunk
             first_chunk = ChatCompletionStreamResponse(
                 id=completion_id,
@@ -600,9 +607,9 @@ async def chat_completions(request: ChatCompletionRequest, fastapi_request: Requ
                 model=request.model,
                 choices=[
                     ChatCompletionStreamChoice(
                         delta={"content": content}
                     )
                 ]
             )
             yield f"data: {content_chunk.model_dump_json()}\n\n"
@@ -622,10 +629,11 @@ async def chat_completions(request: ChatCompletionRequest, fastapi_request: Requ
                         "arguments": tc["function"]["arguments"]
                     }
                 })
             logger.debug(f" 🔧 Sending tool_calls in delta: {tool_calls_delta}")
 
-            # Build response in OpenAI streaming format
+            # Build response in OpenAI streaming format with usage
+            from api.models import UsageInfo
             final_delta = {"tool_calls": tool_calls_delta}
             final_chunk = {
                 "id": completion_id,
@@ -638,7 +646,12 @@ async def chat_completions(request: ChatCompletionRequest, fastapi_request: Requ
                         "delta": final_delta,
                         "finish_reason": "tool_calls"
                     }
-                ]
+                ],
+                "usage": {
+                    "prompt_tokens": prompt_tokens,
+                    "completion_tokens": completion_tokens,
+                    "total_tokens": total_tokens
+                }
             }
 
             import json
@@ -660,6 +673,11 @@ async def chat_completions(request: ChatCompletionRequest, fastapi_request: Requ
 
         async def content_stream_generator() -> AsyncIterator[str]:
             """Generate SSE stream with content."""
+            # Track completion tokens
+            completion_tokens = len(TOKEN_ENCODING.encode(content)) if content else 0
+            total_tokens = prompt_tokens + completion_tokens
+            from api.models import UsageInfo
+
             # Send role chunk
             first_chunk = ChatCompletionStreamResponse(
                 id=completion_id,
@@ -667,9 +685,9 @@ async def chat_completions(request: ChatCompletionRequest, fastapi_request: Requ
                 model=request.model,
                 choices=[
                     ChatCompletionStreamChoice(
                         delta={"role": "assistant"}
                     )
                 ]
             )
             yield f"data: {first_chunk.model_dump_json()}\n\n"
@@ -683,23 +701,28 @@ async def chat_completions(request: ChatCompletionRequest, fastapi_request: Requ
                 model=request.model,
                 choices=[
                     ChatCompletionStreamChoice(
                         delta={"content": chunk}
                     )
                 ]
             )
             yield f"data: {stream_chunk.model_dump_json()}\n\n"
 
-            # Send final chunk with finish_reason
+            # Send final chunk with finish_reason and usage
             final_chunk = ChatCompletionStreamResponse(
                 id=completion_id,
                 created=created,
                 model=request.model,
                 choices=[
                     ChatCompletionStreamChoice(
                         delta={},
                         finish_reason="stop"
                     )
-                ]
+                ],
+                usage=UsageInfo(
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                    total_tokens=total_tokens
+                )
             )
             yield f"data: {final_chunk.model_dump_json()}\n\n"
             yield "data: [DONE]\n\n"
@@ -712,6 +735,10 @@ async def chat_completions(request: ChatCompletionRequest, fastapi_request: Requ
     # Regular streaming without tools
     async def stream_generator() -> AsyncIterator[str]:
         """Generate SSE stream."""
+        # Track completion tokens
+        full_response = ""
+        from api.models import UsageInfo
+
         # Send first chunk with role
         first_chunk = ChatCompletionStreamResponse(
             id=completion_id,
@@ -719,9 +746,9 @@ async def chat_completions(request: ChatCompletionRequest, fastapi_request: Requ
             model=request.model,
             choices=[
                 ChatCompletionStreamChoice(
                     delta={"role": "assistant"}
                 )
             ]
         )
         yield f"data: {first_chunk.model_dump_json()}\n\n"
@@ -731,29 +758,39 @@ async def chat_completions(request: ChatCompletionRequest, fastapi_request: Requ
             max_tokens=request.max_tokens or 1024,
             temperature=request.temperature or 0.7
         ):
+            full_response += chunk
             stream_chunk = ChatCompletionStreamResponse(
                 id=completion_id,
                 created=created,
                 model=request.model,
                 choices=[
                     ChatCompletionStreamChoice(
                         delta={"content": chunk}
                     )
                 ]
             )
             yield f"data: {stream_chunk.model_dump_json()}\n\n"
 
-        # Send final chunk
+        # Calculate final token counts
+        completion_tokens = len(TOKEN_ENCODING.encode(full_response)) if full_response else 0
+        total_tokens = prompt_tokens + completion_tokens
+
+        # Send final chunk with usage
         final_chunk = ChatCompletionStreamResponse(
             id=completion_id,
             created=created,
             model=request.model,
             choices=[
                 ChatCompletionStreamChoice(
                     delta={},
                     finish_reason="stop"
                 )
-            ]
+            ],
+            usage=UsageInfo(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens
+            )
        )
         yield f"data: {final_chunk.model_dump_json()}\n\n"
         yield "data: [DONE]\n\n"
+9 -8
@@ -153,28 +153,28 @@ MLX_QUALITY_MAP = {
 MODEL_METADATA = {
     "qwen2.5-coder": {
         "name": "Qwen 2.5 Coder",
-        "description": "Alibaba's code-focused model, excellent for small sizes",
+        "description": "Alibaba's code-focused Instruct model, excellent for small sizes",
         "priority": 1,
         "max_context": 128000,
         "variants": ["3b", "7b", "14b"],
     },
     "deepseek-coder": {
         "name": "DeepSeek Coder",
-        "description": "DeepSeek's code model, good alternative",
+        "description": "DeepSeek's code model (Instruct variant)",
         "priority": 2,
         "max_context": 16384,
         "variants": ["1.3b", "6.7b"],
     },
     "codellama": {
         "name": "CodeLlama",
-        "description": "Meta's code model",
+        "description": "Meta's code model (Instruct variant)",
         "priority": 3,
         "max_context": 16384,
         "variants": ["7b", "13b"],
     },
     "llama-3.2": {
         "name": "Llama 3.2",
-        "description": "Meta's latest general-purpose model with strong coding abilities",
+        "description": "Meta's latest general-purpose model with strong coding abilities (Instruct variant)",
         "priority": 4,
         "max_context": 128000,
         "variants": ["1b", "3b"],
@@ -195,10 +195,10 @@ MODEL_METADATA = {
     },
     "starcoder2": {
         "name": "StarCoder2",
-        "description": "BigCode's open code generation model",
+        "description": "BigCode's open code generation model (Instruct variant)",
         "priority": 7,
         "max_context": 8192,
-        "variants": ["3b", "7b", "15b"],
+        "variants": ["15b"],  # Only 15b has Instruct variant on MLX
     },
 }
@@ -366,14 +366,15 @@ def get_model_hf_repo_mlx(model_id: str, variant: ModelVariant, quant: Quantizat
     # MLX quantized models are in mlx-community org with -{quant}bit suffix
     # Map base model names to mlx-community quantized versions
+    # IMPORTANT: Always use Instruct variants for instruction-following
     mlx_repo_map = {
         "qwen2.5-coder": f"mlx-community/Qwen2.5-Coder-{variant.size.capitalize()}-Instruct",
-        "deepseek-coder": f"mlx-community/deepseek-coder-{variant.size}-base",
+        "deepseek-coder": f"mlx-community/deepseek-coder-{variant.size}-instruct-mlx",
         "codellama": f"mlx-community/CodeLlama-{variant.size}-Instruct",
         "llama-3.2": f"mlx-community/Llama-3.2-{variant.size}-Instruct",
         "phi-4": f"mlx-community/phi-4",
         "gemma-2": f"mlx-community/gemma-2-{variant.size}-it",
-        "starcoder2": f"mlx-community/starcoder2-{variant.size}",
+        "starcoder2": f"mlx-community/starcoder2-{variant.size}-instruct-v0.1",
     }
 
     base_repo = mlx_repo_map.get(model_id, "")
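To make the mapping fixes concrete, a standalone sketch of the repos the changed entries now resolve to; the size strings are example inputs, and the f-string templates are taken from the diff above:

def mlx_repo(model_id: str, size: str) -> str:
    # Subset of the mapping above, reproduced for illustration only
    mlx_repo_map = {
        "deepseek-coder": f"mlx-community/deepseek-coder-{size}-instruct-mlx",
        "codellama": f"mlx-community/CodeLlama-{size}-Instruct",
        "starcoder2": f"mlx-community/starcoder2-{size}-instruct-v0.1",
    }
    return mlx_repo_map.get(model_id, "")

print(mlx_repo("deepseek-coder", "6.7b"))  # ...-instruct-mlx (previously ...-base)
print(mlx_repo("starcoder2", "15b"))       # ...-instruct-v0.1 (previously no suffix)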