Compare commits
19 Commits
- bdc8db9678
- 0134ccae53
- 4ea36783d6
- 6ab726b46c
- d22c52ec04
- 5fa8cd4e0e
- 2c46d48004
- 0945cee162
- 58e4b2c645
- 929f069d14
- bdcb013d6b
- 9fdc3a6d02
- c18c20487c
- 1d1d7b4468
- 4f2b9252c4
- 3dbc76de04
- af2d616f76
- 1ac32c7ec3
- d33fa406b6
@@ -0,0 +1,20 @@
# opencode ignore patterns
# Excludes large documentation files from context padding

# Agent rules (not project context)
AGENT_WORKER.md
AGENT_REVIEW.md

# Review reports
reports/

# Design docs and test plans (historical documentation)
docs/design/
docs/test-plans/

# TODO file
TODO.md

# Non-code files
*.md
!README.md
@@ -64,6 +64,19 @@
- No circular imports
- No duplicate code (>3 lines copied)

- [ ] **Minimal, Maintainable, Modular Code**
  - **Minimal:** Only code needed to solve the problem, no over-engineering
  - **Maintainable:** Clear names, self-documenting, consistent style
  - **Modular:** Single Responsibility Principle, loose coupling, clear interfaces
  - **STRICT ENFORCEMENT:**
    - Functions should do ONE thing (if it does 2+ things, break it up)
    - No monolithic blocks (>50 lines in one function)
    - Clear separation of concerns
    - Interfaces between modules are stable and well-defined
    - Easy to understand for new maintainers
    - No "temp" or "quick" solutions - production quality only
  - **BLOCKING:** Code that is too complex, monolithic, or poorly structured must be rejected

- [ ] **Error handling is robust**
  - No bare `except:` clauses
  - All errors have clear messages

+58 -1
@@ -84,7 +84,64 @@ def test_parse_simple_tool():
# Then write minimal code to pass
```

### Rule 3: No Production Debugging
### Rule 3: Minimal, Maintainable, Modular Code
**Core Focus:** Keep code minimal, maintainable, and modular.

#### Minimal
- Write only the code needed to solve the problem
- Avoid unnecessary abstractions or over-engineering
- Keep functions small and focused (max 50 lines)
- Prefer simple solutions over complex ones
- Remove dead code and unused imports immediately

#### Maintainable
- Clear, descriptive variable and function names
- One concept per file/module
- Self-documenting code with minimal comments
- Consistent code style throughout
- Easy to understand for future maintainers

#### Modular
- Single Responsibility Principle: One purpose per module/function
- Loose coupling between components
- Clear, stable interfaces between modules
- Easy to test in isolation
- Reusable components where appropriate

```python
# BAD: Monolithic, complex, hard to maintain
def process_user_request(request_data, validate=True, save=True, notify=True, format_output=False):
    # 200+ lines doing everything
    validation_result = validate_request(request_data)
    if validation_result.is_valid:
        if save:
            db_connection = get_db_connection()
            cursor = db_connection.cursor()
            cursor.execute("INSERT INTO requests ...", request_data)
            db_connection.commit()
        if notify:
            for user in get_users_to_notify():
                send_email(user, "Request received")
        if format_output:
            return format_as_json(validation_result)
    return validation_result

# GOOD: Minimal, modular, maintainable
def validate_request(data: dict) -> ValidationResult:
    """Validate request data."""
    return ValidationResult(is_valid=len(data) > 0)

def save_request(data: dict) -> str:
    """Save request to database."""
    return db.insert("requests", data)

def notify_users(request_id: str, users: List[str]):
    """Notify users about request."""
    for user in users:
        send_email(user, f"Request {request_id} received")
```

### Rule 4: No Production Debugging
- NEVER add `print()` statements for debugging
- Use `logging` module with appropriate levels
- Remove ALL debug logging before committing

@@ -54,8 +54,13 @@ python main.py --port 8080 # Custom port
python main.py --detect # Show hardware info only
python main.py --federation # Enable network federation
python main.py --mcp # Enable MCP server
python main.py --use-opencode-tools # Use opencode tools (adds ~27k tokens)
```

**Tool Mode Options:**
- Default: Local tool server (~125 tokens, saves context window space)
- `--use-opencode-tools`: Full opencode tool definitions (~27k tokens, more capabilities)

## Connect to Opencode

Add to your opencode config:

@@ -0,0 +1,276 @@
# TODO: CUDA and Android Support in Federation

## Overview

This document tracks known issues and recommendations for adding CUDA (NVIDIA) and Android nodes to the local_swarm federation system.

## Current Status

- ✅ **Apple Silicon (macOS)**: Fully supported with MLX backend
- ⚠️ **CUDA/Android**: Not currently supported, requires implementation work
- ✅ **Linux**: Should work with llama.cpp + CUDA
- ✅ **Windows**: Should work with llama.cpp + CUDA (not tested)

## Known Issues

### 1. No CUDA Backend for macOS

**Problem:**
- `__init__.py` only chooses MLX or llama.cpp
- No CUDA path for macOS
- Apple Silicon only supports Metal acceleration, not CUDA

**Impact:**
- CUDA/Android nodes on macOS cannot use GPU acceleration
- These nodes will fall back to CPU-only mode

**References:**
- `src/backends/__init__.py` (lines 26-32)
- `src/hardware/detector.py` (Apple Silicon detection)

**Recommendation:**
- Current architecture is correct for macOS - CUDA is not supported on Apple Silicon
- Would need separate CUDA backend implementation (not recommended)

---

### 2. Platform Detection in `hardware/detector.py`

**Current Detection:**
```python
def detect_gpu():
    # macOS: Apple Silicon (Metal only, no CUDA)
    # Linux/Windows: NVIDIA/AMD/Intel GPU (potential CUDA)
    # Android/Termux: CPU-only (no GPU)
```

**Impact:**
- Android/Termux devices detected as Linux
- Will use CPU-only mode (expected)
- No special handling for Android platform

**Potential Issue:**
- Termux on Android reports as "linux"
- May have different requirements (file paths, permissions)
- Need to test if file paths work correctly on Android

**References:**
- `src/hardware/detector.py:170-221` (Android/Termux detection via `is_termux()`)

**Recommendation:**
- Add explicit Android platform detection beyond `is_termux()` (see the sketch after this list)
- Test file path handling on Termux
- Consider Android's unique file system limitations
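
A rough sketch of what explicit Termux/Android detection could look like. This is illustrative only: `is_termux()` already exists in `src/hardware/detector.py`, and the environment-variable and path checks below are common Termux/Android conventions, not code taken from this repo.

```python
import os
import platform
from pathlib import Path

def detect_android_platform() -> str:
    """Return 'android-termux', 'android', or the plain platform name."""
    # Termux exports PREFIX pointing into its own app directory
    prefix = os.environ.get("PREFIX", "")
    if "com.termux" in prefix or Path("/data/data/com.termux/files/usr").exists():
        return "android-termux"
    # Generic Android userspace still reports the Linux kernel
    if platform.system() == "Linux" and Path("/system/build.prop").exists():
        return "android"
    return platform.system().lower()
```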

---

### 3. Llama.cpp Backend Configuration

**Current GPU Layer Logic:**
```python
# src/backends/__init__.py (line 35)
if hardware.gpu and not hardware.is_apple_silicon:
    n_gpu_layers = -1  # Offload all to GPU (Metal/CUDA)
else:
    n_gpu_layers = 0  # CPU-only
```

**For CUDA Support on Linux:**
- Should set `n_gpu_layers` based on actual GPU count
- NVIDIA: Set to GPU count (1-8 for multi-GPU)
- AMD ROCm: Different backend, not tested

**Impact:**
- Currently hardcoded to -1 on Apple Silicon (Metal)
- CUDA nodes on Linux need proper layer configuration
- No validation that requested layers match available GPU

**References:**
- `src/backends/llamacpp.py` (line 16, n_gpu_layers parameter)
- `src/backends/__init__.py` (line 35)

**Recommendation:**
- Make `n_gpu_layers` configurable per backend
- Auto-detect GPU capabilities from `pynvml` or system (see the sketch after this list)
- Add GPU layer validation
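
A minimal sketch of the auto-detection idea, assuming `pynvml` (the NVIDIA management library bindings) is available. The function name and the layer heuristic are illustrative, not part of the current backends.

```python
import pynvml

def suggest_gpu_layers(model_vram_gb: float, total_layers: int = 32) -> int:
    """Return -1 (offload everything), a partial layer count, or 0 (CPU-only)."""
    try:
        pynvml.nvmlInit()
    except pynvml.NVMLError:
        return 0  # No NVIDIA driver/GPU visible -> CPU-only
    try:
        if pynvml.nvmlDeviceGetCount() == 0:
            return 0
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        free_gb = pynvml.nvmlDeviceGetMemoryInfo(handle).free / 1024**3
        if free_gb >= model_vram_gb:
            return -1  # Whole model fits in VRAM
        # Offload roughly in proportion to the VRAM that is actually free
        return max(0, int(total_layers * free_gb / model_vram_gb))
    finally:
        pynvml.nvmlShutdown()
```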

---

### 4. Seed Variation Mode (Not an Issue, but Important)

**Current Behavior:**
```python
# src/swarm/manager.py (line 76-82)
if use_seed_variation is None and hardware.is_apple_silicon:
    self.use_seed_variation = True  # Auto-enabled on macOS
```

**How It Works:**
- Runs 1 model instance with different random seeds
- Simulates multiple "workers" for consensus
- Saves memory by not loading multiple models (see the sketch after this list)
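
Conceptually, seed variation looks something like the sketch below. This is illustrative pseudocode, not the actual `SwarmManager` implementation; the `generate(prompt, seed=...)` call stands in for whatever the backend actually exposes.

```python
def generate_with_seed_variation(model, prompt: str, seeds=(1, 2)) -> str:
    """One loaded model, one vote per seed, naive majority consensus."""
    votes = [model.generate(prompt, seed=seed) for seed in seeds]
    # Real consensus uses similarity scoring; counting exact duplicates
    # is the simplest possible stand-in
    return max(set(votes), key=votes.count)
```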

**Impact on Federation:**
- Your Mac: 1 worker → 2 votes (from 2 seeds)
- Peer Mac: 2 workers → 2 votes (from 2 seeds)
- Total: 4 votes instead of 8 (if using 4 actual instances)

**This is CORRECT behavior** for seed variation mode.

**Recommendation:**
- To get 4 votes per machine (8 total), use `--instances 4` flag
- Seed variation is a design choice, not a bug

---

### 5. Federation Client Timeout

**Status:** ✅ **FIXED**

**Previous:**
- Default timeout: 30 seconds
- Peers on slow networks or slow machines would timeout

**Current:**
- Default timeout: 60 seconds (increased in `src/network/federation.py:38`)
- Gives peers more time to respond

**References:**
- `src/network/federation.py` (line 38)

**Recommendation:**
- Current 60s is reasonable
- Consider making timeout configurable per peer in discovery
- Add retry logic for failed requests (see the sketch after this list)
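
A sketch of the retry idea, assuming an async HTTP client such as `httpx`. The actual federation client lives in `src/network/federation.py`; the names below are illustrative.

```python
import asyncio
import httpx

async def post_with_retry(url: str, payload: dict, retries: int = 3, timeout: float = 60.0) -> dict:
    """POST to a peer, retrying with exponential backoff on failure."""
    last_error = None
    for attempt in range(retries):
        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                response = await client.post(url, json=payload)
                response.raise_for_status()
                return response.json()
        except (httpx.HTTPError, asyncio.TimeoutError) as e:
            last_error = e
            await asyncio.sleep(2 ** attempt)  # back off 1s, 2s, 4s, ...
    raise last_error
```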

---

### 6. Network Discovery

**Current Implementation:** ✅ **PLATFORM AGNOSTIC**

**Uses:**
- mDNS/Bonjour for peer discovery (see the sketch after this list)
- Standard network protocols
- No platform-specific blocking
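
For reference, a stripped-down version of the mDNS idea using the `zeroconf` package. The repo's actual implementation is in `src/network/discovery.py`; the service type and property names here are assumptions for illustration only.

```python
import socket
from zeroconf import ServiceInfo, Zeroconf

def advertise_swarm(port: int, ip: str) -> Zeroconf:
    """Register this swarm instance so peers can find it via mDNS."""
    zc = Zeroconf()
    info = ServiceInfo(
        "_localswarm._tcp.local.",  # illustrative service type
        f"swarm-{socket.gethostname()}._localswarm._tcp.local.",
        addresses=[socket.inet_aton(ip)],
        port=port,
        properties={"version": "0.1.0"},
    )
    zc.register_service(info)
    return zc
```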

**Status:** Should work on all platforms (macOS, Linux, Windows, Android)

**References:**
- `src/network/discovery.py` (standard mDNS implementation)

**Recommendation:**
- No changes needed
- Test on Linux/Windows/Android if needed

---

## Implementation Priorities

### High Priority (Breaking Features)

1. **CUDA Backend for Linux** (if needed)
   - Add CUDA-specific backend or extend llama.cpp
   - Auto-detect NVIDIA GPU and configure layers
   - Test on actual CUDA hardware
   - **Effort:** 3-5 days

2. **Android Platform Detection**
   - Add explicit Android detection beyond Termux
   - Handle Android's file system and package manager differences
   - Test on real Android device
   - **Effort:** 2-3 days

### Medium Priority (Improvements)

1. **GPU Layer Auto-Configuration**
   - Auto-detect GPU capabilities from system
   - Match requested layers to available hardware
   - Add validation and helpful error messages
   - **Effort:** 1-2 days

2. **Federation Metrics**
   - Add per-peer timeout in PeerInfo
   - Track latency and success rates
   - Better error handling for retry logic
   - **Effort:** 1 day

### Low Priority (Nice to Have)

1. **GPU Backend Selection UI**
   - Allow users to manually select MLX vs llama.cpp
   - Add warning for CUDA backend on macOS (not supported)
   - **Effort:** 2 hours

2. **Seed Variation Toggle**
   - Add command-line flag to disable seed variation
   - Document the trade-offs clearly
   - **Effort:** 30 minutes

## Testing Checklist

Before marking any issue as complete, test on:

### macOS (Apple Silicon)
- [ ] Federation with macOS peers (current environment)
- [ ] Seed variation mode works correctly
- [ ] MLX backend loads and generates
- [ ] No crashes with multiple instances

### Linux (NVIDIA GPU)
- [ ] llama.cpp backend loads with CUDA support
- [ ] Federation with Linux peers works
- [ ] GPU layers configured correctly
- [ ] No GPU conflicts

### Windows (NVIDIA GPU)
- [ ] llama.cpp backend loads with CUDA support
- [ ] Federation with Windows peers works
- [ ] No GPU conflicts

### Android (CPU-only)
- [ ] Federation with Android peers works (mDNS should work)
- [ ] CPU-only generation works
- [ ] File paths work on Termux/Android

## Notes

### Architecture Decisions

**Why not per-platform backends:**
- Simplifies codebase (single MLX path, single llama.cpp path)
- Reduces maintenance burden
- Trade-off: Can't optimize for platform-specific GPUs in backends

**Why seed variation on macOS:**
- Apple Silicon has unified memory, not discrete VRAM
- Loading multiple models would consume too much RAM
- Seed variation allows consensus quality with 1 model instance

**CUDA/Android is not a bug:**
- Current system is designed for Apple Silicon + llama.cpp
- Adding CUDA support requires significant architecture work
- Focus on federation quality for current platforms first

## Related Files

- `src/backends/__init__.py` - Backend selection logic
- `src/backends/mlx.py` - Apple Silicon MLX backend
- `src/backends/llamacpp.py` - llama.cpp backend (supports CUDA)
- `src/hardware/detector.py` - Platform and GPU detection
- `src/network/federation.py` - Federation communication
- `src/network/discovery.py` - Peer discovery via mDNS
- `src/swarm/manager.py` - Swarm orchestration

## Conclusion

The current federation implementation is **platform-agnostic** and should work on Linux/Windows with CUDA nodes. The main limitation is that macOS (Apple Silicon) only supports Metal/MLX, not CUDA.

**For immediate use:**
- Use `--instances 4` flag on each machine to get 4 votes per machine
- Test federation between different platforms (macOS + Linux)
- Android/Termux should work as-is (CPU-only mode)

**For future work:**
- Implement high-priority items if CUDA/Android support is needed
- Add GPU layer auto-configuration for better hardware utilization

@@ -1,12 +1,18 @@
Use tools to execute commands and fetch information. Output only tool calls.

Available tools:
- bash: Execute shell commands
- webfetch: Fetch web content (supports text/markdown/html formats)
- read: Read files
- write: Create files

IMPORTANT: When requesting webfetch, ALWAYS provide a URL that actually exists. Do not hallucinate or guess URLs. If a URL returns 404 or errors, stop trying.

Format:
TOOL: bash
ARGUMENTS: {"command": "ls -la", "description": "Lists files in directory"}
ARGUMENTS: {"command": "your command here"}

TOOL: webfetch
ARGUMENTS: {"url": "https://example.com", "format": "markdown"}
ARGUMENTS: {"url": "https://example.com", "format": "text"}

Available tools: bash, webfetch

No explanations. No numbered lists. No markdown. Only tool calls.
No explanations. No numbered lists. No markdown. Only output tool calls.

@@ -0,0 +1,98 @@
# Investigation: 31k Token Context Issue

## Problem
When making requests through opencode to local_swarm, the LLM receives ~31k tokens of context even for simple empty directory queries.

## Root Cause Identified

**NOT an issue with this repo's codebase - this is expected behavior for function calling.**

### How it works:

1. **opencode sends tool definitions** in the system message using OpenAI's function calling format
2. **Each tool definition is ~450 tokens** (name + description + parameters)
3. **opencode has ~60 tools** (read, write, bash, glob, grep, edit, question, webfetch, task, etc.)
4. **Total tool definition tokens:** ~27,000 tokens

### Calculation:
```
Single tool definition: ~450 tokens
Number of tools: ~60
Tool schemas total: ~27,000 tokens
System message: ~500 tokens
User query: ~100 tokens
---
Total: ~27,600 tokens
```

**This matches the observed ~31k tokens.**

## Why This Happens

OpenAI's function calling protocol requires sending the **complete function schemas** to the LLM with every request. This is how the model:
- Knows what tools are available
- Understands parameter requirements
- Knows how to format tool calls

All major LLM providers using function calling work this way (OpenAI, Anthropic, local models, etc.).

## Verification

```bash
python -c "
import tiktoken
enc = tiktoken.get_encoding('cl100k_base')

# Example from actual opencode tool definition
read_tool_schema = '''{\"type\": \"function\", \"function\": {\"name\": \"read\", \"description\": \"Read a file or directory from the local filesystem...[full description]\", \"parameters\": {...}}}'''

print(f'Single tool schema: {len(enc.encode(read_tool_schema))} tokens')
print(f'Estimated 60 tools: {len(enc.encode(read_tool_schema)) * 60:,} tokens')
"
```

Result:
- Single tool definition: ~451 tokens
- 60 tools: ~27,060 tokens
- Plus system + user message: ~27,660 total

## This Is NOT a Bug

The 31k token context is **correct and expected** for function calling with 60+ tools. This is how:
- OpenAI API works
- Claude API works
- Local models with function calling work

## Potential Optimizations (Optional)

If reducing context size is critical, consider:

### Option 1: Dynamic Tool Selection
- Only send tools relevant to current task (see the sketch after this list)
- Example: For file operations, only send [read, write, glob, edit]
- Trade-off: Requires opencode to intelligently filter tools
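
If this were ever implemented, the filtering itself is simple; the hard part is deciding which category a task belongs to. A rough sketch with a hypothetical helper (this is not existing opencode or local_swarm code, and the group names are made up for illustration):

```python
# Hypothetical mapping from task categories to the few tools they actually need
TOOL_GROUPS = {
    "files": {"read", "write", "glob", "edit"},
    "shell": {"bash"},
    "web": {"webfetch"},
}

def select_tools(all_tools: list[dict], category: str) -> list[dict]:
    """Keep only the tool definitions relevant to the current task category."""
    wanted = TOOL_GROUPS.get(category, set())
    return [t for t in all_tools if t.get("function", {}).get("name") in wanted]
```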

### Option 2: Compressed Tool Descriptions
- Shorten tool descriptions to essentials
- Example: "Read file at path (required: filePath)"
- Trade-off: Model may make more errors with less guidance

### Option 3: Tool Grouping
- Group similar tools into single "tools: [read, write, glob]" parameter
- Trade-off: Breaks OpenAI compatibility

## Recommendation

**NO ACTION REQUIRED.** The 31k token context is:
- Standard for function calling with many tools
- Within capabilities of modern LLMs (32k-128k context windows)
- Not caused by this repo's code

The `.opencodeignore` created earlier will help with opencode's own system prompt, but doesn't affect the LLM context sent to local_swarm.

## Additional Finding

While investigating, verified:
- `config/prompts/tool_instructions.txt`: 125 tokens ✅
- This repo's tool execution code: No token bloat ✅
- Issue is purely opencode's function calling protocol ✅

@@ -10,218 +10,63 @@ import sys
import multiprocessing as mp

# CRITICAL: Set spawn method BEFORE any other imports on macOS
# This prevents fork-related issues with Metal GPU
if sys.platform == "darwin":
    try:
        mp.set_start_method("spawn", force=True)
    except RuntimeError:
        pass  # Already set
        pass

import argparse
import asyncio
from pathlib import Path

# Add src to path - resolve for Windows compatibility
# Add src to path
src_path = Path(__file__).parent.resolve() / "src"
sys.path.insert(0, str(src_path))

# Also add parent dir for Windows import issues
if str(Path(__file__).parent.resolve()) not in sys.path:
    sys.path.insert(0, str(Path(__file__).parent.resolve()))

# These imports must come AFTER setting spawn method on macOS
from hardware.detector import detect_hardware
from models.selector import select_optimal_model
from models.downloader import download_model_for_config
from swarm import SwarmManager
from api import create_server
from api.routes import set_federated_swarm
from mcp_server import create_mcp_server
from interactive import (
    interactive_model_selection,
    show_startup_summary,
    show_runtime_menu,
    custom_configuration,
)
from network import create_discovery_service, FederatedSwarm
from tools.executor import ToolExecutor, set_tool_executor
from cli.parser import parse_args
from cli.tool_server import run_tool_server
from utils.network import get_local_ip
from utils.logging_config import setup_logging
from hardware.detector import detect_hardware
from interactive import print_hardware_info

# Set up logging (DEBUG level for development)
# Set up logging
setup_logging()


async def setup_swarm(model_config, hardware):
    """Download model and initialize swarm."""
    # Download model
    print("\n⬇️ Downloading model...")
    try:
        model_path = download_model_for_config(model_config)
        print(f"✓ Model ready at: {model_path}")
    except Exception as e:
        print(f"\n❌ Error downloading model: {e}", file=sys.stderr)
        return None
def handle_detect_mode(hardware) -> int:
    """Handle --detect mode."""
    print_hardware_info(hardware)
    print("\n✅ Detection complete")
    return 0


def handle_tool_server_mode(args, hardware) -> int:
    """Handle --tool-server mode."""
    print("\n🔧 Starting Tool Execution Server...")
    host = args.host if args.host else get_local_ip()

    # Initialize swarm
    print("\n🚀 Initializing swarm...")
    try:
        swarm = SwarmManager(
            model_config=model_config,
            hardware=hardware,
            consensus_strategy="similarity"
        )

        success = await swarm.initialize(str(model_path))
        if not success:
            print("❌ Failed to initialize swarm")
            return None

        return swarm
    except Exception as e:
        print(f"\n❌ Error initializing swarm: {e}", file=sys.stderr)
        return None
        asyncio.run(run_tool_server(host, args.tool_port))
        return 0
    except KeyboardInterrupt:
        print("\n\nTool server stopped")
        return 0


def get_local_ip():
    """Get the local network IP address (private networks only)."""
    import socket
    try:
        # Create a socket and connect to a public DNS server
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        s.settimeout(2)
        # Try to connect to Google's DNS - this doesn't actually send data
        s.connect(("8.8.8.8", 80))
        ip = s.getsockname()[0]
        s.close()

        # Check if it's a private IP (only 192.168.x.x for this network)
        is_private = (
            ip.startswith('192.168.')
        )

        if is_private:
            print(f" 📡 Detected local IP: {ip}")
            return ip
        else:
            # If not private, return localhost for safety
            print(f" ⚠️ IP {ip} is not a private network, binding to localhost")
            return "127.0.0.1"
    except Exception as e:
        print(f" ⚠️ Could not detect local IP: {e}, using localhost")
        return "127.0.0.1"


def main():
    parser = argparse.ArgumentParser(
        description="Local Swarm - AI-powered coding LLM swarm",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python main.py # Interactive setup and start
python main.py --auto # Auto-detect and start without menu
python main.py --detect # Show hardware detection only
python main.py --model qwen:3b:q4 # Use specific model (skip menu)
python main.py --port 17615 # Use custom port (default: 17615)
python main.py --host 192.168.1.5 # Bind to specific IP
python main.py --instances 4 # Force number of instances
python main.py --download-only # Download model only
python main.py --test # Test with sample prompt
python main.py --mcp # Enable MCP server
python main.py --federation # Enable federation with other instances
python main.py --federation --peer 192.168.1.10:17615 # Manual peer
"""
    )
async def run_main_mode(args, hardware) -> int:
    """Run the main application mode."""
    from cli.main_runner import MainRunner

    parser.add_argument(
        "--auto",
        action="store_true",
        help="Auto-detect best configuration without interactive menu"
    )
    parser.add_argument(
        "--detect",
        action="store_true",
        help="Show hardware detection and exit"
    )
    parser.add_argument(
        "--model",
        type=str,
        help="Model to use (format: name:size:quant, e.g., qwen:3b:q4)"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=17615,
        help="Port to run the API server on (default: 17615)"
    )
    parser.add_argument(
        "--instances",
        type=int,
        help="Force number of instances (overrides auto-calculation)"
    )
    parser.add_argument(
        "--download-only",
        action="store_true",
        help="Download models only, don't start server"
    )
    parser.add_argument(
        "--test",
        action="store_true",
        help="Test with a sample prompt"
    )
    parser.add_argument(
        "--mcp",
        action="store_true",
        help="Enable MCP server alongside HTTP API"
    )
    parser.add_argument(
        "--config",
        type=str,
        default="config.yaml",
        help="Path to config file"
    )
    parser.add_argument(
        "--host",
        type=str,
        default=None,
        help="Host IP to bind to (default: auto-detect)"
    )
    parser.add_argument(
        "--federation",
        action="store_true",
        help="Enable federation with other Local Swarm instances on the network"
    )
    parser.add_argument(
        "--peer",
        action="append",
        dest="peers",
        help="Manually add a peer (format: host:port, can be used multiple times)"
    )
    parser.add_argument(
        "--tool-server",
        action="store_true",
        help="Run as dedicated tool execution server (executes read/write/bash tools)"
    )
    parser.add_argument(
        "--tool-port",
        type=int,
        default=17616,
        help="Port for tool execution server (default: 17616)"
    )
    parser.add_argument(
        "--tool-host",
        type=str,
        default=None,
        nargs='?',
        const='',  # When --tool-host is used without a value, use empty string
        help="URL of tool execution server. Use without value for auto-detected local IP (http://<local-ip>:17616), or provide explicit URL."
    )
    parser.add_argument(
        "--version",
        action="version",
        version="%(prog)s 0.1.0"
    )

    args = parser.parse_args()
    runner = MainRunner(hardware, args)
    return await runner.run()


def main() -> int:
    """Main entry point."""
    args = parse_args()

    # Detect hardware first
    print("\n🔍 Detecting hardware...")
@@ -229,316 +74,26 @@ Examples:
        hardware = detect_hardware()
    except Exception as e:
        print(f"\n❌ Error detecting hardware: {e}", file=sys.stderr)
        sys.exit(1)
        return 1

    # Handle detect mode
    if args.detect:
        # Just show hardware info
        from interactive import print_hardware_info
        print_hardware_info(hardware)
        print("\n✅ Detection complete")
        return
        return handle_detect_mode(hardware)

    # Tool server mode - run minimal tool-only server
    # Handle tool server mode
    if args.tool_server:
        print("\n🔧 Starting Tool Execution Server...")
        from fastapi import FastAPI
        import uvicorn

        # Initialize local tool executor
        tool_executor = ToolExecutor(tool_host_url=None)
        set_tool_executor(tool_executor)

        app = FastAPI(title="Local Swarm Tool Server")

        @app.post("/v1/tools/execute")
        async def execute_tool(request: dict):
            tool_name = request.get("tool", "")
            tool_args = request.get("arguments", {})
            result = await tool_executor.execute(tool_name, tool_args)
            return {"result": result}

        @app.get("/health")
        async def health():
            return {"status": "healthy", "mode": "tool-server"}

        host = args.host if args.host else get_local_ip()
        tool_port = args.tool_port
        print(f"🔗 Tool server running at http://{host}:{tool_port}")
        print(f" Endpoints:")
        print(f" - POST /v1/tools/execute")
        print(f" - GET /health")
        print(f"\n✅ Tool server ready!")

        uvicorn.run(app, host=host, port=tool_port)
        return
        return handle_tool_server_mode(args, hardware)

    # Determine model configuration
    config = None

    if args.model or args.instances or args.auto:
        # Use command-line arguments or auto-detect
        print("\n📊 Calculating optimal configuration...")
        try:
            config = select_optimal_model(
                hardware,
                preferred_model=args.model,
                force_instances=args.instances
            )

            if not config:
                print("\n❌ No suitable model found for your hardware")
                print(" Minimum requirement: 2 GB available memory")
                sys.exit(1)

            # Show brief summary
            print(f"\n✓ Selected: {config.display_name}")
            print(f" Instances: {config.instances}")
            print(f" Memory: {config.total_memory_gb:.1f} GB")

        except Exception as e:
            print(f"\n❌ Error selecting model: {e}", file=sys.stderr)
            sys.exit(1)
    else:
        # Interactive mode - show menu
        config = interactive_model_selection(hardware)

        if not config:
            print("\n❌ No configuration selected")
            sys.exit(1)

    if args.download_only:
        # Download model only
        print("\n" + "=" * 70)
        print("⬇️ Download Mode: Downloading model only")
        print("=" * 70)

        try:
            model_path = download_model_for_config(config)
            print(f"✓ Model downloaded to: {model_path}")
            print("\n" + "=" * 70)
            print("✅ Download complete")
            print("=" * 70)
        except Exception as e:
            print(f"\n❌ Download failed: {e}", file=sys.stderr)
            sys.exit(1)

    elif args.test:
        # Test mode with sample prompt
        print("\n" + "=" * 70)
        print("🧪 Test Mode: Running sample inference")
        print("=" * 70)

        async def test_inference():
            show_startup_summary(hardware, config)
            swarm = await setup_swarm(config, hardware)
            if not swarm:
                return False

            try:
                # Test prompt
                prompt = "Write a Python function to calculate factorial:"
                print(f"\nPrompt: {prompt}\n")
                print("Generating responses...\n")

                result = await swarm.generate(prompt, max_tokens=200)

                print("\n" + "=" * 70)
                print("SELECTED RESPONSE:")
                print("=" * 70)
                print(result.selected_response.text)
                print("\n" + "=" * 70)
                print(f"Strategy: {result.strategy}")
                print(f"Confidence: {result.confidence:.2f}")
                print(f"Latency: {result.selected_response.latency_ms:.1f}ms")
                print(f"Tokens/sec: {result.selected_response.tokens_per_second:.1f}")

                # Show all responses
                print("\nAll responses received:")
                for i, resp in enumerate(result.all_responses):
                    preview = resp.text[:60].replace('\n', ' ')
                    print(f" Worker {i}: {preview}... ({resp.latency_ms:.1f}ms)")

                return True
            finally:
                await swarm.shutdown()

        success = asyncio.run(test_inference())

        if success:
            print("\n" + "=" * 70)
            print("✅ Test complete")
            print("=" * 70)
        else:
            print("\n❌ Test failed")
            sys.exit(1)

    else:
        # Full mode (download + start API server + optional MCP)
        show_startup_summary(hardware, config)

        async def run_server():
            swarm = await setup_swarm(config, hardware)
            if not swarm:
                return False

            # Initialize tool executor
            if args.tool_host is not None:
                # --tool-host was provided
                if args.tool_host == "":
                    # --tool-host with no value - use local IP with default port
                    local_ip = get_local_ip()
                    tool_host_url = f"http://{local_ip}:17616"
                    print(f"\n🔧 Using remote tool host: {tool_host_url} (auto-detected local IP)")
                else:
                    # --tool-host with explicit value
                    tool_host_url = args.tool_host
                    print(f"\n🔧 Using remote tool host: {tool_host_url}")
                tool_executor = ToolExecutor(tool_host_url=tool_host_url)
                set_tool_executor(tool_executor)
            else:
                # Local tool execution (default)
                tool_executor = ToolExecutor(tool_host_url=None)
                set_tool_executor(tool_executor)

            # Update summary with runtime info
            show_startup_summary(hardware, config, swarm)

            # Initialize federation if enabled
            discovery = None
            federated_swarm = None
            if args.federation:
                print("\n🌐 Initializing federation...")
                try:
                    # Use specified host for advertising if provided
                    advertise_ip = args.host if args.host else None
                    discovery = await create_discovery_service(args.port, advertise_ip=advertise_ip)

                    # Get swarm info for advertising
                    swarm_info = {
                        "version": "0.1.0",
                        "instances": config.instances,
                        "model_id": config.model_id,
                        "hardware_summary": f"{hardware.cpu_cores} CPU, {hardware.ram_gb:.1f}GB RAM"
                    }

                    await discovery.start_advertising(swarm_info)
                    await discovery.start_listening()

                    # Add manual peers if specified
                    if args.peers:
                        print(f" 📍 Adding {len(args.peers)} manual peer(s)...")
                        from network.discovery import PeerInfo
                        from datetime import datetime
                        for peer_str in args.peers:
                            try:
                                host, port = peer_str.rsplit(':', 1)
                                port = int(port)
                                peer = PeerInfo(
                                    host=host,
                                    port=port,
                                    name=f"manual_{host}_{port}",
                                    version="0.1.0",
                                    instances=0,
                                    model_id="unknown",
                                    hardware_summary="manual",
                                    last_seen=datetime.now()
                                )
                                discovery.peers[peer.name] = peer
                                print(f" ✓ Added peer: {host}:{port}")
                            except Exception as e:
                                print(f" ⚠️ Failed to add peer {peer_str}: {e}")

                    # Create federated swarm wrapper
                    federated_swarm = FederatedSwarm(swarm, discovery)
                    set_federated_swarm(federated_swarm)

                    # Start health check loop in background
                    asyncio.create_task(discovery.start_health_check_loop(interval_seconds=10))

                    print(f" ✓ Federation enabled")
                    print(f" ✓ Discovery active on port {discovery.discovery_port}")
                    print(f" ✓ Peer health checks every 10s")
                except Exception as e:
                    print(f" ⚠️ Failed to initialize federation: {e}")
                    print(" Continuing without federation...")

            mcp_server = None
            try:
                # Create and start API server
                print("\n🌐 Starting HTTP API server...")
                # Use provided host or auto-detect
                if args.host:
                    host = args.host
                    print(f"🔗 Using specified host: {host}:{args.port}")
                else:
                    # Use local network IP instead of 0.0.0.0 for security
                    host = get_local_ip()
                    print(f"🔗 Binding to {host}:{args.port}")
                server = create_server(swarm, host=host, port=args.port)

                print(f"\n✅ Local Swarm is running!")
                print(f" API: http://{host}:{args.port}/v1")
                print(f" Health: http://{host}:{args.port}/health")

                if args.federation and discovery:
                    peers = discovery.get_peers()
                    print(f"\n🌐 Federation: Enabled")
                    print(f" Discovery port: {discovery.discovery_port}")
                    if peers:
                        print(f" Peers discovered: {len(peers)}")
                        for peer in peers:
                            print(f" - {peer.name} ({peer.model_id})")
                    else:
                        print(f" Peers discovered: 0 (waiting for peers...)")

                # Show tool server status
                if args.tool_host is not None:
                    print(f"\n🔧 Tool Server: Remote")
                    if args.tool_host == "":
                        local_ip = get_local_ip()
                        print(f" URL: http://{local_ip}:17616 (auto-detected)")
                    else:
                        print(f" URL: {args.tool_host}")
                    print(f" Mode: Tools executed remotely on tool host")
                else:
                    print(f"\n🔧 Tool Server: Local")
                    print(f" Mode: Tools executed on this machine")

                if args.mcp:
                    # Start MCP server alongside HTTP API
                    print("\n🤖 Starting MCP server...")
                    mcp_server = await create_mcp_server(swarm)
                    print(" MCP server active (stdio)")

                print(f"\n💡 Configure opencode to use:")
                print(f' base_url: http://127.0.0.1:{args.port}/v1')
                print(f' api_key: any (not used)')
                print(f"\nPress Ctrl+C to stop...\n")

                # Start HTTP server (this will block)
                await server.start()

            except KeyboardInterrupt:
                print("\n\nReceived stop signal")
            finally:
                if federated_swarm:
                    await federated_swarm.close()
                if discovery:
                    await discovery.stop()
                await swarm.shutdown()

            return True

        try:
            success = asyncio.run(run_server())
            if success:
                print("\n" + "=" * 70)
                print("✅ Server stopped gracefully")
                print("=" * 70)
        except Exception as e:
            print(f"\n❌ Error running server: {e}", file=sys.stderr)
            sys.exit(1)
    # Run main mode
    try:
        return asyncio.run(run_main_mode(args, hardware))
    except KeyboardInterrupt:
        print("\n\nReceived stop signal")
        return 0
    except Exception as e:
        print(f"\n❌ Error: {e}", file=sys.stderr)
        return 1


if __name__ == "__main__":
    main()
    sys.exit(main())

@@ -0,0 +1,287 @@
"""Chat completion handlers for Local Swarm.

Contains the business logic for chat completions, separated from HTTP routing.
"""

import json
import logging
import time
import uuid
from typing import Optional

from api.models import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionChoice,
    ChatMessage,
    UsageInfo,
)
from api.formatting import format_messages_with_tools
from api.tool_parser import parse_tool_calls
from utils.token_counter import count_tokens
from tools.executor import get_tool_executor


logger = logging.getLogger(__name__)


def _sanitize_tools(tools: Optional[list]) -> Optional[list]:
    """Sanitize tool definitions to fix invalid schemas.

    Removes extra 'description' from properties if present.

    Args:
        tools: List of tool definitions

    Returns:
        Sanitized tools list
    """
    if not tools:
        return tools

    sanitized = []
    for tool in tools:
        if tool.type == "function" and tool.function.parameters:
            params = tool.function.parameters
            # Remove invalid 'description' from properties if present
            if 'properties' in params and 'description' in params.get('properties', {}):
                invalid_props = ['description']
                # Also remove 'description' from required if present
                if 'required' in params:
                    params['required'] = [r for r in params.get('required', []) if r not in invalid_props]
                # Remove invalid properties
                params['properties'] = {k: v for k, v in params.get('properties', {}).items() if k not in invalid_props}
                logger.debug(f" 🔧 Sanitized tool '{tool.function.name}': removed {invalid_props} from properties/required")
        sanitized.append(tool)

    return sanitized


async def _execute_tools(
    tool_calls: list,
    client_working_dir: Optional[str],
    executor
) -> str:
    """Execute tool calls and return results.

    Args:
        tool_calls: List of parsed tool calls
        client_working_dir: Working directory for file operations
        executor: Tool executor instance

    Returns:
        Combined tool results as string
    """
    from api.routes import execute_tool_server_side

    tool_results = []
    for i, tc in enumerate(tool_calls):
        tool_name = tc.get("function", {}).get("name", "")
        tool_args_str = tc.get("function", {}).get("arguments", "{}")
        try:
            tool_args = json.loads(tool_args_str) if isinstance(tool_args_str, str) else tool_args_str
        except:
            tool_args = {}

        logger.debug(f" [{i+1}/{len(tool_calls)}] Executing: {tool_name}({tool_args})")
        result = await execute_tool_server_side(tool_name, tool_args, working_dir=client_working_dir)
        tool_results.append(f"Tool '{tool_name}' result: {result}")
        logger.debug(f" ✓ Completed: {result[:100]}..." if len(result) > 100 else f" ✓ Result: {result}")

    return "\n\n".join(tool_results)


def _create_response(
    content: str,
    tool_calls: list,
    finish_reason: str,
    prompt: str,
    request: ChatCompletionRequest
) -> ChatCompletionResponse:
    """Create a chat completion response.

    Args:
        content: Response content
        tool_calls: List of tool calls
        finish_reason: Finish reason
        prompt: Original prompt for token counting
        request: Original request

    Returns:
        ChatCompletionResponse
    """
    prompt_tokens = count_tokens(prompt)
    completion_tokens = count_tokens(content)
    total_tokens = prompt_tokens + completion_tokens

    return ChatCompletionResponse(
        id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
        created=int(time.time()),
        model=request.model,
        choices=[
            ChatCompletionChoice(
                index=0,
                message=ChatMessage(
                    role="assistant",
                    content=content,
                    tool_calls=tool_calls
                ),
                finish_reason=finish_reason
            )
        ],
        usage=UsageInfo(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens
        )
    )


async def _generate_with_local_swarm(
    swarm_manager,
    prompt: str,
    max_tokens: int,
    temperature: float
) -> tuple[str, int, float]:
    """Generate response using local swarm.

    Args:
        swarm_manager: Swarm manager instance
        prompt: Prompt to generate from
        max_tokens: Maximum tokens to generate
        temperature: Sampling temperature

    Returns:
        Tuple of (response_text, tokens_generated, tokens_per_second)
    """
    result = await swarm_manager.generate(
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        use_consensus=True
    )

    response = result.selected_response
    return (
        response.text,
        response.tokens_generated,
        response.tokens_per_second
    )


async def _generate_with_federation(
    federated_swarm,
    prompt: str,
    max_tokens: int,
    temperature: float
) -> tuple[str, list, str]:
    """Generate response using federated swarm.

    Args:
        federated_swarm: Federated swarm instance
        prompt: Prompt to generate from
        max_tokens: Maximum tokens to generate
        temperature: Sampling temperature

    Returns:
        Tuple of (response_text, tool_calls, finish_reason)
    """
    result = await federated_swarm.generate_with_federation(
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        min_peers=0
    )

    content = result.final_response

    # Check for tool calls
    content_parsed, tool_calls_parsed = parse_tool_calls(content)
    if tool_calls_parsed:
        return content_parsed, tool_calls_parsed, "tool_calls"

    return content, [], "stop"


async def handle_chat_completion(
    request: ChatCompletionRequest,
    swarm_manager,
    federated_swarm,
    client_working_dir: Optional[str],
    use_opencode_tools: bool
) -> ChatCompletionResponse:
    """Handle a chat completion request.

    Args:
        request: Chat completion request
        swarm_manager: Swarm manager instance
        federated_swarm: Optional federated swarm instance
        client_working_dir: Client working directory
        use_opencode_tools: Whether to use opencode tool definitions

    Returns:
        Chat completion response
    """
    # Format messages into prompt
    if use_opencode_tools:
        sanitized_tools = _sanitize_tools(request.tools)
        prompt = format_messages_with_tools(request.messages, sanitized_tools)
        has_tools = sanitized_tools is not None and len(sanitized_tools) > 0
    else:
        prompt = format_messages_with_tools(request.messages, None)
        has_tools = request.tools is not None and len(request.tools) > 0

    logger.debug(f"\n{'='*60}")
    logger.debug(f"REQUEST: has_tools={has_tools}, stream={request.stream}")
    logger.debug(f"MODE: {'opencode' if use_opencode_tools else 'local'} tools")
    logger.debug(f"{'='*60}")

    # Use federation if available
    if federated_swarm is not None:
        peers = federated_swarm.discovery.get_peers()
        if peers:
            logger.info(f"🌐 Using federation with {len(peers)} peer(s)...")
            content, tool_calls, finish_reason = await _generate_with_federation(
                federated_swarm, prompt, request.max_tokens or 1024, request.temperature or 0.7
            )
            return _create_response(content, tool_calls, finish_reason, prompt, request)

    # Use local swarm
    logger.debug("Using local swarm generation")
    response_text, tokens_generated, tps = await _generate_with_local_swarm(
        swarm_manager, prompt, request.max_tokens or 1024, request.temperature or 0.7
    )

    logger.debug(f"DEBUG: Generated response (tokens={tokens_generated}, t/s={tps:.1f})")
    logger.debug(f"DEBUG: Response preview: {response_text[:200]}...")

    # Parse tool calls if tools were provided
    content = response_text
    tool_calls = []
    finish_reason = "stop"

    if has_tools:
        logger.debug(f"DEBUG: Parsing tool calls from response...")
        content, tool_calls_parsed = parse_tool_calls(response_text)
        logger.debug(f"DEBUG: parse_tool_calls returned: content_len={len(content)}, parsed={tool_calls_parsed is not None}")

        if tool_calls_parsed:
            logger.debug(f" 🔧 Model requesting {len(tool_calls_parsed)} tool(s)...")
            executor = get_tool_executor()
            if executor:
                logger.debug(f" 🔗 Tool executor: {executor.tool_host_url or 'local'}")
            else:
                logger.debug(f" ⚠️ No tool executor configured!")

            # Execute tools
            tool_results_str = await _execute_tools(tool_calls_parsed, client_working_dir, executor)
            content = tool_results_str
            finish_reason = "stop"
            tool_calls = []  # Clear tool_calls since we executed them
            logger.debug(f" ✅ All tools executed, returning results")
        else:
            logger.debug(f"DEBUG: No tool calls parsed from response")
    else:
        logger.debug(f"DEBUG: No tools requested, returning normal response")

    return _create_response(content, tool_calls, finish_reason, prompt, request)

@@ -0,0 +1,265 @@
|
||||
"""Message formatting module for Local Swarm.
|
||||
|
||||
Formats chat messages into prompts and handles tool instructions.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
from api.models import ChatMessage
|
||||
from utils.token_counter import count_tokens
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cache for tool instructions (loaded from config file)
|
||||
_TOOL_INSTRUCTIONS_CACHE: Optional[str] = None
|
||||
|
||||
# Global flag for tool mode (default: local tool server to save tokens)
|
||||
_USE_OPENCODE_TOOLS: bool = False
|
||||
|
||||
|
||||
def set_use_opencode_tools(value: bool) -> None:
|
||||
"""Set whether to use opencode's tool definitions (default: False = local tool server).
|
||||
|
||||
Args:
|
||||
value: True to use opencode tools (~27k tokens), False to use local tool server (~125 tokens)
|
||||
"""
|
||||
global _USE_OPENCODE_TOOLS
|
||||
_USE_OPENCODE_TOOLS = value
|
||||
logger.info(f"🔧 Tool mode set to: {'opencode tools (~27k tokens)' if value else 'local tool server (~125 tokens)'}")
|
||||
|
||||
|
||||
def _load_tool_instructions() -> str:
|
||||
"""Load tool instructions from config file.
|
||||
|
||||
Loads from config/prompts/tool_instructions.txt
|
||||
Falls back to default if file not found.
|
||||
|
||||
Returns:
|
||||
Tool instructions string
|
||||
"""
|
||||
global _TOOL_INSTRUCTIONS_CACHE
|
||||
|
||||
if _TOOL_INSTRUCTIONS_CACHE is not None:
|
||||
return _TOOL_INSTRUCTIONS_CACHE
|
||||
|
||||
# Try to load from config file
|
||||
config_path = Path(__file__).parent.parent.parent / "config" / "prompts" / "tool_instructions.txt"
|
||||
|
||||
try:
|
||||
if config_path.exists():
|
||||
with open(config_path, 'r') as f:
|
||||
_TOOL_INSTRUCTIONS_CACHE = f.read().strip()
|
||||
logger.debug(f"Loaded tool instructions from {config_path}")
|
||||
else:
|
||||
# Fallback default instructions
|
||||
_TOOL_INSTRUCTIONS_CACHE = """You MUST use tools. DO NOT explain. DO NOT use markdown.
|
||||
|
||||
OUTPUT THIS EXACT FORMAT - NOTHING ELSE:
|
||||
|
||||
TOOL: bash
|
||||
ARGUMENTS: {"command": "your command here"}
|
||||
|
||||
Available tools:
|
||||
- bash: Run shell commands
|
||||
- write: Create files
|
||||
- read: Read files
|
||||
|
||||
NEVER write explanations.
|
||||
NEVER use numbered lists.
|
||||
NEVER use markdown code blocks.
|
||||
ONLY output TOOL: lines."""
|
||||
logger.warning(f"Tool instructions config not found at {config_path}, using default")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading tool instructions: {e}")
|
||||
# Use minimal fallback
|
||||
_TOOL_INSTRUCTIONS_CACHE = 'Use TOOL: tool_name\nARGUMENTS: {"param": "value"} format.'
|
||||
|
||||
return _TOOL_INSTRUCTIONS_CACHE
|
||||
|
||||
|
||||
def _is_initial_request(messages: List[ChatMessage]) -> bool:
|
||||
"""Check if this is an initial request (no assistant or tool messages).
|
||||
|
||||
Args:
|
||||
messages: List of chat messages
|
||||
|
||||
Returns:
|
||||
True if this is the initial request
|
||||
"""
|
||||
has_assistant = any(msg.role == "assistant" for msg in messages)
|
||||
has_tool = any(msg.role == "tool" for msg in messages)
|
||||
return not has_assistant and not has_tool
|
||||
|
||||
|
||||
def _compress_large_request(messages: List[ChatMessage], max_tokens: int = 4000) -> List[ChatMessage]:
|
||||
"""Compress large initial requests by keeping only user messages.
|
||||
|
||||
Args:
|
||||
messages: List of chat messages
|
||||
max_tokens: Maximum tokens before compression
|
||||
|
||||
Returns:
|
||||
Compressed list of messages
|
||||
"""
|
||||
full_text = "\n".join([f"{msg.role}: {msg.content}" for msg in messages])
|
||||
current_tokens = count_tokens(full_text)
    if current_tokens <= max_tokens:
        return messages

    logger.info(f"🗜️ COMPRESSING: Initial request is {current_tokens} tokens, compressing to <{max_tokens}...")

    # Keep only user messages
    user_messages = [msg for msg in messages if msg.role == "user"]

    if not user_messages:
        logger.warning("No user messages found in initial request!")
        return []

    # Get the last user message
    last_user_msg = user_messages[-1]
    user_content = last_user_msg.content

    # Truncate if still too long
    if len(user_content) > 2000:
        user_content = user_content[:2000] + "... [truncated for token limit]"
        logger.debug(f"Truncated user message from {len(last_user_msg.content)} to 2000 chars")

    return [ChatMessage(role="user", content=user_content)]


def _filter_messages(messages: List[ChatMessage]) -> List[ChatMessage]:
    """Filter messages for processing.

    For initial requests >4000 tokens, compress aggressively.
    Otherwise, just remove system messages.

    Args:
        messages: List of chat messages

    Returns:
        Filtered list of messages
    """
    if _is_initial_request(messages):
        full_text = "\n".join([f"{msg.role}: {msg.content}" for msg in messages])
        if count_tokens(full_text) > 4000:
            return _compress_large_request(messages)

    # Normal filtering: remove system messages
    return [msg for msg in messages if msg.role != "system"]


def _add_tool_instructions(messages: List[ChatMessage]) -> List[ChatMessage]:
    """Add tool instructions to messages if needed.

    Args:
        messages: List of chat messages

    Returns:
        Messages with tool instructions added
    """
    has_assistant = any(msg.role == "assistant" for msg in messages)

    if has_assistant:
        return messages

    tool_instructions = _load_tool_instructions()
    logger.debug(f"Using {'opencode' if _USE_OPENCODE_TOOLS else 'local'} tool mode: {len(tool_instructions)} chars")

    return [ChatMessage(role="system", content=tool_instructions)] + messages


def _format_to_chatml(messages: List[ChatMessage]) -> str:
    """Format messages to ChatML format.

    Args:
        messages: List of chat messages

    Returns:
        ChatML formatted string
    """
    formatted = []

    for msg in messages:
        role = msg.role
        content = msg.content

        if role == "system":
            formatted.append(f"<|im_start|>system\n{content}<|im_end|>")
        elif role == "user":
            formatted.append(f"<|im_start|>user\n{content}<|im_end|>")
        elif role == "assistant":
            formatted.append(f"<|im_start|>assistant\n{content}<|im_end|>")
        elif role == "tool":
            tool_name = getattr(msg, 'name', 'tool')
            formatted.append(f"<|im_start|>tool\n{tool_name}: {content}<|im_end|>")

    formatted.append("<|im_start|>assistant\n")
    return "\n".join(formatted)


def _log_prompt_preview(messages: List[ChatMessage]) -> None:
    """Log a preview of the prompt for debugging.

    Args:
        messages: List of chat messages
    """
    preview = []
    for msg in messages:
        if msg.role == "system":
            preview.append(f"[SYSTEM] {msg.content[:200]}...")
        elif msg.role == "user":
            preview.append(f"[USER] {msg.content}")
    logger.debug(f"Prompt preview: {' | '.join(preview)}")


def format_messages_with_tools(
    messages: List[ChatMessage],
    tools: Optional[list] = None
) -> str:
    """Format chat messages into a single prompt using ChatML format.

    Note: Tools are handled server-side. The model should respond normally.
    IMPORTANT: If _USE_OPENCODE_TOOLS is True, use opencode's tool definitions (~27k tokens).
    If False, use local tool server (~125 tokens) to save tokens.

    Args:
        messages: List of chat messages
        tools: Optional list of tools (currently ignored, server-side handling)

    Returns:
        Formatted prompt string in ChatML format
    """
    # Filter messages
    filtered_messages = _filter_messages(messages)

    # Add tool instructions if needed
    filtered_messages = _add_tool_instructions(filtered_messages)

    # Log preview
    _log_prompt_preview(filtered_messages)

    # Format to ChatML
    result = _format_to_chatml(filtered_messages)

    # Log final token count
    final_tokens = count_tokens(result)
    original_tokens = count_tokens("\n".join([f"{msg.role}: {msg.content}" for msg in messages]))
    logger.info(f"📊 Final prompt size: {final_tokens} tokens (reduced from {original_tokens})")

    return result


def format_messages(messages: List[ChatMessage]) -> str:
    """Format chat messages into a single prompt using ChatML format.

    Args:
        messages: List of chat messages

    Returns:
        Formatted prompt string
    """
    return format_messages_with_tools(messages, None)
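To make the output format concrete, here is a minimal sketch of the ChatML layout the formatting code above produces; the `ChatMessage` stand-in below is illustrative only (the real model lives elsewhere in the repo):

```python
from dataclasses import dataclass

@dataclass
class ChatMessage:  # stand-in for the project's ChatMessage model, illustration only
    role: str
    content: str

msgs = [
    ChatMessage(role="system", content="You are a coding assistant."),
    ChatMessage(role="user", content="Write a hello world in Python."),
]

# Mirrors what _format_to_chatml() emits before the model starts generating.
prompt = "\n".join(
    f"<|im_start|>{m.role}\n{m.content}<|im_end|>" for m in msgs
) + "\n<|im_start|>assistant\n"
print(prompt)
```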
@@ -76,6 +76,7 @@ class UsageInfo(BaseModel):
    prompt_tokens: int = Field(default=0, description="Tokens in prompt")
    completion_tokens: int = Field(default=0, description="Tokens in completion")
    total_tokens: int = Field(default=0, description="Total tokens")
    tokens_per_second: Optional[float] = Field(default=None, description="Generation speed in tokens per second")


class ChatCompletionResponse(BaseModel):
+64
-853
File diff suppressed because it is too large
+18
-10
@@ -18,21 +18,23 @@ from swarm.status_monitor import StatusMonitor

class APIServer:
    """OpenAI-compatible API server."""

    def __init__(self, swarm_manager: SwarmManager, host: str = "127.0.0.1", port: int = 17615, show_live_status: bool = True):
    def __init__(self, swarm_manager: SwarmManager, host: str = "127.0.0.1", port: int = 17615, show_live_status: bool = True, use_opencode_tools: bool = False):
        """
        Initialize API server.

        Args:
            swarm_manager: Swarm manager instance
            host: Host to bind to
            port: Port to listen on
            show_live_status: Whether to show live worker status updates
            use_opencode_tools: Whether to use opencode's tool definitions (~27k tokens) or local tool server (~125 tokens)
        """
        self.swarm_manager = swarm_manager
        self.host = host
        self.port = port
        self.show_live_status = show_live_status
        self.use_opencode_tools = use_opencode_tools
        self.status_monitor: Optional[StatusMonitor] = None
        self.app = self._create_app()

@@ -44,6 +46,9 @@ class APIServer:
        """Lifespan context manager for startup/shutdown."""
        # Startup: Set swarm manager in routes
        set_swarm_manager(self.swarm_manager)
        # Set tool mode in routes
        from api.routes import set_use_opencode_tools
        set_use_opencode_tools(self.use_opencode_tools)
        print(f"\n🌐 API server starting on http://{self.host}:{self.port}")
        print(f" Endpoints:")
        print(f" - POST /v1/chat/completions")
@@ -90,32 +95,35 @@ class APIServer:
            self.app,
            host=self.host,
            port=self.port,
            log_level="info"
            log_level="warning",
            access_log=False
        )
        server = uvicorn.Server(config)
        await server.serve()

    def run_sync(self):
        """Run server synchronously (blocking)."""
        uvicorn.run(
            self.app,
            host=self.host,
            port=self.port,
            log_level="info"
            log_level="warning",
            access_log=False
        )


def create_server(swarm_manager: SwarmManager, host: str = "127.0.0.1", port: int = 17615, show_live_status: bool = True) -> APIServer:
def create_server(swarm_manager: SwarmManager, host: str = "127.0.0.1", port: int = 17615, show_live_status: bool = True, use_opencode_tools: bool = False) -> APIServer:
    """
    Create API server instance.

    Args:
        swarm_manager: Swarm manager instance
        host: Host to bind to
        port: Port to listen on
        show_live_status: Whether to show live worker status updates
        use_opencode_tools: Whether to use opencode's tool definitions (~27k tokens) or local tool server (~125 tokens)

    Returns:
        APIServer instance
    """
    return APIServer(swarm_manager, host, port, show_live_status)
    return APIServer(swarm_manager, host, port, show_live_status, use_opencode_tools)
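As a rough illustration of how this factory is meant to be wired up (the swarm initialization itself is covered by the runner code later in this change set; treat this as a sketch, not the project's entry point):

```python
from api import create_server
from swarm import SwarmManager  # as imported elsewhere in this change set

def serve(swarm: SwarmManager) -> None:
    # Local tool server mode keeps the injected tool instructions at ~125 tokens.
    server = create_server(
        swarm,
        host="127.0.0.1",
        port=17615,
        use_opencode_tools=False,
    )
    server.run_sync()  # blocking; the async runner uses `await server.start()` instead
```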
@@ -0,0 +1,250 @@
|
||||
"""Tool parsing module for Local Swarm.
|
||||
|
||||
Parses tool calls from model output in various formats.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Tuple, Optional, List, Dict, Any
|
||||
|
||||
|
||||
def ensure_tool_arguments(tool_name: str, args_dict: dict) -> dict:
|
||||
"""Ensure tool arguments have all required fields.
|
||||
|
||||
For bash tool: inject 'description' field if missing.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool
|
||||
args_dict: Tool arguments dictionary
|
||||
|
||||
Returns:
|
||||
Updated arguments dictionary
|
||||
"""
|
||||
if tool_name == 'bash' and 'description' not in args_dict:
|
||||
# Generate description from command
|
||||
command = args_dict.get('command', '')
|
||||
desc = command.split()[0] if command else 'Execute command'
|
||||
args_dict['description'] = desc
|
||||
return args_dict
|
||||
|
||||
|
||||
def _parse_standard_format(text: str) -> Tuple[Optional[str], Optional[List[Dict[str, Any]]]]:
|
||||
"""Parse standard TOOL: format.
|
||||
|
||||
Format: TOOL: name\nARGUMENTS: {"key": "value"}
|
||||
|
||||
Args:
|
||||
text: Model output text
|
||||
|
||||
Returns:
|
||||
Tuple of (content_without_tools, tool_calls) or (None, None) if not found
|
||||
"""
|
||||
tool_pattern = r'TOOL:\s*(\w+)\s*\nARGUMENTS:\s*(\{[^}]*\})'
|
||||
tool_matches = list(re.finditer(tool_pattern, text, re.IGNORECASE))
|
||||
|
||||
if not tool_matches:
|
||||
return None, None
|
||||
|
||||
tool_calls = []
|
||||
for i, tool_match in enumerate(tool_matches):
|
||||
tool_name = tool_match.group(1)
|
||||
args_str = tool_match.group(2)
|
||||
try:
|
||||
args_dict = json.loads(args_str)
|
||||
args_dict = ensure_tool_arguments(tool_name, args_dict)
|
||||
tool_calls.append({
|
||||
"id": f"call_{i+1}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_name,
|
||||
"arguments": json.dumps(args_dict)
|
||||
}
|
||||
})
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
if tool_calls:
|
||||
first_start = tool_matches[0].start()
|
||||
content = text[:first_start].strip()
|
||||
return content, tool_calls
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def _parse_markdown_format(text: str) -> Tuple[Optional[str], Optional[List[Dict[str, Any]]]]:
|
||||
"""Parse markdown code block format.
|
||||
|
||||
Format: ```bash command```
|
||||
|
||||
Args:
|
||||
text: Model output text
|
||||
|
||||
Returns:
|
||||
Tuple of (content_without_tools, tool_calls) or (None, None) if not found
|
||||
"""
|
||||
markdown_pattern = r'```(?:bash|shell|sh)?\s*\n(.*?)\n```'
|
||||
markdown_matches = list(re.finditer(markdown_pattern, text, re.DOTALL))
|
||||
|
||||
if not markdown_matches:
|
||||
return None, None
|
||||
|
||||
tool_calls = []
|
||||
for i, match in enumerate(markdown_matches):
|
||||
code_content = match.group(1).strip()
|
||||
if code_content:
|
||||
args_dict = {"command": code_content}
|
||||
args_dict = ensure_tool_arguments("bash", args_dict)
|
||||
tool_calls.append({
|
||||
"id": f"call_{i+1}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "bash",
|
||||
"arguments": json.dumps(args_dict)
|
||||
}
|
||||
})
|
||||
|
||||
if tool_calls:
|
||||
first_start = markdown_matches[0].start()
|
||||
content = text[:first_start].strip()
|
||||
return content, tool_calls
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def _parse_command_lines(text: str) -> Tuple[Optional[str], Optional[List[Dict[str, Any]]]]:
|
||||
"""Parse command lines in text.
|
||||
|
||||
Matches common bash commands with their arguments.
|
||||
|
||||
Args:
|
||||
text: Model output text
|
||||
|
||||
Returns:
|
||||
Tuple of (content_without_tools, tool_calls) or (None, None) if not found
|
||||
"""
|
||||
command_lines = []
|
||||
command_pattern = r'^(npm|npx|mkdir|cd|ls|cat|echo|git|python|pip|node|yarn|create-react-app)\s+'
|
||||
|
||||
for line in text.split('\n'):
|
||||
line = line.strip()
|
||||
if re.match(command_pattern, line):
|
||||
command_lines.append(line)
|
||||
|
||||
if command_lines:
|
||||
combined_command = ' && '.join(command_lines)
|
||||
args_dict = {"command": combined_command}
|
||||
args_dict = ensure_tool_arguments("bash", args_dict)
|
||||
return "", [{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "bash",
|
||||
"arguments": json.dumps(args_dict)
|
||||
}
|
||||
}]
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def _parse_standalone_commands(text: str) -> Tuple[Optional[str], Optional[List[Dict[str, Any]]]]:
|
||||
"""Parse standalone bash commands.
|
||||
|
||||
Args:
|
||||
text: Model output text
|
||||
|
||||
Returns:
|
||||
Tuple of (content_without_tools, tool_calls) or (None, None) if not found
|
||||
"""
|
||||
standalone_pattern = r'(?:^|\n)(npm\s+\w+|npx\s+\w+|mkdir\s+\w+|cd\s+\w+|git\s+\w+)(?:\s|$)'
|
||||
standalone_matches = list(re.finditer(standalone_pattern, text, re.MULTILINE))
|
||||
|
||||
if standalone_matches:
|
||||
commands = [match.group(1).strip() for match in standalone_matches]
|
||||
if commands:
|
||||
combined_command = ' && '.join(commands)
|
||||
args_dict = {"command": combined_command}
|
||||
args_dict = ensure_tool_arguments("bash", args_dict)
|
||||
return "", [{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "bash",
|
||||
"arguments": json.dumps(args_dict)
|
||||
}
|
||||
}]
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def _parse_urls(text: str) -> Tuple[Optional[str], Optional[List[Dict[str, Any]]]]:
|
||||
"""Parse URLs for webfetch tool.
|
||||
|
||||
Args:
|
||||
text: Model output text
|
||||
|
||||
Returns:
|
||||
Tuple of (content_without_tools, tool_calls) or (None, None) if not found
|
||||
"""
|
||||
url_pattern = r'https?://[^\s<>"\')\]]+[a-zA-Z0-9]'
|
||||
url_matches = list(re.finditer(url_pattern, text))
|
||||
|
||||
if url_matches:
|
||||
urls = [match.group(0) for match in url_matches]
|
||||
if urls:
|
||||
tool_calls = []
|
||||
for i, url in enumerate(urls):
|
||||
tool_calls.append({
|
||||
"id": f"call_{i+1}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "webfetch",
|
||||
"arguments": json.dumps({"url": url, "format": "markdown"})
|
||||
}
|
||||
})
|
||||
return "", tool_calls
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def parse_tool_calls(text: str) -> Tuple[str, Optional[List[Dict[str, Any]]]]:
    """Parse tool calls from model output using multiple formats.

    Supports multiple formats for compatibility with different model sizes:
    1. Standard: TOOL: name\nARGUMENTS: {"key": "value"}
    2. Markdown: ```bash command```
    3. Command lines: npm install, git clone, etc.
    4. Standalone commands
    5. URLs: for webfetch tool

    Args:
        text: Model output text

    Returns:
        Tuple of (content_without_tools, tool_calls or None)
    """
    # Priority 1: Standard format
    result = _parse_standard_format(text)
    if result[1] is not None:
        return result[0] or "", result[1]

    # Priority 2: Markdown code blocks
    result = _parse_markdown_format(text)
    if result[1] is not None:
        return result[0] or "", result[1]

    # Priority 3: Command lines
    result = _parse_command_lines(text)
    if result[1] is not None:
        return result[0] or "", result[1]

    # Priority 4: Standalone commands
    result = _parse_standalone_commands(text)
    if result[1] is not None:
        return result[0] or "", result[1]

    # Priority 5: URLs
    result = _parse_urls(text)
    if result[1] is not None:
        return result[0] or "", result[1]

    return text, None
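A small illustrative call against the parser above; the expected output follows the standard-format branch, and the import path is an assumption since the diff does not show the file name:

```python
from tools.parser import parse_tool_calls  # module path is an assumption

output = 'Let me check the repo.\nTOOL: bash\nARGUMENTS: {"command": "ls -la"}'
content, calls = parse_tool_calls(output)

print(content)                             # "Let me check the repo."
print(calls[0]["function"]["name"])        # "bash"
# ensure_tool_arguments injects a description derived from the command:
print(calls[0]["function"]["arguments"])   # '{"command": "ls -la", "description": "ls"}'
```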
@@ -4,7 +4,7 @@ Creates the appropriate backend based on hardware and platform.
"""

from typing import Optional
from hardware.detector import HardwareProfile, detect_hardware
from hardware.detector import HardwareProfile, detect_hardware, calculate_gpu_layers
from backends.base import LLMBackend
from backends.llamacpp import LlamaCppBackend
from backends.mlx import MLXBackend
@@ -31,15 +31,17 @@ def create_backend(hardware: Optional[HardwareProfile] = None) -> LLMBackend:
    # Otherwise use llama.cpp (supports CUDA, ROCm, SYCL, CPU)
    print("Using llama.cpp backend")

    # Determine GPU layers
    # Auto-configure GPU layers based on hardware
    n_gpu_layers = calculate_gpu_layers(hardware.gpu)

    if hardware.gpu and not hardware.is_apple_silicon:
        # Has external GPU, offload all layers
        n_gpu_layers = -1
        print(f" GPU detected: {hardware.gpu.name}")
        print(f" Offloading all layers to GPU")
        if hardware.gpu.is_nvidia:
            print(f" Compute capability: {hardware.gpu.compute_capability or 'unknown'}")
        if hardware.gpu.device_count > 1:
            print(f" GPU count: {hardware.gpu.device_count}")
        print(f" Offloading {n_gpu_layers} layers to GPU")
    else:
        # CPU only
        n_gpu_layers = 0
        print(f" No GPU detected, using CPU")

    return LlamaCppBackend(n_gpu_layers=n_gpu_layers)
@@ -0,0 +1,285 @@
|
||||
"""Main application runner for Local Swarm.
|
||||
|
||||
Handles the primary application modes: download-only, test, and full server mode.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
from models.selector import select_optimal_model, ModelConfig
|
||||
from models.downloader import download_model_for_config
|
||||
from swarm import SwarmManager
|
||||
from api import create_server
|
||||
from api.routes import set_federated_swarm
|
||||
from interactive import (
|
||||
interactive_model_selection,
|
||||
show_startup_summary,
|
||||
show_runtime_menu,
|
||||
)
|
||||
from network import create_discovery_service, FederatedSwarm
|
||||
from tools.executor import ToolExecutor, set_tool_executor
|
||||
from utils.network import get_local_ip
|
||||
|
||||
|
||||
class MainRunner:
|
||||
"""Runs the main application logic."""
|
||||
|
||||
def __init__(self, hardware, args):
|
||||
"""Initialize the main runner.
|
||||
|
||||
Args:
|
||||
hardware: Hardware profile
|
||||
args: Parsed command line arguments
|
||||
"""
|
||||
self.hardware = hardware
|
||||
self.args = args
|
||||
self.config: Optional[ModelConfig] = None
|
||||
self.swarm: Optional[SwarmManager] = None
|
||||
self.discovery = None
|
||||
self.federated_swarm = None
|
||||
self.mcp_server = None
|
||||
|
||||
async def run(self) -> int:
|
||||
"""Run the main application.
|
||||
|
||||
Returns:
|
||||
Exit code (0 for success, 1 for error)
|
||||
"""
|
||||
# Get configuration
|
||||
self.config = self._get_configuration()
|
||||
if not self.config:
|
||||
return 1
|
||||
|
||||
# Handle download-only mode
|
||||
if self.args.download_only:
|
||||
return await self._run_download_mode()
|
||||
|
||||
# Handle test mode
|
||||
if self.args.test:
|
||||
return await self._run_test_mode()
|
||||
|
||||
# Run full server mode
|
||||
return await self._run_server_mode()
|
||||
|
||||
def _get_configuration(self) -> Optional[ModelConfig]:
|
||||
"""Get the model configuration."""
|
||||
if self.args.model or self.args.instances or self.args.auto:
|
||||
return self._get_auto_config()
|
||||
else:
|
||||
return interactive_model_selection(self.hardware)
|
||||
|
||||
def _get_auto_config(self) -> Optional[ModelConfig]:
|
||||
"""Get auto-detected configuration."""
|
||||
print("\n📊 Calculating optimal configuration...")
|
||||
try:
|
||||
config = select_optimal_model(
|
||||
self.hardware,
|
||||
preferred_model=self.args.model,
|
||||
force_instances=self.args.instances
|
||||
)
|
||||
|
||||
if not config:
|
||||
print("\n❌ No suitable model found for your hardware")
|
||||
print(" Minimum requirement: 2 GB available memory")
|
||||
return None
|
||||
|
||||
print(f"\n✓ Selected: {config.display_name}")
|
||||
print(f" Instances: {config.instances}")
|
||||
print(f" Memory: {config.total_memory_gb:.1f} GB")
|
||||
return config
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ Error selecting model: {e}", file=sys.stderr)
|
||||
return None
|
||||
|
||||
async def _run_download_mode(self) -> int:
|
||||
"""Run download-only mode."""
|
||||
print("\n" + "=" * 70)
|
||||
print("⬇️ Download Mode: Downloading model only")
|
||||
print("=" * 70)
|
||||
|
||||
try:
|
||||
model_path = download_model_for_config(self.config)
|
||||
print(f"✓ Model downloaded to: {model_path}")
|
||||
print("\n" + "=" * 70)
|
||||
print("✅ Download complete")
|
||||
print("=" * 70)
|
||||
return 0
|
||||
except Exception as e:
|
||||
print(f"\n❌ Download failed: {e}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
async def _run_test_mode(self) -> int:
|
||||
"""Run test mode with sample prompt."""
|
||||
from cli.test_runner import run_test
|
||||
return await run_test(self.hardware, self.config)
|
||||
|
||||
async def _run_server_mode(self) -> int:
|
||||
"""Run full server mode."""
|
||||
show_startup_summary(self.hardware, self.config)
|
||||
|
||||
# Setup swarm
|
||||
if not await self._setup_swarm():
|
||||
return 1
|
||||
|
||||
# Initialize tool executor
|
||||
self._setup_tool_executor()
|
||||
|
||||
# Show updated summary with runtime info
|
||||
show_startup_summary(self.hardware, self.config, self.swarm)
|
||||
|
||||
# Initialize federation if enabled
|
||||
if self.args.federation:
|
||||
await self._setup_federation()
|
||||
|
||||
# Start MCP server if enabled
|
||||
if self.args.mcp:
|
||||
await self._setup_mcp()
|
||||
|
||||
# Run server
|
||||
return await self._run_server()
|
||||
|
||||
async def _setup_swarm(self) -> bool:
|
||||
"""Setup the swarm.
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
print("\n⬇️ Downloading model...")
|
||||
try:
|
||||
model_path = download_model_for_config(self.config)
|
||||
print(f"✓ Model ready at: {model_path}")
|
||||
except Exception as e:
|
||||
print(f"\n❌ Error downloading model: {e}", file=sys.stderr)
|
||||
return False
|
||||
|
||||
print("\n🚀 Initializing swarm...")
|
||||
try:
|
||||
self.swarm = SwarmManager(
|
||||
model_config=self.config,
|
||||
hardware=self.hardware,
|
||||
consensus_strategy="similarity"
|
||||
)
|
||||
|
||||
success = await self.swarm.initialize(str(model_path))
|
||||
if not success:
|
||||
print("❌ Failed to initialize swarm")
|
||||
return False
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"\n❌ Error initializing swarm: {e}", file=sys.stderr)
|
||||
return False
|
||||
|
||||
def _setup_tool_executor(self) -> None:
|
||||
"""Setup the tool executor."""
|
||||
if self.args.tool_host is not None:
|
||||
if self.args.tool_host == "":
|
||||
tool_host_url = f"http://{get_local_ip()}:17616"
|
||||
print(f"\n🔧 Using remote tool host: {tool_host_url} (auto-detected)")
|
||||
else:
|
||||
tool_host_url = self.args.tool_host
|
||||
print(f"\n🔧 Using remote tool host: {tool_host_url}")
|
||||
executor = ToolExecutor(tool_host_url=tool_host_url)
|
||||
else:
|
||||
executor = ToolExecutor(tool_host_url=None)
|
||||
print("\n🔧 Tool Server: Local")
|
||||
|
||||
set_tool_executor(executor)
|
||||
|
||||
async def _setup_federation(self) -> None:
|
||||
"""Setup federation."""
|
||||
print("\n🌐 Initializing federation...")
|
||||
try:
|
||||
advertise_ip = self.args.host if self.args.host else None
|
||||
self.discovery = await create_discovery_service(
|
||||
self.args.port,
|
||||
advertise_ip=advertise_ip
|
||||
)
|
||||
|
||||
swarm_info = {
|
||||
"version": "0.1.0",
|
||||
"instances": self.config.instances,
|
||||
"model_id": self.config.model_id,
|
||||
"hardware_summary": f"{self.hardware.cpu_cores} CPU, {self.hardware.ram_gb:.1f}GB RAM"
|
||||
}
|
||||
|
||||
await self.discovery.start_advertising(swarm_info)
|
||||
await self.discovery.start_listening()
|
||||
|
||||
# Add manual peers
|
||||
if self.args.peers:
|
||||
await self._add_manual_peers()
|
||||
|
||||
self.federated_swarm = FederatedSwarm(self.swarm, self.discovery)
|
||||
set_federated_swarm(self.federated_swarm)
|
||||
|
||||
# Start health check loop
|
||||
asyncio.create_task(
|
||||
self.discovery.start_health_check_loop(interval_seconds=10)
|
||||
)
|
||||
|
||||
print(f" ✓ Federation enabled")
|
||||
print(f" ✓ Discovery active on port {self.discovery.discovery_port}")
|
||||
print(f" ✓ Peer health checks every 10s")
|
||||
except Exception as e:
|
||||
print(f" ⚠️ Failed to initialize federation: {e}")
|
||||
print(" Continuing without federation...")
|
||||
|
||||
async def _add_manual_peers(self) -> None:
|
||||
"""Add manual peers from command line."""
|
||||
print(f" 📍 Adding {len(self.args.peers)} manual peer(s)...")
|
||||
from network.discovery import PeerInfo
|
||||
from datetime import datetime
|
||||
|
||||
for peer_str in self.args.peers:
|
||||
try:
|
||||
host, port = peer_str.rsplit(':', 1)
|
||||
port = int(port)
|
||||
peer = PeerInfo(
|
||||
host=host,
|
||||
port=port,
|
||||
name=f"manual_{host}_{port}",
|
||||
version="0.1.0",
|
||||
instances=0,
|
||||
model_id="unknown",
|
||||
hardware_summary="manual",
|
||||
last_seen=datetime.now()
|
||||
)
|
||||
self.discovery.peers[peer.name] = peer
|
||||
print(f" ✓ Added peer: {host}:{port}")
|
||||
except Exception as e:
|
||||
print(f" ⚠️ Failed to add peer {peer_str}: {e}")
|
||||
|
||||
async def _setup_mcp(self) -> None:
|
||||
"""Setup MCP server."""
|
||||
print("\n🤖 Starting MCP server...")
|
||||
from mcp_server import create_mcp_server
|
||||
self.mcp_server = await create_mcp_server(self.swarm)
|
||||
print(" MCP server active (stdio)")
|
||||
|
||||
async def _run_server(self) -> int:
|
||||
"""Run the API server."""
|
||||
from cli.server_runner import ServerRunner
|
||||
|
||||
runner = ServerRunner(
|
||||
self.swarm,
|
||||
self.discovery,
|
||||
self.federated_swarm,
|
||||
self.args
|
||||
)
|
||||
|
||||
try:
|
||||
return await runner.run()
|
||||
finally:
|
||||
await self._shutdown()
|
||||
|
||||
async def _shutdown(self) -> None:
|
||||
"""Shutdown all services."""
|
||||
if self.federated_swarm:
|
||||
await self.federated_swarm.close()
|
||||
if self.discovery:
|
||||
await self.discovery.stop()
|
||||
if self.swarm:
|
||||
await self.swarm.shutdown()
|
||||
@@ -0,0 +1,151 @@
|
||||
"""CLI argument parsing for Local Swarm."""
|
||||
|
||||
import argparse
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def create_parser() -> argparse.ArgumentParser:
|
||||
"""Create and configure the argument parser."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Local Swarm - AI-powered coding LLM swarm",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
python main.py # Interactive setup and start
|
||||
python main.py --auto # Auto-detect and start without menu
|
||||
python main.py --detect # Show hardware detection only
|
||||
python main.py --model qwen:3b:q4 # Use specific model (skip menu)
|
||||
python main.py --port 17615 # Use custom port (default: 17615)
|
||||
python main.py --host 192.168.1.5 # Bind to specific IP
|
||||
python main.py --instances 4 # Force number of instances
|
||||
python main.py --download-only # Download model only
|
||||
python main.py --test # Test with sample prompt
|
||||
python main.py --mcp # Enable MCP server
|
||||
python main.py --federation # Enable federation with other instances
|
||||
python main.py --federation --peer 192.168.1.10:17615 # Manual peer
|
||||
"""
|
||||
)
|
||||
|
||||
# Mode options
|
||||
parser.add_argument(
|
||||
"--auto",
|
||||
action="store_true",
|
||||
help="Auto-detect best configuration without interactive menu"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--detect",
|
||||
action="store_true",
|
||||
help="Show hardware detection and exit"
|
||||
)
|
||||
|
||||
# Model options
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
help="Model to use (format: name:size:quant, e.g., qwen:3b:q4)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--instances",
|
||||
type=int,
|
||||
help="Force number of instances (overrides auto-calculation)"
|
||||
)
|
||||
|
||||
# Server options
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=17615,
|
||||
help="Port to run the API server on (default: 17615)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Host IP to bind to (default: auto-detect)"
|
||||
)
|
||||
|
||||
# Operation modes
|
||||
parser.add_argument(
|
||||
"--download-only",
|
||||
action="store_true",
|
||||
help="Download models only, don't start server"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test",
|
||||
action="store_true",
|
||||
help="Test with a sample prompt"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mcp",
|
||||
action="store_true",
|
||||
help="Enable MCP server alongside HTTP API"
|
||||
)
|
||||
|
||||
# Configuration
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
default="config.yaml",
|
||||
help="Path to config file"
|
||||
)
|
||||
|
||||
# Federation options
|
||||
parser.add_argument(
|
||||
"--federation",
|
||||
action="store_true",
|
||||
help="Enable federation with other Local Swarm instances on the network"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--peer",
|
||||
action="append",
|
||||
dest="peers",
|
||||
help="Manually add a peer (format: host:port, can be used multiple times)"
|
||||
)
|
||||
|
||||
# Tool server options
|
||||
parser.add_argument(
|
||||
"--tool-server",
|
||||
action="store_true",
|
||||
help="Run as dedicated tool execution server (executes read/write/bash tools)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tool-port",
|
||||
type=int,
|
||||
default=17616,
|
||||
help="Port for tool execution server (default: 17616)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tool-host",
|
||||
type=str,
|
||||
default=None,
|
||||
nargs='?',
|
||||
const='',
|
||||
help="URL of tool execution server. Use without value for auto-detected local IP"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-opencode-tools",
|
||||
action="store_true",
|
||||
help="Use opencode's tool definitions (~27k tokens). Default: use local tool server"
|
||||
)
|
||||
|
||||
# Version
|
||||
parser.add_argument(
|
||||
"--version",
|
||||
action="version",
|
||||
version="%(prog)s 0.1.0"
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def parse_args(args: Optional[list] = None):
|
||||
"""Parse command line arguments.
|
||||
|
||||
Args:
|
||||
args: Command line arguments (defaults to sys.argv)
|
||||
|
||||
Returns:
|
||||
Parsed arguments namespace
|
||||
"""
|
||||
parser = create_parser()
|
||||
return parser.parse_args(args)
|
||||
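A quick, purely illustrative check of the parser defined above; the module path `cli.args` is an assumption based on the other `cli.*` imports in this change set:

```python
from cli.args import parse_args  # import path is an assumption

args = parse_args(["--auto", "--port", "17615", "--federation",
                   "--peer", "192.168.1.10:17615"])
print(args.auto)                 # True
print(args.port)                 # 17615
print(args.peers)                # ['192.168.1.10:17615']
print(args.use_opencode_tools)   # False (default: local tool server)
```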
@@ -0,0 +1,103 @@
|
||||
"""Server runner for Local Swarm."""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from api import create_server
|
||||
from api.routes import set_federated_swarm
|
||||
from utils.network import get_local_ip
|
||||
|
||||
|
||||
class ServerRunner:
|
||||
"""Handles server startup and shutdown."""
|
||||
|
||||
def __init__(self, swarm, discovery, federated_swarm, args):
|
||||
"""Initialize server runner.
|
||||
|
||||
Args:
|
||||
swarm: Swarm manager instance
|
||||
discovery: Discovery service (optional)
|
||||
federated_swarm: Federated swarm (optional)
|
||||
args: Command line arguments
|
||||
"""
|
||||
self.swarm = swarm
|
||||
self.discovery = discovery
|
||||
self.federated_swarm = federated_swarm
|
||||
self.args = args
|
||||
self.mcp_server = None
|
||||
|
||||
async def run(self) -> int:
|
||||
"""Run the server.
|
||||
|
||||
Returns:
|
||||
Exit code
|
||||
"""
|
||||
print("\n🌐 Starting HTTP API server...")
|
||||
|
||||
# Determine host
|
||||
host = self._get_host()
|
||||
|
||||
# Show tool mode
|
||||
self._show_tool_mode()
|
||||
|
||||
# Create and start server
|
||||
server = create_server(
|
||||
self.swarm,
|
||||
host=host,
|
||||
port=self.args.port,
|
||||
use_opencode_tools=self.args.use_opencode_tools
|
||||
)
|
||||
|
||||
self._print_connection_info(host)
|
||||
|
||||
# Start server
|
||||
try:
|
||||
await server.start()
|
||||
finally:
|
||||
await self._shutdown()
|
||||
|
||||
return 0
|
||||
|
||||
def _get_host(self) -> str:
|
||||
"""Get the host to bind to."""
|
||||
if self.args.host:
|
||||
print(f"🔗 Using specified host: {self.args.host}:{self.args.port}")
|
||||
return self.args.host
|
||||
else:
|
||||
host = get_local_ip()
|
||||
print(f"🔗 Binding to {host}:{self.args.port}")
|
||||
return host
|
||||
|
||||
def _show_tool_mode(self) -> None:
|
||||
"""Display tool mode information."""
|
||||
if self.args.use_opencode_tools:
|
||||
print(f"🔧 Tool mode: opencode tools (~27k tokens)")
|
||||
else:
|
||||
print(f"🔧 Tool mode: local tool server (~125 tokens)")
|
||||
|
||||
def _print_connection_info(self, host: str) -> None:
|
||||
"""Print server connection information."""
|
||||
print(f"\n✅ Local Swarm is running!")
|
||||
print(f" API: http://{host}:{self.args.port}/v1")
|
||||
print(f" Health: http://{host}:{self.args.port}/health")
|
||||
|
||||
if self.args.federation and self.discovery:
|
||||
peers = self.discovery.get_peers()
|
||||
print(f"\n🌐 Federation: Enabled")
|
||||
print(f" Discovery port: {self.discovery.discovery_port}")
|
||||
if peers:
|
||||
print(f" Peers discovered: {len(peers)}")
|
||||
|
||||
print(f"\n💡 Configure opencode to use:")
|
||||
print(f' base_url: http://127.0.0.1:{self.args.port}/v1')
|
||||
print(f' api_key: any (not used)')
|
||||
print(f"\nPress Ctrl+C to stop...\n")
|
||||
|
||||
async def _shutdown(self) -> None:
|
||||
"""Shutdown all services."""
|
||||
if self.federated_swarm:
|
||||
await self.federated_swarm.close()
|
||||
if self.discovery:
|
||||
await self.discovery.stop()
|
||||
if self.swarm:
|
||||
await self.swarm.shutdown()
|
||||
@@ -0,0 +1,81 @@
|
||||
"""Test mode runner for Local Swarm."""
|
||||
|
||||
import asyncio
|
||||
from models.downloader import download_model_for_config
|
||||
from swarm import SwarmManager
|
||||
from interactive import show_startup_summary
|
||||
|
||||
|
||||
async def run_test(hardware, config) -> int:
|
||||
"""Run test mode with sample prompt.
|
||||
|
||||
Args:
|
||||
hardware: Hardware profile
|
||||
config: Model configuration
|
||||
|
||||
Returns:
|
||||
Exit code (0 for success, 1 for error)
|
||||
"""
|
||||
print("\n" + "=" * 70)
|
||||
print("🧪 Test Mode: Running sample inference")
|
||||
print("=" * 70)
|
||||
|
||||
show_startup_summary(hardware, config)
|
||||
|
||||
# Download model
|
||||
print("\n⬇️ Downloading model...")
|
||||
try:
|
||||
model_path = download_model_for_config(config)
|
||||
print(f"✓ Model ready at: {model_path}")
|
||||
except Exception as e:
|
||||
print(f"\n❌ Error downloading model: {e}")
|
||||
return 1
|
||||
|
||||
# Initialize swarm
|
||||
print("\n🚀 Initializing swarm...")
|
||||
try:
|
||||
swarm = SwarmManager(
|
||||
model_config=config,
|
||||
hardware=hardware,
|
||||
consensus_strategy="similarity"
|
||||
)
|
||||
|
||||
success = await swarm.initialize(str(model_path))
|
||||
if not success:
|
||||
print("❌ Failed to initialize swarm")
|
||||
return 1
|
||||
except Exception as e:
|
||||
print(f"\n❌ Error initializing swarm: {e}")
|
||||
return 1
|
||||
|
||||
try:
|
||||
# Test prompt
|
||||
prompt = "Write a Python function to calculate factorial:"
|
||||
print(f"\nPrompt: {prompt}\n")
|
||||
print("Generating responses...\n")
|
||||
|
||||
result = await swarm.generate(prompt, max_tokens=200)
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("SELECTED RESPONSE:")
|
||||
print("=" * 70)
|
||||
print(result.selected_response.text)
|
||||
print("\n" + "=" * 70)
|
||||
print(f"Strategy: {result.strategy}")
|
||||
print(f"Confidence: {result.confidence:.2f}")
|
||||
print(f"Latency: {result.selected_response.latency_ms:.1f}ms")
|
||||
print(f"Tokens/sec: {result.selected_response.tokens_per_second:.1f}")
|
||||
|
||||
# Show all responses
|
||||
print("\nAll responses received:")
|
||||
for i, resp in enumerate(result.all_responses):
|
||||
preview = resp.text[:60].replace('\n', ' ')
|
||||
print(f" Worker {i}: {preview}... ({resp.latency_ms:.1f}ms)")
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("✅ Test complete")
|
||||
print("=" * 70)
|
||||
return 0
|
||||
|
||||
finally:
|
||||
await swarm.shutdown()
|
||||
@@ -0,0 +1,69 @@
"""Tool server for Local Swarm.

Standalone tool execution server for distributed setups.
"""

import logging
from typing import Optional

from fastapi import FastAPI
import uvicorn

from tools.executor import ToolExecutor, set_tool_executor


logger = logging.getLogger(__name__)


def create_tool_server_app() -> FastAPI:
    """Create the tool server FastAPI application.

    Returns:
        Configured FastAPI application
    """
    app = FastAPI(title="Local Swarm Tool Server")

    @app.post("/v1/tools/execute")
    async def execute_tool(request: dict):
        tool_name = request.get("tool", "")
        tool_args = request.get("arguments", {})

        # Get the global executor
        from tools.executor import get_tool_executor
        executor = get_tool_executor()

        if executor is None:
            return {"result": "Error: No tool executor configured"}

        result = await executor.execute(tool_name, tool_args)
        return {"result": result}

    @app.get("/health")
    async def health():
        return {"status": "healthy", "mode": "tool-server"}

    return app


async def run_tool_server(host: str, port: int) -> None:
    """Run the tool server.

    Args:
        host: Host to bind to
        port: Port to listen on
    """
    # Initialize local tool executor
    tool_executor = ToolExecutor(tool_host_url=None)
    set_tool_executor(tool_executor)

    app = create_tool_server_app()

    print(f"🔗 Tool server running at http://{host}:{port}")
    print(f" Endpoints:")
    print(f" - POST /v1/tools/execute")
    print(f" - GET /health")
    print(f"\n✅ Tool server ready!")

    config = uvicorn.Config(app, host=host, port=port, log_level="warning")
    server = uvicorn.Server(config)
    await server.serve()
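For reference, a minimal client call against the endpoint above using only the standard library; it assumes a tool server started via the `--tool-server` flag on the default port 17616:

```python
import json
import urllib.request

# Payload keys match what execute_tool() reads: "tool" and "arguments".
payload = {"tool": "bash", "arguments": {"command": "echo hello", "description": "echo"}}
req = urllib.request.Request(
    "http://127.0.0.1:17616/v1/tools/execute",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp))  # {"result": "..."} on success
```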
+132
-2
@@ -2,6 +2,7 @@

from dataclasses import dataclass
from typing import Optional, List
import os
import platform
import psutil

@@ -17,6 +18,8 @@ class GPUInfo:
    is_nvidia: bool = False
    is_amd: bool = False
    is_mobile: bool = False
    compute_capability: Optional[str] = None  # CUDA compute capability
    device_count: int = 1  # Number of GPUs available


@dataclass
@@ -70,10 +73,55 @@ class HardwareProfile:
|
||||
return self.available_memory_gb
|
||||
|
||||
|
||||
def is_android() -> bool:
|
||||
"""Check if running on Android (beyond just Termux)."""
|
||||
# Check multiple Android indicators
|
||||
|
||||
# 1. Check for Android-specific environment variables
|
||||
android_env_vars = [
|
||||
"ANDROID_ROOT",
|
||||
"ANDROID_DATA",
|
||||
"ANDROID_ART_ROOT",
|
||||
"ANDROID_I18N_ROOT",
|
||||
"ANDROID_TZDATA_ROOT",
|
||||
]
|
||||
if any(os.environ.get(var) for var in android_env_vars):
|
||||
return True
|
||||
|
||||
# 2. Check for Android-specific paths
|
||||
android_paths = [
|
||||
"/system/build.prop",
|
||||
"/system/bin/app_process",
|
||||
"/data/data",
|
||||
]
|
||||
if any(os.path.exists(path) for path in android_paths):
|
||||
return True
|
||||
|
||||
# 3. Check for Termux (which runs on Android)
|
||||
if _is_android_or_termux():
|
||||
return True
|
||||
|
||||
# 4. Check /proc/sys/kernel/osrelease for Android
|
||||
try:
|
||||
if os.path.exists("/proc/sys/kernel/osrelease"):
|
||||
with open("/proc/sys/kernel/osrelease", "r") as f:
|
||||
release = f.read().lower()
|
||||
if "android" in release:
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def detect_os() -> str:
|
||||
"""Detect the operating system."""
|
||||
system = platform.system().lower()
|
||||
if system == "darwin":
|
||||
|
||||
# Check for Android first (reports as Linux)
|
||||
if system == "linux" and is_android():
|
||||
return "android"
|
||||
elif system == "darwin":
|
||||
return "darwin"
|
||||
elif system == "windows":
|
||||
return "windows"
|
||||
@@ -132,6 +180,14 @@ def detect_nvidia_gpu() -> Optional[GPUInfo]:
|
||||
except Exception:
|
||||
driver = None
|
||||
|
||||
# Get compute capability
|
||||
compute_capability = None
|
||||
try:
|
||||
major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
|
||||
compute_capability = f"{major}.{minor}"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return GPUInfo(
|
||||
name=name,
|
||||
vram_gb=vram_gb,
|
||||
@@ -139,7 +195,9 @@ def detect_nvidia_gpu() -> Optional[GPUInfo]:
|
||||
device_id=0,
|
||||
is_nvidia=True,
|
||||
is_apple_silicon=False,
|
||||
is_amd=False
|
||||
is_amd=False,
|
||||
compute_capability=compute_capability,
|
||||
device_count=device_count
|
||||
)
|
||||
finally:
|
||||
pynvml.nvmlShutdown()
|
||||
@@ -219,6 +277,78 @@ def detect_gpu() -> Optional[GPUInfo]:
|
||||
return None
|
||||
|
||||
|
||||
def calculate_gpu_layers(gpu: Optional[GPUInfo]) -> int:
    """Calculate optimal number of GPU layers to offload.

    Args:
        gpu: GPU information (None if no GPU)

    Returns:
        Number of layers to offload (-1 = all, 0 = CPU only)
    """
    if gpu is None:
        return 0

    if gpu.is_apple_silicon:
        # Apple Silicon: offload all layers (unified memory)
        return -1

    if gpu.is_nvidia:
        # NVIDIA: Check compute capability for compatibility
        if gpu.compute_capability:
            major, _ = gpu.compute_capability.split('.')
            if int(major) < 5:
                # Very old GPUs (Kepler and earlier) may have issues
                return 0

        # Multi-GPU support: use device_count to determine layers
        # For now, offload all layers if we have any NVIDIA GPU
        return -1

    if gpu.is_amd:
        # AMD: ROCm support varies, be conservative
        return -1

    # Unknown GPU type: use CPU
    return 0


def validate_gpu_layers(requested_layers: int, gpu: Optional[GPUInfo]) -> int:
    """Validate and adjust requested GPU layers.

    Args:
        requested_layers: Requested number of layers (-1 = all)
        gpu: GPU information

    Returns:
        Validated layer count
    """
    if requested_layers == 0:
        return 0

    if gpu is None:
        if requested_layers != 0:
            raise ValueError(
                f"Requested {requested_layers} GPU layers but no GPU detected. "
                "Use n_gpu_layers=0 for CPU-only mode."
            )
        return 0

    if gpu.is_apple_silicon:
        # Apple Silicon always uses all layers
        return -1

    if gpu.is_nvidia and gpu.compute_capability:
        major, _ = gpu.compute_capability.split('.')
        if int(major) < 5:
            raise ValueError(
                f"NVIDIA GPU {gpu.name} has compute capability {gpu.compute_capability}. "
                f"Minimum required is 5.0. Use n_gpu_layers=0 for CPU mode."
            )

    return requested_layers
def detect_hardware() -> HardwareProfile:
|
||||
"""Detect complete hardware profile."""
|
||||
os_name = detect_os()
|
||||
|
||||
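To illustrate how the two GPU-layer helpers above are meant to be combined (a sketch only; it relies on `detect_gpu` from the same module):

```python
from hardware.detector import detect_gpu, calculate_gpu_layers, validate_gpu_layers

gpu = detect_gpu()                          # None when no supported GPU is found
layers = calculate_gpu_layers(gpu)          # -1 = offload all layers, 0 = CPU only
layers = validate_gpu_layers(layers, gpu)   # raises ValueError for unusable requests
print(f"n_gpu_layers={layers}")
```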
@@ -10,6 +10,64 @@ from typing import Optional
|
||||
from hardware.detector import GPUInfo
|
||||
|
||||
|
||||
# Android-specific file paths for common operations
|
||||
ANDROID_PATHS = {
|
||||
"termux_home": "/data/data/com.termux/files/home",
|
||||
"termux_usr": "/data/data/com.termux/files/usr",
|
||||
"termux_bin": "/data/data/com.termux/files/usr/bin",
|
||||
"shared_storage": "/sdcard",
|
||||
"android_data": "/data/data",
|
||||
}
|
||||
|
||||
|
||||
def get_android_path(path_type: str, subpath: str = "") -> str:
|
||||
"""Get Android-specific file path.
|
||||
|
||||
Args:
|
||||
path_type: Type of path (termux_home, shared_storage, etc.)
|
||||
subpath: Additional path components
|
||||
|
||||
Returns:
|
||||
Full path string
|
||||
"""
|
||||
base = ANDROID_PATHS.get(path_type, path_type)
|
||||
if subpath:
|
||||
return os.path.join(base, subpath)
|
||||
return base
|
||||
|
||||
|
||||
def normalize_path_for_android(path: str) -> str:
|
||||
"""Normalize a path for Android/Termux environment.
|
||||
|
||||
Args:
|
||||
path: Original path
|
||||
|
||||
Returns:
|
||||
Normalized path for Android
|
||||
"""
|
||||
# Expand user home directory properly on Android
|
||||
if path.startswith("~/"):
|
||||
if is_termux():
|
||||
home = ANDROID_PATHS["termux_home"]
|
||||
else:
|
||||
home = os.environ.get("HOME", "/")
|
||||
path = os.path.join(home, path[2:])
|
||||
|
||||
# Handle /sdcard paths
|
||||
if path.startswith("/sdcard") and not os.path.exists("/sdcard"):
|
||||
# Try alternative storage paths
|
||||
alternatives = [
|
||||
"/storage/emulated/0",
|
||||
"/storage/self/primary",
|
||||
]
|
||||
for alt in alternatives:
|
||||
if os.path.exists(alt):
|
||||
path = path.replace("/sdcard", alt, 1)
|
||||
break
|
||||
|
||||
return os.path.normpath(path)
|
||||
|
||||
|
||||
def is_termux() -> bool:
|
||||
"""Check if running in Termux environment."""
|
||||
return (
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
"""Configuration selection for Local Swarm interactive mode."""
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
from hardware.detector import HardwareProfile
|
||||
from models.registry import Model, list_models
|
||||
from models.selector import ModelConfig, select_optimal_model, calculate_max_instances
|
||||
from interactive.ui import print_section, MenuOption, display_menu
|
||||
|
||||
|
||||
def get_recommended_config(
|
||||
hardware: HardwareProfile,
|
||||
context_size: int = 32768,
|
||||
offload_percent: float = 0.0
|
||||
) -> Optional[ModelConfig]:
|
||||
"""Get the recommended configuration for the hardware with context and offload settings."""
|
||||
use_mlx = hardware.is_apple_silicon if hardware else False
|
||||
return select_optimal_model(
|
||||
hardware,
|
||||
context_size=context_size,
|
||||
offload_percent=offload_percent,
|
||||
use_mlx=use_mlx
|
||||
)
|
||||
|
||||
|
||||
def list_available_configurations(
|
||||
hardware: HardwareProfile,
|
||||
context_size: int = 32768,
|
||||
offload_percent: float = 0.0
|
||||
) -> List[Tuple[str, ModelConfig]]:
|
||||
"""List all feasible configurations for the hardware with context and offload settings."""
|
||||
from models.selector import calculate_memory_with_offload, get_available_memory_with_offload
|
||||
|
||||
configs = []
|
||||
available_vram, available_ram = get_available_memory_with_offload(hardware, offload_percent)
|
||||
|
||||
# Use MLX models on Apple Silicon
|
||||
use_mlx = hardware.is_apple_silicon if hardware else False
|
||||
is_mac = use_mlx
|
||||
|
||||
for model in list_models(use_mlx=use_mlx):
|
||||
for variant in model.variants:
|
||||
for quant in variant.quantizations:
|
||||
# Calculate memory with context and offload
|
||||
if 'bit' in quant.name:
|
||||
quantization_bits = int(quant.name.replace('bit', ''))
|
||||
elif 'q4' in quant.name:
|
||||
quantization_bits = 4
|
||||
elif 'q5' in quant.name:
|
||||
quantization_bits = 5
|
||||
elif 'q6' in quant.name:
|
||||
quantization_bits = 6
|
||||
else:
|
||||
quantization_bits = 4
|
||||
|
||||
vram_per_instance, ram_per_instance = calculate_memory_with_offload(
|
||||
quant.vram_gb, context_size, offload_percent, quantization_bits
|
||||
)
|
||||
|
||||
# Check if at least 1 instance fits in VRAM
|
||||
if vram_per_instance <= available_vram:
|
||||
if is_mac:
|
||||
num_responses = 3
|
||||
total_memory = vram_per_instance + ram_per_instance
|
||||
else:
|
||||
num_responses = calculate_max_instances(available_vram, vram_per_instance)
|
||||
total_memory = (vram_per_instance + ram_per_instance) * num_responses
|
||||
|
||||
config = ModelConfig(
|
||||
model=model,
|
||||
variant=variant,
|
||||
quantization=quant,
|
||||
instances=num_responses,
|
||||
memory_per_instance_gb=vram_per_instance + ram_per_instance,
|
||||
total_memory_gb=total_memory,
|
||||
context_size=context_size,
|
||||
offload_percent=offload_percent,
|
||||
vram_usage_gb=vram_per_instance,
|
||||
ram_usage_gb=ram_per_instance
|
||||
)
|
||||
|
||||
ctx_label = model.context_label
|
||||
label = f"{model.name} [{ctx_label}] {variant.size} ({quant.name})"
|
||||
configs.append((label, config))
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
def select_context_size() -> int:
|
||||
"""Let user select context window size."""
|
||||
print_section("Context Size Selection")
|
||||
print(" Context window determines how much text the model can process at once.")
|
||||
print(" Larger context = more memory usage but can handle longer code files.\n")
|
||||
|
||||
options = [
|
||||
MenuOption("1", "16K tokens", "Good for small code files"),
|
||||
MenuOption("2", "32K tokens (Recommended)", "Best balance for most users"),
|
||||
MenuOption("3", "64K tokens", "Large codebases"),
|
||||
MenuOption("4", "128K tokens", "Very large files (uses more memory)"),
|
||||
]
|
||||
|
||||
choice = display_menu(options, "Select Context Size")
|
||||
|
||||
context_map = {"1": 16384, "2": 32768, "3": 65536, "4": 131072}
|
||||
return context_map.get(choice, 32768)
|
||||
|
||||
|
||||
def select_offload_option() -> float:
|
||||
"""Let user select offloading option."""
|
||||
print_section("Memory Offloading")
|
||||
print(" Offloading moves some model layers to system RAM.")
|
||||
print(" This allows larger models/contexts but may be slower.\n")
|
||||
|
||||
options = [
|
||||
MenuOption("1", "No offload (Default)", "100% GPU VRAM - fastest"),
|
||||
MenuOption("2", "20% offload", "80% GPU + 20% RAM - balanced"),
|
||||
MenuOption("3", "50% offload", "50% GPU + 50% RAM - maximum capacity"),
|
||||
]
|
||||
|
||||
choice = display_menu(options, "Select Offloading")
|
||||
|
||||
offload_map = {"1": 0.0, "2": 0.2, "3": 0.5}
|
||||
return offload_map.get(choice, 0.0)
|
||||
@@ -0,0 +1,87 @@
|
||||
"""Display functions for Local Swarm interactive mode.
|
||||
|
||||
Hardware info and resource usage display.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from hardware.detector import HardwareProfile
|
||||
from interactive.ui import print_section
|
||||
|
||||
|
||||
def print_hardware_info(hardware: HardwareProfile) -> None:
|
||||
"""Print detailed hardware information."""
|
||||
print_section("Hardware Detection")
|
||||
|
||||
print(f" Operating System: {hardware.os.capitalize()}")
|
||||
print(f" CPU: {hardware.cpu_cores} cores")
|
||||
print(f" System RAM: {hardware.ram_gb:.1f} GB")
|
||||
print(f" Available RAM: {hardware.ram_available_gb:.1f} GB")
|
||||
|
||||
if hardware.gpu:
|
||||
print(f"\n GPU Detected:")
|
||||
print(f" Name: {hardware.gpu.name}")
|
||||
if hardware.is_apple_silicon:
|
||||
print(f" Type: Apple Silicon (Unified Memory)")
|
||||
print(f" Total Memory: {hardware.gpu.vram_gb:.1f} GB")
|
||||
else:
|
||||
print(f" Type: {hardware.gpu.name}")
|
||||
print(f" VRAM: {hardware.gpu.vram_gb:.1f} GB")
|
||||
if hardware.gpu.driver_version:
|
||||
print(f" Driver: {hardware.gpu.driver_version}")
|
||||
else:
|
||||
print(f"\n GPU: None detected (CPU-only mode)")
|
||||
|
||||
if hardware.has_dedicated_gpu:
|
||||
# Dedicated GPU: hard limit based on VRAM
|
||||
print(f"\n Available for LLMs: {hardware.available_memory_gb:.1f} GB")
|
||||
print(f" (Using 100% of GPU VRAM minus buffer)")
|
||||
elif hardware.is_apple_silicon:
|
||||
# Apple Silicon: show recommendation vs limit (like CPU-only)
|
||||
print(f"\n Recommended for LLMs: {hardware.recommended_memory_gb:.1f} GB (50% of unified memory)")
|
||||
print(f" Maximum available: {hardware.available_memory_gb:.1f} GB (unified memory - 4GB safety)")
|
||||
else:
|
||||
# CPU-only: show recommendation vs limit
|
||||
print(f"\n Recommended for LLMs: {hardware.recommended_memory_gb:.1f} GB (50% of RAM)")
|
||||
print(f" Maximum available: {hardware.available_memory_gb:.1f} GB (system RAM - 4GB safety)")
|
||||
|
||||
|
||||
def print_resource_usage(swarm_manager) -> None:
|
||||
"""Print current resource usage if swarm is running."""
|
||||
if swarm_manager is None:
|
||||
return
|
||||
|
||||
print_section("Current Resource Usage")
|
||||
|
||||
status = swarm_manager.get_status()
|
||||
workers = swarm_manager.get_worker_info()
|
||||
|
||||
print(f" Swarm Status: {'Running' if status.is_running else 'Stopped'}")
|
||||
print(f" Model: {status.model_name}")
|
||||
print(f" Workers: {status.healthy_workers}/{status.total_workers} healthy")
|
||||
print(f" Consensus Strategy: {status.strategy}")
|
||||
print(f" Memory Usage: {status.total_memory_gb:.2f} GB")
|
||||
print(f" Memory per Worker: {status.total_memory_gb / status.total_workers:.2f} GB" if status.total_workers > 0 else " Memory per Worker: N/A")
|
||||
|
||||
if workers:
|
||||
print(f"\n Worker Details:")
|
||||
for w in workers:
|
||||
status_icon = "✓" if w.is_healthy else "✗"
|
||||
|
||||
# Show IP for remote workers
|
||||
location = f" [{w.ip_address}]" if w.is_remote and w.ip_address else ""
|
||||
|
||||
print(f" [{status_icon}] {w.name}{location}: {w.backend_name}")
|
||||
|
||||
# Show live data if available
|
||||
if w.is_generating:
|
||||
progress_bar = "█" * int(w.progress / 5) + "░" * (20 - int(w.progress / 5))
|
||||
print(f" 🔄 Generating: {progress_bar} ({w.progress:.0f}%)")
|
||||
print(f" 📏 Context: {w.context_used:,} tokens")
|
||||
if w.last_output:
|
||||
preview = w.last_output[:60].replace('\n', ' ')
|
||||
print(f" 💬 Last: {preview}...")
|
||||
|
||||
if w.stats.total_requests > 0:
|
||||
print(f" 📊 Requests: {w.stats.total_requests}")
|
||||
print(f" ⏱️ Avg Latency: {w.stats.avg_latency_ms:.1f}ms")
|
||||
print(f" 🚀 Tokens/sec: {w.stats.tokens_per_second:.1f}")
|
||||
@@ -0,0 +1,226 @@
|
||||
"""Tips and help content for Local Swarm.
|
||||
|
||||
Educational content about models, quantization, and optimization.
|
||||
"""
|
||||
|
||||
from hardware.detector import HardwareProfile
|
||||
from interactive.ui import clear_screen, print_header, print_section
|
||||
|
||||
|
||||
def show_model_recommendations():
|
||||
"""Display model recommendations."""
|
||||
clear_screen()
|
||||
print_header("Model Recommendations")
|
||||
|
||||
print_section("Best Models for Coding (Ranked)")
|
||||
print("""
|
||||
🥇 Qwen 2.5 Coder - BEST OVERALL
|
||||
• Excellent code completion and generation
|
||||
• Strong performance even at smaller sizes (3B)
|
||||
• Good at following instructions
|
||||
• Recommended for most users
|
||||
|
||||
🥈 DeepSeek Coder - GREAT ALTERNATIVE
|
||||
• Very capable on coding tasks
|
||||
• Good balance of speed and quality
|
||||
• Smaller 1.3B option for low-end hardware
|
||||
|
||||
🥉 CodeLlama - SOLID CHOICE
|
||||
• Meta's dedicated code model
|
||||
• Good performance, widely tested
|
||||
• Larger sizes (13B+) for complex tasks
|
||||
|
||||
Other Good Options:
|
||||
• Llama 3.2 - General model with good coding skills
|
||||
• Phi-4 - Microsoft's efficient small model
|
||||
• Gemma 2 - Google's open model
|
||||
• StarCoder2 - Good for code completion
|
||||
|
||||
Which size to choose?
|
||||
• 1-3B: Fast, good for simple tasks, low VRAM
|
||||
• 7B: Sweet spot for most users
|
||||
• 13-15B: Better quality, needs more VRAM
|
||||
• 30B+: Best quality but very slow
|
||||
""")
|
||||
input("\n Press Enter to continue...")
|
||||
|
||||
|
||||
def show_quantization_guide():
|
||||
"""Display quantization guide."""
|
||||
clear_screen()
|
||||
print_header("Quantization Guide")
|
||||
|
||||
print_section("What is Quantization?")
|
||||
print("""
|
||||
Quantization compresses the model to use less memory.
|
||||
Lower precision = smaller size = faster loading
|
||||
But may reduce quality slightly.
|
||||
""")
|
||||
|
||||
print_section("Quantization Levels")
|
||||
print("""
|
||||
Q4_K_M (Good) - RECOMMENDED FOR MOST USERS
|
||||
• 4-bit quantization with medium quality
|
||||
• ~70% smaller than original
|
||||
• Minimal quality loss for coding
|
||||
• Best speed/memory/quality balance
|
||||
• Use this if unsure!
|
||||
|
||||
Q5_K_M (Better)
|
||||
• 5-bit quantization with better quality
|
||||
• ~60% smaller than original
|
||||
• Better for complex reasoning
|
||||
• Slightly more VRAM needed
|
||||
|
||||
Q6_K (Best)
|
||||
• 6-bit quantization with highest quality
|
||||
• ~50% smaller than original
|
||||
• Close to original model quality
|
||||
• Requires more VRAM
|
||||
• Use if you have plenty of memory
|
||||
|
||||
When to use each:
|
||||
• Q4_K_M: Default choice, works great
|
||||
• Q5_K_M: If you have extra VRAM, want better quality
|
||||
• Q6_K: If VRAM is abundant, want best quality
|
||||
""")
|
||||
|
||||
print_section("Quick Reference")
|
||||
print("""
|
||||
Size comparison for 7B model:
|
||||
• Original (FP16): ~14 GB
|
||||
• Q6_K: ~6 GB
|
||||
• Q5_K_M: ~5.2 GB
|
||||
• Q4_K_M: ~4.5 GB
|
||||
""")
|
||||
input("\n Press Enter to continue...")
|
||||
|
||||
|
||||
def show_instance_tips(hardware: HardwareProfile):
|
||||
"""Display tips for optimal instance count."""
|
||||
clear_screen()
|
||||
print_header("Instance Count Optimization")
|
||||
|
||||
print_section("What Are Instances?")
|
||||
print("""
|
||||
Each instance = one copy of the model running.
|
||||
Multiple instances = multiple workers voting on answers.
|
||||
More instances = better consensus but uses more memory.
|
||||
""")
|
||||
|
||||
print_section("Recommended Instance Counts")
|
||||
print(f"""
|
||||
Based on your hardware ({hardware.available_memory_gb:.1f} GB available):
|
||||
|
||||
Minimum: 2 instances
|
||||
• Required for consensus voting
|
||||
• Detects bad/hallucinated responses
|
||||
• Better than single model
|
||||
|
||||
Good Range: 3-5 instances
|
||||
• Most common setup
|
||||
• Good consensus quality
|
||||
• Reasonable memory usage
|
||||
• Recommended sweet spot
|
||||
|
||||
Maximum: 8 instances
|
||||
• Best consensus quality
|
||||
• Higher memory usage
|
||||
• Diminishing returns after 5-6
|
||||
• Use only if VRAM abundant
|
||||
|
||||
Research Note:
|
||||
Studies show consensus with 3-5 models gives 85-90%
|
||||
of the benefit, with minimal overhead. More than 8
|
||||
provides minimal improvement.
|
||||
""")
|
||||
|
||||
print_section("Memory Calculation Example")
|
||||
print(f"""
|
||||
Your available memory: {hardware.available_memory_gb:.1f} GB
|
||||
|
||||
Example: 7B model at Q4_K_M (4.5 GB per instance)
|
||||
• 2 instances: 9.0 GB used
|
||||
• 3 instances: 13.5 GB used
|
||||
• 4 instances: 18.0 GB used
|
||||
|
||||
Rule of thumb: Leave 10% buffer for overhead
|
||||
""")
|
||||
input("\n Press Enter to continue...")


def show_hardware_tips(hardware: HardwareProfile):
"""Display hardware-specific tips."""
clear_screen()
print_header("Hardware Optimization Tips")

print_section("Your Hardware Profile")
print(f"""
OS: {hardware.os.capitalize()}
CPU: {hardware.cpu_cores} cores
Available Memory: {hardware.available_memory_gb:.1f} GB
GPU: {hardware.gpu.name if hardware.gpu else "None (CPU mode)"}
""")

if hardware.is_apple_silicon:
print_section("Apple Silicon Tips")
print("""
✓ Using MLX backend (optimized for Metal)
✓ Unified memory architecture
✓ 50% of RAM allocated for LLMs

Tips:
• Use Q4_K_M quantization for best balance
• 7B models work great on 16GB+ Macs
• 3B models good for 8GB Macs
• M1/M2/M3 all supported
• Close other apps for best performance
""")
elif hardware.gpu and not hardware.is_apple_silicon:
print_section("Discrete GPU Tips")
print(f"""
✓ GPU: {hardware.gpu.name}
✓ Using 100% of VRAM

Tips:
• Install CUDA/ROCm drivers for acceleration
• Use Q4_K_M or Q5_K_M quantization
• Monitor GPU temperature during long runs
• Close GPU-intensive apps (games, etc.)
• 7B-13B models work well on 8-16GB VRAM
""")
else:
print_section("CPU-Only Tips")
print("""
✓ Running in CPU mode
✓ 50% of system RAM allocated

Tips:
• Use smaller models (3B-4B) for speed
• Use Q4_K_M quantization
• Fewer instances (2-3) recommended
• Expect slower generation than GPU
• Good for testing, not production use
• Consider cloud GPU for heavy use
""")

print_section("General Optimization")
print("""
Speed vs Quality:
• Smaller models (3B) = faster, less capable
• Larger models (7B+) = slower, smarter
• Q4 = faster, less precise
• Q6 = slower, more precise

Memory Management:
• Leave 10-20% RAM/VRAM free
• Close browsers and heavy apps
• Use swap if necessary (slower)

Best Practices:
• Start with recommended config
• Test with --test flag first
• Monitor memory usage
• Adjust instances based on performance
""")
input("\n Press Enter to continue...")
@@ -0,0 +1,63 @@
"""UI utilities for Local Swarm interactive mode.

Terminal display helpers and formatting functions.
"""

import subprocess
import os
from typing import List
from dataclasses import dataclass


@dataclass
class MenuOption:
"""A menu option."""
key: str
label: str
description: str = ""


def clear_screen():
"""Clear the terminal screen."""
subprocess.run(['cls' if os.name == 'nt' else 'clear'], shell=True, check=False)


def print_header(title: str):
"""Print a formatted header."""
width = 70
print("=" * width)
print(f" {title}".ljust(width))
print("=" * width)
print()


def print_section(title: str):
"""Print a section title."""
print(f"\n{'─' * 70}")
print(f" {title}")
print(f"{'─' * 70}")


def display_menu(options: List[MenuOption], title: str = "Menu") -> str:
"""Display a menu and return the user's choice.

Args:
options: List of menu options
title: Menu title

Returns:
Selected option key
"""
print_section(title)

for opt in options:
desc = f" - {opt.description}" if opt.description else ""
print(f" [{opt.key}] {opt.label}{desc}")

print()
while True:
choice = input(" Enter your choice: ").strip().lower()
valid_keys = [opt.key.lower() for opt in options]
if choice in valid_keys:
return choice
print(f" Invalid choice. Please enter one of: {', '.join(valid_keys)}")
+71
-33
@@ -87,7 +87,7 @@ class Model:

# MLX quantization sizes (GB) based on mlx-community models
# HARDCODED: These are verified to exist on HuggingFace mlx-community
# Last verified: 2025-02-23
# Last verified: 2025-02-25
# DO NOT make API calls on startup - use this hardcoded list
MLX_QUANT_SIZES = {
# Format: model_id: {variant_size: {quant_bit: vram_gb}}
@@ -101,16 +101,15 @@ MLX_QUANT_SIZES = {
# 5bit does NOT exist for 14b
},
"deepseek-coder": {
"1.3b": {"4bit": 0.8, "6bit": 1.2},
# 3bit, 5bit, 8bit do NOT exist
"6.7b": {"4bit": 3.9, "6bit": 5.9, "8bit": 7.9},
# 3bit, 5bit do NOT exist
"1.3b": {}, # Only base models exist, no quantized versions
"6.7b": {"4bit": 3.9}, # Only 4bit exists (base and instruct)
},
"deepseek-coder-v2-lite": {
"instruct": {"4bit": 4.5, "6bit": 6.5, "8bit": 8.5}, # V2 Lite has better MLX support
},
"codellama": {
"7b": {"4bit": 4.1, "6bit": 6.1, "8bit": 8.1},
# 3bit, 5bit do NOT exist
"13b": {"4bit": 7.6, "6bit": 11.4, "8bit": 15.2},
# 3bit, 5bit do NOT exist
"7b": {"4bit": 4.1, "6bit": 6.1, "8bit": 8.1}, # Instruct variants only
"13b": {"4bit": 7.6, "6bit": 11.4, "8bit": 15.2}, # Instruct variants only
},
"llama-3.2": {
"1b": {"4bit": 0.6, "8bit": 1.2},
@@ -131,12 +130,9 @@ MLX_QUANT_SIZES = {
# 3bit, 5bit do NOT exist
},
"starcoder2": {
"3b": {"4bit": 1.8, "6bit": 2.6, "8bit": 3.5},
# 3bit, 5bit do NOT exist
"7b": {"4bit": 4.1, "6bit": 6.1, "8bit": 8.1},
# 3bit, 5bit do NOT exist
"15b": {"4bit": 8.8, "6bit": 13.2, "8bit": 17.6},
# 3bit, 5bit do NOT exist
"3b": {"4bit": 1.8}, # Only 4bit exists
"7b": {"4bit": 4.1}, # Only 4bit exists
"15b": {"4bit": 8.8, "8bit": 17.6}, # Has 4bit base, 4bit/8bit instruct variants
},
}

@@ -165,6 +161,13 @@ MODEL_METADATA = {
"max_context": 16384,
"variants": ["1.3b", "6.7b"],
},
"deepseek-coder-v2-lite": {
"name": "DeepSeek Coder V2 Lite",
"description": "DeepSeek's V2 Lite model with better MLX support",
"priority": 2,
"max_context": 16384,
"variants": ["instruct"],
},
"codellama": {
"name": "CodeLlama",
"description": "Meta's code model",
@@ -364,25 +367,60 @@ def get_model_hf_repo_mlx(model_id: str, variant: ModelVariant, quant: Quantizat
"q8": "8bit",
}

# MLX quantized models are in mlx-community org with -{quant}bit suffix
# Map base model names to mlx-community quantized versions
mlx_repo_map = {
"qwen2.5-coder": f"mlx-community/Qwen2.5-Coder-{variant.size.capitalize()}-Instruct",
"deepseek-coder": f"mlx-community/deepseek-coder-{variant.size}-base",
"codellama": f"mlx-community/CodeLlama-{variant.size}-Instruct",
"llama-3.2": f"mlx-community/Llama-3.2-{variant.size}-Instruct",
"phi-4": f"mlx-community/phi-4",
"gemma-2": f"mlx-community/gemma-2-{variant.size}-it",
"starcoder2": f"mlx-community/starcoder2-{variant.size}",
}
# Convert GGUF quant name to MLX quant name
mlx_quant = gguf_to_mlx_quant.get(quant.name, quant.name) if quant else None

base_repo = mlx_repo_map.get(model_id, "")
if base_repo and quant:
# Convert GGUF quant name to MLX quant name
mlx_quant = gguf_to_mlx_quant.get(quant.name, quant.name)
# Append quantization suffix
return f"{base_repo}-{mlx_quant}"
return base_repo
# MLX quantized models are in mlx-community org
# Repository naming varies by model - these are verified to exist on HF
if model_id == "qwen2.5-coder":
# Qwen: mlx-community/Qwen2.5-Coder-{Size}-Instruct-{quant}bit
return f"mlx-community/Qwen2.5-Coder-{variant.size.capitalize()}-Instruct-{mlx_quant}"

elif model_id == "deepseek-coder":
# DeepSeek: Very limited MLX support
# 1.3b: Only base models exist (no quantized versions)
# 6.7b: mlx-community/deepseek-coder-6.7b-base-4bit-mlx (base only)
# mlx-community/deepseek-coder-6.7b-instruct-hf-4bit-mlx (instruct)
if variant.size == "1.3b":
# Only base model exists, no quantization
return "mlx-community/deepseek-coder-1.3b-base-mlx"
elif variant.size == "6.7b":
# Use instruct variant (better for coding) with hf-{quant}bit-mlx suffix
return f"mlx-community/deepseek-coder-6.7b-instruct-hf-{mlx_quant}-mlx"

elif model_id == "deepseek-coder-v2-lite":
# DeepSeek Coder V2 Lite: Has good MLX support
# mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx
# mlx-community/DeepSeek-Coder-V2-Lite-Instruct-6bit
# mlx-community/DeepSeek-Coder-V2-Lite-Instruct-8bit
if mlx_quant == "4bit":
return "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx"
else:
# 6bit and 8bit don't have -mlx suffix
return f"mlx-community/DeepSeek-Coder-V2-Lite-Instruct-{mlx_quant}"

elif model_id == "codellama":
# CodeLlama: mlx-community/CodeLlama-{size}-Instruct-hf-{quant}bit-mlx
# Only Instruct variants have quantized versions
return f"mlx-community/CodeLlama-{variant.size}-Instruct-hf-{mlx_quant}-mlx"

elif model_id == "llama-3.2":
# Llama 3.2: mlx-community/Llama-3.2-{size}-Instruct-{quant}bit
return f"mlx-community/Llama-3.2-{variant.size}-Instruct-{mlx_quant}"

elif model_id == "phi-4":
# Phi-4: mlx-community/phi-4-{quant}bit
return f"mlx-community/phi-4-{mlx_quant}"

elif model_id == "gemma-2":
# Gemma 2: mlx-community/gemma-2-{size}-it-{quant}bit
return f"mlx-community/gemma-2-{variant.size}-it-{mlx_quant}"

elif model_id == "starcoder2":
# StarCoder2: mlx-community/starcoder2-{size}-{quant}bit
return f"mlx-community/starcoder2-{variant.size}-{mlx_quant}"

return ""


def get_model_filename(model_id: str, variant: ModelVariant, quant: QuantizationConfig) -> str:

@@ -6,10 +6,43 @@ Uses mDNS/Bonjour to discover other Local Swarm instances on the local network.
import socket
import asyncio
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import datetime, timedelta


@dataclass
class PeerMetrics:
"""Metrics for tracking peer performance."""
total_requests: int = 0
successful_requests: int = 0
failed_requests: int = 0
total_latency_ms: float = 0.0
avg_latency_ms: float = 0.0
last_error: Optional[str] = None
last_error_time: Optional[datetime] = None

@property
def success_rate(self) -> float:
"""Calculate success rate (0.0 to 1.0)."""
if self.total_requests == 0:
return 1.0
return self.successful_requests / self.total_requests

def record_success(self, latency_ms: float):
"""Record a successful request."""
self.total_requests += 1
self.successful_requests += 1
self.total_latency_ms += latency_ms
self.avg_latency_ms = self.total_latency_ms / self.successful_requests

def record_failure(self, error: str):
"""Record a failed request."""
self.total_requests += 1
self.failed_requests += 1
self.last_error = error
self.last_error_time = datetime.now()


@dataclass
class PeerInfo:
"""Information about a peer swarm."""
@@ -21,6 +54,8 @@ class PeerInfo:
model_id: str
hardware_summary: str
last_seen: datetime
timeout_seconds: float = 60.0 # Configurable timeout per peer
metrics: PeerMetrics = field(default_factory=PeerMetrics)

@property
def api_url(self) -> str:
@@ -100,6 +135,8 @@ class SwarmDiscovery:
await asyncio.to_thread(self._zeroconf.register_service, self._info)
print(f" ✓ Advertising on mDNS: {service_name}")
print(f" IP: {ip}:{self.listen_port}")
print(f" Service type: {self.SERVICE_TYPE}")
print(f" Properties: instances={swarm_info.get('instances', 0)}, model={swarm_info.get('model_id', 'unknown')}")

except ImportError:
print(" ⚠️ zeroconf not installed, skipping mDNS advertising")
@@ -117,6 +154,10 @@ class SwarmDiscovery:
self._async_zeroconf = AsyncZeroconf()
self._zeroconf = self._async_zeroconf.zeroconf

# Store event loop reference for callbacks
self._loop = asyncio.get_event_loop()
print(f" Event loop: {self._loop}")

# Create async browser (passes the underlying Zeroconf instance)
self._browser = AsyncServiceBrowser(
self._zeroconf,
@@ -125,6 +166,7 @@ class SwarmDiscovery:
)

print(f" ✓ Listening for peers on {self.SERVICE_TYPE}")
print(f" Will discover peers advertising on same network")
self._running = True

except ImportError:
@@ -136,16 +178,23 @@ class SwarmDiscovery:
"""Handle mDNS service state changes (called from zeroconf background thread)."""
from zeroconf import ServiceStateChange

print(f" [mDNS] Service state change: {name} -> {state_change.name}")

if state_change == ServiceStateChange.Added:
print(f" [mDNS] Service added: {name}")
# Schedule coroutine on the event loop from this background thread
if self._loop is not None and self._loop.is_running():
print(f" [mDNS] Scheduling peer addition...")
asyncio.run_coroutine_threadsafe(
self._add_peer(zeroconf, service_type, name),
self._loop
)
else:
print(f" [mDNS] Warning: Event loop not available")
elif state_change == ServiceStateChange.Removed:
# Service removed
peer_key = name.replace(f".{self.SERVICE_TYPE}", "")
print(f" [mDNS] Service removed: {peer_key}")
if peer_key in self.peers:
del self.peers[peer_key]
print(f" 👋 Peer left: {peer_key}")
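The state-change handler above runs on zeroconf's background thread, which is why it schedules `_add_peer` with `asyncio.run_coroutine_threadsafe` rather than awaiting it. A stripped-down sketch of that pattern, independent of zeroconf (all names are illustrative):

```python
import asyncio
import threading


async def add_peer(name: str) -> None:
    print(f"added {name}")


def main() -> None:
    loop = asyncio.new_event_loop()

    def background_callback(name: str) -> None:
        # Called from a non-asyncio thread; hand the coroutine to the running loop.
        if loop.is_running():
            asyncio.run_coroutine_threadsafe(add_peer(name), loop)

    threading.Timer(0.1, background_callback, args=("peer-1",)).start()
    loop.run_until_complete(asyncio.sleep(0.5))  # keep the loop alive briefly


if __name__ == "__main__":
    main()
```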
@@ -292,13 +341,13 @@ class SwarmDiscovery:
finally:
s.close()

# Verify it's the correct private IP (192.168.x.x only for this network)
is_private = ip.startswith('192.168.')

if is_private:
# Only bind to 192.168.x.x as requested
if ip.startswith('192.168.'):
print(f" ✓ Using IP: {ip}")
return ip
else:
print(f" ⚠️ IP {ip} is not private, using localhost")
print(f" ⚠️ IP {ip} is not 192.168.x.x, using localhost")
print(f" Federation requires 192.168.x.x network")
return '127.0.0.1'
except Exception as e:
print(f" ⚠️ Error detecting IP: {e}")

+170
-70
@@ -5,7 +5,7 @@ Handles communication between peer swarms for distributed consensus.

import asyncio
import time
from typing import List, Optional, Dict, Any
from typing import List, Optional, Dict, Any, Tuple
from dataclasses import dataclass

from network.discovery import PeerInfo
@@ -20,6 +20,8 @@ class PeerVote:
confidence: float
latency_ms: float
worker_count: int
tokens_per_second: float = 0.0
tokens_generated: int = 0


@dataclass
@@ -29,12 +31,14 @@ class FederationResult:
local_confidence: float
peer_votes: List[PeerVote]
strategy: str
winner: str = "" # Name of the winning node ("local" or peer name)
global_tokens_per_second: float = 0.0 # Includes sync + voting overhead


class FederationClient:
"""Client for communicating with peer swarms."""

def __init__(self, timeout: float = 30.0):
def __init__(self, timeout: float = 60.0):
"""
Initialize federation client.

@@ -79,42 +83,58 @@ class FederationClient:
Returns:
PeerVote or None if request failed
"""
request_start = time.time()
# Use peer-specific timeout if available, otherwise use default
timeout = getattr(peer, 'timeout_seconds', self.timeout)

try:
import aiohttp

session = await self._get_session()
# Create session with peer-specific timeout
session_timeout = aiohttp.ClientTimeout(total=timeout)
async with aiohttp.ClientSession(timeout=session_timeout) as session:
url = f"{peer.api_url}/v1/federation/vote"
payload = {
"prompt": prompt,
"max_tokens": max_tokens,
"temperature": temperature,
"request_id": f"fed_{time.time()}"
}

url = f"{peer.api_url}/v1/federation/vote"
payload = {
"prompt": prompt,
"max_tokens": max_tokens,
"temperature": temperature,
"request_id": f"fed_{time.time()}"
}
print(f" → Sending request to {url} (timeout: {timeout}s)")
async with session.post(url, json=payload) as resp:
print(f" ← Got response {resp.status} from {peer.name}")
if resp.status != 200:
print(f" ✗ Peer {peer.name} returned status {resp.status}")
peer.metrics.record_failure(f"HTTP {resp.status}")
return None

print(f" → Sending request to {url}")
async with session.post(url, json=payload) as resp:
print(f" ← Got response {resp.status} from {peer.name}")
if resp.status != 200:
print(f" ✗ Peer {peer.name} returned status {resp.status}")
return None
data = await resp.json()
latency_ms = (time.time() - request_start) * 1000
print(f" ✓ Peer {peer.name} responded successfully ({latency_ms:.0f}ms)")

# Record success metrics
peer.metrics.record_success(latency_ms)

data = await resp.json()
print(f" ✓ Peer {peer.name} responded successfully")

return PeerVote(
peer_name=peer.name,
response_text=data.get("response", ""),
confidence=data.get("confidence", 0.5),
latency_ms=data.get("latency_ms", 0),
worker_count=data.get("worker_count", 0)
)
return PeerVote(
peer_name=peer.name,
response_text=data.get("response", ""),
confidence=data.get("confidence", 0.5),
latency_ms=data.get("latency_ms", latency_ms),
worker_count=data.get("worker_count", 0),
tokens_per_second=data.get("tokens_per_second", 0.0),
tokens_generated=data.get("tokens_generated", 0)
)

except asyncio.TimeoutError:
print(f" ⚠️ Peer {peer.name} timed out (>{self.timeout}s)")
error_msg = f"Timeout ({timeout}s)"
print(f" ⚠️ Peer {peer.name} {error_msg}")
peer.metrics.record_failure(error_msg)
return None
except Exception as e:
print(f" ⚠️ Error contacting peer {peer.name}: {e}")
error_msg = str(e)
print(f" ⚠️ Error contacting peer {peer.name}: {error_msg}")
peer.metrics.record_failure(error_msg)
return None

async def health_check(self, peer: PeerInfo) -> bool:
@@ -172,6 +192,8 @@ class FederatedSwarm:
) -> FederationResult:
"""
Generate with federation across peer swarms.

Optimized: Runs local and peer generation in parallel for maximum speed.

Args:
prompt: Input prompt
@@ -182,74 +204,131 @@ class FederatedSwarm:
Returns:
FederationResult with final response
"""
# Phase 1: Local generation and consensus
print(f" 🏠 Local swarm generating...")
local_result = await self.local_swarm.generate(
peers = self.discovery.get_peers()

if len(peers) == 0:
if min_peers > 0:
raise RuntimeError(f"Federation requires {min_peers} peers, but none found")

# Solo mode - just run local generation
print(f" 🏠 Solo mode - local swarm generating...")
solo_start_time = time.time()
local_result = await self.local_swarm.generate(
prompt=prompt,
max_tokens=max_tokens,
temperature=temperature,
use_consensus=True
)
solo_end_time = time.time()
total_elapsed = solo_end_time - solo_start_time
tokens_generated = local_result.selected_response.tokens_generated
global_tps = tokens_generated / total_elapsed if total_elapsed > 0 else 0.0

print(f"\n 📊 Global Performance:")
print(f" Total tokens: {tokens_generated}")
print(f" Total time: {total_elapsed:.2f}s")
print(f" Global speed: {global_tps:.1f} t/s")

return FederationResult(
final_response=local_result.selected_response.text,
local_confidence=local_result.confidence,
peer_votes=[],
strategy="solo",
global_tokens_per_second=global_tps
)

# Parallel generation: Local swarm AND peers generate simultaneously
print(f" 🏠 Local swarm AND {len(peers)} peer(s) generating in parallel...")

# Track timing for global t/sec calculation (includes sync + voting overhead)
federation_start_time = time.time()
total_tokens_generated = 0

# Start local generation
local_task = self.local_swarm.generate(
prompt=prompt,
max_tokens=max_tokens,
temperature=temperature,
use_consensus=True
)

local_best = local_result.selected_response
local_confidence = local_result.confidence

print(f" ✓ Local best (confidence: {local_confidence:.2f})")

# Phase 2: Collect peer votes
peers = self.discovery.get_peers()

if len(peers) == 0:
if min_peers > 0:
raise RuntimeError(f"Federation requires {min_peers} peers, but none found")

# Solo mode - just return local result
return FederationResult(
final_response=local_best.text,
local_confidence=local_confidence,
peer_votes=[],
strategy="solo"
)

print(f" 🌐 Requesting votes from {len(peers)} peer(s)...")
for peer in peers:
print(f" → Contacting {peer.name} at {peer.api_url}")

peer_votes = []

# Start peer requests
vote_tasks = [
self.federation_client.request_vote(peer, prompt, max_tokens, temperature)
for peer in peers
]

results = await asyncio.gather(*vote_tasks, return_exceptions=True)

for peer, result in zip(peers, results):

# Run everything in parallel
all_tasks = [local_task] + vote_tasks
results = await asyncio.gather(*all_tasks, return_exceptions=True)

# Separate local result from peer votes
local_result_raw = results[0]
if isinstance(local_result_raw, Exception):
print(f" ✗ Local swarm failed: {local_result_raw}")
raise RuntimeError(f"Local generation failed: {local_result_raw}")

from swarm.manager import ConsensusResult
local_result: ConsensusResult = local_result_raw # Now guaranteed not to be an exception
local_best = local_result.selected_response
local_confidence = local_result.confidence
local_tps = local_best.tokens_per_second
total_tokens_generated += local_best.tokens_generated
print(f" ✓ Local completed (confidence: {local_confidence:.2f}, {local_tps:.1f} t/s)")

# Collect peer votes
peer_votes = []
for peer, result in zip(peers, results[1:]):
if isinstance(result, Exception):
print(f" ✗ Peer {peer.name} failed: {result}")
elif result is not None:
peer_votes.append(result)
print(f" ✓ Peer {peer.name} voted (confidence: {result.confidence:.2f})")
total_tokens_generated += result.tokens_generated if hasattr(result, 'tokens_generated') else 0
print(f" ✓ Peer {peer.name} completed (confidence: {result.confidence:.2f}, {result.tokens_per_second:.1f} t/s)")

if len(peer_votes) == 0:
# No peers responded, use local result
print(" ⚠️ No peers responded, using local result")

# Calculate global t/sec even in fallback mode
federation_end_time = time.time()
total_elapsed_seconds = federation_end_time - federation_start_time
global_tps = total_tokens_generated / total_elapsed_seconds if total_elapsed_seconds > 0 else 0.0

print(f"\n 📊 Global Performance:")
print(f" Total tokens: {total_tokens_generated}")
print(f" Total time: {total_elapsed_seconds:.2f}s")
print(f" Global speed: {global_tps:.1f} t/s")

return FederationResult(
final_response=local_best.text,
local_confidence=local_confidence,
peer_votes=[],
strategy="local_fallback"
strategy="local_fallback",
global_tokens_per_second=global_tps
)

# Phase 3: Global consensus
# Global consensus
print(f" 🗳️ Running global consensus ({len(peer_votes) + 1} votes)...")
final_response, winner = self._weighted_vote(local_best.text, local_confidence, peer_votes)

final_response = self._weighted_vote(local_best.text, local_confidence, peer_votes)
# Calculate global tokens/sec including sync + voting overhead
federation_end_time = time.time()
total_elapsed_seconds = federation_end_time - federation_start_time
global_tps = total_tokens_generated / total_elapsed_seconds if total_elapsed_seconds > 0 else 0.0

print(f"\n 📊 Global Performance:")
print(f" Total tokens: {total_tokens_generated}")
print(f" Total time: {total_elapsed_seconds:.2f}s")
print(f" Global speed: {global_tps:.1f} t/s (includes sync + voting)")

return FederationResult(
final_response=final_response,
local_confidence=local_confidence,
peer_votes=peer_votes,
strategy=self.consensus_strategy
strategy=self.consensus_strategy,
winner=winner,
global_tokens_per_second=global_tps
)

def _weighted_vote(
@@ -257,11 +336,14 @@ class FederatedSwarm:
local_response: str,
local_confidence: float,
peer_votes: List[PeerVote]
) -> str:
) -> Tuple[str, str]:
"""
Select best response using weighted voting.

Weights by confidence score. Higher confidence = more weight.

Returns:
Tuple of (selected_response, winner_name)
"""
# Collect all votes with their weights
all_votes = [(local_response, local_confidence, "local")]
@@ -292,15 +374,15 @@ class FederatedSwarm:
best_idx = max(range(len(scores)), key=lambda i: scores[i])
best = all_votes[best_idx]
print(f" ✓ Selected response from {best[2]} (quality score: {scores[best_idx]:.2f})")
return best[0]
return best[0], best[2]

# Default: weighted selection - pick highest confidence
best = max(all_votes, key=lambda x: x[1])
print(f" ✓ Selected response from {best[2]} (confidence: {best[1]:.2f})")
return best[0]
return best[0], best[2]

async def get_federation_status(self) -> Dict[str, Any]:
"""Get current federation status."""
"""Get current federation status with peer metrics."""
peers = self.discovery.get_peers()

# Check health of all peers
@@ -308,7 +390,24 @@ class FederatedSwarm:
health_results = await asyncio.gather(*health_checks, return_exceptions=True)

healthy_peers = []
peer_metrics_info = []

for peer, healthy in zip(peers, health_results):
peer_info = {
"name": peer.name,
"healthy": healthy is True,
"timeout": peer.timeout_seconds,
"model": peer.model_id,
"instances": peer.instances,
"metrics": {
"success_rate": peer.metrics.success_rate,
"avg_latency_ms": round(peer.metrics.avg_latency_ms, 2),
"total_requests": peer.metrics.total_requests,
"last_error": peer.metrics.last_error,
}
}
peer_metrics_info.append(peer_info)

if healthy is True:
healthy_peers.append(peer.name)

@@ -317,6 +416,7 @@ class FederatedSwarm:
"total_peers": len(peers),
"healthy_peers": len(healthy_peers),
"peer_names": [p.name for p in peers],
"peer_details": peer_metrics_info,
"strategy": self.consensus_strategy
}

+21
-3
@@ -232,7 +232,7 @@ class SwarmManager:
response = await worker.generate_with_progress(request)
responses.append(response)
if not self.mcp_mode:
print(f" ✓ {worker.name} completed ({response.tokens_generated} tokens)")
print(f" ✓ {worker.name} completed ({response.tokens_generated} tokens, {response.tokens_per_second:.1f} t/s)")
except Exception as e:
responses.append(e)
if not self.mcp_mode:
@@ -283,6 +283,11 @@ class SwarmManager:

if not self.mcp_mode:
print(f" Got {len(valid_responses)} valid responses")

# Print performance summary
print(f"\n 📊 Performance Summary:")
for i, resp in enumerate(valid_responses, 1):
print(f" Worker {i}: {resp.tokens_generated} tokens @ {resp.tokens_per_second:.1f} t/s ({resp.latency_ms:.0f}ms)")

# Run consensus
result = await self.consensus.select_best(valid_responses)
@@ -352,13 +357,21 @@ class SwarmManager:
if not self.mcp_mode:
print(f"🔄 Starting stream from {fastest_worker.name}...")
chunk_count = 0
total_chars = 0
start_time = asyncio.get_event_loop().time()
async for chunk in fastest_worker.generate_with_progress_stream(request):
chunk_count += 1
total_chars += len(chunk)
if not self.mcp_mode and chunk_count % 50 == 0: # Print progress every 50 chunks
print(f" Streamed {chunk_count} chunks...")
yield chunk
end_time = asyncio.get_event_loop().time()
duration = end_time - start_time
# Estimate tokens (roughly 4 chars per token)
estimated_tokens = total_chars // 4
tps = estimated_tokens / duration if duration > 0 else 0
if not self.mcp_mode:
print(f" Stream complete: {chunk_count} chunks total")
print(f" Stream complete: {chunk_count} chunks, {estimated_tokens} tokens, {tps:.1f} t/s")

def get_status(self) -> SwarmStatus:
"""Get current swarm status."""
@@ -494,7 +507,7 @@ class SwarmManager:
try:
response = await worker.generate_with_progress(request)
responses.append(response)
print(f" ✓ Response {i+1} completed ({response.tokens_generated} tokens)")
print(f" ✓ Response {i+1} completed ({response.tokens_generated} tokens, {response.tokens_per_second:.1f} t/s)")
except Exception as e:
responses.append(e)
print(f" ✗ Response {i+1} failed: {e}")
@@ -513,6 +526,11 @@ class SwarmManager:

print(f" Got {len(valid_responses)} valid responses")

# Print performance summary
print(f"\n 📊 Performance Summary:")
for i, resp in enumerate(valid_responses, 1):
print(f" Seed {i}: {resp.tokens_generated} tokens @ {resp.tokens_per_second:.1f} t/s ({resp.latency_ms:.0f}ms)")

# Run consensus
result = await self.consensus.select_best(valid_responses)
print(f" Selected response using '{result.strategy}' strategy (confidence: {result.confidence:.2f})")

@@ -66,25 +66,19 @@ class StatusMonitor:
if not self.swarm_manager or not self.swarm_manager.workers:
return

# Clear previous display
self._clear_display()

# Get worker status
workers = self.swarm_manager.workers
generating_workers = [w for w in workers if w._is_generating]

if not generating_workers:
# No active generation, show minimal status
lines = []
lines.append("📊 Workers Idle")
for w in workers:
status = "🟢" if w.is_healthy() else "🔴"
ip_str = f" [{w._ip_address}]" if w._is_remote else ""
lines.append(f" {status} {w.name}{ip_str}: Idle")

self._print_lines(lines)
# No active generation, clear display and return (don't spam "Workers Idle")
if self._last_lines > 0:
self._clear_display()
return

# Clear previous display
self._clear_display()

# Active generation - show detailed status
lines = []
lines.append(f"⚡ {len(generating_workers)} Worker{'s' if len(generating_workers) > 1 else ''} Active")

+34
-25
@@ -5,12 +5,14 @@ Remote execution allows a single "tool host" to manage the workspace
while workers perform distributed generation.
"""

import asyncio
import logging
import os
import subprocess
import aiohttp
from typing import Optional

from utils.project_discovery import discover_project_root


logger = logging.getLogger(__name__)
@@ -84,29 +86,7 @@ class ToolExecutor:
logger.debug(f" ❌ Error contacting tool host: {e}")
return f"Error contacting tool host: {str(e)}"

def _discover_project_root(self, start_dir: Optional[str] = None) -> str:
"""Discover the project root directory by looking for common markers."""
import os
if start_dir is None:
start_dir = os.getcwd()
current = os.path.abspath(start_dir)

# Common project root markers
markers = ['.git', 'package.json', 'pyproject.toml', 'Cargo.toml', 'go.mod',
'requirements.txt', 'setup.py', 'pom.xml', 'build.gradle', '.project', '.venv']

while True:
try:
if any(os.path.exists(os.path.join(current, marker)) for marker in markers):
return current
except Exception:
pass # Permission errors, just skip
parent = os.path.dirname(current)
if parent == current: # Reached filesystem root
break
current = parent

return start_dir


async def _execute_local(self, tool_name: str, tool_args: dict) -> str:
"""Execute tool locally."""
@@ -117,6 +97,8 @@ class ToolExecutor:
return await self._execute_write(tool_args)
elif tool_name == "bash":
return await self._execute_bash(tool_args)
elif tool_name == "webfetch":
return await self._execute_webfetch(tool_args)
elif tool_name == "question":
return f"Question: {tool_args}"
elif tool_name == "skill":
@@ -127,7 +109,7 @@ class ToolExecutor:
return "Current todo list: (empty)"
else:
return f"Tool '{tool_name}' not implemented"


except Exception as e:
return f"Error executing {tool_name}: {str(e)}"

@@ -328,7 +310,34 @@ class ToolExecutor:
logger.debug(f" 📄 Partial output (last 500 chars): ...{partial_output[-500:]}")

return f"Error executing bash: {error_msg}"


async def _execute_webfetch(self, args: dict) -> str:
"""Execute webfetch tool."""
url = args.get("url", "")
format = args.get("format", "text") # Default to text

if not url:
return "Error: url required"

logger.debug(f" 🌐 Fetching: {url[:100]}... (format: {format})")

try:
session = await self._get_session()
async with session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
if resp.status == 200:
content = await resp.text()
logger.debug(f" ✓ Fetched {len(content)} chars")
return content
else:
logger.debug(f" ❌ HTTP {resp.status}: {url[:100]}")
return f"Error: HTTP {resp.status} from {url[:100]}"
except asyncio.TimeoutError:
logger.debug(f" ⏰ Timeout fetching: {url[:100]}")
return f"Error: Timeout fetching {url[:100]} (30s)"
except Exception as e:
logger.debug(f" ❌ Error: {e}")
return f"Error fetching {url[:100]}: {str(e)}"

async def close(self):
"""Close HTTP session."""
if self._session:

@@ -7,11 +7,11 @@ import logging
import sys


def setup_logging(level=logging.DEBUG):
def setup_logging(level=logging.INFO):
"""Set up logging configuration.

Args:
level: Logging level (default: DEBUG for development)
level: Logging level (default: INFO)
"""
# Create formatter
formatter = logging.Formatter(

@@ -0,0 +1,45 @@
"""Network utilities for Local Swarm."""

import socket
from typing import Optional


def get_local_ip() -> str:
"""Get the local network IP address (private networks only).

Returns:
Local IP address or 127.0.0.1 if detection fails
"""
try:
# Create a socket and connect to a public DNS server
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.settimeout(2)
# Try to connect to Google's DNS - this doesn't actually send data
s.connect(("8.8.8.8", 80))
ip = s.getsockname()[0]
s.close()

# Check if it's a private IP
is_private = ip.startswith('192.168.')

if is_private:
print(f" 📡 Detected local IP: {ip}")
return ip
else:
print(f" ⚠️ IP {ip} is not a private network, binding to localhost")
return "127.0.0.1"
except Exception as e:
print(f" ⚠️ Could not detect local IP: {e}, using localhost")
return "127.0.0.1"


def is_private_ip(ip: str) -> bool:
"""Check if an IP address is private.

Args:
ip: IP address string

Returns:
True if IP is private
"""
return ip.startswith('192.168.') or ip.startswith('10.') or ip.startswith('172.16.')
@@ -0,0 +1,86 @@
"""Project root discovery utilities.

Provides functionality to discover project root directories.
"""

import os
from typing import Optional, List


# Common project root markers
DEFAULT_MARKERS = [
'.git', 'package.json', 'pyproject.toml', 'Cargo.toml', 'go.mod',
'requirements.txt', 'setup.py', 'pom.xml', 'build.gradle', '.project', '.venv'
]


def discover_project_root(
start_dir: Optional[str] = None,
markers: Optional[List[str]] = None
) -> str:
"""Discover the project root directory by looking for common markers.

Args:
start_dir: Directory to start searching from (defaults to cwd)
markers: List of marker files/directories to look for (defaults to DEFAULT_MARKERS)

Returns:
Path to project root, or start_dir if no markers found
"""
if start_dir is None:
start_dir = os.getcwd()

if markers is None:
markers = DEFAULT_MARKERS

current = os.path.abspath(start_dir)

while True:
try:
if any(os.path.exists(os.path.join(current, marker)) for marker in markers):
return current
except (OSError, PermissionError):
pass # Permission errors, just skip

parent = os.path.dirname(current)
if parent == current: # Reached filesystem root
break
current = parent

return start_dir


def is_within_project(path: str, project_root: str) -> bool:
"""Check if a path is within a project root.

Args:
path: Path to check
project_root: Project root directory

Returns:
True if path is within project root
"""
try:
real_path = os.path.realpath(path)
real_root = os.path.realpath(project_root)
return real_path.startswith(real_root)
except (OSError, ValueError):
return False


def get_relative_to_project(path: str, project_root: str) -> str:
"""Get path relative to project root.

Args:
path: Absolute or relative path
project_root: Project root directory

Returns:
Path relative to project root
"""
try:
real_path = os.path.realpath(path)
real_root = os.path.realpath(project_root)
return os.path.relpath(real_path, real_root)
except (OSError, ValueError):
return path
@@ -0,0 +1,90 @@
"""Token counting utilities for Local Swarm.

Centralizes token counting functionality to avoid duplication across modules.
"""

import tiktoken
from typing import Optional


# Initialize tokenizer for accurate token counting
TOKEN_ENCODING = tiktoken.get_encoding('cl100k_base')


def count_tokens(text: str) -> int:
"""Count tokens in a text string using tiktoken.

Args:
text: Text to count tokens for

Returns:
Number of tokens
"""
if not text:
return 0
return len(TOKEN_ENCODING.encode(text))


def count_tokens_in_messages(messages: list) -> int:
"""Count tokens in a list of messages.

Args:
messages: List of message objects with content attribute

Returns:
Total token count
"""
total = 0
for msg in messages:
if hasattr(msg, 'content') and msg.content:
total += count_tokens(msg.content)
return total


def estimate_tokens_from_characters(char_count: int, chars_per_token: int = 4) -> int:
"""Estimate token count from character count.

This is a fallback when tiktoken is not available or for quick estimates.

Args:
char_count: Number of characters
chars_per_token: Average characters per token (default 4)

Returns:
Estimated token count
"""
return char_count // chars_per_token


def truncate_to_max_tokens(text: str, max_tokens: int) -> str:
"""Truncate text to fit within max tokens.

Args:
text: Text to truncate
max_tokens: Maximum number of tokens allowed

Returns:
Truncated text
"""
tokens = TOKEN_ENCODING.encode(text)
if len(tokens) <= max_tokens:
return text
truncated = tokens[:max_tokens]
return TOKEN_ENCODING.decode(truncated)


def format_token_info(prompt_tokens: int, completion_tokens: int) -> dict:
"""Format token information for responses.

Args:
prompt_tokens: Number of prompt tokens
completion_tokens: Number of completion tokens

Returns:
Dictionary with token counts
"""
return {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens
}
@@ -0,0 +1,16 @@
# Patch to add real-time streaming for tools

# This patch adds real-time streaming of assistant content ("thinking") and tool calls
# when tools are used. Previously, all content was buffered until complete,
# causing opencode to wait with no feedback.

# Key changes:
# 1. Stream model output incrementally as it's generated
# 2. Parse for tool_calls and content in each chunk
# 3. Send content chunks immediately (the "thinking")
# 4. Send tool_calls deltas immediately when found
# 5. Don't execute tools server-side in streaming mode
# 6. Send DONE marker at end

# Apply this patch with:
# patch -p1 < this_file src/api/routes.py
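For reference, the streaming shape these notes describe looks roughly like the following (a simplified sketch of an OpenAI-style SSE delta generator; it is not the actual contents of src/api/routes.py):

```python
import json
from typing import AsyncIterator


async def stream_chunks(model_output: AsyncIterator[dict]) -> AsyncIterator[str]:
    """Forward content and tool_call deltas as SSE lines, ending with [DONE]."""
    async for piece in model_output:
        delta = {}
        if piece.get("content"):        # the "thinking" text, sent immediately
            delta["content"] = piece["content"]
        if piece.get("tool_calls"):     # tool call fragments; not executed server-side
            delta["tool_calls"] = piece["tool_calls"]
        if delta:
            payload = {"choices": [{"delta": delta, "index": 0}]}
            yield f"data: {json.dumps(payload)}\n\n"
    yield "data: [DONE]\n\n"
```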
@@ -0,0 +1,63 @@
## Test Plan for CUDA and Android Support

### Unit Tests

#### Test Case 1: NVIDIA GPU Detection
- **Input:** System with NVIDIA GPU and pynvml installed
- **Expected Output:** GPUInfo with correct name, VRAM, and is_nvidia=True
- **Location:** src/hardware/detector.py:detect_nvidia_gpu()

#### Test Case 2: GPU Layer Configuration for CUDA
- **Input:** HardwareProfile with NVIDIA GPU (4GB VRAM)
- **Expected Output:** n_gpu_layers=-1 (all layers), proper CUDA configuration
- **Location:** src/backends/__init__.py:create_backend()

#### Test Case 3: Android Platform Detection
- **Input:** platform.system() returns 'Linux', Termux environment detected
- **Expected Output:** is_android=True, proper Android path handling
- **Location:** src/hardware/detector.py:detect_android()

#### Test Case 4: PeerInfo with Timeout
- **Input:** PeerInfo with custom timeout
- **Expected Output:** FederationClient respects peer timeout
- **Location:** src/network/discovery.py:PeerInfo

### Integration Tests

#### End-to-End Flow 1: CUDA Backend Creation
1. Detect hardware with NVIDIA GPU
2. Create backend via factory
3. Verify n_gpu_layers=-1 set
4. Load test model
5. Expected: Successful GPU offload

#### End-to-End Flow 2: Android Device Join Federation
1. Start discovery on Android (Termux)
2. Advertise Android hardware
3. Join federation from macOS peer
4. Send vote request
5. Expected: Android responds successfully

#### End-to-End Flow 3: Federation with Per-Peer Timeout
1. Add peer with 30s timeout
2. Add peer with 60s timeout
3. Request votes from both
4. Expected: Each peer uses its own timeout

### Manual Verification

#### Command to Run:
```bash
python -m pytest tests/ -v -k "cuda or android or federation"
```

#### Expected Output:
- All tests pass
- No ImportError for pynvml
- GPU layer detection works on CUDA machines
- Android detection passes on Termux

#### Platform Testing:
1. **macOS (Apple Silicon):** MLX backend loads
2. **Linux (NVIDIA):** CUDA backend auto-detects
3. **Android (Termux):** CPU-only mode, proper paths
@@ -0,0 +1,166 @@
"""Tests for federation metrics and peer timeout."""

import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))

import pytest
from datetime import datetime
from network.discovery import PeerInfo, PeerMetrics
from network.federation import FederationClient, PeerVote


class TestPeerMetrics:
"""Test peer metrics tracking."""

def test_peer_metrics_defaults(self):
"""Test default metric values."""
metrics = PeerMetrics()
assert metrics.total_requests == 0
assert metrics.successful_requests == 0
assert metrics.failed_requests == 0
assert metrics.success_rate == 1.0 # No requests = 100% success

def test_record_success(self):
"""Test recording successful requests."""
metrics = PeerMetrics()
metrics.record_success(100.0)

assert metrics.total_requests == 1
assert metrics.successful_requests == 1
assert metrics.failed_requests == 0
assert metrics.success_rate == 1.0
assert metrics.avg_latency_ms == 100.0

# Record another success
metrics.record_success(200.0)
assert metrics.total_requests == 2
assert metrics.avg_latency_ms == 150.0 # (100 + 200) / 2

def test_record_failure(self):
"""Test recording failed requests."""
metrics = PeerMetrics()
metrics.record_failure("Connection timeout")

assert metrics.total_requests == 1
assert metrics.successful_requests == 0
assert metrics.failed_requests == 1
assert metrics.success_rate == 0.0
assert metrics.last_error == "Connection timeout"
assert metrics.last_error_time is not None

def test_mixed_success_and_failure(self):
"""Test mixed success and failure recording."""
metrics = PeerMetrics()
metrics.record_success(100.0)
metrics.record_failure("Error")
metrics.record_success(150.0)

assert metrics.total_requests == 3
assert metrics.successful_requests == 2
assert metrics.failed_requests == 1
assert metrics.success_rate == 2/3


class TestPeerInfo:
"""Test PeerInfo with metrics and timeout."""

def test_peer_info_defaults(self):
"""Test PeerInfo default values."""
peer = PeerInfo(
host="192.168.1.100",
port=17615,
name="test-peer",
version="0.1.0",
instances=2,
model_id="qwen:7b:q4",
hardware_summary="Apple M1 Pro",
last_seen=datetime.now()
)

assert peer.timeout_seconds == 60.0 # Default timeout
assert peer.metrics is not None
assert isinstance(peer.metrics, PeerMetrics)
assert peer.api_url == "http://192.168.1.100:17615"

def test_peer_info_custom_timeout(self):
"""Test PeerInfo with custom timeout."""
peer = PeerInfo(
host="192.168.1.100",
port=17615,
name="slow-peer",
version="0.1.0",
instances=1,
model_id="test-model",
hardware_summary="CPU only",
last_seen=datetime.now(),
timeout_seconds=120.0 # Custom timeout
)

assert peer.timeout_seconds == 120.0


class TestFederationClient:
"""Test FederationClient with peer-specific timeouts."""

@pytest.fixture
def client(self):
return FederationClient(timeout=60.0)

@pytest.fixture
def fast_peer(self):
return PeerInfo(
host="192.168.1.10",
port=17615,
name="fast-peer",
version="0.1.0",
instances=2,
model_id="qwen:7b:q4",
hardware_summary="Apple M1 Max",
last_seen=datetime.now(),
timeout_seconds=30.0 # Fast peer with short timeout
)

@pytest.fixture
def slow_peer(self):
return PeerInfo(
host="192.168.1.11",
port=17615,
name="slow-peer",
version="0.1.0",
instances=1,
model_id="qwen:7b:q4",
hardware_summary="CPU only",
last_seen=datetime.now(),
timeout_seconds=90.0 # Slow peer with longer timeout
)

def test_peer_timeout_override(self, client, fast_peer, slow_peer):
"""Test that peer-specific timeout overrides default."""
# The client should use the peer's timeout, not the default
assert fast_peer.timeout_seconds == 30.0
assert slow_peer.timeout_seconds == 90.0
assert client.timeout == 60.0 # Default unchanged

def test_metrics_updated_on_success(self, fast_peer):
"""Test that metrics are updated on successful request."""
assert fast_peer.metrics.total_requests == 0

# Simulate recording a success (this would happen in request_vote)
fast_peer.metrics.record_success(150.0)

assert fast_peer.metrics.total_requests == 1
assert fast_peer.metrics.successful_requests == 1
assert fast_peer.metrics.success_rate == 1.0

def test_metrics_updated_on_failure(self, slow_peer):
"""Test that metrics are updated on failed request."""
assert slow_peer.metrics.total_requests == 0

# Simulate recording a failure
slow_peer.metrics.record_failure("Connection refused")

assert slow_peer.metrics.total_requests == 1
assert slow_peer.metrics.failed_requests == 1
assert slow_peer.metrics.success_rate == 0.0
assert slow_peer.metrics.last_error == "Connection refused"
@@ -0,0 +1,176 @@
"""Tests for hardware detection and GPU layer configuration."""

import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))

import pytest
from unittest.mock import Mock, patch, MagicMock
from hardware.detector import (
GPUInfo, HardwareProfile, detect_nvidia_gpu,
calculate_gpu_layers, validate_gpu_layers, is_android
)


class TestNvidiaGPU:
"""Test NVIDIA GPU detection."""

def test_detect_nvidia_gpu_success(self):
"""Test successful NVIDIA GPU detection."""
# Mock the entire import system
mock_pynvml = Mock()
mock_pynvml.nvmlInit = Mock()
mock_pynvml.nvmlShutdown = Mock()
mock_pynvml.nvmlDeviceGetCount = Mock(return_value=1)

# Mock device handle and info
mock_handle = Mock()
mock_pynvml.nvmlDeviceGetHandleByIndex = Mock(return_value=mock_handle)
mock_pynvml.nvmlDeviceGetName = Mock(return_value="NVIDIA GeForce RTX 3080")

# Mock memory info
mock_mem = Mock()
mock_mem.total = 10737418240 # 10 GB
mock_pynvml.nvmlDeviceGetMemoryInfo = Mock(return_value=mock_mem)

# Mock driver version
mock_pynvml.nvmlSystemGetDriverVersion = Mock(return_value="535.104.05")

# Mock compute capability
mock_pynvml.nvmlDeviceGetCudaComputeCapability = Mock(return_value=(8, 6))

# Patch __import__ to return our mock
def mock_import(name, *args, **kwargs):
if name == 'pynvml':
return mock_pynvml
return __builtins__.__import__(name, *args, **kwargs)

with patch('builtins.__import__', side_effect=mock_import):
gpu = detect_nvidia_gpu()

assert gpu is not None
assert gpu.name == "NVIDIA GeForce RTX 3080"
assert gpu.vram_gb == 10.0
assert gpu.driver_version == "535.104.05"
assert gpu.is_nvidia is True
assert gpu.compute_capability == "8.6"
assert gpu.device_count == 1

def test_detect_nvidia_gpu_no_gpu(self):
"""Test detection when no NVIDIA GPU present."""
mock_pynvml = Mock()
mock_pynvml.nvmlInit = Mock()
mock_pynvml.nvmlShutdown = Mock()
mock_pynvml.nvmlDeviceGetCount = Mock(return_value=0)

def mock_import(name, *args, **kwargs):
if name == 'pynvml':
return mock_pynvml
return __builtins__.__import__(name, *args, **kwargs)

with patch('builtins.__import__', side_effect=mock_import):
gpu = detect_nvidia_gpu()

assert gpu is None

def test_detect_nvidia_gpu_import_error(self):
"""Test detection when pynvml is not installed."""
def mock_import(name, *args, **kwargs):
if name == 'pynvml':
raise ImportError("No module named 'pynvml'")
return __builtins__.__import__(name, *args, **kwargs)

with patch('builtins.__import__', side_effect=mock_import):
gpu = detect_nvidia_gpu()

assert gpu is None


class TestGPULayerCalculation:
"""Test GPU layer auto-configuration."""

def test_calculate_gpu_layers_apple_silicon(self):
"""Test layer calculation for Apple Silicon."""
gpu = GPUInfo(
name="Apple Silicon GPU",
vram_gb=32.0,
is_apple_silicon=True
)
assert calculate_gpu_layers(gpu) == -1

def test_calculate_gpu_layers_nvidia(self):
"""Test layer calculation for NVIDIA GPU."""
gpu = GPUInfo(
name="NVIDIA GeForce RTX 3080",
vram_gb=10.0,
is_nvidia=True,
compute_capability="8.6"
)
assert calculate_gpu_layers(gpu) == -1

def test_calculate_gpu_layers_old_nvidia(self):
"""Test layer calculation for old NVIDIA GPU."""
gpu = GPUInfo(
name="NVIDIA GeForce GTX 680",
vram_gb=2.0,
is_nvidia=True,
compute_capability="3.0"
)
assert calculate_gpu_layers(gpu) == 0 # Too old

def test_calculate_gpu_layers_no_gpu(self):
"""Test layer calculation with no GPU."""
assert calculate_gpu_layers(None) == 0

def test_validate_gpu_layers_success(self):
"""Test successful layer validation."""
gpu = GPUInfo(
name="NVIDIA GeForce RTX 3080",
vram_gb=10.0,
is_nvidia=True,
compute_capability="8.6"
)
assert validate_gpu_layers(-1, gpu) == -1

def test_validate_gpu_layers_no_gpu_error(self):
"""Test validation error when GPU requested but none available."""
with pytest.raises(ValueError, match="no GPU detected"):
validate_gpu_layers(-1, None)

def test_validate_gpu_layers_old_gpu_error(self):
"""Test validation error for unsupported GPU."""
gpu = GPUInfo(
name="NVIDIA GeForce GTX 680",
vram_gb=2.0,
is_nvidia=True,
compute_capability="3.0"
)
with pytest.raises(ValueError, match="Minimum required is 5.0"):
validate_gpu_layers(-1, gpu)


class TestAndroidDetection:
"""Test Android platform detection."""

@patch.dict('os.environ', {'ANDROID_ROOT': '/system'}, clear=True)
@patch('os.path.exists')
def test_is_android_env_var(self, mock_exists):
"""Test Android detection via environment variables."""
mock_exists.return_value = False
assert is_android() is True

@patch.dict('os.environ', {}, clear=True)
@patch('os.path.exists')
def test_is_android_paths(self, mock_exists):
"""Test Android detection via filesystem paths."""
def exists_side_effect(path):
return path == "/system/build.prop"
mock_exists.side_effect = exists_side_effect
assert is_android() is True

@patch.dict('os.environ', {}, clear=True)
@patch('os.path.exists')
def test_is_not_android(self, mock_exists):
"""Test non-Android system."""
mock_exists.return_value = False
assert is_android() is False
@@ -133,7 +133,7 @@ ls -la

def test_tool_instructions_content():
"""Test that tool instructions contain required sections (REVIEW-2026-02-24 Blocker #4)."""
from api.routes import _load_tool_instructions
from api.formatting import _load_tool_instructions

# Load instructions from config file
instructions = _load_tool_instructions()
@@ -147,7 +147,7 @@ def test_tool_instructions_token_count():

def test_tool_instructions_token_count():
"""Test that tool instructions are within token budget (REVIEW-2026-02-24 Blocker #1)."""
from api.routes import _load_tool_instructions
from api.formatting import _load_tool_instructions

# Load instructions from config file
instructions = _load_tool_instructions()