Spaces:
Running
Running
| """ | |
| Agent cache management for per-session agents. | |
| This module handles storing and retrieving agents for different users/sessions. | |
| Each agent is cached by (session_id, provider, model, api_key_hash, mode) to avoid recreating them. | |
| """ | |
| from datetime import datetime, timedelta | |
| from typing import Dict, Tuple, Any | |
| import hashlib | |
# Timestamp of the most recent access for each cached agent; uses the same
# key tuple as agent_cache so stale entries can be found and evicted.
agent_last_used: Dict[Tuple[str, str, str, str, str], datetime] = dict()
# Module-level agent store:
# (session_id, provider, model, api_key_hash, mode) -> agent
agent_cache: Dict[Tuple[str, str, str, str, str], Any] = dict()
async def get_or_create_agent(
    session_id: str,
    provider: str,
    api_key: str,
    model: str,
    mode: str,
    agent_factory_method
):
    """
    Get existing agent from cache or create new one.

    Agents are cached under (session_id, provider, model, api_key_hash, mode),
    so changing any of these values yields a separate cached agent. Only a
    truncated SHA-256 digest of the API key is stored in the cache key, never
    the key itself.

    Args:
        session_id: Unique identifier for user session (from gr.Request)
        provider: "huggingface" or "openai"
        api_key: User's API key or OAuth token. ``None`` is hashed like an
            empty string instead of raising.
        model: Model name/repo ID
        mode: Agent mode (e.g., "Single Agent (All Tools)", "Specialized Subagents (3 Specialists)")
        agent_factory_method: Async function (no arguments) to create agent if not cached

    Returns:
        Cached or newly created agent

    Example:
        agent = await get_or_create_agent(
            session_id="abc123",
            provider="openai",
            api_key="sk-...",
            model="gpt-4o-mini",
            mode="Single Agent (All Tools)",
            agent_factory_method=lambda: AgentFactory.create_streaming_agent_with_openai(...)
        )
    """
    # Hash the API key so different keys get separate cached agents without
    # keeping the raw secret in the cache key. A missing key (None) is
    # treated as an empty string rather than crashing on .encode().
    api_key_hash = hashlib.sha256((api_key or "").encode()).hexdigest()[:16]

    # Cache key includes the API key hash AND the mode.
    cache_key = (session_id, provider, model, api_key_hash, mode)

    # Short label used only for log output; str() keeps logging from failing
    # on a non-string session id (identical output for ordinary strings).
    session_label = str(session_id)[:8]

    # Cache hit: refresh the last-used timestamp and reuse the agent.
    if cache_key in agent_cache:
        print(f"[CACHE HIT] Reusing agent for session {session_label}...")
        agent_last_used[cache_key] = datetime.now()
        return agent_cache[cache_key]

    # Cache miss - create new agent
    print(f"[CACHE MISS] Creating new {provider} agent for session {session_label}...")

    # Call the factory method to create the agent.
    # NOTE: control is yielded at this await, so two concurrent calls with the
    # same key can both miss and both build an agent; the later store wins.
    agent = await agent_factory_method()

    # Store in cache together with its last-used timestamp.
    agent_cache[cache_key] = agent
    agent_last_used[cache_key] = datetime.now()
    print(f"[CACHE] Stored agent. Total agents in cache: {len(agent_cache)}")
    return agent
def cleanup_old_agents(max_age_hours: float = 1) -> int:
    """
    Remove agents that haven't been used in max_age_hours.

    Call this periodically to prevent memory leaks.

    An agent is stale when the time since its entry in ``agent_last_used`` is
    strictly greater than ``max_age_hours``. Removal tolerates the two cache
    dictionaries being out of sync: a timestamp without a matching cached
    agent is dropped instead of raising ``KeyError`` and aborting the sweep.

    Args:
        max_age_hours: Remove agents older than this many hours
            (fractional values are accepted)

    Returns:
        Number of stale cache keys removed
    """
    max_age = timedelta(hours=max_age_hours)
    now = datetime.now()

    # Collect keys first: the dictionaries must not be mutated while iterating.
    stale_keys = [
        cache_key
        for cache_key, last_used in agent_last_used.items()
        if now - last_used > max_age
    ]

    # Remove old agents
    for cache_key in stale_keys:
        print(f"[CLEANUP] Removing stale agent: {cache_key}")
        # pop with a default so one missing entry cannot stop the cleanup.
        agent_cache.pop(cache_key, None)
        agent_last_used.pop(cache_key, None)

    return len(stale_keys)
def get_cache_stats():
    """
    Summarize the current contents of the agent cache.

    Returns:
        Dict with "total_agents" (number of cached agents) and "cache_keys"
        (a list snapshot of the cache key tuples).
    """
    cached_keys = [*agent_cache]
    stats = {}
    stats["total_agents"] = len(cached_keys)
    stats["cache_keys"] = cached_keys
    return stats