d30eedaa63
- Fix streaming to work even when tools are present (was forcing JSON mode) - Fix response format: use empty list [] instead of null for tool_calls - Add exclude_none config to ChatMessage model to match OpenAI format - Remove tool instructions from prompt (were confusing 3B model) - Fix tool call parsing to handle markdown code blocks properly - Change default instances from 3 to 1 for faster debugging - Allow 1 instance minimum in interactive config (was 2 on Mac) - Add debug logging to track requests and responses Fixes infinite loop issue where opencode would retry requests repeatedly
429 lines
14 KiB
Python
429 lines
14 KiB
Python
"""OpenAI-compatible API routes for Local Swarm."""
|
|
|
|
import time
|
|
import uuid
|
|
from typing import AsyncIterator, Optional
|
|
|
|
from fastapi import APIRouter, HTTPException
|
|
from fastapi.responses import StreamingResponse
|
|
|
|
from api.models import (
|
|
ChatCompletionRequest,
|
|
ChatCompletionResponse,
|
|
ChatCompletionChoice,
|
|
ChatCompletionStreamResponse,
|
|
ChatCompletionStreamChoice,
|
|
ChatMessage,
|
|
UsageInfo,
|
|
ModelListResponse,
|
|
ModelInfo,
|
|
HealthResponse,
|
|
)
|
|
from swarm.manager import SwarmManager
|
|
|
|
|
|
# Shared router holding every endpoint in this module; mounted by the app.
router = APIRouter()

# Global swarm manager instance (set during startup)
# Stays None until set_swarm_manager() runs; endpoints answer 503 while unset.
swarm_manager: Optional[SwarmManager] = None
|
|
|
|
|
|
def set_swarm_manager(manager: SwarmManager) -> None:
    """Set the global swarm manager instance.

    Called once during startup so route handlers can reach the swarm
    through the module-level ``swarm_manager``.
    """
    global swarm_manager
    swarm_manager = manager
|
|
|
|
|
|
def format_tool_description(tool) -> str:
    """Render one tool definition as plain text for inclusion in a prompt.

    Produces a ``### name`` heading, a description line, and one bullet
    per parameter with its type, required marker, and description.
    """
    func = tool.function
    pieces = [f"### {func.name}\n", f"Description: {func.description}\n"]

    params = func.parameters or {}
    properties = params.get('properties')
    if properties:
        pieces.append("Parameters:\n")
        required_names = params.get('required', [])
        for name, info in properties.items():
            marker = " (required)" if name in required_names else ""
            kind = info.get('type', 'any')
            detail = info.get('description', 'No description')
            pieces.append(f"  - {name} ({kind}){marker}: {detail}\n")

    return "".join(pieces)
|
|
|
|
|
|
def format_messages_with_tools(messages: list, tools: Optional[list] = None) -> str:
    """Build a ChatML prompt string from a list of chat messages.

    Tool definitions are accepted for interface compatibility but are
    currently ignored — the model responds normally.
    """
    parts = []

    for msg in messages:
        content = msg.content
        if msg.role in ("system", "user", "assistant"):
            parts.append(f"<|im_start|>{msg.role}\n{content}<|im_end|>")
        elif msg.role == "tool":
            # Tool results carry an optional tool name; default to "tool".
            tool_name = getattr(msg, 'name', 'tool')
            parts.append(f"<|im_start|>tool\n{tool_name}: {content}<|im_end|>")

    # Open an assistant turn for the model to complete.
    parts.append("<|im_start|>assistant\n")
    return "\n".join(parts)
|
|
|
|
|
|
def parse_tool_calls(text: str) -> tuple:
    """Parse tool calls from model output.

    Two formats are recognised, in order of preference:

    1. A JSON object containing a ``"tool_calls"`` key anywhere in the
       text, including inside a markdown code fence — the scan attempts
       a full decode starting at every ``{``.
    2. Plain ``name(arg1=value1, arg2=value2)`` function-call patterns.

    Returns:
        tuple: (content_without_tools, list_of_tool_calls or None)
    """
    import json
    import re

    # BUGFIX: the previous regex r'\{[^}]*"tool_calls"[^}]*\}' could not
    # match nested braces, so real payloads like
    # {"tool_calls": [{"function": {...}}]} were never detected.
    # raw_decode parses a complete JSON value from each '{' instead.
    decoder = json.JSONDecoder()
    for brace in re.finditer(r'\{', text):
        try:
            data, _end = decoder.raw_decode(text, brace.start())
        except (json.JSONDecodeError, ValueError):
            continue
        if isinstance(data, dict) and "tool_calls" in data:
            # Keep only the prose that preceded the JSON payload.
            content = text[:brace.start()].strip()
            return content, data["tool_calls"]

    # Try alternative format: look for function call patterns
    # Pattern: function_name(arg1=value1, arg2=value2)
    func_pattern = r'(\w+)\s*\(([^)]*)\)'
    matches = list(re.finditer(func_pattern, text))

    if matches:
        tool_calls = []
        last_end = 0
        content_parts = []

        for i, match in enumerate(matches):
            func_name = match.group(1)
            args_str = match.group(2)

            # Preserve the prose surrounding each call.
            content_parts.append(text[last_end:match.start()].strip())
            last_end = match.end()

            # Simple arg parsing: key=value, quotes stripped from values.
            args_dict = {}
            if args_str:
                for arg in args_str.split(','):
                    if '=' in arg:
                        key, value = arg.split('=', 1)
                        args_dict[key.strip()] = value.strip().strip('"\'')

            tool_calls.append({
                "id": f"call_{i}",
                "type": "function",
                "function": {
                    "name": func_name,
                    "arguments": json.dumps(args_dict)
                }
            })

        # Add remaining text after the last call.
        content_parts.append(text[last_end:].strip())
        content = " ".join(p for p in content_parts if p)

        return content, tool_calls

    # No tool calls found
    return text, None
|
|
|
|
|
|
# Keep old function for backward compatibility
|
|
def format_messages(messages: list) -> str:
    """Legacy alias: format chat messages as a ChatML prompt, no tools."""
    return format_messages_with_tools(messages)
|
|
|
|
|
|
@router.get("/v1/models", response_model=ModelListResponse)
async def list_models():
    """List available models (OpenAI-compatible endpoint)."""
    if swarm_manager is None:
        raise HTTPException(status_code=503, detail="Swarm not initialized")

    status = swarm_manager.get_status()

    # Expose both the generic alias and the slugified real model name.
    model_ids = ["local-swarm", status.model_name.lower().replace(" ", "-")]
    return ModelListResponse(
        data=[
            ModelInfo(
                id=model_id,
                created=int(time.time()),
                owned_by="local-swarm"
            )
            for model_id in model_ids
        ]
    )
|
|
|
|
|
|
@router.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """
    Generate chat completion.

    Supports both regular (consensus) and streaming responses. Tool
    definitions are forwarded to the prompt formatter; for non-streaming
    requests, tool calls are parsed out of the generated text.
    """
    if swarm_manager is None:
        raise HTTPException(status_code=503, detail="Swarm not initialized")

    if not swarm_manager.get_status().is_running:
        raise HTTPException(status_code=503, detail="Swarm not running")

    # Format messages into prompt (with tools if provided)
    prompt = format_messages_with_tools(request.messages, request.tools)
    has_tools = request.tools is not None and len(request.tools) > 0

    # BUGFIX: resolve defaults with explicit None checks.  The previous
    # `request.temperature or 0.7` silently replaced a legitimate
    # temperature of 0.0 (greedy decoding) with 0.7; same for max_tokens.
    max_tokens = request.max_tokens if request.max_tokens is not None else 1024
    temperature = request.temperature if request.temperature is not None else 0.7

    # Generate ID
    completion_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
    created = int(time.time())

    if request.stream:
        # Return streaming response
        async def stream_generator() -> AsyncIterator[str]:
            """Generate the SSE stream: role chunk, content chunks, final
            stop chunk, then the [DONE] sentinel."""
            # Send first chunk with role
            first_chunk = ChatCompletionStreamResponse(
                id=completion_id,
                created=created,
                model=request.model,
                choices=[
                    ChatCompletionStreamChoice(
                        delta={"role": "assistant"}
                    )
                ]
            )
            yield f"data: {first_chunk.model_dump_json()}\n\n"

            # Stream content as it arrives from the swarm.
            async for chunk in swarm_manager.generate_stream(
                prompt=prompt,
                max_tokens=max_tokens,
                temperature=temperature
            ):
                stream_chunk = ChatCompletionStreamResponse(
                    id=completion_id,
                    created=created,
                    model=request.model,
                    choices=[
                        ChatCompletionStreamChoice(
                            delta={"content": chunk}
                        )
                    ]
                )
                yield f"data: {stream_chunk.model_dump_json()}\n\n"

            # Send final chunk
            final_chunk = ChatCompletionStreamResponse(
                id=completion_id,
                created=created,
                model=request.model,
                choices=[
                    ChatCompletionStreamChoice(
                        delta={},
                        finish_reason="stop"
                    )
                ]
            )
            yield f"data: {final_chunk.model_dump_json()}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            stream_generator(),
            media_type="text/event-stream"
        )

    else:
        # Regular response with consensus
        try:
            result = await swarm_manager.generate(
                prompt=prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                use_consensus=True
            )

            response_text = result.selected_response.text
            tokens_generated = result.selected_response.tokens_generated

            # Parse tool calls if tools were provided.
            # BUGFIX: tool_calls must stay an empty list (never None) when
            # no calls are found — the previous code assigned the parser's
            # None result directly, producing `"tool_calls": null`.
            content = response_text
            tool_calls = []
            finish_reason = "stop"

            if has_tools:
                content, parsed_calls = parse_tool_calls(response_text)
                if parsed_calls:
                    finish_reason = "tool_calls"
                    # Convert raw dicts to ToolCall objects
                    from api.models import ToolCall
                    tool_calls = [
                        ToolCall(
                            id=tc.get("id", f"call_{i}"),
                            type=tc.get("type", "function"),
                            function=tc.get("function", {})
                        )
                        for i, tc in enumerate(parsed_calls)
                    ]

            # Estimate prompt tokens (rough approximation: ~4 chars/token)
            prompt_tokens = len(prompt) // 4

            return ChatCompletionResponse(
                id=completion_id,
                created=created,
                model=request.model,
                choices=[
                    ChatCompletionChoice(
                        index=0,
                        message=ChatMessage(
                            role="assistant",
                            content=content,
                            tool_calls=tool_calls
                        ),
                        finish_reason=finish_reason
                    )
                ],
                usage=UsageInfo(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=tokens_generated,
                    total_tokens=prompt_tokens + tokens_generated
                )
            )

        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
|
|
|
|
|
|
@router.get("/health", response_model=HealthResponse)
async def health_check():
    """Check API and swarm health."""
    if swarm_manager is None:
        # Startup has not registered the manager yet.
        return HealthResponse(
            status="initializing",
            version="0.1.0",
            workers=0,
            model="unknown"
        )

    status = swarm_manager.get_status()
    state = "healthy" if status.is_running else "degraded"

    return HealthResponse(
        status=state,
        version="0.1.0",
        workers=status.healthy_workers,
        model=status.model_name
    )
|
|
|
|
|
|
@router.get("/v1/health", response_model=HealthResponse)
async def health_check_v1():
    """Alias of the health check, served under the /v1 prefix."""
    return await health_check()
|
|
|
|
|
|
# Global federation instance (set during startup)
# Stays None when federation is disabled; the federation endpoints below
# return empty/disabled results while it is unset.
federated_swarm = None
|
|
|
|
|
|
def set_federated_swarm(federation) -> None:
    """Set the global federation instance.

    Called during startup when federation is enabled so the federation
    endpoints can reach it through the module-level ``federated_swarm``.
    """
    global federated_swarm
    federated_swarm = federation
|
|
|
|
|
|
@router.post("/v1/federation/vote")
async def federation_vote(request: dict):
    """
    Receive a vote request from a peer swarm.

    Peer swarms call this endpoint to obtain our "best local" response
    (selected by local consensus) for federated voting.
    """
    if swarm_manager is None:
        raise HTTPException(status_code=503, detail="Swarm not initialized")

    if not swarm_manager.get_status().is_running:
        raise HTTPException(status_code=503, detail="Swarm not running")

    try:
        # Generate with local consensus; the request body supplies
        # optional overrides for the generation parameters.
        result = await swarm_manager.generate(
            prompt=request.get("prompt", ""),
            max_tokens=request.get("max_tokens", 1024),
            temperature=request.get("temperature", 0.7),
            use_consensus=True
        )

        best = result.selected_response
        return {
            "response": best.text,
            "confidence": result.confidence,
            "latency_ms": best.latency_ms,
            "worker_count": len(result.all_responses),
            "strategy": result.strategy
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
|
|
|
|
|
|
@router.get("/v1/federation/status")
async def federation_status():
    """Get federation status."""
    if federated_swarm is None:
        return {"enabled": False, "message": "Federation not enabled"}

    return await federated_swarm.get_federation_status()
|
|
|
|
|
|
@router.get("/v1/federation/peers")
async def federation_peers():
    """Get list of discovered peers."""
    if federated_swarm is None or federated_swarm.discovery is None:
        return {"peers": []}

    def describe(peer):
        # Flatten a peer record into plain JSON-friendly fields.
        return {
            "name": peer.name,
            "host": peer.host,
            "port": peer.port,
            "model_id": peer.model_id,
            "instances": peer.instances,
            "api_url": peer.api_url
        }

    return {"peers": [describe(p) for p in federated_swarm.discovery.get_peers()]}