# Spaces: Running
import logging

from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.events import EventQueue
from a2a.server.tasks import TaskUpdater
from a2a.types import (
    Part,
    TaskState,
    TextPart,
)
from a2a.utils import new_agent_text_message, new_task
from google.adk.artifacts import InMemoryArtifactService
from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
from google.adk.runners import Runner
from google.adk.sessions import InMemorySessionService
from google.genai import types

logger = logging.getLogger(__name__)
class ADKAgentExecutor(AgentExecutor):
    """A2A ``AgentExecutor`` that delegates each request to a Google ADK agent.

    The executor owns a single ADK ``Runner`` backed by in-memory artifact,
    session and memory services, so all state lives only in this process.
    One ADK session is kept per A2A ``context_id`` and reused across turns.
    """

    def __init__(
        self,
        agent,
        status_message: str = "Processing request...",
        artifact_name: str = "response",
    ):
        """Initialize a generic ADK agent executor.

        Args:
            agent: The ADK agent instance. Its ``name`` is used as the ADK
                app name for the runner and its sessions.
            status_message: Message sent with the ``working`` status update
                while the agent is processing a request.
            artifact_name: Name given to the artifact that carries the final
                response text.
        """
        self.agent = agent
        self.status_message = status_message
        self.artifact_name = artifact_name
        self.runner = Runner(
            app_name=agent.name,
            agent=agent,
            artifact_service=InMemoryArtifactService(),
            session_service=InMemorySessionService(),
            memory_service=InMemoryMemoryService(),
        )

    async def cancel(
        self,
        context: RequestContext,
        event_queue: EventQueue,
    ) -> None:
        """Cancel the execution of a specific task.

        Raises:
            NotImplementedError: Always; cancellation is not supported.
        """
        raise NotImplementedError(
            "Cancellation is not implemented for ADKAgentExecutor."
        )

    async def execute(
        self,
        context: RequestContext,
        event_queue: EventQueue,
    ) -> None:
        """Run the ADK agent for one A2A request and publish the outcome.

        Creates and enqueues a new task when the request does not continue an
        existing one, then marks the task ``working`` with
        ``self.status_message``. It runs the agent in the ADK session keyed by
        the task's ``context_id``, attaches the final response text as an
        artifact named ``self.artifact_name`` and completes the task.
        Exceptions raised while processing are logged and reported by marking
        the task ``failed`` with the error text; they are not re-raised.
        """
        query = context.get_user_input()

        task = context.current_task
        if task is None:
            # Only a brand-new task has to be announced on the queue; an
            # existing task is already known to the server.
            task = new_task(context.message)
            await event_queue.enqueue_event(task)

        updater = TaskUpdater(event_queue, task.id, task.context_id)
        user_id = self._resolve_user_id(context)

        try:
            await updater.update_status(
                TaskState.working,
                new_agent_text_message(self.status_message, task.context_id, task.id),
            )
            # Reuse the session for this context so multi-turn conversations
            # keep their history instead of re-creating the session each turn.
            session = await self._get_or_create_session(user_id, task.context_id)
            response_text = await self._run_agent(user_id, session.id, query)
            await updater.add_artifact(
                [Part(root=TextPart(text=response_text))],
                name=self.artifact_name,
            )
            await updater.complete()
        except Exception as e:
            logger.exception("ADK agent execution failed for task %s", task.id)
            await updater.update_status(
                TaskState.failed,
                new_agent_text_message(f"Error: {e!s}", task.context_id, task.id),
                final=True,
            )

    @staticmethod
    def _resolve_user_id(context: RequestContext) -> str:
        """Return the caller's user name, or ``"a2a_user"`` if it is missing or empty."""
        # call_context may be None; an unauthenticated user may report an
        # empty user name, which would otherwise become the ADK user id.
        user = getattr(context.call_context, "user", None)
        return getattr(user, "user_name", None) or "a2a_user"

    async def _get_or_create_session(self, user_id: str, session_id: str):
        """Return the ADK session ``session_id`` for ``user_id``, creating it on first use."""
        service = self.runner.session_service
        session = await service.get_session(
            app_name=self.agent.name,
            user_id=user_id,
            session_id=session_id,
        )
        if session is None:
            session = await service.create_session(
                app_name=self.agent.name,
                user_id=user_id,
                state={},
                session_id=session_id,
            )
        return session

    async def _run_agent(self, user_id: str, session_id: str, query: str) -> str:
        """Send ``query`` to the agent and return the text of its final response.

        Tool calls and tool responses seen in non-final events are logged.
        Each final event that has content parts replaces the result with the
        concatenation of its text parts. The result is ``""`` when no such
        event arrives.
        """
        content = types.Content(
            role="user", parts=[types.Part.from_text(text=query)]
        )
        response_text = ""
        async for event in self.runner.run_async(
            user_id=user_id, session_id=session_id, new_message=content
        ):
            parts = (event.content.parts or []) if event.content else []
            if event.is_final_response():
                if parts:
                    response_text = "".join(p.text for p in parts if p.text)
                    logger.info(" 🛠️ **Response from LLM Call: %s", response_text)
                continue
            for part in parts:
                function_call = getattr(part, "function_call", None)
                if function_call is not None:
                    logger.info(" 🛠️ **Tool Call: %s", function_call.name)
                    continue
                function_response = getattr(part, "function_response", None)
                if function_response is not None:
                    logger.info(" ⚡ **Tool Response: %s", function_response.response)
        return response_text