File size: 3,305 Bytes
20f762e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a805769
20f762e
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""Hierarchical orchestrator using middleware and sub-teams."""

import asyncio
from collections.abc import AsyncGenerator

import structlog

from src.agents.judge_agent_llm import LLMSubIterationJudge
from src.agents.magentic_agents import create_search_agent
from src.middleware.sub_iteration import SubIterationMiddleware, SubIterationTeam
from src.services.embeddings import get_embedding_service
from src.state import init_magentic_state
from src.utils.models import AgentEvent

# Module-level structlog logger used by the orchestrator below.
logger = structlog.get_logger()


class ResearchTeam(SubIterationTeam):
    """Expose a Magentic search ChatAgent through the SubIterationTeam protocol.

    The wrapped agent is built by ``create_search_agent()``. ``execute`` runs it
    on a task and hands back the text of the last assistant message it produced.
    """

    def __init__(self) -> None:
        self.agent = create_search_agent()

    async def execute(self, task: str) -> str:
        """Run the agent on ``task`` and return its final assistant text.

        Response messages are scanned from last to first; the first one whose
        role is ``"assistant"`` and whose text is non-empty wins. When no such
        message exists (including when the response has no messages at all),
        a fixed placeholder string is returned instead.
        """
        response = await self.agent.run(task)
        latest_reply = next(
            (
                message.text
                for message in reversed(response.messages or [])
                if message.role == "assistant" and message.text
            ),
            None,
        )
        if latest_reply is None:
            return "No response from agent."
        return str(latest_reply)


class HierarchicalOrchestrator:
    """Orchestrator that uses hierarchical teams and sub-iterations.

    A single ``ResearchTeam`` is driven by a ``SubIterationMiddleware`` that
    consults an ``LLMSubIterationJudge``; ``run`` exposes that process as an
    async stream of ``AgentEvent`` objects.
    """

    def __init__(self, max_iterations: int = 5) -> None:
        """Build the research team, judge, and sub-iteration middleware.

        Args:
            max_iterations: Forwarded to ``SubIterationMiddleware`` as its
                ``max_iterations`` argument. Defaults to 5, the value that was
                previously hard-coded, so existing callers are unaffected.
        """
        self.team = ResearchTeam()
        self.judge = LLMSubIterationJudge()
        self.middleware = SubIterationMiddleware(
            self.team, self.judge, max_iterations=max_iterations
        )

    @staticmethod
    def _init_state() -> None:
        """Initialise magentic state, falling back to the default state.

        Any exception from obtaining the embedding service or from
        initialising state with it is logged as a warning and answered with a
        plain ``init_magentic_state()`` call (deliberate best-effort).
        """
        try:
            service = get_embedding_service()
            init_magentic_state(service)
        except Exception as e:
            logger.warning(
                "Embedding service initialization failed, using default state",
                error=str(e),
            )
            init_magentic_state()

    async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]:
        """Run the middleware on ``query`` and stream its progress events.

        Yields, in order: a ``started`` event; every event the middleware
        passes to its event callback; and finally either a ``complete`` event
        (result text plus the judge's assessment) or an ``error`` event if the
        middleware raised.

        If the consumer stops iterating early (``aclose()``, cancellation, or
        garbage collection of the generator), the background middleware task
        and any pending queue read are cancelled instead of being orphaned.
        """
        logger.info("Starting hierarchical orchestrator", query=query)
        self._init_state()

        yield AgentEvent(type="started", message=f"Starting research: {query}")

        queue: asyncio.Queue[AgentEvent | None] = asyncio.Queue()

        async def event_callback(event: AgentEvent) -> None:
            await queue.put(event)

        task_future = asyncio.create_task(self.middleware.run(query, event_callback))
        get_event: asyncio.Task[AgentEvent | None] | None = None

        try:
            # Forward events as they arrive, racing each queue read against
            # completion of the middleware task.
            while not task_future.done():
                get_event = asyncio.create_task(queue.get())
                done, _ = await asyncio.wait(
                    {task_future, get_event}, return_when=asyncio.FIRST_COMPLETED
                )

                if get_event in done:
                    event = get_event.result()
                    if event:
                        yield event
                else:
                    # Middleware finished first; abandon the pending read.
                    # Queue.get does not consume an item when cancelled, so
                    # anything still queued is picked up by the drain below.
                    get_event.cancel()

            # Drain events emitted after the last successful read.
            while not queue.empty():
                ev = queue.get_nowait()
                if ev:
                    yield ev
        finally:
            # Runs on normal completion and also when the consumer abandons
            # the generator mid-stream; in the latter case stop the background
            # work rather than leaving it running unobserved.
            if get_event is not None and not get_event.done():
                get_event.cancel()
            if not task_future.done():
                task_future.cancel()

        # Build the final event inside the try, but yield it outside, so that
        # only failures of the middleware (or of summarising its result) are
        # reported as orchestrator errors.
        try:
            result, assessment = await task_future

            assessment_text = assessment.reasoning if assessment else "None"
            final_event = AgentEvent(
                type="complete",
                message=(
                    f"Research complete.\n\nResult:\n{result}\n\nAssessment:\n{assessment_text}"
                ),
                data={"assessment": assessment.model_dump() if assessment else None},
            )
        except Exception as e:
            # logger.exception records the traceback that logger.error dropped.
            logger.exception("Orchestrator failed", error=str(e))
            final_event = AgentEvent(type="error", message=f"Orchestrator failed: {e}")

        yield final_event