BirdScopeAI / agent_cache.py
facemelter's picture
Initial commit to hf space for hackathon
ff0e97f verified
"""
Agent cache management for per-session agents.
This module handles storing and retrieving agents for different users/sessions.
Each agent is cached by (session_id, provider, model, api_key_hash, mode) to avoid recreating them.
"""
from datetime import datetime, timedelta
from typing import Dict, Tuple, Any
import hashlib
# Shape of a cache key: (session_id, provider, model, api_key_hash, mode).
_AgentKey = Tuple[str, str, str, str, str]

# Live agents, one entry per cache key.
agent_cache: Dict[_AgentKey, Any] = {}

# Time of the most recent use of each cached agent, keyed like agent_cache.
agent_last_used: Dict[_AgentKey, datetime] = {}
async def get_or_create_agent(
    session_id: str,
    provider: str,
    api_key: str,
    model: str,
    mode: str,
    agent_factory_method
) -> Any:
    """
    Get existing agent from cache or create new one.

    Agents are keyed by (session_id, provider, model, api_key_hash, mode), so
    changing any of those values yields a separate cached agent.

    Args:
        session_id: Unique identifier for user session (from gr.Request)
        provider: "huggingface" or "openai"
        api_key: User's API key or OAuth token. None or an empty string is
            accepted and hashed as the empty string.
        model: Model name/repo ID
        mode: Agent mode (e.g., "Single Agent (All Tools)",
            "Specialized Subagents (3 Specialists)")
        agent_factory_method: Zero-argument callable that creates the agent
            if it is not cached. It may be async (its result is awaited) or
            return the agent directly.

    Returns:
        Cached or newly created agent

    Example:
        agent = await get_or_create_agent(
            session_id="abc123",
            provider="openai",
            api_key="sk-...",
            model="gpt-4o-mini",
            mode="Single Agent (All Tools)",
            agent_factory_method=lambda: AgentFactory.create_streaming_agent_with_openai(...)
        )
    """
    # Function-scope import: only needed for the awaitable check below.
    import inspect

    # Hash the API key so different keys get separate agents without the raw
    # secret ever being stored in the cache key. A missing key (None) is
    # treated as the empty string instead of crashing on .encode().
    key_material = (api_key or "").encode("utf-8")
    api_key_hash = hashlib.sha256(key_material).hexdigest()[:16]

    cache_key = (session_id, provider, model, api_key_hash, mode)

    # Cache hit: refresh the last-used timestamp and hand back the agent.
    if cache_key in agent_cache:
        print(f"[CACHE HIT] Reusing agent for session {session_id[:8]}...")
        agent_last_used[cache_key] = datetime.now()
        return agent_cache[cache_key]

    # Cache miss: build a new agent via the factory.
    print(f"[CACHE MISS] Creating new {provider} agent for session {session_id[:8]}...")

    # NOTE(review): control is yielded while the factory is awaited, so two
    # concurrent calls with the same key can both miss and both build an
    # agent; the one stored last stays in the cache.
    agent = agent_factory_method()
    if inspect.isawaitable(agent):
        agent = await agent

    agent_cache[cache_key] = agent
    agent_last_used[cache_key] = datetime.now()
    print(f"[CACHE] Stored agent. Total agents in cache: {len(agent_cache)}")
    return agent
def cleanup_old_agents(max_age_hours: float = 1) -> int:
    """
    Remove agents that haven't been used in max_age_hours.

    Call this periodically to prevent memory leaks.

    Args:
        max_age_hours: Remove agents whose last use is older than this many
            hours. Fractional values are accepted.

    Returns:
        Number of agents removed
    """
    # Anything last used strictly before this instant is stale.
    cutoff = datetime.now() - timedelta(hours=max_age_hours)

    # Snapshot the items first so the dict is not being iterated while
    # entries are deleted below.
    to_remove = [
        cache_key
        for cache_key, last_used in list(agent_last_used.items())
        if last_used < cutoff
    ]

    # Remove old agents
    for cache_key in to_remove:
        print(f"[CLEANUP] Removing stale agent: {cache_key}")
        # pop() with a default tolerates the two maps being out of sync;
        # a plain `del` would raise KeyError and abort the cleanup halfway.
        agent_cache.pop(cache_key, None)
        agent_last_used.pop(cache_key, None)

    return len(to_remove)
def get_cache_stats():
    """Report how many agents are cached and under which keys."""
    # Copy the keys once so the count and the listing describe the same state.
    keys = [*agent_cache]
    stats = {
        "total_agents": len(keys),
        "cache_keys": keys,
    }
    return stats