diff --git a/backend/main.py b/backend/main.py index 55c0c1f..e57a80f 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1011,28 +1011,81 @@ async def serve_static_root_files(): return FileResponse(str(STATIC_DIR / "vite.svg")) -# SPA fallback via middleware — this avoids the catch-all route problem -# that breaks WebSocket routing in some Starlette versions -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.requests import Request as StarletteRequest +# SPA fallback via pure ASGI middleware — BaseHTTPMiddleware breaks WebSocket +# routing in Starlette 0.27.x, so we use a raw ASGI wrapper instead. from starlette.responses import Response -class SPAFallbackMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: StarletteRequest, call_next): - response = await call_next(request) - # If a non-API GET request returned 404, serve index.html for SPA - if (response.status_code == 404 - and request.method == "GET" - and not request.url.path.startswith("/api") - and not request.url.path.startswith("/docs") - and not request.url.path.startswith("/openapi") - and not request.url.path.startswith("/assets") - and "websocket" not in request.headers.get("upgrade", "").lower()): +class SPAFallbackMiddleware: + """Pure ASGI middleware that serves index.html for unknown GET routes. + Unlike BaseHTTPMiddleware, this correctly passes WebSocket connections + through without wrapping them.""" + + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + # Pass WebSocket and non-HTTP connections straight through + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + path = scope.get("path", "") + method = scope.get("method", "GET") + + # Only intercept GET requests for non-API, non-asset paths + skip_prefixes = ("/api", "/docs", "/openapi", "/assets") + if method != "GET" or any(path.startswith(p) for p in skip_prefixes): + await self.app(scope, receive, send) + return + + # Capture the response status to decide if we should serve index.html + response_started = False + initial_status = None + + async def send_wrapper(message): + nonlocal response_started, initial_status + if message["type"] == "http.response.start": + initial_status = message.get("status", 200) + if initial_status != 404: + response_started = True + await send(message) + # If 404, we hold off — we'll serve index.html instead + elif message["type"] == "http.response.body": + if response_started: + await send(message) + # If not started (was 404), we drop the body + + await self.app(scope, receive, send_wrapper) + + # If the app returned 404, serve index.html for SPA routing + if not response_started and initial_status == 404: index = STATIC_DIR / "index.html" if index.exists(): - return FileResponse(str(index)) - return response + body = index.read_bytes() + await send({ + "type": "http.response.start", + "status": 200, + "headers": [ + [b"content-type", b"text/html; charset=utf-8"], + [b"content-length", str(len(body)).encode()], + ], + }) + await send({ + "type": "http.response.body", + "body": body, + }) + else: + # No index.html, pass through the 404 + await send({ + "type": "http.response.start", + "status": 404, + "headers": [[b"content-type", b"text/plain"]], + }) + await send({ + "type": "http.response.body", + "body": b"Not Found", + }) app.add_middleware(SPAFallbackMiddleware)