fix: replace BaseHTTPMiddleware with pure ASGI middleware for WebSocket compat

This commit is contained in:
Zac Gaetano 2026-04-05 11:51:01 -04:00
parent de8a7798c3
commit 6e6bb86bcd

View file

@ -1011,28 +1011,81 @@ async def serve_static_root_files():
return FileResponse(str(STATIC_DIR / "vite.svg"))
# SPA fallback via middleware — this avoids the catch-all route problem
# that breaks WebSocket routing in some Starlette versions
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request as StarletteRequest
# SPA fallback via pure ASGI middleware — BaseHTTPMiddleware breaks WebSocket
# routing in Starlette 0.27.x, so we use a raw ASGI wrapper instead.
from starlette.responses import Response
class SPAFallbackMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: StarletteRequest, call_next):
response = await call_next(request)
# If a non-API GET request returned 404, serve index.html for SPA
if (response.status_code == 404
and request.method == "GET"
and not request.url.path.startswith("/api")
and not request.url.path.startswith("/docs")
and not request.url.path.startswith("/openapi")
and not request.url.path.startswith("/assets")
and "websocket" not in request.headers.get("upgrade", "").lower()):
class SPAFallbackMiddleware:
"""Pure ASGI middleware that serves index.html for unknown GET routes.
Unlike BaseHTTPMiddleware, this correctly passes WebSocket connections
through without wrapping them."""
def __init__(self, app):
self.app = app
async def __call__(self, scope, receive, send):
# Pass WebSocket and non-HTTP connections straight through
if scope["type"] != "http":
await self.app(scope, receive, send)
return
path = scope.get("path", "")
method = scope.get("method", "GET")
# Only intercept GET requests for non-API, non-asset paths
skip_prefixes = ("/api", "/docs", "/openapi", "/assets")
if method != "GET" or any(path.startswith(p) for p in skip_prefixes):
await self.app(scope, receive, send)
return
# Capture the response status to decide if we should serve index.html
response_started = False
initial_status = None
async def send_wrapper(message):
nonlocal response_started, initial_status
if message["type"] == "http.response.start":
initial_status = message.get("status", 200)
if initial_status != 404:
response_started = True
await send(message)
# If 404, we hold off — we'll serve index.html instead
elif message["type"] == "http.response.body":
if response_started:
await send(message)
# If not started (was 404), we drop the body
await self.app(scope, receive, send_wrapper)
# If the app returned 404, serve index.html for SPA routing
if not response_started and initial_status == 404:
index = STATIC_DIR / "index.html"
if index.exists():
return FileResponse(str(index))
return response
body = index.read_bytes()
await send({
"type": "http.response.start",
"status": 200,
"headers": [
[b"content-type", b"text/html; charset=utf-8"],
[b"content-length", str(len(body)).encode()],
],
})
await send({
"type": "http.response.body",
"body": body,
})
else:
# No index.html, pass through the 404
await send({
"type": "http.response.start",
"status": 404,
"headers": [[b"content-type", b"text/plain"]],
})
await send({
"type": "http.response.body",
"body": b"Not Found",
})
app.add_middleware(SPAFallbackMiddleware)