refactor(api): extract formatting, parsing, and handlers from routes
Extracted the large monolithic routes.py (1183 lines) into focused modules: api/formatting.py (message formatting and tool instructions), api/tool_parser.py (tool call parsing from various formats), api/chat_handlers.py (chat completion business logic), utils/token_counter.py (centralized token counting utilities), and utils/project_discovery.py (shared project root discovery). routes.py is now 252 lines (under the 300-line limit). All 35 tests pass. Eliminated the code duplication of _discover_project_root. Refs: previous review report findings on modularity.
This commit is contained in:
@@ -0,0 +1,287 @@
|
||||
"""Chat completion handlers for Local Swarm.
|
||||
|
||||
Contains the business logic for chat completions, separated from HTTP routing.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from api.models import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionChoice,
|
||||
ChatMessage,
|
||||
UsageInfo,
|
||||
)
|
||||
from api.formatting import format_messages_with_tools
|
||||
from api.tool_parser import parse_tool_calls
|
||||
from utils.token_counter import count_tokens
|
||||
from tools.executor import get_tool_executor
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _sanitize_tools(tools: Optional[list]) -> Optional[list]:
|
||||
"""Sanitize tool definitions to fix invalid schemas.
|
||||
|
||||
Removes extra 'description' from properties if present.
|
||||
|
||||
Args:
|
||||
tools: List of tool definitions
|
||||
|
||||
Returns:
|
||||
Sanitized tools list
|
||||
"""
|
||||
if not tools:
|
||||
return tools
|
||||
|
||||
sanitized = []
|
||||
for tool in tools:
|
||||
if tool.type == "function" and tool.function.parameters:
|
||||
params = tool.function.parameters
|
||||
# Remove invalid 'description' from properties if present
|
||||
if 'properties' in params and 'description' in params.get('properties', {}):
|
||||
invalid_props = ['description']
|
||||
# Also remove 'description' from required if present
|
||||
if 'required' in params:
|
||||
params['required'] = [r for r in params.get('required', []) if r not in invalid_props]
|
||||
# Remove invalid properties
|
||||
params['properties'] = {k: v for k, v in params.get('properties', {}).items() if k not in invalid_props}
|
||||
logger.debug(f" 🔧 Sanitized tool '{tool.function.name}': removed {invalid_props} from properties/required")
|
||||
sanitized.append(tool)
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
async def _execute_tools(
    tool_calls: list,
    client_working_dir: Optional[str],
    executor
) -> str:
    """Execute tool calls one after another and return their combined results.

    Args:
        tool_calls: List of parsed tool calls; each is a dict with a
            "function" entry holding "name" and "arguments" (a JSON string
            or an already-decoded value)
        client_working_dir: Working directory passed to each tool execution
        executor: Tool executor instance (accepted for interface
            compatibility; not used by this function)

    Returns:
        One "Tool '<name>' result: <result>" entry per call, joined by blank
        lines
    """
    # Function-scope import kept from the original; presumably this avoids a
    # circular import with api.routes - confirm before hoisting it.
    from api.routes import execute_tool_server_side

    total = len(tool_calls)
    tool_results = []

    for position, call in enumerate(tool_calls, start=1):
        function_spec = call.get("function", {})
        tool_name = function_spec.get("name", "")
        raw_args = function_spec.get("arguments", "{}")

        try:
            tool_args = json.loads(raw_args) if isinstance(raw_args, str) else raw_args
        except ValueError as exc:
            # json.JSONDecodeError is a ValueError. Malformed arguments fall
            # back to an empty dict (as before), but no longer silently, and
            # unrelated errors such as cancellation are no longer swallowed.
            logger.warning(f"Invalid JSON arguments for tool '{tool_name}': {exc}")
            tool_args = {}

        logger.debug(f" [{position}/{total}] Executing: {tool_name}({tool_args})")
        result = await execute_tool_server_side(tool_name, tool_args, working_dir=client_working_dir)
        tool_results.append(f"Tool '{tool_name}' result: {result}")

        # Preview through str() so a non-string result cannot break logging.
        result_text = str(result)
        if len(result_text) > 100:
            logger.debug(f" ✓ Completed: {result_text[:100]}...")
        else:
            logger.debug(f" ✓ Result: {result_text}")

    return "\n\n".join(tool_results)
|
||||
|
||||
|
||||
def _create_response(
    content: str,
    tool_calls: list,
    finish_reason: str,
    prompt: str,
    request: ChatCompletionRequest
) -> ChatCompletionResponse:
    """Assemble the chat completion response for a finished generation.

    Token usage is computed by running count_tokens over the prompt and the
    response content.

    Args:
        content: Response content
        tool_calls: List of tool calls to attach to the assistant message
        finish_reason: Finish reason reported on the single choice
        prompt: Original prompt, used for prompt token counting
        request: Original request (supplies the model name)

    Returns:
        ChatCompletionResponse with one assistant choice and usage info
    """
    prompt_tokens = count_tokens(prompt)
    completion_tokens = count_tokens(content)

    response_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
    created_at = int(time.time())
    model_name = request.model

    assistant_message = ChatMessage(
        role="assistant",
        content=content,
        tool_calls=tool_calls
    )
    only_choice = ChatCompletionChoice(
        index=0,
        message=assistant_message,
        finish_reason=finish_reason
    )
    usage = UsageInfo(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens
    )

    return ChatCompletionResponse(
        id=response_id,
        created=created_at,
        model=model_name,
        choices=[only_choice],
        usage=usage
    )
|
||||
|
||||
|
||||
async def _generate_with_local_swarm(
|
||||
swarm_manager,
|
||||
prompt: str,
|
||||
max_tokens: int,
|
||||
temperature: float
|
||||
) -> tuple[str, int, float]:
|
||||
"""Generate response using local swarm.
|
||||
|
||||
Args:
|
||||
swarm_manager: Swarm manager instance
|
||||
prompt: Prompt to generate from
|
||||
max_tokens: Maximum tokens to generate
|
||||
temperature: Sampling temperature
|
||||
|
||||
Returns:
|
||||
Tuple of (response_text, tokens_generated, tokens_per_second)
|
||||
"""
|
||||
result = await swarm_manager.generate(
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
use_consensus=True
|
||||
)
|
||||
|
||||
response = result.selected_response
|
||||
return (
|
||||
response.text,
|
||||
response.tokens_generated,
|
||||
response.tokens_per_second
|
||||
)
|
||||
|
||||
|
||||
async def _generate_with_federation(
    federated_swarm,
    prompt: str,
    max_tokens: int,
    temperature: float
) -> tuple[str, list, str]:
    """Run a federated generation and classify the outcome.

    The final response is passed through parse_tool_calls; when tool calls
    are found the stripped content and the calls are returned with the
    "tool_calls" finish reason, otherwise the raw response is returned with
    "stop".

    Args:
        federated_swarm: Federated swarm instance
        prompt: Prompt to generate from
        max_tokens: Maximum tokens to generate
        temperature: Sampling temperature

    Returns:
        Tuple of (response_text, tool_calls, finish_reason)
    """
    outcome = await federated_swarm.generate_with_federation(
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        min_peers=0
    )
    raw_text = outcome.final_response

    # Check for tool calls
    stripped_text, parsed_calls = parse_tool_calls(raw_text)
    if not parsed_calls:
        return raw_text, [], "stop"

    return stripped_text, parsed_calls, "tool_calls"
|
||||
|
||||
|
||||
async def handle_chat_completion(
    request: ChatCompletionRequest,
    swarm_manager,
    federated_swarm,
    client_working_dir: Optional[str],
    use_opencode_tools: bool
) -> ChatCompletionResponse:
    """Handle a chat completion request.

    Flow: build the prompt, try federation when a federated swarm with at
    least one peer is available, otherwise generate on the local swarm. On
    the local path, when the request carried tools, tool calls parsed from
    the model output are executed server-side and their results become the
    response content.

    Args:
        request: Chat completion request
        swarm_manager: Swarm manager instance
        federated_swarm: Optional federated swarm instance (None to disable)
        client_working_dir: Client working directory for tool execution
        use_opencode_tools: Whether to use opencode tool definitions

    Returns:
        Chat completion response
    """
    # Format messages into prompt
    if use_opencode_tools:
        sanitized_tools = _sanitize_tools(request.tools)
        prompt = format_messages_with_tools(request.messages, sanitized_tools)
        has_tools = bool(sanitized_tools)
    else:
        prompt = format_messages_with_tools(request.messages, None)
        has_tools = bool(request.tools)

    max_tokens = request.max_tokens or 1024
    # Compare against None: `request.temperature or 0.7` replaced an explicit
    # temperature of 0.0 with the 0.7 default.
    temperature = 0.7 if request.temperature is None else request.temperature

    logger.debug(f"\n{'='*60}")
    logger.debug(f"REQUEST: has_tools={has_tools}, stream={request.stream}")
    logger.debug(f"MODE: {'opencode' if use_opencode_tools else 'local'} tools")
    logger.debug(f"{'='*60}")

    # Use federation if available
    if federated_swarm is not None:
        peers = federated_swarm.discovery.get_peers()
        if peers:
            logger.info(f"🌐 Using federation with {len(peers)} peer(s)...")
            content, tool_calls, finish_reason = await _generate_with_federation(
                federated_swarm, prompt, max_tokens, temperature
            )
            return _create_response(content, tool_calls, finish_reason, prompt, request)

    # Use local swarm
    logger.debug("Using local swarm generation")
    response_text, tokens_generated, tps = await _generate_with_local_swarm(
        swarm_manager, prompt, max_tokens, temperature
    )

    logger.debug(f"DEBUG: Generated response (tokens={tokens_generated}, t/s={tps:.1f})")
    logger.debug(f"DEBUG: Response preview: {response_text[:200]}...")

    # Defaults for a plain (non-tool) answer
    content = response_text
    tool_calls = []
    finish_reason = "stop"

    if not has_tools:
        logger.debug("DEBUG: No tools requested, returning normal response")
        return _create_response(content, tool_calls, finish_reason, prompt, request)

    logger.debug("DEBUG: Parsing tool calls from response...")
    content, tool_calls_parsed = parse_tool_calls(response_text)
    logger.debug(f"DEBUG: parse_tool_calls returned: content_len={len(content)}, parsed={tool_calls_parsed is not None}")

    if not tool_calls_parsed:
        logger.debug("DEBUG: No tool calls parsed from response")
        return _create_response(content, tool_calls, finish_reason, prompt, request)

    logger.debug(f" 🔧 Model requesting {len(tool_calls_parsed)} tool(s)...")
    executor = get_tool_executor()
    if executor:
        logger.debug(f" 🔗 Tool executor: {executor.tool_host_url or 'local'}")
    else:
        logger.debug(" ⚠️ No tool executor configured!")

    # Execute tools; the results replace the model text, and tool_calls stays
    # empty because the calls have already been executed server-side.
    content = await _execute_tools(tool_calls_parsed, client_working_dir, executor)
    logger.debug(" ✅ All tools executed, returning results")

    return _create_response(content, tool_calls, finish_reason, prompt, request)
|
||||
@@ -0,0 +1,265 @@
|
||||
"""Message formatting module for Local Swarm.
|
||||
|
||||
Formats chat messages into prompts and handles tool instructions.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
from api.models import ChatMessage
|
||||
from utils.token_counter import count_tokens
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cache for tool instructions (loaded from config file)
|
||||
_TOOL_INSTRUCTIONS_CACHE: Optional[str] = None
|
||||
|
||||
# Global flag for tool mode (default: local tool server to save tokens)
|
||||
_USE_OPENCODE_TOOLS: bool = False
|
||||
|
||||
|
||||
def set_use_opencode_tools(value: bool) -> None:
    """Switch the module-wide tool mode and log the new setting.

    Args:
        value: True to use opencode tools (~27k tokens), False to use the
            local tool server (~125 tokens)
    """
    global _USE_OPENCODE_TOOLS
    _USE_OPENCODE_TOOLS = value

    if value:
        mode_label = 'opencode tools (~27k tokens)'
    else:
        mode_label = 'local tool server (~125 tokens)'
    logger.info(f"🔧 Tool mode set to: {mode_label}")
|
||||
|
||||
|
||||
def _load_tool_instructions() -> str:
    """Load tool instructions from the config file, caching the result.

    Reads config/prompts/tool_instructions.txt, located three directory
    levels above this file. Falls back to built-in default instructions when
    the file is missing, and to a minimal one-line instruction on any other
    read error. The outcome is stored in the module-level cache, so the file
    is read at most once per process.

    Returns:
        Tool instructions string
    """
    global _TOOL_INSTRUCTIONS_CACHE

    if _TOOL_INSTRUCTIONS_CACHE is not None:
        return _TOOL_INSTRUCTIONS_CACHE

    config_path = Path(__file__).parent.parent.parent / "config" / "prompts" / "tool_instructions.txt"

    try:
        # Explicit UTF-8 so the prompt text does not depend on the platform's
        # default encoding.
        _TOOL_INSTRUCTIONS_CACHE = config_path.read_text(encoding="utf-8").strip()
        logger.debug(f"Loaded tool instructions from {config_path}")
    except (FileNotFoundError, NotADirectoryError):
        # Fallback default instructions
        _TOOL_INSTRUCTIONS_CACHE = """You MUST use tools. DO NOT explain. DO NOT use markdown.

OUTPUT THIS EXACT FORMAT - NOTHING ELSE:

TOOL: bash
ARGUMENTS: {"command": "your command here"}

Available tools:
- bash: Run shell commands
- write: Create files
- read: Read files

NEVER write explanations.
NEVER use numbered lists.
NEVER use markdown code blocks.
ONLY output TOOL: lines."""
        logger.warning(f"Tool instructions config not found at {config_path}, using default")
    except Exception as e:
        # Deliberately broad best-effort: prompt loading must never fail a request.
        logger.error(f"Error loading tool instructions: {e}")
        # Use minimal fallback
        _TOOL_INSTRUCTIONS_CACHE = 'Use TOOL: tool_name\nARGUMENTS: {"param": "value"} format.'

    return _TOOL_INSTRUCTIONS_CACHE
|
||||
|
||||
|
||||
def _is_initial_request(messages: List[ChatMessage]) -> bool:
|
||||
"""Check if this is an initial request (no assistant or tool messages).
|
||||
|
||||
Args:
|
||||
messages: List of chat messages
|
||||
|
||||
Returns:
|
||||
True if this is the initial request
|
||||
"""
|
||||
has_assistant = any(msg.role == "assistant" for msg in messages)
|
||||
has_tool = any(msg.role == "tool" for msg in messages)
|
||||
return not has_assistant and not has_tool
|
||||
|
||||
|
||||
def _compress_large_request(messages: List[ChatMessage], max_tokens: int = 4000) -> List[ChatMessage]:
    """Shrink an oversized request down to its last user message.

    The messages are returned untouched when their "role: content"
    transcript is within max_tokens. Otherwise only the last user message is
    kept, cut to 2000 characters (plus a truncation marker) when longer.

    Args:
        messages: List of chat messages
        max_tokens: Maximum tokens before compression

    Returns:
        The original list, a single-element list with the (possibly
        truncated) last user message, or an empty list when there is no user
        message
    """
    transcript = "\n".join(f"{msg.role}: {msg.content}" for msg in messages)
    token_total = count_tokens(transcript)
    if token_total <= max_tokens:
        return messages

    logger.info(f"🗜️ COMPRESSING: Initial request is {token_total} tokens, compressing to <{max_tokens}...")

    # Walk the whole list so the *last* user message wins.
    latest_user = None
    for msg in messages:
        if msg.role == "user":
            latest_user = msg

    if latest_user is None:
        logger.warning("No user messages found in initial request!")
        return []

    original_text = latest_user.content
    if len(original_text) <= 2000:
        kept_text = original_text
    else:
        kept_text = original_text[:2000] + "... [truncated for token limit]"
        logger.debug(f"Truncated user message from {len(original_text)} to 2000 chars")

    return [ChatMessage(role="user", content=kept_text)]
|
||||
|
||||
|
||||
def _filter_messages(messages: List[ChatMessage]) -> List[ChatMessage]:
    """Select the messages that go into the prompt.

    Initial requests whose transcript exceeds 4000 tokens are handed to
    _compress_large_request. Every other request simply has its system
    messages removed.

    Args:
        messages: List of chat messages

    Returns:
        Filtered list of messages
    """
    if _is_initial_request(messages):
        transcript = "\n".join(f"{msg.role}: {msg.content}" for msg in messages)
        oversized = count_tokens(transcript) > 4000
        if oversized:
            return _compress_large_request(messages)

    # Normal filtering: remove system messages
    kept = []
    for msg in messages:
        if msg.role != "system":
            kept.append(msg)
    return kept
|
||||
|
||||
|
||||
def _add_tool_instructions(messages: List[ChatMessage]) -> List[ChatMessage]:
|
||||
"""Add tool instructions to messages if needed.
|
||||
|
||||
Args:
|
||||
messages: List of chat messages
|
||||
|
||||
Returns:
|
||||
Messages with tool instructions added
|
||||
"""
|
||||
has_assistant = any(msg.role == "assistant" for msg in messages)
|
||||
|
||||
if has_assistant:
|
||||
return messages
|
||||
|
||||
tool_instructions = _load_tool_instructions()
|
||||
logger.debug(f"Using {'opencode' if _USE_OPENCODE_TOOLS else 'local'} tool mode: {len(tool_instructions)} chars")
|
||||
|
||||
return [ChatMessage(role="system", content=tool_instructions)] + messages
|
||||
|
||||
|
||||
def _format_to_chatml(messages: List[ChatMessage]) -> str:
|
||||
"""Format messages to ChatML format.
|
||||
|
||||
Args:
|
||||
messages: List of chat messages
|
||||
|
||||
Returns:
|
||||
ChatML formatted string
|
||||
"""
|
||||
formatted = []
|
||||
|
||||
for msg in messages:
|
||||
role = msg.role
|
||||
content = msg.content
|
||||
|
||||
if role == "system":
|
||||
formatted.append(f"<|im_start|>system\n{content}<|im_end|>")
|
||||
elif role == "user":
|
||||
formatted.append(f"<|im_start|>user\n{content}<|im_end|>")
|
||||
elif role == "assistant":
|
||||
formatted.append(f"<|im_start|>assistant\n{content}<|im_end|>")
|
||||
elif role == "tool":
|
||||
tool_name = getattr(msg, 'name', 'tool')
|
||||
formatted.append(f"<|im_start|>tool\n{tool_name}: {content}<|im_end|>")
|
||||
|
||||
formatted.append("<|im_start|>assistant\n")
|
||||
return "\n".join(formatted)
|
||||
|
||||
|
||||
def _log_prompt_preview(messages: List[ChatMessage]) -> None:
    """Emit a one-line debug preview of the system and user messages.

    System content is cut to its first 200 characters; user content is
    logged in full. Messages with other roles are left out.

    Args:
        messages: List of chat messages
    """
    def _render(msg):
        # Returns the preview fragment for one message, or None to skip it.
        if msg.role == "system":
            return f"[SYSTEM] {msg.content[:200]}..."
        if msg.role == "user":
            return f"[USER] {msg.content}"
        return None

    fragments = [piece for piece in map(_render, messages) if piece is not None]
    logger.debug(f"Prompt preview: {' | '.join(fragments)}")
|
||||
|
||||
|
||||
def format_messages_with_tools(
    messages: List[ChatMessage],
    tools: Optional[list] = None
) -> str:
    """Turn chat messages into a single ChatML prompt.

    Pipeline: filter the messages, prepend tool instructions when needed,
    log a preview, render to ChatML, then log the final token count next to
    the token count of the unfiltered transcript.

    Args:
        messages: List of chat messages
        tools: Optional list of tools (not used by this function)

    Returns:
        Formatted prompt string in ChatML format
    """
    prepared = _add_tool_instructions(_filter_messages(messages))
    _log_prompt_preview(prepared)

    prompt_text = _format_to_chatml(prepared)

    # Log final token count against the size of the untouched input
    final_tokens = count_tokens(prompt_text)
    raw_transcript = "\n".join(f"{msg.role}: {msg.content}" for msg in messages)
    original_tokens = count_tokens(raw_transcript)
    logger.info(f"📊 Final prompt size: {final_tokens} tokens (reduced from {original_tokens})")

    return prompt_text
|
||||
|
||||
|
||||
def format_messages(messages: List[ChatMessage]) -> str:
    """Format chat messages into a ChatML prompt without any tool list.

    Convenience wrapper around format_messages_with_tools.

    Args:
        messages: List of chat messages

    Returns:
        Formatted prompt string
    """
    prompt_text = format_messages_with_tools(messages, tools=None)
    return prompt_text
|
||||
+63
-1170
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,250 @@
|
||||
"""Tool parsing module for Local Swarm.
|
||||
|
||||
Parses tool calls from model output in various formats.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Tuple, Optional, List, Dict, Any
|
||||
|
||||
|
||||
def ensure_tool_arguments(tool_name: str, args_dict: dict) -> dict:
    """Ensure tool arguments have all required fields.

    For the bash tool a missing 'description' is filled in with the first
    word of the command, or 'Execute command' when the command is missing,
    blank, or not a string. Other tools are left untouched.

    Args:
        tool_name: Name of the tool
        args_dict: Tool arguments dictionary (updated in place)

    Returns:
        The same arguments dictionary
    """
    if tool_name == 'bash' and 'description' not in args_dict:
        command = args_dict.get('command', '')
        # A whitespace-only command splits into an empty list, so test the
        # split result rather than the raw string before indexing.
        words = command.split() if isinstance(command, str) else []
        args_dict['description'] = words[0] if words else 'Execute command'
    return args_dict
|
||||
|
||||
|
||||
def _parse_standard_format(text: str) -> Tuple[Optional[str], Optional[List[Dict[str, Any]]]]:
|
||||
"""Parse standard TOOL: format.
|
||||
|
||||
Format: TOOL: name\nARGUMENTS: {"key": "value"}
|
||||
|
||||
Args:
|
||||
text: Model output text
|
||||
|
||||
Returns:
|
||||
Tuple of (content_without_tools, tool_calls) or (None, None) if not found
|
||||
"""
|
||||
tool_pattern = r'TOOL:\s*(\w+)\s*\nARGUMENTS:\s*(\{[^}]*\})'
|
||||
tool_matches = list(re.finditer(tool_pattern, text, re.IGNORECASE))
|
||||
|
||||
if not tool_matches:
|
||||
return None, None
|
||||
|
||||
tool_calls = []
|
||||
for i, tool_match in enumerate(tool_matches):
|
||||
tool_name = tool_match.group(1)
|
||||
args_str = tool_match.group(2)
|
||||
try:
|
||||
args_dict = json.loads(args_str)
|
||||
args_dict = ensure_tool_arguments(tool_name, args_dict)
|
||||
tool_calls.append({
|
||||
"id": f"call_{i+1}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_name,
|
||||
"arguments": json.dumps(args_dict)
|
||||
}
|
||||
})
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
if tool_calls:
|
||||
first_start = tool_matches[0].start()
|
||||
content = text[:first_start].strip()
|
||||
return content, tool_calls
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def _parse_markdown_format(text: str) -> Tuple[Optional[str], Optional[List[Dict[str, Any]]]]:
|
||||
"""Parse markdown code block format.
|
||||
|
||||
Format: ```bash command```
|
||||
|
||||
Args:
|
||||
text: Model output text
|
||||
|
||||
Returns:
|
||||
Tuple of (content_without_tools, tool_calls) or (None, None) if not found
|
||||
"""
|
||||
markdown_pattern = r'```(?:bash|shell|sh)?\s*\n(.*?)\n```'
|
||||
markdown_matches = list(re.finditer(markdown_pattern, text, re.DOTALL))
|
||||
|
||||
if not markdown_matches:
|
||||
return None, None
|
||||
|
||||
tool_calls = []
|
||||
for i, match in enumerate(markdown_matches):
|
||||
code_content = match.group(1).strip()
|
||||
if code_content:
|
||||
args_dict = {"command": code_content}
|
||||
args_dict = ensure_tool_arguments("bash", args_dict)
|
||||
tool_calls.append({
|
||||
"id": f"call_{i+1}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "bash",
|
||||
"arguments": json.dumps(args_dict)
|
||||
}
|
||||
})
|
||||
|
||||
if tool_calls:
|
||||
first_start = markdown_matches[0].start()
|
||||
content = text[:first_start].strip()
|
||||
return content, tool_calls
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def _parse_command_lines(text: str) -> Tuple[Optional[str], Optional[List[Dict[str, Any]]]]:
|
||||
"""Parse command lines in text.
|
||||
|
||||
Matches common bash commands with their arguments.
|
||||
|
||||
Args:
|
||||
text: Model output text
|
||||
|
||||
Returns:
|
||||
Tuple of (content_without_tools, tool_calls) or (None, None) if not found
|
||||
"""
|
||||
command_lines = []
|
||||
command_pattern = r'^(npm|npx|mkdir|cd|ls|cat|echo|git|python|pip|node|yarn|create-react-app)\s+'
|
||||
|
||||
for line in text.split('\n'):
|
||||
line = line.strip()
|
||||
if re.match(command_pattern, line):
|
||||
command_lines.append(line)
|
||||
|
||||
if command_lines:
|
||||
combined_command = ' && '.join(command_lines)
|
||||
args_dict = {"command": combined_command}
|
||||
args_dict = ensure_tool_arguments("bash", args_dict)
|
||||
return "", [{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "bash",
|
||||
"arguments": json.dumps(args_dict)
|
||||
}
|
||||
}]
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def _parse_standalone_commands(text: str) -> Tuple[Optional[str], Optional[List[Dict[str, Any]]]]:
|
||||
"""Parse standalone bash commands.
|
||||
|
||||
Args:
|
||||
text: Model output text
|
||||
|
||||
Returns:
|
||||
Tuple of (content_without_tools, tool_calls) or (None, None) if not found
|
||||
"""
|
||||
standalone_pattern = r'(?:^|\n)(npm\s+\w+|npx\s+\w+|mkdir\s+\w+|cd\s+\w+|git\s+\w+)(?:\s|$)'
|
||||
standalone_matches = list(re.finditer(standalone_pattern, text, re.MULTILINE))
|
||||
|
||||
if standalone_matches:
|
||||
commands = [match.group(1).strip() for match in standalone_matches]
|
||||
if commands:
|
||||
combined_command = ' && '.join(commands)
|
||||
args_dict = {"command": combined_command}
|
||||
args_dict = ensure_tool_arguments("bash", args_dict)
|
||||
return "", [{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "bash",
|
||||
"arguments": json.dumps(args_dict)
|
||||
}
|
||||
}]
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def _parse_urls(text: str) -> Tuple[Optional[str], Optional[List[Dict[str, Any]]]]:
|
||||
"""Parse URLs for webfetch tool.
|
||||
|
||||
Args:
|
||||
text: Model output text
|
||||
|
||||
Returns:
|
||||
Tuple of (content_without_tools, tool_calls) or (None, None) if not found
|
||||
"""
|
||||
url_pattern = r'https?://[^\s<>"\')\]]+[a-zA-Z0-9]'
|
||||
url_matches = list(re.finditer(url_pattern, text))
|
||||
|
||||
if url_matches:
|
||||
urls = [match.group(0) for match in url_matches]
|
||||
if urls:
|
||||
tool_calls = []
|
||||
for i, url in enumerate(urls):
|
||||
tool_calls.append({
|
||||
"id": f"call_{i+1}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "webfetch",
|
||||
"arguments": json.dumps({"url": url, "format": "markdown"})
|
||||
}
|
||||
})
|
||||
return "", tool_calls
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def parse_tool_calls(text: str) -> Tuple[str, Optional[List[Dict[str, Any]]]]:
    """Extract tool calls from model output, trying several formats in order.

    The parsers are tried by priority and the first one that yields tool
    calls wins:
    1. Standard: TOOL: name\\nARGUMENTS: {"key": "value"}
    2. Markdown: ```bash command```
    3. Command lines: npm install, git clone, etc.
    4. Standalone commands
    5. URLs: for webfetch tool

    Args:
        text: Model output text

    Returns:
        Tuple of (content_without_tools, tool_calls), or (text, None) when no
        parser finds a tool call
    """
    parsers_by_priority = (
        _parse_standard_format,
        _parse_markdown_format,
        _parse_command_lines,
        _parse_standalone_commands,
        _parse_urls,
    )

    for parser in parsers_by_priority:
        remaining, calls = parser(text)
        if calls is not None:
            # Parsers may report None content; callers always get a string.
            return remaining or "", calls

    return text, None
|
||||
+2
-23
@@ -12,6 +12,7 @@ import subprocess
|
||||
import aiohttp
|
||||
from typing import Optional
|
||||
|
||||
from utils.project_discovery import discover_project_root
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -85,29 +86,7 @@ class ToolExecutor:
|
||||
logger.debug(f" ❌ Error contacting tool host: {e}")
|
||||
return f"Error contacting tool host: {str(e)}"
|
||||
|
||||
def _discover_project_root(self, start_dir: Optional[str] = None) -> str:
|
||||
"""Discover the project root directory by looking for common markers."""
|
||||
import os
|
||||
if start_dir is None:
|
||||
start_dir = os.getcwd()
|
||||
current = os.path.abspath(start_dir)
|
||||
|
||||
# Common project root markers
|
||||
markers = ['.git', 'package.json', 'pyproject.toml', 'Cargo.toml', 'go.mod',
|
||||
'requirements.txt', 'setup.py', 'pom.xml', 'build.gradle', '.project', '.venv']
|
||||
|
||||
while True:
|
||||
try:
|
||||
if any(os.path.exists(os.path.join(current, marker)) for marker in markers):
|
||||
return current
|
||||
except Exception:
|
||||
pass # Permission errors, just skip
|
||||
parent = os.path.dirname(current)
|
||||
if parent == current: # Reached filesystem root
|
||||
break
|
||||
current = parent
|
||||
|
||||
return start_dir
|
||||
|
||||
|
||||
async def _execute_local(self, tool_name: str, tool_args: dict) -> str:
|
||||
"""Execute tool locally."""
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
"""Project root discovery utilities.
|
||||
|
||||
Provides functionality to discover project root directories.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional, List
|
||||
|
||||
|
||||
# Common project root markers
|
||||
DEFAULT_MARKERS = [
    # Version control
    '.git',
    # Build / package manifests
    'package.json',
    'pyproject.toml',
    'Cargo.toml',
    'go.mod',
    'requirements.txt',
    'setup.py',
    'pom.xml',
    'build.gradle',
    # IDE and environment directories
    '.project',
    '.venv',
]
|
||||
|
||||
|
||||
def discover_project_root(
    start_dir: Optional[str] = None,
    markers: Optional[List[str]] = None
) -> str:
    """Walk upwards from start_dir until a directory holds a project marker.

    Args:
        start_dir: Directory to start searching from (defaults to cwd)
        markers: Marker files/directories to look for (defaults to DEFAULT_MARKERS)

    Returns:
        Absolute path of the first directory containing a marker, or the
        starting directory as given if the filesystem root is reached
        without a match
    """
    origin = os.getcwd() if start_dir is None else start_dir
    wanted = DEFAULT_MARKERS if markers is None else markers

    here = os.path.abspath(origin)
    while True:
        try:
            for marker in wanted:
                if os.path.exists(os.path.join(here, marker)):
                    return here
        except (OSError, PermissionError):
            pass  # Permission errors, just skip

        above = os.path.dirname(here)
        if above == here:  # Reached filesystem root
            return origin
        here = above
|
||||
|
||||
|
||||
def is_within_project(path: str, project_root: str) -> bool:
    """Check if a path is within a project root.

    Both paths are resolved with os.path.realpath and compared component by
    component, so a sibling directory that merely shares a name prefix
    (e.g. /work/proj2 for root /work/proj) is not treated as inside.

    Args:
        path: Path to check
        project_root: Project root directory

    Returns:
        True if path is the project root itself or lies below it; False
        otherwise, including when the paths cannot be compared
    """
    try:
        real_path = os.path.realpath(path)
        real_root = os.path.realpath(project_root)
        # commonpath works on whole path components, unlike str.startswith.
        shared = os.path.commonpath([real_path, real_root])
        return os.path.normcase(shared) == os.path.normcase(real_root)
    except (OSError, ValueError):
        # ValueError: e.g. paths on different drives
        return False
|
||||
|
||||
|
||||
def get_relative_to_project(path: str, project_root: str) -> str:
    """Express a path relative to the project root.

    Both arguments are resolved with os.path.realpath before the relative
    path is computed.

    Args:
        path: Absolute or relative path
        project_root: Project root directory

    Returns:
        Path relative to project root, or the input path unchanged if the
        computation fails
    """
    try:
        resolved_target, resolved_root = (
            os.path.realpath(candidate) for candidate in (path, project_root)
        )
        return os.path.relpath(resolved_target, start=resolved_root)
    except (OSError, ValueError):
        return path
|
||||
@@ -0,0 +1,90 @@
|
||||
"""Token counting utilities for Local Swarm.
|
||||
|
||||
Centralizes token counting functionality to avoid duplication across modules.
|
||||
"""
|
||||
|
||||
import tiktoken
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# Initialize tokenizer for accurate token counting
|
||||
TOKEN_ENCODING = tiktoken.get_encoding('cl100k_base')
|
||||
|
||||
|
||||
def count_tokens(text: str) -> int:
    """Return the number of tokens the module tokenizer produces for text.

    Args:
        text: Text to count tokens for

    Returns:
        Number of tokens; 0 for empty or otherwise falsy text
    """
    return len(TOKEN_ENCODING.encode(text)) if text else 0
|
||||
|
||||
|
||||
def count_tokens_in_messages(messages: list) -> int:
    """Sum the token counts of every message's content.

    Messages without a content attribute, or with empty/None content, add
    nothing to the total.

    Args:
        messages: List of message objects with content attribute

    Returns:
        Total token count
    """
    return sum(
        count_tokens(msg.content)
        for msg in messages
        if getattr(msg, 'content', None)
    )
|
||||
|
||||
|
||||
def estimate_tokens_from_characters(char_count: int, chars_per_token: int = 4) -> int:
    """Approximate a token count from a character count.

    Intended for quick estimates that do not need the tokenizer.

    Args:
        char_count: Number of characters
        chars_per_token: Average characters per token (default 4)

    Returns:
        char_count divided by chars_per_token, rounded down
    """
    whole_tokens, _leftover_chars = divmod(char_count, chars_per_token)
    return whole_tokens
|
||||
|
||||
|
||||
def truncate_to_max_tokens(text: str, max_tokens: int) -> str:
    """Truncate text to fit within max tokens.

    Args:
        text: Text to truncate
        max_tokens: Maximum number of tokens allowed

    Returns:
        The text unchanged when it already fits, otherwise the decoded first
        max_tokens tokens. Empty text or a non-positive budget yields "".
    """
    # Guard first: a negative budget would otherwise slice tokens[:-k] and
    # keep almost the whole text, and falsy text is treated the same way
    # count_tokens treats it (no tokens).
    if not text or max_tokens <= 0:
        return ""

    tokens = TOKEN_ENCODING.encode(text)
    if len(tokens) <= max_tokens:
        return text
    return TOKEN_ENCODING.decode(tokens[:max_tokens])
|
||||
|
||||
|
||||
def format_token_info(prompt_tokens: int, completion_tokens: int) -> dict:
    """Build the usage dictionary reported with a response.

    Args:
        prompt_tokens: Number of prompt tokens
        completion_tokens: Number of completion tokens

    Returns:
        Dictionary with prompt_tokens, completion_tokens and their sum as
        total_tokens
    """
    token_info = dict(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    token_info["total_tokens"] = prompt_tokens + completion_tokens
    return token_info
|
||||
@@ -133,7 +133,7 @@ ls -la
|
||||
|
||||
def test_tool_instructions_content():
|
||||
"""Test that tool instructions contain required sections (REVIEW-2026-02-24 Blocker #4)."""
|
||||
from api.routes import _load_tool_instructions
|
||||
from api.formatting import _load_tool_instructions
|
||||
|
||||
# Load instructions from config file
|
||||
instructions = _load_tool_instructions()
|
||||
@@ -147,7 +147,7 @@ def test_tool_instructions_content():
|
||||
|
||||
def test_tool_instructions_token_count():
|
||||
"""Test that tool instructions are within token budget (REVIEW-2026-02-24 Blocker #1)."""
|
||||
from api.routes import _load_tool_instructions
|
||||
from api.formatting import _load_tool_instructions
|
||||
|
||||
# Load instructions from config file
|
||||
instructions = _load_tool_instructions()
|
||||
|
||||
Reference in New Issue
Block a user