File size: 3,648 Bytes
ff0e97f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""
Subagent Router

Orchestrates routing between specialized subagents using LangGraph's
delegation pattern.
"""
from typing import Dict, Any, List, Literal
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.prebuilt import ToolNode
from langgraph.checkpoint.memory import InMemorySaver
from .subagent_config import SubAgentConfig
from .subagent_factory import SubAgentFactory

async def create_router_agent(all_tools: List[Any], llm: BaseChatModel):
    """
    Create a router agent that orchestrates specialized subagents.

    The compiled graph runs ``router`` first. It then hands the conversation
    to exactly one specialist node, chosen by ``route_to_specialist`` from the
    router's reply, and then ends.

    Args:
        all_tools: Full list of available MCP tools
        llm: Language model for the router

    Returns:
        Compiled LangGraph workflow (with an in-memory checkpointer)
    """
    # Node names of the specialists; also the strings the router reply may name.
    subagent_names = ("image_identifier", "species_explorer", "taxonomy_specialist")

    async def router_node(state: MessagesState):
        """Ask the LLM, primed with the routing prompt, which specialist fits."""
        router_prompt = SubAgentConfig.get_router_prompt()

        # The system prompt is only sent to the model. It is not written back
        # into graph state, so it does not accumulate across turns.
        messages = [SystemMessage(content=router_prompt)] + state["messages"]

        response = await llm.ainvoke(messages)

        # The LLM is not bound to any tools here; route_to_specialist()
        # inspects this reply's text to pick the subagent.
        return {"messages": [response]}

    async def create_subagent_node(subagent_name: str):
        """Create a graph node that delegates the conversation to one subagent."""
        subagent = None  # built lazily on first use, then reused

        async def subagent_node(state: MessagesState):
            nonlocal subagent
            # The factory inputs (name, tools, llm) never change for this node,
            # so building the specialist on every turn would only repeat work.
            # If creation fails, `subagent` stays None and the next call retries.
            if subagent is None:
                subagent = await SubAgentFactory.create_subagent(
                    subagent_name, all_tools, llm
                )

            # Run the subagent on the shared conversation state.
            return await subagent.ainvoke(state)

        return subagent_node

    def _message_text(message) -> str:
        """Return the lower-cased text of *message*, whatever its content shape.

        ``content`` is either a plain string or a list of content blocks
        (strings, or dicts such as ``{"type": "text", "text": ...}``).
        Non-text blocks (e.g. images) are ignored.
        """
        content = message.content
        if isinstance(content, str):
            return content.lower()
        parts = []
        for block in content or []:
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and block.get("type") == "text":
                parts.append(str(block.get("text", "")))
        return " ".join(parts).lower()

    # Build the graph
    workflow = StateGraph(MessagesState)

    # Add nodes
    workflow.add_node("router", router_node)
    for name in subagent_names:
        workflow.add_node(name, await create_subagent_node(name))

    # Define routing logic
    def route_to_specialist(state: MessagesState) -> Literal["image_identifier", "species_explorer", "taxonomy_specialist", END]:
        """Pick the specialist for the last message (the router's reply).

        If the reply names exactly one specialist (``image_identifier`` or
        ``image identifier``), that choice wins. Otherwise simple keyword
        rules apply, defaulting to ``species_explorer``.
        """
        last_message = state["messages"][-1]
        content = _message_text(last_message)

        # Honour an unambiguous decision by the router LLM.
        named = [
            name
            for name in subagent_names
            if name in content or name.replace("_", " ") in content
        ]
        if len(named) == 1:
            return named[0]

        # Keyword-based fallback (could be improved with LLM classification).
        # Checked in this order, so earlier groups take precedence.
        if any(word in content for word in ["identify", "what bird", "classify", "image", "photo"]):
            return "image_identifier"
        elif any(word in content for word in ["audio", "sound", "call", "song", "find", "search"]):
            return "species_explorer"
        elif any(word in content for word in ["family", "families", "conservation", "endangered", "taxonomy"]):
            return "taxonomy_specialist"
        else:
            # Default to species explorer for general queries
            return "species_explorer"

    # Connect nodes: router -> one specialist -> END
    workflow.add_edge(START, "router")
    workflow.add_conditional_edges("router", route_to_specialist)
    for name in subagent_names:
        workflow.add_edge(name, END)

    # Compile with memory for conversation context
    return workflow.compile(checkpointer=InMemorySaver())