ruslanmv committed on
Commit
cd6d7ff
·
1 Parent(s): a6720b5

Update middleware.py

Browse files
Files changed (1) hide show
  1. app/middleware.py +132 -38
app/middleware.py CHANGED
@@ -1,14 +1,22 @@
 
 
 
1
  import time
2
  import logging
3
  import json
4
- from typing import Callable
5
- from fastapi import FastAPI, Request, Response
 
 
 
6
  from fastapi.middleware.cors import CORSMiddleware
 
7
  from starlette.middleware.gzip import GZipMiddleware
 
8
 
9
- # Try to import python-json-logger; fall back to a tiny JSON formatter if missing.
10
  try:
11
- from pythonjsonlogger import jsonlogger # type: ignore[import-not-found]
12
  _HAS_PY_JSON_LOGGER = True
13
  except Exception:
14
  _HAS_PY_JSON_LOGGER = False
@@ -17,7 +25,6 @@ from .deps import get_settings
17
  from .core.rate_limit import RateLimiter
18
  from .core.logging import add_trace_id
19
 
20
- # ---- Fallback JSON formatter (if python-json-logger isn't available) ----
21
  class _SimpleJsonFormatter(logging.Formatter):
22
  def format(self, record: logging.LogRecord) -> str:
23
  payload = {
@@ -25,73 +32,160 @@ class _SimpleJsonFormatter(logging.Formatter):
25
  "name": record.name,
26
  "levelname": record.levelname,
27
  "message": record.getMessage(),
28
- # We attach trace_id via logger.info(..., extra={"trace_id": "..."}).
29
  "trace_id": getattr(record, "trace_id", None),
30
  }
31
  try:
32
  return json.dumps(payload, ensure_ascii=False)
33
  except Exception:
34
- # Last-ditch plain log if JSON serialization ever fails
35
  return (
36
  f'{payload["asctime"]} {payload["name"]} {payload["levelname"]} '
37
  f'{payload["message"]} trace_id={payload["trace_id"]}'
38
  )
39
 
40
- # Setup structured logging
41
- logger = logging.getLogger("matrix-ai")
42
- if not logger.handlers:
43
- logger.setLevel(logging.INFO)
44
- handler = logging.StreamHandler()
45
  if _HAS_PY_JSON_LOGGER:
46
- # Same fields you had; python-json-logger builds JSON from this format string
47
- formatter = jsonlogger.JsonFormatter(
48
  "%(asctime)s %(name)s %(levelname)s %(message)s %(trace_id)s"
49
  )
50
  else:
51
- formatter = _SimpleJsonFormatter()
52
  logging.getLogger("uvicorn.error").warning(
53
  "python-json-logger not found; using a minimal JSON formatter."
54
  )
55
- handler.setFormatter(formatter)
56
- logger.addHandler(handler)
57
 
58
  _rate_limiter = RateLimiter()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  def attach_middlewares(app: FastAPI) -> None:
61
- """Attaches all required middlewares to the FastAPI app."""
62
- # NOTE: We keep GZip, but your SSE endpoints already set `Content-Encoding: identity`
63
- # so they won't be buffered/compressed.
64
  app.add_middleware(GZipMiddleware, minimum_size=512)
65
-
66
  app.add_middleware(
67
  CORSMiddleware,
68
  allow_origins=["*"],
69
  allow_credentials=True,
70
  allow_methods=["*"],
71
  allow_headers=["*"],
 
72
  )
73
 
74
  @app.middleware("http")
75
  async def rate_limit_and_log_middleware(request: Request, call_next: Callable):
76
- # Attach per-request trace id
77
  add_trace_id(request)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  settings = get_settings()
80
- client_ip = request.client.host if request.client else "unknown"
81
-
82
- # Simple fixed-window limiter
83
- if not _rate_limiter.allow(
84
- client_ip, request.url.path, settings.limits.rate_per_min
85
- ):
86
- return Response(status_code=429, content="Rate limit exceeded")
87
-
88
- start_time = time.time()
89
- response = await call_next(request)
90
- process_time = (time.time() - start_time) * 1000.0
91
- response.headers["X-Process-Time-Ms"] = f"{process_time:.2f}"
92
-
93
- logger.info(
94
- f'"{request.method} {request.url.path}" {response.status_code}',
95
- extra={"trace_id": getattr(request.state, "trace_id", "N/A")},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  )
97
  return response
 
1
+ # app/middleware.py
2
+ from __future__ import annotations
3
+
4
  import time
5
  import logging
6
  import json
7
+ import asyncio
8
+ from typing import Callable, Optional
9
+
10
+ from anyio import EndOfStream
11
+ from fastapi import FastAPI, Request
12
  from fastapi.middleware.cors import CORSMiddleware
13
+ from starlette.responses import Response, JSONResponse
14
  from starlette.middleware.gzip import GZipMiddleware
15
+ from starlette.exceptions import ClientDisconnect
16
 
17
+ # Optional: python-json-logger for structured logs; fallback to a minimal JSON formatter.
18
  try:
19
+ from pythonjsonlogger import jsonlogger # type: ignore
20
  _HAS_PY_JSON_LOGGER = True
21
  except Exception:
22
  _HAS_PY_JSON_LOGGER = False
 
25
  from .core.rate_limit import RateLimiter
26
  from .core.logging import add_trace_id
27
 
 
28
  class _SimpleJsonFormatter(logging.Formatter):
29
  def format(self, record: logging.LogRecord) -> str:
30
  payload = {
 
32
  "name": record.name,
33
  "levelname": record.levelname,
34
  "message": record.getMessage(),
 
35
  "trace_id": getattr(record, "trace_id", None),
36
  }
37
  try:
38
  return json.dumps(payload, ensure_ascii=False)
39
  except Exception:
 
40
  return (
41
  f'{payload["asctime"]} {payload["name"]} {payload["levelname"]} '
42
  f'{payload["message"]} trace_id={payload["trace_id"]}'
43
  )
44
 
45
def _select_formatter() -> logging.Formatter:
    """Return the JSON formatter used for the "matrix-ai" log handler.

    Uses python-json-logger when it was importable; otherwise falls back to
    the in-file ``_SimpleJsonFormatter`` and emits a one-off warning on the
    "uvicorn.error" logger.
    """
    if _HAS_PY_JSON_LOGGER:
        # python-json-logger builds its JSON fields from this format string.
        return jsonlogger.JsonFormatter(
            "%(asctime)s %(name)s %(levelname)s %(message)s %(trace_id)s"
        )
    fallback = _SimpleJsonFormatter()
    logging.getLogger("uvicorn.error").warning(
        "python-json-logger not found; using a minimal JSON formatter."
    )
    return fallback


# Structured logging: configure the "matrix-ai" logger once. The handler check
# keeps a re-import of this module from stacking duplicate handlers.
_logger = logging.getLogger("matrix-ai")
if not _logger.handlers:
    _handler = logging.StreamHandler()
    _formatter = _select_formatter()
    _handler.setFormatter(_formatter)
    _logger.addHandler(_handler)
    _logger.setLevel(logging.INFO)

# Single limiter instance shared by every request handled in this process.
_rate_limiter = RateLimiter()
# Request paths ending in one of these suffixes are treated as SSE streams.
_SSE_PATH_SUFFIXES = ("/chat/stream", "/v1/chat/stream")
# Probe endpoints that skip rate limiting and access logging.
_HEALTH_PATHS = ("/health", "/livez", "/readyz")
64
+
65
+ def _client_ip(request: Request) -> str:
66
+ xff = request.headers.get("x-forwarded-for")
67
+ if xff:
68
+ return xff.split(",")[0].strip()
69
+ return request.client.host if request.client else "unknown"
70
+
71
def _is_sse(request: Request, response: Optional[Response] = None) -> bool:
    """Return True when the exchange looks like a Server-Sent Events stream.

    Checked in order:
      1. the request path ends with one of ``_SSE_PATH_SUFFIXES``;
      2. *response* (when given) has a ``text/event-stream`` content type;
      3. the request's ``Accept`` header mentions ``text/event-stream``.

    Media types are case-insensitive, so header values are lower-cased
    before matching.
    """
    if request.url.path.endswith(_SSE_PATH_SUFFIXES):
        return True
    if response is not None:
        content_type = response.headers.get("content-type", "")
        if content_type.strip().lower().startswith("text/event-stream"):
            return True
    accept = request.headers.get("accept", "")
    return "text/event-stream" in accept.lower()
81
 
82
def attach_middlewares(app: FastAPI) -> None:
    """Attach GZip, CORS and the rate-limit/logging HTTP middleware to *app*.

    The HTTP middleware, per request:
      * calls ``add_trace_id`` and reads ``request.state.trace_id``
        (falling back to ``"N/A"``);
      * lets paths in ``_HEALTH_PATHS`` through without rate limiting or
        access logging;
      * answers 429 when ``_rate_limiter.allow(...)`` rejects the request;
      * turns client disconnects into an empty 204 response;
      * turns any other exception into a logged JSON 500 response;
      * stamps ``X-Trace-Id``, ``X-Process-Time-Ms`` and ``Server-Timing``
        on the response and writes one access-log line.
    """
    app.add_middleware(GZipMiddleware, minimum_size=512)
    # NOTE(review): wildcard origins together with allow_credentials=True is
    # very permissive; confirm this is intended for the deployment.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
        expose_headers=["X-Trace-Id", "X-Process-Time-Ms", "Server-Timing"],
    )

    def _no_content(trace_id) -> Response:
        """Build the empty 204 reply used when the client has gone away."""
        resp = Response(status_code=204)
        resp.headers.setdefault("X-Trace-Id", str(trace_id))
        return resp

    @app.middleware("http")
    async def rate_limit_and_log_middleware(request: Request, call_next: Callable):
        add_trace_id(request)
        trace_id = getattr(request.state, "trace_id", "N/A")
        log_extra = {"trace_id": trace_id}

        path = request.url.path
        method = request.method
        ua = request.headers.get("user-agent", "-")
        ip = _client_ip(request)

        # Health probes: no rate limiting, no access log, never raise.
        if path in _HEALTH_PATHS:
            try:
                response = await call_next(request)
            except Exception:
                # Still report "unhealthy" to the probe, but leave a trace of
                # why instead of dropping the error silently.
                _logger.exception(
                    "Health check failed for %s", path, extra=log_extra
                )
                return JSONResponse({"status": "unhealthy"}, status_code=500)
            response.headers.setdefault("X-Trace-Id", str(trace_id))
            return response

        settings = get_settings()
        if not _rate_limiter.allow(ip, path, settings.limits.rate_per_min):
            _logger.warning(
                "429 Too Many Requests from %s on %s",
                ip, path, extra=log_extra,
            )
            return JSONResponse(
                {"detail": "Too Many Requests"},
                status_code=429,
                headers={"X-Trace-Id": str(trace_id)},
            )

        # perf_counter is monotonic, so durations are immune to wall-clock
        # adjustments.
        t0 = time.perf_counter()
        try:
            response = await call_next(request)

        # Treat disconnects as benign (return 204).
        # NOTE(review): swallowing asyncio.CancelledError is normally
        # discouraged; kept because the original does so deliberately.
        # NOTE(review): ClientDisconnect is defined in starlette.requests;
        # confirm the module-level import of it resolves.
        except (EndOfStream, ClientDisconnect, asyncio.CancelledError):
            _logger.info(
                "Client disconnected from stream. Path: %s, IP: %s",
                path, ip, extra=log_extra,
            )
            return _no_content(trace_id)

        except Exception as e:
            # Starlette sometimes surfaces a streaming disconnect as this
            # RuntimeError. It is checked here, inside the generic handler,
            # because an exception re-raised from one `except` clause is NOT
            # caught by a sibling clause of the same `try`; every other
            # RuntimeError must still reach the 500 path below.
            if isinstance(e, RuntimeError) and str(e) == "No response returned.":
                _logger.info(
                    "Downstream produced no response (likely streaming disconnect). "
                    "Path: %s, IP: %s",
                    path, ip, extra=log_extra,
                )
                return _no_content(trace_id)

            _logger.exception(
                "Unhandled error while processing %s %s: %s",
                method, path, e, extra=log_extra,
            )
            dur_ms = (time.perf_counter() - t0) * 1000.0
            return JSONResponse(
                {"detail": "Internal Server Error"},
                status_code=500,
                headers={
                    "X-Trace-Id": str(trace_id),
                    "X-Process-Time-Ms": f"{dur_ms:.2f}",
                    "Server-Timing": f"app;dur={dur_ms:.2f}",
                },
            )

        # Defensive: downstream must hand back a real Response object.
        if not isinstance(response, Response):
            _logger.error(
                "Downstream returned no Response object for %s",
                path, extra=log_extra,
            )
            return JSONResponse(
                {"detail": "Internal Server Error"},
                status_code=500,
                headers={"X-Trace-Id": str(trace_id)},
            )

        sse = _is_sse(request, response)
        dur_ms = (time.perf_counter() - t0) * 1000.0
        # setdefault: never overwrite values a handler already set.
        response.headers.setdefault("X-Trace-Id", str(trace_id))
        response.headers.setdefault("X-Process-Time-Ms", f"{dur_ms:.2f}")
        response.headers.setdefault("Server-Timing", f"app;dur={dur_ms:.2f}")

        if sse:
            response.headers.setdefault("Cache-Control", "no-cache")
            access_fmt = '"%s %s" %s (SSE) ip=%s ua="%s" %.2fms'
        else:
            access_fmt = '"%s %s" %s ip=%s ua="%s" %.2fms'
        _logger.info(
            access_fmt,
            method, path, response.status_code, ip, ua, dur_ms,
            extra=log_extra,
        )
        return response