# CODERAMA / agent_executor.py
# Last commit by debasisdwivedy: d8f7860 ("Fixing requirements file")
import logging
logger = logging.getLogger(__name__)
from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.events import EventQueue
from a2a.server.tasks import TaskUpdater
from a2a.types import (
Part,
TaskState,
TextPart,
)
from a2a.utils import new_agent_text_message, new_task
from google.adk.artifacts import InMemoryArtifactService
from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
from google.adk.runners import Runner
from google.adk.sessions import InMemorySessionService
from google.genai import types
class ADKAgentExecutor(AgentExecutor):
    """A2A ``AgentExecutor`` that runs a Google ADK agent for each request.

    Each A2A context id is used as the ADK session id, so follow-up messages
    in the same conversation reuse the existing session instead of trying to
    create it again. The agent's final response text is published as a single
    text artifact and the task is completed; any error raised while working
    is logged and reported to the client as a failed task status.
    """

    # User id used when the request carries no usable user name.
    _DEFAULT_USER_ID = "a2a_user"

    def __init__(
        self,
        agent,
        status_message="Processing request...",
        artifact_name="response",
    ):
        """Initialize a generic ADK agent executor.

        Args:
            agent: The ADK agent instance. Its ``name`` is used as the ADK
                app name for the runner and for every session it creates.
            status_message: Message to display while processing.
            artifact_name: Name for the response artifact.
        """
        self.agent = agent
        self.status_message = status_message
        self.artifact_name = artifact_name
        # In-memory services: sessions, artifacts and memory are held by this
        # executor's runner only, not persisted anywhere.
        self.runner = Runner(
            app_name=agent.name,
            agent=agent,
            artifact_service=InMemoryArtifactService(),
            session_service=InMemorySessionService(),
            memory_service=InMemoryMemoryService(),
        )

    async def cancel(
        self,
        context: RequestContext,
        event_queue: EventQueue,
    ) -> None:
        """Cancel the execution of a specific task.

        Raises:
            NotImplementedError: Always; cancellation is not supported.
        """
        raise NotImplementedError(
            "Cancellation is not implemented for ADKAgentExecutor."
        )

    async def execute(
        self,
        context: RequestContext,
        event_queue: EventQueue,
    ) -> None:
        """Run the ADK agent on the incoming message and publish the result.

        Enqueues the task, marks it ``working`` with ``status_message``, runs
        the agent in the session for ``task.context_id``, then attaches the
        final response text as an artifact named ``artifact_name`` and
        completes the task. Exceptions raised while working are logged and
        turned into a final ``failed`` status instead of propagating.

        Args:
            context: The A2A request context (user input, task, caller info).
            event_queue: Queue that receives task and status events.
        """
        query = context.get_user_input()
        task = context.current_task or new_task(context.message)
        await event_queue.enqueue_event(task)
        updater = TaskUpdater(event_queue, task.id, task.context_id)
        user_id = self._resolve_user_id(context)

        try:
            # Update status with custom message
            await updater.update_status(
                TaskState.working,
                new_agent_text_message(self.status_message, task.context_id, task.id),
            )
            session = await self._get_or_create_session(user_id, task.context_id)
            response_text = await self._run_agent(user_id, session.id, query)
            # Add response as artifact with custom name
            await updater.add_artifact(
                [Part(root=TextPart(text=response_text))],
                name=self.artifact_name,
            )
            await updater.complete()
        except Exception as e:
            # Boundary handler: keep the traceback in the logs, and report the
            # error to the client through the task status.
            logger.exception("ADK agent execution failed for task %s", task.id)
            await updater.update_status(
                TaskState.failed,
                new_agent_text_message(f"Error: {e!s}", task.context_id, task.id),
                final=True,
            )

    @classmethod
    def _resolve_user_id(cls, context: RequestContext) -> str:
        """Return the caller's user name, or the default id when unavailable."""
        call_context = context.call_context
        user = getattr(call_context, "user", None) if call_context else None
        user_name = getattr(user, "user_name", None)
        # A missing or empty user name (presumably an unauthenticated caller;
        # verify against the a2a-sdk user model) must not key sessions by an
        # empty user id, so it is treated like a request with no call context.
        return user_name or cls._DEFAULT_USER_ID

    async def _get_or_create_session(self, user_id: str, session_id: str):
        """Return the ADK session ``session_id`` for ``user_id``, creating it once.

        Reusing an existing session keeps multi-turn conversations in the same
        A2A context working, instead of re-creating the session on each message.
        """
        session_service = self.runner.session_service
        session = await session_service.get_session(
            app_name=self.agent.name,
            user_id=user_id,
            session_id=session_id,
        )
        if session is None:
            session = await session_service.create_session(
                app_name=self.agent.name,
                user_id=user_id,
                state={},
                session_id=session_id,
            )
        return session

    async def _run_agent(self, user_id: str, session_id: str, query: str) -> str:
        """Send ``query`` to the agent and return the text of its final response.

        Tool calls and tool responses are only logged. If several final
        responses with content are emitted, the last one wins. Returns an
        empty string when no final response carries text.
        """
        content = types.Content(
            role="user", parts=[types.Part.from_text(text=query)]
        )
        response_text = ""
        async for event in self.runner.run_async(
            user_id=user_id, session_id=session_id, new_message=content
        ):
            if not (event.content and event.content.parts):
                continue
            parts = event.content.parts
            if event.is_final_response():
                response_text = "".join(p.text for p in parts if p.text)
                logger.info(" 🛠️ **Response from LLM Call: %s", response_text)
                continue
            for part in parts:
                function_call = getattr(part, "function_call", None)
                if function_call is not None:
                    logger.info(" 🛠️ **Tool Call: %s", function_call.name)
                    continue
                function_response = getattr(part, "function_response", None)
                if function_response is not None:
                    logger.info(" ⚡ **Tool Response: %s", function_response.response)
        return response_text