use cookie-based implementation
Browse files- agent_manager/__init__.py +16 -31
- api/views.py +28 -9
- backend/settings.py +4 -4
agent_manager/__init__.py
CHANGED
|
@@ -146,7 +146,6 @@ STRUCTURED_CHAT = StructuredChatWrapper(CHAT)
|
|
| 146 |
SESSION_AGENTS = {}
|
| 147 |
|
| 148 |
def set_session_agent(session_key):
|
| 149 |
-
print(f"New session created: {session_key}")
|
| 150 |
memory = InMemorySaver()
|
| 151 |
agent = create_agent(
|
| 152 |
model=STRUCTURED_CHAT,
|
|
@@ -155,45 +154,31 @@ def set_session_agent(session_key):
|
|
| 155 |
)
|
| 156 |
SESSION_AGENTS[session_key] = agent
|
| 157 |
|
| 158 |
-
def get_or_create_agent(
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
# Check if session key exists in cache
|
| 170 |
-
cached_chat_session = cache.get(cache_key)
|
| 171 |
-
|
| 172 |
-
if cached_chat_session is not None:
|
| 173 |
-
# Session key exists in cache
|
| 174 |
-
if cached_chat_session != chat_session:
|
| 175 |
-
# Chat session is different, update to new session agent
|
| 176 |
-
set_session_agent(session_key)
|
| 177 |
-
# Update cache with new chat_session
|
| 178 |
-
cache.set(cache_key, chat_session)
|
| 179 |
-
# If chat_session is the same, continue without changes
|
| 180 |
-
else:
|
| 181 |
-
# Session key doesn't exist, add it to cache
|
| 182 |
-
cache.set(cache_key, chat_session)
|
| 183 |
|
| 184 |
-
|
| 185 |
-
return SESSION_AGENTS.get(session_key)
|
| 186 |
|
| 187 |
|
| 188 |
def get_agent(session_id: str):
|
| 189 |
"""Return an existing agent for a session, or None if expired/closed."""
|
| 190 |
return SESSION_AGENTS.get(session_id)
|
| 191 |
|
| 192 |
-
def end_session(
|
| 193 |
"""Delete an agent session to free memory."""
|
| 194 |
-
session_key =
|
| 195 |
-
if session_key in SESSION_AGENTS:
|
| 196 |
del SESSION_AGENTS[session_key]
|
|
|
|
| 197 |
return True
|
| 198 |
return False
|
| 199 |
|
|
|
|
| 146 |
SESSION_AGENTS = {}
|
| 147 |
|
| 148 |
def set_session_agent(session_key):
|
|
|
|
| 149 |
memory = InMemorySaver()
|
| 150 |
agent = create_agent(
|
| 151 |
model=STRUCTURED_CHAT,
|
|
|
|
| 154 |
)
|
| 155 |
SESSION_AGENTS[session_key] = agent
|
| 156 |
|
| 157 |
+
def get_or_create_agent(chat_session):
|
| 158 |
+
"""Get or create an agent keyed by the provided chat_session token."""
|
| 159 |
+
# Normalize to string to avoid type-mismatch keys
|
| 160 |
+
session_key = str(chat_session) if chat_session else None
|
| 161 |
+
|
| 162 |
+
if not session_key:
|
| 163 |
+
session_key = str(uuid.uuid4())
|
| 164 |
+
|
| 165 |
+
if session_key not in SESSION_AGENTS:
|
| 166 |
+
set_session_agent(session_key)
|
| 167 |
+
cache.set(f"chat_session_{session_key}", True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
|
| 169 |
+
return SESSION_AGENTS.get(session_key), session_key
|
|
|
|
| 170 |
|
| 171 |
|
| 172 |
def get_agent(session_id: str):
|
| 173 |
"""Return an existing agent for a session, or None if expired/closed."""
|
| 174 |
return SESSION_AGENTS.get(session_id)
|
| 175 |
|
| 176 |
+
def end_session(chat_session):
|
| 177 |
"""Delete an agent session to free memory."""
|
| 178 |
+
session_key = str(chat_session) if chat_session is not None else None
|
| 179 |
+
if session_key and session_key in SESSION_AGENTS:
|
| 180 |
del SESSION_AGENTS[session_key]
|
| 181 |
+
cache.delete(f"chat_session_{session_key}")
|
| 182 |
return True
|
| 183 |
return False
|
| 184 |
|
api/views.py
CHANGED
|
@@ -4,6 +4,7 @@ from rest_framework.permissions import AllowAny
|
|
| 4 |
from rest_framework.response import Response
|
| 5 |
from rest_framework import status
|
| 6 |
from agent_manager import get_or_create_agent, end_session, get_message_list
|
|
|
|
| 7 |
|
| 8 |
@csrf_exempt
|
| 9 |
@permission_classes([AllowAny])
|
|
@@ -16,7 +17,8 @@ def hello(request):
|
|
| 16 |
@api_view(['POST'])
|
| 17 |
def chat(request):
|
| 18 |
"""Start or continue an existing chat session."""
|
| 19 |
-
|
|
|
|
| 20 |
message = request.data.get("message")
|
| 21 |
|
| 22 |
if not message:
|
|
@@ -25,15 +27,15 @@ def chat(request):
|
|
| 25 |
"response": "Invalid message."
|
| 26 |
}, status=status.HTTP_400_BAD_REQUEST)
|
| 27 |
|
| 28 |
-
|
|
|
|
| 29 |
|
| 30 |
mode = request.data.get("mode")
|
| 31 |
tone = request.data.get("tone")
|
| 32 |
messages = get_message_list(mode, tone, message)
|
| 33 |
|
| 34 |
-
print("Message:", message, "Session Key:", request.session.session_key)
|
| 35 |
result = agent.invoke({ "messages": messages },
|
| 36 |
-
config={ "configurable": {"thread_id":
|
| 37 |
)
|
| 38 |
|
| 39 |
last_message = result.get('messages', [])[-1] if result.get('messages') else None
|
|
@@ -44,20 +46,37 @@ def chat(request):
|
|
| 44 |
"response": "Server Error"
|
| 45 |
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
| 46 |
|
| 47 |
-
|
| 48 |
"status": "success",
|
| 49 |
"response": last_message.content
|
| 50 |
}, status=status.HTTP_200_OK)
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
@csrf_exempt
|
| 53 |
@permission_classes([AllowAny])
|
| 54 |
@api_view(['POST'])
|
| 55 |
def end(request):
|
| 56 |
"""End and delete the chat session."""
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
| 61 |
return Response({
|
| 62 |
"status": "error",
|
| 63 |
"response": "No active session."
|
|
|
|
| 4 |
from rest_framework.response import Response
|
| 5 |
from rest_framework import status
|
| 6 |
from agent_manager import get_or_create_agent, end_session, get_message_list
|
| 7 |
+
from django.conf import settings
|
| 8 |
|
| 9 |
@csrf_exempt
|
| 10 |
@permission_classes([AllowAny])
|
|
|
|
| 17 |
@api_view(['POST'])
|
| 18 |
def chat(request):
|
| 19 |
"""Start or continue an existing chat session."""
|
| 20 |
+
# Prefer secure HttpOnly cookie for session tracking
|
| 21 |
+
cookie_session = request.COOKIES.get("gm_session")
|
| 22 |
message = request.data.get("message")
|
| 23 |
|
| 24 |
if not message:
|
|
|
|
| 27 |
"response": "Invalid message."
|
| 28 |
}, status=status.HTTP_400_BAD_REQUEST)
|
| 29 |
|
| 30 |
+
# Use cookie if present; otherwise create a new session
|
| 31 |
+
agent, session_key = get_or_create_agent(cookie_session)
|
| 32 |
|
| 33 |
mode = request.data.get("mode")
|
| 34 |
tone = request.data.get("tone")
|
| 35 |
messages = get_message_list(mode, tone, message)
|
| 36 |
|
|
|
|
| 37 |
result = agent.invoke({ "messages": messages },
|
| 38 |
+
config={ "configurable": {"thread_id": session_key } }
|
| 39 |
)
|
| 40 |
|
| 41 |
last_message = result.get('messages', [])[-1] if result.get('messages') else None
|
|
|
|
| 46 |
"response": "Server Error"
|
| 47 |
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
| 48 |
|
| 49 |
+
resp = Response({
|
| 50 |
"status": "success",
|
| 51 |
"response": last_message.content
|
| 52 |
}, status=status.HTTP_200_OK)
|
| 53 |
|
| 54 |
+
# If cookie was missing, set it now with secure attributes
|
| 55 |
+
if not cookie_session:
|
| 56 |
+
secure = True if settings.MODE == 'production' else False
|
| 57 |
+
samesite = 'None' if settings.MODE == 'production' else 'Lax'
|
| 58 |
+
resp.set_cookie(
|
| 59 |
+
"gm_session",
|
| 60 |
+
value=session_key,
|
| 61 |
+
httponly=True,
|
| 62 |
+
secure=secure,
|
| 63 |
+
samesite=samesite,
|
| 64 |
+
max_age=60 * 60 * 24
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
return resp
|
| 68 |
+
|
| 69 |
@csrf_exempt
|
| 70 |
@permission_classes([AllowAny])
|
| 71 |
@api_view(['POST'])
|
| 72 |
def end(request):
|
| 73 |
"""End and delete the chat session."""
|
| 74 |
+
cookie_session = request.COOKIES.get("gm_session")
|
| 75 |
+
if end_session(cookie_session):
|
| 76 |
+
resp = Response({"status": "success", "message": "Session ended successfully"})
|
| 77 |
+
# Clear cookie
|
| 78 |
+
resp.delete_cookie("gm_session")
|
| 79 |
+
return resp
|
| 80 |
return Response({
|
| 81 |
"status": "error",
|
| 82 |
"response": "No active session."
|
backend/settings.py
CHANGED
|
@@ -44,12 +44,12 @@ CORS_ALLOW_ALL_ORIGINS = False if MODE == 'production' else True
|
|
| 44 |
|
| 45 |
|
| 46 |
SESSION_COOKIE_HTTPONLY = True
|
| 47 |
-
SESSION_COOKIE_SECURE =
|
| 48 |
SESSION_EXPIRE_AT_BROWSER_CLOSE = True
|
| 49 |
SESSION_COOKIE_AGE = 60 * 60 * 24 # 1 day
|
| 50 |
|
| 51 |
CSRF_COOKIE_HTTPONLY = True
|
| 52 |
-
CSRF_COOKIE_SECURE =
|
| 53 |
CSRF_TRUSTED_ORIGINS = [
|
| 54 |
origin.strip()
|
| 55 |
for origin in os.environ.get("CSRF_TRUSTED_ORIGINS", "").split(",")
|
|
@@ -60,9 +60,9 @@ CSRF_TRUSTED_ORIGINS = [
|
|
| 60 |
'http://localhost:3000'
|
| 61 |
]
|
| 62 |
|
| 63 |
-
SECURE_SSL_REDIRECT =
|
| 64 |
|
| 65 |
-
SECURE_CONTENT_TYPE_NOSNIFF =
|
| 66 |
|
| 67 |
# HSTS settings - only enable in production with proper HTTPS configuration
|
| 68 |
# WARNING: Once enabled, browsers will remember this for SECURE_HSTS_SECONDS seconds
|
|
|
|
| 44 |
|
| 45 |
|
| 46 |
SESSION_COOKIE_HTTPONLY = True
|
| 47 |
+
SESSION_COOKIE_SECURE = True if MODE == 'production' else False # secure cookies only over HTTPS in production
|
| 48 |
SESSION_EXPIRE_AT_BROWSER_CLOSE = True
|
| 49 |
SESSION_COOKIE_AGE = 60 * 60 * 24 # 1 day
|
| 50 |
|
| 51 |
CSRF_COOKIE_HTTPONLY = True
|
| 52 |
+
CSRF_COOKIE_SECURE = True if MODE == 'production' else False
|
| 53 |
CSRF_TRUSTED_ORIGINS = [
|
| 54 |
origin.strip()
|
| 55 |
for origin in os.environ.get("CSRF_TRUSTED_ORIGINS", "").split(",")
|
|
|
|
| 60 |
'http://localhost:3000'
|
| 61 |
]
|
| 62 |
|
| 63 |
+
SECURE_SSL_REDIRECT = True if MODE == 'production' else False
|
| 64 |
|
| 65 |
+
SECURE_CONTENT_TYPE_NOSNIFF = True if MODE == 'production' else False
|
| 66 |
|
| 67 |
# HSTS settings - only enable in production with proper HTTPS configuration
|
| 68 |
# WARNING: Once enabled, browsers will remember this for SECURE_HSTS_SECONDS seconds
|