refactor: unified generation interface for federation and local modes
- Created `_generate_with_consensus()`, which handles both federated and local generation.
- Callers don't need to know which mode is being used; the choice is transparent to them.
- The tool execution loop uses the same unified interface for all iterations.
- Removed the special-case federation logic from the main handler.
- Federation is now a transparent layer around generation.
- All 41 tests pass.
This commit is contained in:
+44
-29
@@ -228,25 +228,48 @@ def _create_response(
|
||||
return response
|
||||
|
||||
|
||||
async def _generate_with_local_swarm(
|
||||
swarm_manager,
|
||||
async def _generate_with_consensus(
|
||||
prompt: str,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
stream: bool = False
|
||||
swarm_manager,
|
||||
federated_swarm=None
|
||||
) -> tuple[str, int, float]:
|
||||
"""Generate response using local swarm.
|
||||
"""Generate response with consensus (local or federated).
|
||||
|
||||
This is the unified generation interface - it handles both local-only
|
||||
and federated generation transparently. Callers don't need to know
|
||||
which mode is being used.
|
||||
|
||||
Args:
|
||||
swarm_manager: Swarm manager instance
|
||||
prompt: Prompt to generate from
|
||||
max_tokens: Maximum tokens to generate
|
||||
temperature: Sampling temperature
|
||||
stream: Whether this is a streaming request
|
||||
swarm_manager: Local swarm manager instance
|
||||
federated_swarm: Optional federated swarm for multi-node consensus
|
||||
|
||||
Returns:
|
||||
Tuple of (response_text, tokens_generated, tokens_per_second)
|
||||
"""
|
||||
# Check if federation is available
|
||||
if federated_swarm is not None:
|
||||
peers = federated_swarm.discovery.get_peers()
|
||||
if peers:
|
||||
logger.debug(f"🌐 Using federation with {len(peers)} peer(s)")
|
||||
try:
|
||||
content, tool_calls, finish_reason = await federated_swarm.generate_with_federation(
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature
|
||||
)
|
||||
# Federation returns content directly
|
||||
# Note: tool_calls from federation should be ignored - head node handles tools
|
||||
return content, 0, 0.0 # Tokens/TPS not tracked in federation mode
|
||||
except Exception as e:
|
||||
logger.warning(f"Federation failed, falling back to local: {e}")
|
||||
# Fall through to local generation
|
||||
|
||||
# Local generation (fallback or no federation)
|
||||
try:
|
||||
result = await swarm_manager.generate(
|
||||
prompt=prompt,
|
||||
@@ -254,13 +277,8 @@ async def _generate_with_local_swarm(
|
||||
temperature=temperature,
|
||||
use_consensus=True
|
||||
)
|
||||
|
||||
response = result.selected_response
|
||||
return (
|
||||
response.text,
|
||||
response.tokens_generated,
|
||||
response.tokens_per_second
|
||||
)
|
||||
return response.text, response.tokens_generated, response.tokens_per_second
|
||||
except Exception as e:
|
||||
logger.exception("Error in swarm generation")
|
||||
raise
|
||||
@@ -544,22 +562,15 @@ async def handle_chat_completion(
|
||||
iteration += 1
|
||||
logger.info(f"--- Tool Execution Iteration {iteration} ---")
|
||||
|
||||
# Generate response (use federation if available)
|
||||
# Generate response (unified interface - handles federation automatically)
|
||||
logger.debug(f"Generating response...")
|
||||
if use_federation and iteration == 1:
|
||||
# First iteration: use federation for consensus
|
||||
logger.info(f"🌐 Using federation for generation...")
|
||||
content, tool_calls, finish_reason = await _generate_with_federation(
|
||||
federated_swarm, prompt, request.max_tokens or 1024, request.temperature or 0.7
|
||||
)
|
||||
response_text = content
|
||||
tokens_generated = 0 # Will be calculated from usage if needed
|
||||
tps = 0.0
|
||||
else:
|
||||
# Subsequent iterations or no federation: use local swarm
|
||||
response_text, tokens_generated, tps = await _generate_with_local_swarm(
|
||||
swarm_manager, prompt, request.max_tokens or 1024, request.temperature or 0.7
|
||||
)
|
||||
response_text, tokens_generated, tps = await _generate_with_consensus(
|
||||
prompt=prompt,
|
||||
max_tokens=request.max_tokens or 1024,
|
||||
temperature=request.temperature or 0.7,
|
||||
swarm_manager=swarm_manager,
|
||||
federated_swarm=federated_swarm
|
||||
)
|
||||
|
||||
logger.info(f"Generated response ({len(response_text)} chars, {tokens_generated} tokens)")
|
||||
logger.debug(f"Response: {response_text[:200]}...")
|
||||
@@ -685,8 +696,12 @@ async def handle_chat_completion(
|
||||
logger.info(f" (tool_call_id: {msg.tool_call_id}, name: {msg.name})")
|
||||
logger.debug(f"Full prompt:\n{next_prompt[:1000]}...")
|
||||
|
||||
response_text, tokens_generated, tps = await _generate_with_local_swarm(
|
||||
swarm_manager, next_prompt, request.max_tokens or 1024, request.temperature or 0.7
|
||||
response_text, tokens_generated, tps = await _generate_with_consensus(
|
||||
prompt=next_prompt,
|
||||
max_tokens=request.max_tokens or 1024,
|
||||
temperature=request.temperature or 0.7,
|
||||
swarm_manager=swarm_manager,
|
||||
federated_swarm=federated_swarm
|
||||
)
|
||||
|
||||
logger.info(f"✅ Generated with tool results ({len(response_text)} chars, {tokens_generated} tokens)")
|
||||
|
||||
Reference in New Issue
Block a user