Fix opencode integration: streaming, response format, and tool handling

- Fix streaming to work even when tools are present (streaming was previously disabled whenever tools were supplied)
- Fix response format: use empty list [] instead of null for tool_calls
- Add exclude_none config to ChatMessage model to match OpenAI format
- Remove tool instructions from prompt (they were confusing the 3B model)
- Fix tool call parsing to handle markdown code blocks properly
- Change default instances from 3 to 1 for faster debugging
- Allow 1 instance minimum in interactive config (was 2 on Mac)
- Add debug logging to track requests and responses

Fixes the infinite-loop issue where opencode would retry requests repeatedly.
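For reference, the response shape the format change targets looks roughly like this — a minimal sketch of the standard OpenAI chat-completions payload, not the server's actual serialization code:

```ts
// Sketch of a non-streaming chat completion response (standard OpenAI-compatible shape).
// Assumed for illustration; field names follow the public chat-completions schema.
interface ChatCompletionResponse {
  id: string;                     // e.g. "chatcmpl-..."
  object: "chat.completion";
  created: number;                // unix seconds
  model: string;
  choices: Array<{
    index: number;
    message: {
      role: "assistant";
      content: string | null;
      tool_calls: Array<{         // [] (not null) when no tool is invoked
        id: string;
        type: "function";
        function: { name: string; arguments: string }; // arguments is a JSON string
      }>;
    };
    finish_reason: "stop" | "tool_calls";
  }>;
}
```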
2026-02-24 03:44:46 +01:00
parent 2461f45ca8
commit d30eedaa63
10 changed files with 230 additions and 657 deletions
-30
@@ -34,35 +34,5 @@
{"t":"progress","c":33,"n":33,"f":"tests/__init__.py"} {"t":"progress","c":33,"n":33,"f":"tests/__init__.py"}
{"t":"done","indexed":0,"skipped":33,"total":33} {"t":"done","indexed":0,"skipped":33,"total":33}
{"t":"watch","files":33} {"t":"watch","files":33}
{"t":"reindex","f":"src/swarm/manager.py","s":0}
{"t":"reindex","f":"src/swarm/manager.py","s":0}
{"t":"reindex","f":"src/swarm/manager.py","s":0}
{"t":"reindex","f":"src/swarm/manager.py","s":0}
{"t":"reindex","f":"src/swarm/manager.py","s":0}
{"t":"watch","files":33}
{"t":"reindex","f":"src/interactive.py","s":0}
{"t":"reindex","f":"src/interactive.py","s":0}
{"t":"reindex","f":"src/interactive.py","s":0}
{"t":"reindex","f":"src/interactive.py","s":0}
{"t":"watch","files":33}
{"t":"reindex","f":"src/interactive.py","s":0}
{"t":"reindex","f":"src/interactive.py","s":0}
{"t":"reindex","f":"src/interactive.py","s":0}
{"t":"watch","files":33}
{"t":"reindex","f":"src/models/selector.py","s":0} {"t":"reindex","f":"src/models/selector.py","s":0}
{"t":"reindex","f":"src/models/selector.py","s":0} {"t":"reindex","f":"src/models/selector.py","s":0}
{"t":"reindex","f":"src/models/selector.py","s":0}
{"t":"watch","files":33}
{"t":"reindex","f":"src/models/downloader.py","s":0}
{"t":"watch","files":33}
{"t":"reindex","f":"src/models/registry.py","s":0}
{"t":"reindex","f":"src/models/registry.py","s":0}
{"t":"watch","files":33}
{"t":"reindex","f":"src/models/registry.py","s":0}
{"t":"reindex","f":"src/interactive.py","s":0}
{"t":"reindex","f":"src/interactive.py","s":0}
{"t":"reindex","f":"src/interactive.py","s":0}
{"t":"watch","files":33}
{"t":"reindex","f":"src/models/selector.py","s":0}
{"t":"watch","files":33}
{"t":"reindex","f":"src/models/selector.py","s":0}
-378
@@ -1,378 +0,0 @@
# Agent Guidelines for Code Graph Project
This is a **fast code graph creation project** with MCP (Model Context Protocol) integration for AI-assisted development. The goal is to build a git-aware indexing system that enables semantic code search.
**Stack**: Bun runtime, TypeScript compiler API + tree-sitter (parsing), SurrealDB embedded (storage), MCP protocol (AI tool interface).
---
## IMPORTANT: Use the code-graph MCP server for code navigation
**When working on this project, ALWAYS prefer the code-graph MCP tools over grep, glob, LSP, or other search methods for code navigation.** This is the project's own product — dogfooding it is how we find bugs and gaps.
Available MCP tools (via the `code-graph` server):
- **`search_code`** — Find symbols by name, kind, file, or export status. Use this instead of grep/glob for finding functions, classes, types, etc.
- **`search_pattern`** — AST structural pattern matching with `$VAR`/`$$$VAR` metavariables. Use for finding code patterns like `$FN($$$ARGS)`.
- **`search_regex`** — Regex search on AST nodes. Use for text patterns across the codebase.
- **`find_references`** — Find all references to a symbol (imports, calls, usage sites). Use instead of "find references" via grep.
- **`call_hierarchy`** — Show callers and callees of a function/method. Use to understand call chains.
- **`type_hierarchy`** — Show extends/implements tree. Use to understand class/interface relationships.
- **`resolve_symbol`** — Get full details of a specific symbol by name.
**Note:** The code-graph MCP server now supports Python codebases in addition to TypeScript/JavaScript. All tools above work with Python symbols, functions, classes, and imports.
**Even when other tools might seem faster or more convenient, use code-graph first.** If it fails or gives bad results, open a GitHub issue with `gh issue create --label bug --title "..." --body "..."` describing the query, expected result, and actual result. Then fall back to other tools to continue your work. Real usage is what makes the tool better.
If the index is stale: `rm -rf .git/fcg-index && bun src/cli/index.ts index .`
---
## Code Structure
- See [docs/structure.md](./docs/structure.md) for the full file tree and test file listing.
---
## Key Technical Decisions
### Runtime
- **Bun** is the primary runtime (v1.3.9+)
- All commands use `bun` not `node`
- SQLite available via `bun:sqlite`
- Default to using Bun instead of Node.js:
- `bun <file>`, `bun test`, `bun install`, `bun run <script>`, `bunx <pkg>`
- Prefer `Bun.file` over `node:fs` readFile/writeFile
- `Bun.$\`cmd\`` instead of execa
- Bun auto-loads .env — no dotenv needed
- `bun build --compile` does NOT work yet (NAPI addon can't bundle)
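As a quick illustration of these conventions (a sketch only, not code from this repo):

```ts
import { $ } from "bun";

// Bun.file / Bun.write instead of node:fs readFile/writeFile
const pkg = await Bun.file("package.json").json();
await Bun.write("dist/pkg-name.txt", pkg.name);

// Bun.$`cmd` instead of execa
const head = (await $`git rev-parse --short HEAD`.text()).trim();

// .env is auto-loaded, so env vars are just there
console.log(`building ${pkg.name} @ ${head} (${Bun.env.NODE_ENV ?? "dev"})`);
```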
### Database
- **SurrealDB in embedded mode** - no external DB setup required
- Local graph storage for fast queries
- **DB location**: `.git/commondir/fcg-index` (shared across worktrees)
- **SurrealDB**: `mem://` for tests, `surrealkv:///path` for persistent. `SCHEMAFULL` tables. `option<string>` fields must be omitted (not null).
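The `option<string>` rule in practice — a minimal sketch with a hypothetical helper and row shape, not the actual store code:

```ts
// SCHEMAFULL tables reject null for option<string> fields: drop empty keys instead.
function omitEmpty<T extends Record<string, unknown>>(row: T): Partial<T> {
  return Object.fromEntries(
    Object.entries(row).filter(([, v]) => v !== undefined && v !== null)
  ) as Partial<T>;
}

// docComment is option<string>: it must be absent when there is no value.
const row = omitEmpty({ name: "parseFile", kind: "function", docComment: undefined });
// → { name: "parseFile", kind: "function" }
```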
### Git Integration
- Must support **git worktrees** natively
- Uses **git object hashes** for deduplication
- **libgit2** via native Bun FFI (`cc()`) is the sole git backend
- FFI functions return complete NAPI arrays in single calls (no iterative/global state)
- **Future exploration**: Git Butler integration (virtual branches, stacking workflows)
- **Content-addressed**: Symbols keyed by git blob hash, not file path
### Code Parsing
- **TypeScript compiler** for TS files (default — full type info + all relations)
- **Tree-sitter** (web-tree-sitter WASM) for fast mode (`--fast` — ~5x faster, symbols + extends/implements only)
- **ParserAdapter interface** (`src/common/extractors/parser.ts`) abstracts over both backends
- **Transform executables** communicate over stdio (allows polyglot analyzers)
- Files to process passed as CLI args (or via `-f <filelist>` for large batches)
- Output is **NDJSON** by default (one JSON object per input file)
- A `--pretty` flag for human/agent-readable pretty-printed JSON output
- Input: file paths as args → Output: one JSON object per file on stdout
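A skeleton of such a transform executable (illustrative only; a real analyzer would emit symbols and relations rather than a byte count):

```ts
#!/usr/bin/env bun
// Transform contract: file paths as CLI args in, one NDJSON object per file on stdout.
const args = process.argv.slice(2);
const pretty = args.includes("--pretty");
const files = args.filter((a) => a !== "--pretty");

for (const file of files) {
  const source = await Bun.file(file).text();
  const result = { file, bytes: source.length, symbols: [] as string[] }; // placeholder payload
  console.log(JSON.stringify(result, null, pretty ? 2 : 0));
}
```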
### Search Features
- AST-aware embeddings using **Voyage AI**
- Reranking for relevance
- Fuzzy search on function/variable names
- **Watch mode**: Git-driven incremental reindexing on file changes
### MCP Server
- **Dual Transport**: Supports both stdio (one instance per agent) and HTTP (one instance for multiple agents)
- **Shared Database**: Multiple MCP instances share the same SurrealDB for fast queries
- **Performance**: Both transports optimized for low-latency responses
- **Raw DB Access Tool**: An MCP tool that allows direct SurrealQL queries against the database. **Disabled by default** — must be explicitly enabled via config/flag. Paired with MCP resources that document the DB schema and how to write SurrealQL queries, so agents can self-serve when the built-in tools aren't enough.
### Developer Tools
- **Object Explorer**: Tree view of object properties with dot notation navigation
- **Property Introspection**: Type inference and documentation for object members
### TypeScript Integration
- **Type Storage**: Type information attached to individual functions/variables (per-symbol, not type graph)
- **IDE-Like Completion**: Property trees with types for Object Explorer
- **Documentation**: JSDoc/tsdoc comments preserved with symbols and searchable
### Embeddings Strategy
- **Code Embeddings**: Voyage AI code model for AST-aware code embeddings
- **Text Embeddings**:
- Separate text model for documentation/comments
- Function docs searchable separately from code embeddings
- **Library Data**: Graph-only (no embeddings), indexed by package hash (pnpm), types/docs from `.d.ts`
### Query Optimization
- **Search Intent Parameter**: Optional first parameter on MCP search functions for search goals
- **Fast Model Layer**: Lightweight model sees intent + results and optimizes output
- **Model Config**: OpenAI-compatible API (works with Ollama, llama.cpp, OpenRouter)
- **Intent-Based Reranking**: Results prioritized by search purpose
---
## Content-Addressed Storage (CRITICAL)
**Everything is keyed by blob hash (`objectHash`), NEVER by file path.** This is the single most important design principle.
- **Symbols**: keyed by `blob` (git blob SHA). Same content at different paths = stored once.
- **Relations** (`calls`, `references`, `imports`, `extends`, `implements`): keyed by `fromBlob`/`toBlob`.
- **`branch_file`**: maps `(branch, filePath) → blob`. This is how you find "what blobs exist on branch X".
- **`file_index`**: should be keyed by `objectHash` — it's a "has this blob been parsed?" cache.
- **File path is metadata**, not a key. It's stored for display/navigation but never used as a primary lookup.
**Why**: The DB is shared across worktrees. Two worktrees can have the same file path with different content (different blobs). If you key by filePath, they stomp each other. If you key by blob, they coexist.
**When querying**: Always go through `branch_file` to scope results to a branch's blobs, then look up symbols/relations by those blobs. Never query symbols by filePath directly for branch-scoped results.
```
branch "main": src/foo.ts → blob:⟨aaa⟩, src/bar.ts → blob:⟨bbb⟩
branch "feature": src/foo.ts → blob:⟨aaa⟩, src/bar.ts → blob:⟨ccc⟩
Symbols for blob:⟨aaa⟩: stored ONCE, visible to both branches
Symbols for blob:⟨bbb⟩: visible to main only
Symbols for blob:⟨ccc⟩: visible to feature only
```
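Following that rule, a branch-scoped lookup might look like this — the table and field names come from the description above, but the query itself is a sketch, not the store's actual SurrealQL:

```ts
// Resolve the branch's blobs via branch_file first, then fetch symbols by blob.
async function symbolsOnBranch(
  db: { query(sql: string, vars?: Record<string, unknown>): Promise<unknown> },
  branch: string
) {
  return db.query(
    `SELECT * FROM symbol WHERE blob IN (
       SELECT VALUE blob FROM branch_file WHERE branch = $branch
     )`,
    { branch }
  );
}
```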
---
## Key Patterns
- **Branded types**: `OID`, `RepoPath`, `RelPath`, `GitRef` — use helpers `rel()`, `oid()`, `ref()` in tests (see the sketch below)
- **Content-addressed**: Symbols keyed by git blob hash, not file path (see section above)
- **DB location**: `.git/commondir/fcg-index` (shared across worktrees)
- **SurrealDB**: `mem://` for tests, `surrealkv:///path` for persistent. `SCHEMAFULL` tables. `option<string>` fields must be omitted (not null).
- **Zod v4**: Use `.nonnegative()` not `.nonneg()`. Requires `esModuleInterop: true` in tsconfig.
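A sketch of the branded-types pattern named above (the brand field and helper bodies are assumptions; only the names come from this doc):

```ts
// Nominal typing on top of plain strings: same runtime value, distinct compile-time types.
type OID = string & { readonly __brand: "OID" };
type RepoPath = string & { readonly __brand: "RepoPath" };
type RelPath = string & { readonly __brand: "RelPath" };
type GitRef = string & { readonly __brand: "GitRef" };

// Test helpers: cast plain strings into the branded types.
const oid = (s: string) => s as OID;
const rel = (s: string) => s as RelPath;
const ref = (s: string) => s as GitRef;

declare function blobForPath(branch: GitRef, path: RelPath): OID | undefined;
// blobForPath("main", "src/foo.ts")           // type error: plain strings
// blobForPath(ref("main"), rel("src/foo.ts")) // ok
```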
---
## Development Workflow
```bash
# Install dependencies
bun install
# Run the project (default — TypeScript compiler, full type info)
bun run src/cli/index.ts index .
# Run the project (fast mode — tree-sitter, ~5x faster, symbols only)
bun run src/cli/index.ts index --fast .
# Run tests
bun test
```
### Pre-commit Checks
```bash
bunx tsc --noEmit # Type check
bun test # Run all tests
```
### Testing the MCP Server
Config lives in `mcp_servers.json` (gitignored). Test with `mcp-cli`:
```bash
# List all tools
mcp-cli -d
# Call a tool
mcp-cli call code-graph call_hierarchy '{"symbolName":"parseFile","direction":"both"}'
mcp-cli call code-graph find_references '{"symbolName":"Store","direction":"incoming"}'
mcp-cli call code-graph type_hierarchy '{"symbolName":"SurrealDBStore"}'
mcp-cli call code-graph search_code '{"query":"parseFile","fuzzy":false}'
# Show tool schema
mcp-cli info code-graph call_hierarchy
```
If index is stale after code changes, re-index first:
```bash
rm -rf .git/fcg-index && bun src/cli/index.ts index .
```
---
## Resolved Decisions
1. **Git library**: libgit2 native FFI via Bun `cc()` — isomorphic-git and wasm-git removed
2. **Language scope**: TypeScript-only (TS Compiler API + tree-sitter fast mode), tree-sitter planned for other languages
3. **Database**: SurrealDB embedded (SurrealKV), no external server needed
4. **Content addressing**: Symbols keyed by blob hash for multi-branch dedup
5. **Parser abstraction**: `ParserAdapter` interface with tree-sitter (`--fast`) and TS compiler (default) backends
---
## Architecture Principles
1. **Speed First**: Every design decision prioritizes search speed
2. **Git-Native**: Leverage git's object model for efficiency
3. **Extensible**: Transform pipeline allows language-specific tooling
4. **Embedded**: No external dependencies (SurrealDB embedded)
5. **MCP-First**: Built for AI assistant integration
6. **KISS & DRY**: Keep it simple, don't repeat yourself
7. **Independently Testable**: Every component testable in isolation without spinning up a DB
8. **Design by Contract**: Use contracts to make assumptions explicit and catch bugs early
---
## Design by Contract
All code uses contract utilities from `src/utils/contracts.ts`. Contracts make assumptions explicit and self-documenting.
### Utilities
- **`requires(condition, message)`** — Precondition. Call at the start of a function to validate inputs.
- **`invariant(condition, message)`** — General assertion. Use anywhere a condition must hold.
- **`satisfies(condition, message)`** — Postcondition. Call before returning to validate outputs.
All throw `ContractViolation` with the kind (`precondition`, `invariant`, `postcondition`) and message.
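A minimal sketch of what these utilities could look like (assuming `ContractViolation` is a plain `Error` subclass; the real `src/utils/contracts.ts` may differ):

```ts
export type ContractKind = "precondition" | "invariant" | "postcondition";

export class ContractViolation extends Error {
  constructor(public readonly kind: ContractKind, message: string) {
    super(`${kind} violated: ${message}`);
    this.name = "ContractViolation";
  }
}

const check = (kind: ContractKind) =>
  (condition: unknown, message: string): void => {
    if (!condition) throw new ContractViolation(kind, message);
  };

export const requires = check("precondition");   // validate inputs
export const invariant = check("invariant");      // conditions that must always hold
export const satisfies = check("postcondition");  // validate outputs before returning
```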
### Convention: `pre:` and `post:` Labels
Use JavaScript labeled statements to visually mark contract sections:
```ts
function buildIndex(files: GitFileInfo[], store: Store): IndexResult {
pre: requires(files.length > 0, "must have files to index");
const result = doIndexing(files, store);
post: satisfies(result.indexed <= files.length, "cannot index more files than given");
return result;
}
```
### Rules
- **Adapters**: Use `requires()` to validate inputs from external callers.
- **Core (pure functions)**: Use `requires()` for input constraints, `satisfies()` for output guarantees.
- **MCP tools**: Use `requires()` to validate params after Zod parsing if additional semantic checks are needed.
- Keep contract messages short and descriptive — they show up in error traces.
---
## Code Architecture
### Design Pattern: Pure Functions + Thin Adapters
The codebase is split into layers. The rule is simple:
- **`common/extractors/`** — Pure parsing/extraction functions. Data in, data out. Zero side effects, zero DB imports. Trivially testable. Also contains parser adapters (TypeScript compiler, tree-sitter).
- **`common/git/`** — Git adapters and pure git logic. Thin wrappers around libgit2 FFI. Implements a typed interface for git operations.
- **`common/db/`** — Database adapters (SurrealDB store). Implements the `Store` interface. This is the **only** code that talks to the database.
- **`common/ai/`** — AI client and worker for embeddings (Voyage AI / OpenRouter). Thin adapter around external AI APIs.
- **`mcp/tools/`** — One file per MCP tool. Each exports a self-contained definition (name, schema, handler). Handlers receive a `Store` interface, never a raw DB connection.
- **`common/ingest/`** — Orchestrates the indexing pipeline. Receives adapters, calls core logic. Runs **outside** the MCP server.
- **`cli/commands/`** — CLI subcommand handlers. Thin glue that wires adapters to core logic.
No dependency injection frameworks. Just pass interfaces to functions.
### MCP Tool Pattern
Each MCP tool is one file with a manual registry:
```ts
// mcp/tools/search-code.ts
export const searchCode = {
name: "search_code",
schema: z.object({ query: z.string(), intent: z.string().optional() }),
handler: async (params, store: Store) => { /* ... */ }
}
// mcp/registry.ts — explicit, no magic auto-discovery
import { searchCode } from "./tools/search-code"
import { searchDocs } from "./tools/search-docs"
export const tools = [searchCode, searchDocs]
```
### CLI: Single Binary, Subcommands
```
bun run src/cli/index.ts index [path] [--scope <prefix>] # Index a repo (with progress bar + ETA)
bun run src/cli/index.ts index [path] --fast # Fast mode — tree-sitter only (~5x faster, no typeInfo/calls/refs)
bun run src/cli/index.ts serve # Start MCP server (stdio or HTTP)
bun run src/cli/index.ts search <query> # Quick CLI search
```
Indexing runs **completely independently** from the MCP server. The MCP server only reads from the DB; the indexer writes to it.
### Build: Single Binary
The final deliverable is a **single self-contained binary** compiled with:
```bash
bun build --compile src/cli/index.ts --outfile code-graph
```
This produces a standalone executable with no runtime dependencies. Users run it directly:
```bash
./code-graph index [path]
./code-graph serve
./code-graph search <query>
```
### Progress Bar + ETA
**Non-negotiable.** Any indexing operation MUST show:
- A visual progress bar
- Percentage complete
- Files processed / total files
- Estimated time remaining (ETA)
```
Indexing my-repo...
[████████████░░░░░░░░] 62% | 1,240/2,000 files | ETA: 00:42
```
### Testability
Each layer is testable in isolation:
| Layer | How to test | Needs DB? | Needs git repo? |
|-------|-------------|-----------|-----------------|
| `common/extractors/*` | Call pure functions with test data | No | No |
| `mcp/tools/*` | Pass mock store to handler | No | No |
| `common/db/surrealdb.ts` | Integration test with real SurrealDB | Yes | No |
| `common/git/git.ts` | Integration test with temp git repos | No | Yes (temp) |
| `common/ingest/*` | Pass mock store + mock git adapter | No | No |
**Git testing strategy:** Git operations are inherently side-effectful, so they live in `common/git/git.ts` behind a typed interface. Tests for git logic create **temporary git repos** (init, commit, branch, etc.) and run assertions against those. This isolates git testing from everything else — no DB needed, no MCP needed, just git. The rest of the codebase receives git data through the interface and never touches git directly.
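The temp-repo pattern in `bun:test`, as a sketch — it shells out to plain `git` for setup and asserts on the result; the real tests go through the libgit2 adapter interface instead:

```ts
import { test, expect } from "bun:test";
import { mkdtemp, rm } from "node:fs/promises";
import { tmpdir } from "node:os";
import { join } from "node:path";
import { $ } from "bun";

test("temp repo exposes committed files at HEAD", async () => {
  const dir = await mkdtemp(join(tmpdir(), "fcg-git-"));
  try {
    await $`git -C ${dir} init -q`;
    await Bun.write(join(dir, "a.ts"), "export const a = 1;\n");
    await $`git -C ${dir} add a.ts`;
    await $`git -C ${dir} -c user.name=test -c user.email=test@example.com commit -qm init`;
    const files = (await $`git -C ${dir} ls-tree -r --name-only HEAD`.text()).trim().split("\n");
    expect(files).toContain("a.ts");
  } finally {
    await rm(dir, { recursive: true, force: true });
  }
});
```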
### Data Flow
```
CLI (cli/commands/)
→ Indexer (common/ingest/)
→ Git Adapter (common/git/git.ts) reads repo state
→ Extractors (common/extractors/) pure transforms on file data
→ Store Adapter (common/db/surrealdb.ts) persists to SurrealDB
MCP Server (mcp/server.ts)
→ Registry (mcp/registry.ts) routes to tool
→ Tool handler (mcp/tools/*.ts)
→ Extractors (common/extractors/) pure logic
→ Store Adapter (common/db/surrealdb.ts) reads from SurrealDB
```
---
## Instructions
- Functional, easy-to-test patterns. Keep side effects out of tests.
- Use `pre:` / `post:` labeled statements for design-by-contract
- KISS and DRY — no enterprise code, we are a startup
- Read [PLAN.md](./PLAN.md) for the current implementation plan
- Read [TODO.md](./TODO.md) for the full project roadmap
- Read topic-specific learnings docs **on demand** (use Read tool when working on related code):
- [docs/surrealdb.md](./docs/surrealdb.md) — SurrealDB schema, performance, transactions
- [docs/git.md](./docs/git.md) — Git adapters, watch mode, worktrees
- [docs/typescript.md](./docs/typescript.md) — TypeScript compiler API, parser
- [docs/bun.md](./docs/bun.md) — Bun runtime, FFI, NAPI
- [docs/mcp.md](./docs/mcp.md) — MCP server, SDK, tools
- [docs/architecture.md](./docs/architecture.md) — Design patterns, contracts, testing
- [docs/openrouter.md](./docs/openrouter.md) — OpenRouter API, embeddings, model calls
- [docs/bun-ipc.md](./docs/bun-ipc.md) — Bun IPC for AI worker processes
---
## Library References
Each `docs/*.md` file has upstream doc links at the top. Fetch those when you need library API details.
+6 -3
@@ -31,9 +31,12 @@ class ChatMessage(BaseModel):
"""A chat message.""" """A chat message."""
role: Literal["system", "user", "assistant", "tool"] = Field(..., description="Role of the message sender") role: Literal["system", "user", "assistant", "tool"] = Field(..., description="Role of the message sender")
content: Optional[str] = Field(default=None, description="Message content") content: Optional[str] = Field(default=None, description="Message content")
tool_calls: Optional[List[ToolCall]] = Field(default=None, description="Tool calls from assistant") tool_calls: Optional[List[ToolCall]] = Field(default_factory=list, description="Tool calls from assistant")
tool_call_id: Optional[str] = Field(default=None, description="ID of tool call this message is responding to") #tool_call_id: Optional[str] = Field(default=None, description="ID of tool call this message is responding to")
name: Optional[str] = Field(default=None, description="Name of the tool/function") #name: Optional[str] = Field(default=None, description="Name of the tool/function")
class Config:
exclude_none = True
class ChatCompletionRequest(BaseModel): class ChatCompletionRequest(BaseModel):
+6 -39
@@ -51,42 +51,13 @@ def format_tool_description(tool) -> str:
 def format_messages_with_tools(messages: list, tools: Optional[list] = None) -> str:
-    """Format chat messages and tools into a single prompt using ChatML format."""
+    """Format chat messages into a single prompt using ChatML format.
+
+    Note: Tools are currently ignored - the model will respond normally.
+    """
     formatted = []

-    # Add system message with tool instructions if tools are present
-    if tools:
-        tool_instructions = """You are a helpful assistant with access to tools.
-When you need to use a tool, you MUST respond with ONLY a JSON object in this exact format:
-{"tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "tool_name", "arguments": "{\\"param\\": \\"value\\"}"}}]}
-Important:
-- Use valid JSON format
-- The arguments field must be a JSON string (serialized JSON)
-- Do not include any other text when using tools
-- If you don't need a tool, respond normally without JSON
-Available tools:
-"""
-        for tool in tools:
-            tool_instructions += format_tool_description(tool) + "\n"
-
-        tool_instructions += "\nIf you don't need to use a tool, respond normally. If you use a tool, make sure to format your response as valid JSON with the tool_calls field."
-
-        # Prepend tool instructions to system message or create one
-        has_system = False
-        for msg in messages:
-            if msg.role == "system":
-                msg.content = tool_instructions + "\n\n" + (msg.content or "")
-                has_system = True
-                break
-
-        if not has_system:
-            # Insert system message at the beginning
-            from api.models import ChatMessage
-            messages.insert(0, ChatMessage(role="system", content=tool_instructions))
+    # Tools are accepted but ignored for now - model responds normally

     for msg in messages:
         role = msg.role
@@ -226,10 +197,6 @@ async def chat_completions(request: ChatCompletionRequest):
completion_id = f"chatcmpl-{uuid.uuid4().hex[:12]}" completion_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
created = int(time.time()) created = int(time.time())
# If tools are present, force non-streaming mode for proper tool call handling
if request.tools and request.stream:
request.stream = False
if request.stream: if request.stream:
# Return streaming response # Return streaming response
async def stream_generator() -> AsyncIterator[str]: async def stream_generator() -> AsyncIterator[str]:
@@ -303,7 +270,7 @@ async def chat_completions(request: ChatCompletionRequest):
     # Parse tool calls if tools were provided
     content = response_text
-    tool_calls = None
+    tool_calls = []
     finish_reason = "stop"
     if has_tools:
+66 -59
@@ -14,6 +14,10 @@ from backends.base import LLMBackend, GenerationRequest, GenerationResponse, Mod
 class MLXBackend(LLMBackend):
     """Backend using mlx-lm for Apple Silicon optimized inference."""

+    # Class-level lock to prevent concurrent access to MLX models
+    # MLX/Metal is not thread-safe and concurrent access causes segfaults
+    _mlx_lock = asyncio.Lock()
+
     def __init__(self, context_size: int = 4096):
         """
         Initialize MLX backend.
@@ -85,69 +89,72 @@ class MLXBackend(LLMBackend):
         if not self._loaded or not self._model or not self._tokenizer:
             raise GenerationError("Model not loaded")

-        try:
-            from mlx_lm import generate as mlx_generate
-            from mlx_lm.sample_utils import make_sampler
-            start_time = time.time()
-            # Create sampler with temperature and top_p
-            sampler = make_sampler(
-                temp=request.temperature,
-                top_p=request.top_p,
-                min_p=0.0,
-                min_tokens_to_keep=1,
-                top_k=0
-            )
-            # Define stop sequences for proper response termination
-            # Common stop sequences for chat models
-            stop_sequences = [
-                "<|im_end|>", # Qwen, ChatML format
-                "<|endoftext|>", # GPT-2, Qwen
-                "<|end|>", # Generic
-                "Human:", # Prevent answering for the user
-                "User:", # Prevent answering for the user
-                "Assistant:" # Prevent multiple assistant turns
-            ]
-            response_text = await asyncio.to_thread(
-                mlx_generate,
-                self._model,
-                self._tokenizer,
-                prompt=request.prompt,
-                max_tokens=request.max_tokens,
-                sampler=sampler,
-                verbose=False
-            )
-            # Clean up the response - remove any stop sequences that might have been included
-            for stop_seq in stop_sequences:
-                if stop_seq in response_text:
-                    response_text = response_text.split(stop_seq)[0].strip()
-            end_time = time.time()
-            latency_ms = (end_time - start_time) * 1000
-            # Use tokenizer to get accurate token count
-            try:
-                tokens_generated = len(self._tokenizer.encode(response_text))
-            except Exception:
-                # Fallback: rough estimate
-                tokens_generated = len(response_text) // 4
-            tokens_per_second = tokens_generated / (latency_ms / 1000) if latency_ms > 0 else 0
-            return GenerationResponse(
-                text=response_text,
-                tokens_generated=tokens_generated,
-                tokens_per_second=tokens_per_second,
-                latency_ms=latency_ms,
-                backend_name=self.backend_name
-            )
-        except Exception as e:
-            raise GenerationError(f"MLX generation failed: {e}")
+        # Acquire class-level lock to prevent concurrent MLX access
+        # This prevents segfaults when multiple workers try to use the model simultaneously
+        async with self._mlx_lock:
+            try:
+                from mlx_lm import generate as mlx_generate
+                from mlx_lm.sample_utils import make_sampler
+                start_time = time.time()
+                # Create sampler with temperature and top_p
+                sampler = make_sampler(
+                    temp=request.temperature,
+                    top_p=request.top_p,
+                    min_p=0.0,
+                    min_tokens_to_keep=1,
+                    top_k=0
+                )
+                # Define stop sequences for proper response termination
+                # Common stop sequences for chat models
+                stop_sequences = [
+                    "<|im_end|>", # Qwen, ChatML format
+                    "<|endoftext|>", # GPT-2, Qwen
+                    "<|end|>", # Generic
+                    "Human:", # Prevent answering for the user
+                    "User:", # Prevent answering for the user
+                    "Assistant:" # Prevent multiple assistant turns
+                ]
+                response_text = await asyncio.to_thread(
+                    mlx_generate,
+                    self._model,
+                    self._tokenizer,
+                    prompt=request.prompt,
+                    max_tokens=request.max_tokens,
+                    sampler=sampler,
+                    verbose=False
+                )
+                # Clean up the response - remove any stop sequences that might have been included
+                for stop_seq in stop_sequences:
+                    if stop_seq in response_text:
+                        response_text = response_text.split(stop_seq)[0].strip()
+                end_time = time.time()
+                latency_ms = (end_time - start_time) * 1000
+                # Use tokenizer to get accurate token count
+                try:
+                    tokens_generated = len(self._tokenizer.encode(response_text))
+                except Exception:
+                    # Fallback: rough estimate
+                    tokens_generated = len(response_text) // 4
+                tokens_per_second = tokens_generated / (latency_ms / 1000) if latency_ms > 0 else 0
+                return GenerationResponse(
+                    text=response_text,
+                    tokens_generated=tokens_generated,
+                    tokens_per_second=tokens_per_second,
+                    latency_ms=latency_ms,
+                    backend_name=self.backend_name
+                )
+            except Exception as e:
+                raise GenerationError(f"MLX generation failed: {e}")

     def generate_stream(self, request: GenerationRequest) -> AsyncIterator[str]:
         """Generate text with streaming (simulated for MLX)."""
+4 -10
@@ -159,10 +159,7 @@ def list_available_configurations(
     use_mlx = hardware.is_apple_silicon if hardware else False
     is_mac = use_mlx # Same flag for Mac detection

-    # On Mac, check which quantizations are actually available
-    check_available = use_mlx
-    for model in list_models(use_mlx=use_mlx, check_available=check_available):
+    for model in list_models(use_mlx=use_mlx):
         for variant in model.variants:
             for quant in variant.quantizations:
                 # Calculate memory with context and offload
@@ -409,11 +406,8 @@ def custom_configuration(
     # Use MLX models on Apple Silicon
     use_mlx = hardware.is_apple_silicon if hardware else False

-    # On Mac, check which quantizations are actually available
-    check_available = use_mlx
-
     # List available models with context labels
-    models = list_models(use_mlx=use_mlx, check_available=check_available)
+    models = list_models(use_mlx=use_mlx)
     print(" Available Models:")
     for i, model in enumerate(models, 1):
         ctx_label = model.context_label
@@ -486,10 +480,10 @@ def custom_configuration(
     is_mac = hardware.is_apple_silicon
     count_term = "responses" if is_mac else "instances"

-    # On Mac with seed variation, we can use 2-5 responses (doesn't use more VRAM)
+    # On Mac with seed variation, we can use 1-5 responses (doesn't use more VRAM)
     # On other platforms, calculate based on available VRAM
     if is_mac:
-        min_count = 2
+        min_count = 1
         max_count = 5
         default_count = 3
         print(f"\n 🍎 Apple Silicon: Using seed variation mode")
+4 -3
@@ -25,9 +25,10 @@ def get_model_folder_name(model_id: str, variant: ModelVariant, quant: Quantizat
return f"{model_id}-{variant.size}-{quant.name}" return f"{model_id}-{variant.size}-{quant.name}"
def get_model_folder_name_mlx(model_id: str, variant: ModelVariant) -> str: def get_model_folder_name_mlx(model_id: str, variant: ModelVariant, quant: QuantizationConfig) -> str:
"""Generate a unique folder name for an MLX model configuration.""" """Generate a unique folder name for an MLX model configuration."""
return f"{model_id}-{variant.size}-mlx" # Include quantization in folder name to avoid conflicts
return f"{model_id}-{variant.size}-{quant.name}-mlx"
class ModelDownloader: class ModelDownloader:
@@ -41,7 +42,7 @@ class ModelDownloader:
     def get_model_folder_path(self, model_id: str, variant: ModelVariant, quant: QuantizationConfig) -> Path:
         """Get the folder path where a model should be cached."""
         if self.use_mlx:
-            folder_name = get_model_folder_name_mlx(model_id, variant)
+            folder_name = get_model_folder_name_mlx(model_id, variant, quant)
         else:
             folder_name = get_model_folder_name(model_id, variant, quant)
         return self.cache_dir / folder_name
+44 -73
@@ -86,38 +86,57 @@ class Model:
 # MLX quantization sizes (GB) based on mlx-community models
-# These are approximate sizes for the quantized models
+# HARDOCODED: These are verified to exist on HuggingFace mlx-community
+# Last verified: 2025-02-23
+# DO NOT make API calls on startup - use this hardcoded list
 MLX_QUANT_SIZES = {
     # Format: model_id: {variant_size: {quant_bit: vram_gb}}
+    # Only includes quantizations that actually exist on HF
     "qwen2.5-coder": {
-        "3b": {"3bit": 1.3, "4bit": 1.7, "5bit": 2.1, "6bit": 2.5, "8bit": 3.3},
-        "7b": {"3bit": 3.1, "4bit": 4.1, "5bit": 5.1, "6bit": 6.1, "8bit": 8.1},
-        "14b": {"3bit": 6.2, "4bit": 8.2, "5bit": 10.2, "6bit": 12.2, "8bit": 16.2},
+        "3b": {"3bit": 1.3, "4bit": 1.7, "6bit": 2.5, "8bit": 3.3},
+        # 5bit does NOT exist for 3b
+        "7b": {"3bit": 3.1, "4bit": 4.1, "6bit": 6.1, "8bit": 8.1},
+        # 5bit does NOT exist for 7b
+        "14b": {"3bit": 6.2, "4bit": 8.2, "6bit": 12.2, "8bit": 16.2},
+        # 5bit does NOT exist for 14b
     },
     "deepseek-coder": {
-        "1.3b": {"3bit": 0.6, "4bit": 0.8, "5bit": 1.0, "6bit": 1.2, "8bit": 1.6},
-        "6.7b": {"3bit": 2.9, "4bit": 3.9, "5bit": 4.9, "6bit": 5.9, "8bit": 7.9},
+        "1.3b": {"4bit": 0.8, "6bit": 1.2},
+        # 3bit, 5bit, 8bit do NOT exist
+        "6.7b": {"4bit": 3.9, "6bit": 5.9, "8bit": 7.9},
+        # 3bit, 5bit do NOT exist
     },
     "codellama": {
-        "7b": {"3bit": 3.1, "4bit": 4.1, "5bit": 5.1, "6bit": 6.1, "8bit": 8.1},
-        "13b": {"3bit": 5.7, "4bit": 7.6, "5bit": 9.5, "6bit": 11.4, "8bit": 15.2},
+        "7b": {"4bit": 4.1, "6bit": 6.1, "8bit": 8.1},
+        # 3bit, 5bit do NOT exist
+        "13b": {"4bit": 7.6, "6bit": 11.4, "8bit": 15.2},
+        # 3bit, 5bit do NOT exist
     },
     "llama-3.2": {
-        "1b": {"3bit": 0.5, "4bit": 0.6, "5bit": 0.8, "6bit": 0.9, "8bit": 1.2},
-        "3b": {"3bit": 1.3, "4bit": 1.8, "5bit": 2.2, "6bit": 2.6, "8bit": 3.5},
+        "1b": {"4bit": 0.6, "8bit": 1.2},
+        # 3bit, 5bit, 6bit do NOT exist
+        "3b": {"4bit": 1.8, "6bit": 2.6, "8bit": 3.5},
+        # 3bit, 5bit do NOT exist
     },
     "phi-4": {
-        "4b": {"3bit": 1.8, "4bit": 2.4, "5bit": 3.0, "6bit": 3.6, "8bit": 4.8},
+        "4b": {"4bit": 2.4, "6bit": 3.6, "8bit": 4.8},
+        # 3bit, 5bit do NOT exist
     },
     "gemma-2": {
-        "2b": {"3bit": 0.9, "4bit": 1.2, "5bit": 1.5, "6bit": 1.8, "8bit": 2.4},
-        "4b": {"3bit": 1.8, "4bit": 2.4, "5bit": 3.0, "6bit": 3.6, "8bit": 4.8},
-        "9b": {"3bit": 4.0, "4bit": 5.3, "5bit": 6.6, "6bit": 7.9, "8bit": 10.5},
+        "2b": {"4bit": 1.2, "6bit": 1.8, "8bit": 2.4},
+        # 3bit, 5bit do NOT exist
+        "4b": {"4bit": 2.4, "6bit": 3.6, "8bit": 4.8},
+        # 3bit, 5bit do NOT exist
+        "9b": {"4bit": 5.3, "6bit": 7.9, "8bit": 10.5},
+        # 3bit, 5bit do NOT exist
     },
     "starcoder2": {
-        "3b": {"3bit": 1.3, "4bit": 1.8, "5bit": 2.2, "6bit": 2.6, "8bit": 3.5},
-        "7b": {"3bit": 3.1, "4bit": 4.1, "5bit": 5.1, "6bit": 6.1, "8bit": 8.1},
-        "15b": {"3bit": 6.6, "4bit": 8.8, "5bit": 11.0, "6bit": 13.2, "8bit": 17.6},
+        "3b": {"4bit": 1.8, "6bit": 2.6, "8bit": 3.5},
+        # 3bit, 5bit do NOT exist
+        "7b": {"4bit": 4.1, "6bit": 6.1, "8bit": 8.1},
+        # 3bit, 5bit do NOT exist
+        "15b": {"4bit": 8.8, "6bit": 13.2, "8bit": 17.6},
+        # 3bit, 5bit do NOT exist
     },
 }
@@ -241,51 +260,7 @@ def get_quality_map(use_mlx: bool = False) -> Dict[str, str]:
     return GGUF_QUALITY_MAP

-def filter_available_mlx_quants(model_id: str, variant_size: str) -> Dict[str, float]:
-    """
-    Check which MLX quantizations are actually available on HuggingFace.
-    Returns a dict of available quantization names and their sizes.
-    If check fails, returns all defined quantizations.
-    """
-    import requests
-    all_quants = MLX_QUANT_SIZES.get(model_id, {}).get(variant_size, {})
-    if not all_quants:
-        return {}
-    # Build base repo path (without quantization suffix)
-    mlx_repo_map = {
-        "qwen2.5-coder": f"mlx-community/Qwen2.5-Coder-{variant_size.capitalize()}-Instruct",
-        "deepseek-coder": f"mlx-community/deepseek-coder-{variant_size}-base",
-        "codellama": f"mlx-community/CodeLlama-{variant_size}-Instruct",
-        "llama-3.2": f"mlx-community/Llama-3.2-{variant_size}-Instruct",
-        "phi-4": f"mlx-community/phi-4",
-        "gemma-2": f"mlx-community/gemma-2-{variant_size}-it",
-        "starcoder2": f"mlx-community/starcoder2-{variant_size}",
-    }
-    base_repo = mlx_repo_map.get(model_id, "")
-    if not base_repo:
-        return all_quants
-    # Check which quantizations exist
-    available = {}
-    for quant_name in all_quants.keys():
-        repo_id = f"{base_repo}-{quant_name}"
-        try:
-            api_url = f"https://huggingface.co/api/models/{repo_id}"
-            response = requests.get(api_url, timeout=5)
-            if response.status_code == 200:
-                available[quant_name] = all_quants[quant_name]
-        except Exception:
-            # If check fails, include it anyway (will fail at download with better error)
-            available[quant_name] = all_quants[quant_name]
-    return available if available else all_quants
-
-def build_model_variants(model_id: str, use_mlx: bool = False, check_available: bool = False) -> List[ModelVariant]:
+def build_model_variants(model_id: str, use_mlx: bool = False) -> List[ModelVariant]:
     """Build model variants with appropriate quantizations for the platform."""
     metadata = MODEL_METADATA.get(model_id)
     if not metadata:
@@ -295,11 +270,7 @@ def build_model_variants(model_id: str, use_mlx: bool = False, check_available:
     variants = []
     for variant_size in metadata["variants"]:
-        # For MLX, optionally check which quantizations are actually available
-        if use_mlx and check_available:
-            quant_sizes = filter_available_mlx_quants(model_id, variant_size)
-        else:
-            quant_sizes = get_quantization_sizes(model_id, use_mlx).get(variant_size, {})
+        quant_sizes = get_quantization_sizes(model_id, use_mlx).get(variant_size, {})
         if not quant_sizes:
             continue
@@ -321,12 +292,12 @@ def build_model_variants(model_id: str, use_mlx: bool = False, check_available:
     return variants

-def build_models(use_mlx: bool = False, check_available: bool = False) -> Dict[str, Model]:
+def build_models(use_mlx: bool = False) -> Dict[str, Model]:
     """Build the model registry with platform-appropriate quantizations."""
     models = {}
     for model_id, metadata in MODEL_METADATA.items():
-        variants = build_model_variants(model_id, use_mlx, check_available=check_available)
+        variants = build_model_variants(model_id, use_mlx)
         if not variants:
             continue
@@ -346,19 +317,19 @@ def build_models(use_mlx: bool = False, check_available: bool = False) -> Dict[s
 DEFAULT_MODELS = build_models(use_mlx=False)

-def get_model(model_id: str, use_mlx: bool = False, check_available: bool = False) -> Optional[Model]:
+def get_model(model_id: str, use_mlx: bool = False) -> Optional[Model]:
     """Get a model by ID with platform-appropriate quantizations."""
     if use_mlx:
-        models = build_models(use_mlx=True, check_available=check_available)
+        models = build_models(use_mlx=True)
         return models.get(model_id)
     else:
         return DEFAULT_MODELS.get(model_id)

-def list_models(use_mlx: bool = False, check_available: bool = False) -> List[Model]:
+def list_models(use_mlx: bool = False) -> List[Model]:
     """List all available models sorted by priority."""
     if use_mlx:
-        models = build_models(use_mlx=True, check_available=check_available)
+        models = build_models(use_mlx=True)
     else:
         models = DEFAULT_MODELS
     return sorted(models.values(), key=lambda m: m.priority)
+4 -4
@@ -219,10 +219,10 @@ def select_optimal_model(
     # Only check when user is actually browsing or selecting custom config
     if preferred_model:
         from models.registry import get_model
-        preferred = get_model(preferred_model, use_mlx=use_mlx, check_available=False)
+        preferred = get_model(preferred_model, use_mlx=use_mlx)
         models = [preferred] if preferred else []
     else:
-        models = list_models(use_mlx=use_mlx, check_available=False)
+        models = list_models(use_mlx=use_mlx)

     # Note: On Apple Silicon with MLX, multiple instances work fine in sequential mode
     # The swarm manager will handle sequential execution to avoid GPU conflicts
@@ -295,7 +295,7 @@ def _try_model_with_context(
     # On Mac with MLX (use_mlx=True), use 3 responses by default
     # On other platforms, calculate based on VRAM
     if use_mlx:
-        instances = 3 # Default for seed variation mode
+        instances = 1 # DEBUG: Changed from 3 to 1 for faster testing
     else:
         instances = calculate_max_instances(available_vram, vram_per_instance)
@@ -354,7 +354,7 @@ def _try_smallest_variant_with_context(
     # On Mac with MLX, use 3 responses by default
     if use_mlx:
-        instances = force_instances or 3
+        instances = force_instances or 1 # DEBUG: Changed from 3 to 1
     else:
         instances = force_instances or calculate_max_instances(available_vram, vram_per_instance)
     instances = max(instances, 1)
+96 -58
@@ -37,7 +37,8 @@ class SwarmManager:
         sequential_mode: Optional[bool] = None,
         use_seed_variation: Optional[bool] = None,
         enable_reviewer: bool = False,
-        max_retries: int = 2
+        max_retries: int = 2,
+        mcp_mode: bool = False
     ):
         """
         Initialize swarm manager.
@@ -52,6 +53,7 @@ class SwarmManager:
                 Auto-enabled for Apple Silicon to save memory.
             enable_reviewer: If True, enable a reviewer/critic worker that validates consensus results
             max_retries: Maximum number of retries if reviewer rejects the result
+            mcp_mode: If True, suppress console output for MCP stdio compatibility
         """
         self.model_config = model_config
         self.hardware = hardware
@@ -59,27 +61,30 @@ class SwarmManager:
         self.consensus = ConsensusEngine(strategy=consensus_strategy)
         self._model_path: Optional[str] = None
         self._running = False
+        self.mcp_mode = mcp_mode

         # Auto-enable sequential mode for Apple Silicon to avoid GPU conflicts
         if sequential_mode is None and hardware.is_apple_silicon:
             self.sequential_mode = True
-            print("🍎 Apple Silicon detected: Using sequential generation mode to avoid GPU conflicts")
-            print(" Workers will run one at a time, but all stay loaded in memory")
+            if not self.mcp_mode:
+                print("🍎 Apple Silicon detected: Using sequential generation mode to avoid GPU conflicts")
+                print(" Workers will run one at a time, but all stay loaded in memory")
         else:
             self.sequential_mode = sequential_mode or False

         # Auto-enable seed variation on Apple Silicon to save memory
         if use_seed_variation is None and hardware.is_apple_silicon:
             self.use_seed_variation = True
-            print("🌱 Using seed variation mode: One model, multiple responses with different seeds")
-            print(f" Will generate {model_config.instances} responses with different random seeds")
+            if not self.mcp_mode:
+                print("🌱 Using seed variation mode: One model, multiple responses with different seeds")
+                print(f" Will generate {model_config.instances} responses with different random seeds")
         else:
             self.use_seed_variation = use_seed_variation or False

         # Reviewer/critic mode
         self.enable_reviewer = enable_reviewer
         self.max_retries = max_retries
-        if enable_reviewer:
+        if enable_reviewer and not self.mcp_mode:
             print("👁️ Reviewer mode enabled: A critic worker will validate consensus results")
             print(f" Up to {max_retries} retries if output looks suspicious")
@@ -95,11 +100,13 @@ class SwarmManager:
""" """
self._model_path = model_path self._model_path = model_path
print(f"\n🚀 Initializing swarm with {self.model_config.instances} workers...") if not self.mcp_mode:
print(f"\n🚀 Initializing swarm with {self.model_config.instances} workers...")
# Create and load workers # Create and load workers
for i in range(self.model_config.instances): for i in range(self.model_config.instances):
print(f" Starting worker {i + 1}/{self.model_config.instances}...") if not self.mcp_mode:
print(f" Starting worker {i + 1}/{self.model_config.instances}...")
# Create backend for this worker # Create backend for this worker
backend = create_backend_for_config(self.model_config, self.hardware) backend = create_backend_for_config(self.model_config, self.hardware)
@@ -112,18 +119,23 @@ class SwarmManager:
                 success = await worker.load_model(model_path)
                 if success:
                     self.workers.append(worker)
-                    print(f"{worker.name} ready")
+                    if not self.mcp_mode:
+                        print(f"{worker.name} ready")
                 else:
-                    print(f"{worker.name} failed to load model")
+                    if not self.mcp_mode:
+                        print(f"{worker.name} failed to load model")
             except Exception as e:
-                print(f"{worker.name} error: {e}")
+                if not self.mcp_mode:
+                    print(f"{worker.name} error: {e}")

         if len(self.workers) == 0:
-            print("❌ No workers could be started")
+            if not self.mcp_mode:
+                print("❌ No workers could be started")
             return False

         self._running = True
-        print(f"✓ Swarm initialized with {len(self.workers)} workers")
+        if not self.mcp_mode:
+            print(f"✓ Swarm initialized with {len(self.workers)} workers")

         # Preload embedding model for consensus (if using similarity strategy)
         if self.consensus.strategy == "similarity":
@@ -169,20 +181,25 @@ class SwarmManager:
         if len(healthy_workers) == 1 or not use_consensus:
             # Only one worker, no need for consensus
             # Use generate_with_progress to enable status tracking
-            print(f"\n📝 Running single worker {healthy_workers[0].name}...")
+            if not self.mcp_mode:
+                print(f"\n📝 Running single worker {healthy_workers[0].name}...")

-            # Start live display task
-            stop_event = asyncio.Event()
-            display_task = asyncio.create_task(self._live_worker_display(healthy_workers, stop_event))
+            # Start live display task (only in non-MCP mode)
+            stop_event = None
+            display_task = None
+            if not self.mcp_mode:
+                stop_event = asyncio.Event()
+                display_task = asyncio.create_task(self._live_worker_display(healthy_workers, stop_event))

             try:
                 response = await healthy_workers[0].generate_with_progress(request)
             finally:
-                stop_event.set()
-                try:
-                    await asyncio.wait_for(display_task, timeout=1.0)
-                except asyncio.TimeoutError:
-                    display_task.cancel()
+                if not self.mcp_mode and stop_event is not None and display_task is not None:
+                    stop_event.set()
+                    try:
+                        await asyncio.wait_for(display_task, timeout=1.0)
+                    except asyncio.TimeoutError:
+                        display_task.cancel()

             return ConsensusResult(
                 selected_response=response,
@@ -195,67 +212,83 @@ class SwarmManager:
         # Send to all workers - either in parallel or sequentially
         if self.sequential_mode:
             # Sequential mode: Run workers one at a time to avoid GPU conflicts
-            print(f"\n📝 Running {len(healthy_workers)} workers sequentially (Apple Silicon mode)...")
-            print(f" All workers stay loaded in memory, but run one at a time")
+            if not self.mcp_mode:
+                print(f"\n📝 Running {len(healthy_workers)} workers sequentially (Apple Silicon mode)...")
+                print(f" All workers stay loaded in memory, but run one at a time")

             responses = []
             for i, worker in enumerate(healthy_workers):
-                print(f"\n [{i+1}/{len(healthy_workers)}] Running {worker.name}...")
+                if not self.mcp_mode:
+                    print(f"\n [{i+1}/{len(healthy_workers)}] Running {worker.name}...")

-                # Start live display for this worker
-                stop_event = asyncio.Event()
-                display_task = asyncio.create_task(self._live_worker_display([worker], stop_event))
+                # Start live display for this worker (only in non-MCP mode)
+                stop_event = None
+                display_task = None
+                if not self.mcp_mode:
+                    stop_event = asyncio.Event()
+                    display_task = asyncio.create_task(self._live_worker_display([worker], stop_event))

                 try:
                     response = await worker.generate_with_progress(request)
                     responses.append(response)
-                    print(f"{worker.name} completed ({response.tokens_generated} tokens)")
+                    if not self.mcp_mode:
+                        print(f"{worker.name} completed ({response.tokens_generated} tokens)")
                 except Exception as e:
                     responses.append(e)
-                    print(f"{worker.name} failed: {e}")
+                    if not self.mcp_mode:
+                        print(f"{worker.name} failed: {e}")
                 finally:
-                    stop_event.set()
-                    try:
-                        await asyncio.wait_for(display_task, timeout=0.5)
-                    except asyncio.TimeoutError:
-                        display_task.cancel()
+                    if not self.mcp_mode and stop_event is not None:
+                        stop_event.set()
+                        try:
+                            await asyncio.wait_for(display_task, timeout=0.5)
+                        except asyncio.TimeoutError:
+                            display_task.cancel()
         else:
             # Parallel mode: Run all workers simultaneously
-            print(f"\n📝 Sending request to {len(healthy_workers)} workers in parallel...")
+            if not self.mcp_mode:
+                print(f"\n📝 Sending request to {len(healthy_workers)} workers in parallel...")

-            # Start live display task
-            stop_event = asyncio.Event()
-            display_task = asyncio.create_task(self._live_worker_display(healthy_workers, stop_event))
+            # Start live display task (only in non-MCP mode)
+            stop_event = None
+            display_task = None
+            if not self.mcp_mode:
+                stop_event = asyncio.Event()
+                display_task = asyncio.create_task(self._live_worker_display(healthy_workers, stop_event))

             try:
                 tasks = [w.generate_with_progress(request) for w in healthy_workers]
                 responses = await asyncio.gather(*tasks, return_exceptions=True)
             finally:
-                stop_event.set()
-                try:
-                    await asyncio.wait_for(display_task, timeout=1.0)
-                except asyncio.TimeoutError:
-                    display_task.cancel()
+                if not self.mcp_mode and stop_event is not None and display_task is not None:
+                    stop_event.set()
+                    try:
+                        await asyncio.wait_for(display_task, timeout=1.0)
+                    except asyncio.TimeoutError:
+                        display_task.cancel()

         # Filter out errors
         valid_responses = []
         for i, resp in enumerate(responses):
             if isinstance(resp, Exception):
-                print(f"{healthy_workers[i].name} failed: {resp}")
+                if not self.mcp_mode:
+                    print(f"{healthy_workers[i].name} failed: {resp}")
             else:
-                if not self.sequential_mode:
+                if not self.sequential_mode and not self.mcp_mode:
                     print(f"{healthy_workers[i].name} completed")
                 valid_responses.append(resp)

         if len(valid_responses) == 0:
             raise RuntimeError("All workers failed to generate")

-        print(f" Got {len(valid_responses)} valid responses")
+        if not self.mcp_mode:
+            print(f" Got {len(valid_responses)} valid responses")

         # Run consensus
         result = await self.consensus.select_best(valid_responses)
-        print(f" Selected response using '{result.strategy}' strategy (confidence: {result.confidence:.2f})")
+        if not self.mcp_mode:
+            print(f" Selected response using '{result.strategy}' strategy (confidence: {result.confidence:.2f})")

         return result
@@ -295,32 +328,37 @@ class SwarmManager:
             temperature=temperature
         )

-        print(f"\n🎙️ Streaming from {fastest_worker.name} (fastest)")
-        print(f" Total workers: {len(healthy_workers)}")
+        if not self.mcp_mode:
+            print(f"\n🎙️ Streaming from {fastest_worker.name} (fastest)")
+            print(f" Total workers: {len(healthy_workers)}")

         # Run all workers - we need responses from all for consensus
         # Stream comes from fastest, but we collect all responses
         if len(healthy_workers) > 1:
-            if self.sequential_mode:
-                print(f"📝 Running {len(healthy_workers)} workers sequentially...")
-            else:
-                print(f"📝 Running {len(healthy_workers)} workers in parallel...")
+            if not self.mcp_mode:
+                if self.sequential_mode:
+                    print(f"📝 Running {len(healthy_workers)} workers sequentially...")
+                else:
+                    print(f"📝 Running {len(healthy_workers)} workers in parallel...")

             # Start all other workers (they'll run sequentially or in parallel)
             other_workers = [w for w in healthy_workers if w != fastest_worker]
             for w in other_workers:
-                print(f" Queueing {w.name}...")
+                if not self.mcp_mode:
+                    print(f" Queueing {w.name}...")
                 asyncio.create_task(w.generate_with_progress(request))

         # Stream from fastest worker with progress tracking
-        print(f"🔄 Starting stream from {fastest_worker.name}...")
+        if not self.mcp_mode:
+            print(f"🔄 Starting stream from {fastest_worker.name}...")
         chunk_count = 0
         async for chunk in fastest_worker.generate_with_progress_stream(request):
             chunk_count += 1
-            if chunk_count % 50 == 0: # Print progress every 50 chunks
+            if not self.mcp_mode and chunk_count % 50 == 0: # Print progress every 50 chunks
                 print(f" Streamed {chunk_count} chunks...")
             yield chunk
-        print(f" Stream complete: {chunk_count} chunks total")
+        if not self.mcp_mode:
+            print(f" Stream complete: {chunk_count} chunks total")

     def get_status(self) -> SwarmStatus:
         """Get current swarm status."""