refactor: unified generation interface for federation and local modes

- Created _generate_with_consensus(), replacing _generate_with_local_swarm(), to handle both federated and local generation
- Callers don't need to know which mode is being used - it's transparent
- Tool execution loop uses the same unified interface for all iterations (previously federation was used only on the first iteration)
- Removed special-case federation logic from main handler
- Federation is now a transparent layer around generation: it falls back to local generation when no peers are available or federation fails
- All 41 tests passing
This commit is contained in:
2026-02-25 23:36:24 +01:00
parent 414cb444f3
commit 93844a81b0
+44 -29
View File
@@ -228,25 +228,48 @@ def _create_response(
return response
async def _generate_with_local_swarm(
swarm_manager,
async def _generate_with_consensus(
prompt: str,
max_tokens: int,
temperature: float,
stream: bool = False
swarm_manager,
federated_swarm=None
) -> tuple[str, int, float]:
"""Generate response using local swarm.
"""Generate response with consensus (local or federated).
This is the unified generation interface - it handles both local-only
and federated generation transparently. Callers don't need to know
which mode is being used.
Args:
swarm_manager: Swarm manager instance
prompt: Prompt to generate from
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
stream: Whether this is a streaming request
swarm_manager: Local swarm manager instance
federated_swarm: Optional federated swarm for multi-node consensus
Returns:
Tuple of (response_text, tokens_generated, tokens_per_second)
"""
# Check if federation is available
if federated_swarm is not None:
peers = federated_swarm.discovery.get_peers()
if peers:
logger.debug(f"🌐 Using federation with {len(peers)} peer(s)")
try:
content, tool_calls, finish_reason = await federated_swarm.generate_with_federation(
prompt=prompt,
max_tokens=max_tokens,
temperature=temperature
)
# Federation returns content directly
# Note: tool_calls from federation should be ignored - head node handles tools
return content, 0, 0.0 # Tokens/TPS not tracked in federation mode
except Exception as e:
logger.warning(f"Federation failed, falling back to local: {e}")
# Fall through to local generation
# Local generation (fallback or no federation)
try:
result = await swarm_manager.generate(
prompt=prompt,
@@ -254,13 +277,8 @@ async def _generate_with_local_swarm(
temperature=temperature,
use_consensus=True
)
response = result.selected_response
return (
response.text,
response.tokens_generated,
response.tokens_per_second
)
return response.text, response.tokens_generated, response.tokens_per_second
except Exception as e:
logger.exception("Error in swarm generation")
raise
@@ -544,22 +562,15 @@ async def handle_chat_completion(
iteration += 1
logger.info(f"--- Tool Execution Iteration {iteration} ---")
# Generate response (use federation if available)
# Generate response (unified interface - handles federation automatically)
logger.debug(f"Generating response...")
if use_federation and iteration == 1:
# First iteration: use federation for consensus
logger.info(f"🌐 Using federation for generation...")
content, tool_calls, finish_reason = await _generate_with_federation(
federated_swarm, prompt, request.max_tokens or 1024, request.temperature or 0.7
)
response_text = content
tokens_generated = 0 # Will be calculated from usage if needed
tps = 0.0
else:
# Subsequent iterations or no federation: use local swarm
response_text, tokens_generated, tps = await _generate_with_local_swarm(
swarm_manager, prompt, request.max_tokens or 1024, request.temperature or 0.7
)
response_text, tokens_generated, tps = await _generate_with_consensus(
prompt=prompt,
max_tokens=request.max_tokens or 1024,
temperature=request.temperature or 0.7,
swarm_manager=swarm_manager,
federated_swarm=federated_swarm
)
logger.info(f"Generated response ({len(response_text)} chars, {tokens_generated} tokens)")
logger.debug(f"Response: {response_text[:200]}...")
@@ -685,8 +696,12 @@ async def handle_chat_completion(
logger.info(f" (tool_call_id: {msg.tool_call_id}, name: {msg.name})")
logger.debug(f"Full prompt:\n{next_prompt[:1000]}...")
response_text, tokens_generated, tps = await _generate_with_local_swarm(
swarm_manager, next_prompt, request.max_tokens or 1024, request.temperature or 0.7
response_text, tokens_generated, tps = await _generate_with_consensus(
prompt=next_prompt,
max_tokens=request.max_tokens or 1024,
temperature=request.temperature or 0.7,
swarm_manager=swarm_manager,
federated_swarm=federated_swarm
)
logger.info(f"✅ Generated with tool results ({len(response_text)} chars, {tokens_generated} tokens)")