File size: 3,449 Bytes
ff0e97f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
Agent cache management for per-session agents.

This module handles storing and retrieving agents for different users/sessions.
Each agent is cached by (session_id, provider, model, api_key_hash, mode) to avoid recreating them.
"""

import asyncio
import hashlib
from datetime import datetime, timedelta
from typing import Dict, Tuple, Any

# Cache key layout: (session_id, provider, model, api_key_hash, mode)
_CacheKey = Tuple[str, str, str, str, str]

# Global cache: maps (session_id, provider, model, api_key_hash, mode) -> agent
agent_cache: Dict[_CacheKey, Any] = {}

# Track when each agent was last used
agent_last_used: Dict[_CacheKey, datetime] = {}

# Per-key locks that serialize agent creation, so concurrent requests for the
# same key share a single factory call instead of each building (and then
# overwriting) their own agent. An entry exists only while at least one
# coroutine holds or waits on that key's lock; _agent_creation_refs counts
# those coroutines so the lock is discarded when the last one leaves and these
# dicts do not grow without bound.
# NOTE(review): asyncio.Lock coordinates coroutines on one event loop only; it
# does not make this module thread-safe (neither did the original code).
_agent_creation_locks: Dict[_CacheKey, asyncio.Lock] = {}
_agent_creation_refs: Dict[_CacheKey, int] = {}


def _reuse_cached_agent(cache_key: _CacheKey, session_id: str) -> Any:
    """Log a cache hit, refresh the entry's last-used time, and return the cached agent."""
    print(f"[CACHE HIT] Reusing agent for session {session_id[:8]}...")
    agent_last_used[cache_key] = datetime.now()
    return agent_cache[cache_key]


async def get_or_create_agent(
    session_id: str,
    provider: str,
    api_key: str,
    model: str,
    mode: str,
    agent_factory_method
):
    """
    Get existing agent from cache or create new one.

    Concurrent calls that resolve to the same cache key are serialized: only
    one of them awaits ``agent_factory_method``; the others wait for it and
    then reuse the agent it stored. If the factory raises, the exception
    propagates to that caller, nothing is cached, and the next waiting caller
    (if any) runs its own factory.

    Args:
        session_id: Unique identifier for user session (from gr.Request)
        provider: "huggingface" or "openai"
        api_key: User's API key or OAuth token
        model: Model name/repo ID
        mode: Agent mode (e.g., "Single Agent (All Tools)", "Specialized Subagents (3 Specialists)")
        agent_factory_method: Zero-argument callable returning an awaitable
            that produces the agent; only invoked on a cache miss.

    Returns:
        Cached or newly created agent

    Raises:
        Whatever ``agent_factory_method`` raises.

    Example:
        agent = await get_or_create_agent(
            session_id="abc123",
            provider="openai",
            api_key="sk-...",
            model="gpt-4o-mini",
            mode="Single Agent (All Tools)",
            agent_factory_method=lambda: AgentFactory.create_streaming_agent_with_openai(...)
        )
    """
    # Hash the API key so different keys get separate agents without storing
    # the raw key in the cache key (first 16 hex chars of SHA-256 = 64 bits).
    api_key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16]

    cache_key = (session_id, provider, model, api_key_hash, mode)

    # Fast path: no locking needed when the agent already exists.
    if cache_key in agent_cache:
        return _reuse_cached_agent(cache_key, session_id)

    # Register interest in this key's creation lock. There is no await between
    # fetching/creating the lock and bumping its refcount, so this sequence is
    # atomic with respect to other coroutines on the same event loop.
    lock = _agent_creation_locks.get(cache_key)
    if lock is None:
        lock = _agent_creation_locks[cache_key] = asyncio.Lock()
    _agent_creation_refs[cache_key] = _agent_creation_refs.get(cache_key, 0) + 1

    try:
        async with lock:
            # Re-check: another coroutine may have created the agent while we
            # were waiting for the lock.
            if cache_key in agent_cache:
                return _reuse_cached_agent(cache_key, session_id)

            # Cache miss - create new agent
            print(f"[CACHE MISS] Creating new {provider} agent for session {session_id[:8]}...")

            # Call the factory method to create agent
            agent = await agent_factory_method()

            agent_cache[cache_key] = agent
            agent_last_used[cache_key] = datetime.now()

            print(f"[CACHE] Stored agent. Total agents in cache: {len(agent_cache)}")

            return agent
    finally:
        # Runs on success, factory failure, and cancellation alike; drop the
        # lock once no other coroutine is holding or queued on it.
        remaining = _agent_creation_refs[cache_key] - 1
        if remaining:
            _agent_creation_refs[cache_key] = remaining
        else:
            del _agent_creation_refs[cache_key]
            del _agent_creation_locks[cache_key]

def cleanup_old_agents(max_age_hours: float = 1):
    """
    Remove agents that haven't been used in max_age_hours.

    Call this periodically to prevent memory leaks; entries in agent_cache and
    agent_last_used are otherwise never evicted.

    Args:
        max_age_hours: Remove agents whose last use is more than this many
            hours ago. Fractional values (e.g. 0.5) are accepted.

    Returns:
        Number of agents removed
    """
    # Compute the cutoff once instead of building a timedelta per entry.
    cutoff = datetime.now() - timedelta(hours=max_age_hours)

    # Collect first, delete afterwards: deleting from a dict while iterating
    # over it raises RuntimeError.
    stale_keys = [
        cache_key
        for cache_key, last_used in agent_last_used.items()
        if last_used < cutoff
    ]

    for cache_key in stale_keys:
        session_id, provider, model, _api_key_hash, mode = cache_key
        # Log only non-secret parts of the key: api_key_hash fingerprints the
        # user's credential, and session ids are truncated elsewhere in this module.
        print(
            f"[CLEANUP] Removing stale agent for session {session_id[:8]}... "
            f"({provider}/{model}, {mode})"
        )
        # pop() tolerates the two dicts having drifted out of sync.
        agent_cache.pop(cache_key, None)
        del agent_last_used[cache_key]
        # NOTE(review): evicted agents are simply dropped; if they hold open
        # clients/connections they may need explicit closing — confirm against
        # the agent type produced by the factory.

    return len(stale_keys)

def get_cache_stats():
    """
    Summarize the current contents of the agent cache.

    Returns:
        A dict with "total_agents" (how many agents are cached) and
        "cache_keys" (a list snapshot of every cache key tuple, in insertion
        order). Each key is (session_id, provider, model, api_key_hash, mode),
        so callers should avoid exposing it verbatim.
    """
    keys_snapshot = list(agent_cache)
    stats = {
        "total_agents": len(keys_snapshot),
        "cache_keys": keys_snapshot,
    }
    return stats