fix: replace BaseHTTPMiddleware with pure ASGI middleware for WebSocket compat
This commit is contained in:
parent
de8a7798c3
commit
6e6bb86bcd
1 changed files with 70 additions and 17 deletions
|
|
@ -1011,28 +1011,81 @@ async def serve_static_root_files():
|
|||
return FileResponse(str(STATIC_DIR / "vite.svg"))
|
||||
|
||||
|
||||
# SPA fallback via middleware — this avoids the catch-all route problem
|
||||
# that breaks WebSocket routing in some Starlette versions
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request as StarletteRequest
|
||||
# SPA fallback via pure ASGI middleware — BaseHTTPMiddleware breaks WebSocket
|
||||
# routing in Starlette 0.27.x, so we use a raw ASGI wrapper instead.
|
||||
from starlette.responses import Response
|
||||
|
||||
|
||||
class SPAFallbackMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: StarletteRequest, call_next):
|
||||
response = await call_next(request)
|
||||
# If a non-API GET request returned 404, serve index.html for SPA
|
||||
if (response.status_code == 404
|
||||
and request.method == "GET"
|
||||
and not request.url.path.startswith("/api")
|
||||
and not request.url.path.startswith("/docs")
|
||||
and not request.url.path.startswith("/openapi")
|
||||
and not request.url.path.startswith("/assets")
|
||||
and "websocket" not in request.headers.get("upgrade", "").lower()):
|
||||
class SPAFallbackMiddleware:
|
||||
"""Pure ASGI middleware that serves index.html for unknown GET routes.
|
||||
Unlike BaseHTTPMiddleware, this correctly passes WebSocket connections
|
||||
through without wrapping them."""
|
||||
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
# Pass WebSocket and non-HTTP connections straight through
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
path = scope.get("path", "")
|
||||
method = scope.get("method", "GET")
|
||||
|
||||
# Only intercept GET requests for non-API, non-asset paths
|
||||
skip_prefixes = ("/api", "/docs", "/openapi", "/assets")
|
||||
if method != "GET" or any(path.startswith(p) for p in skip_prefixes):
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
# Capture the response status to decide if we should serve index.html
|
||||
response_started = False
|
||||
initial_status = None
|
||||
|
||||
async def send_wrapper(message):
|
||||
nonlocal response_started, initial_status
|
||||
if message["type"] == "http.response.start":
|
||||
initial_status = message.get("status", 200)
|
||||
if initial_status != 404:
|
||||
response_started = True
|
||||
await send(message)
|
||||
# If 404, we hold off — we'll serve index.html instead
|
||||
elif message["type"] == "http.response.body":
|
||||
if response_started:
|
||||
await send(message)
|
||||
# If not started (was 404), we drop the body
|
||||
|
||||
await self.app(scope, receive, send_wrapper)
|
||||
|
||||
# If the app returned 404, serve index.html for SPA routing
|
||||
if not response_started and initial_status == 404:
|
||||
index = STATIC_DIR / "index.html"
|
||||
if index.exists():
|
||||
return FileResponse(str(index))
|
||||
return response
|
||||
body = index.read_bytes()
|
||||
await send({
|
||||
"type": "http.response.start",
|
||||
"status": 200,
|
||||
"headers": [
|
||||
[b"content-type", b"text/html; charset=utf-8"],
|
||||
[b"content-length", str(len(body)).encode()],
|
||||
],
|
||||
})
|
||||
await send({
|
||||
"type": "http.response.body",
|
||||
"body": body,
|
||||
})
|
||||
else:
|
||||
# No index.html, pass through the 404
|
||||
await send({
|
||||
"type": "http.response.start",
|
||||
"status": 404,
|
||||
"headers": [[b"content-type", b"text/plain"]],
|
||||
})
|
||||
await send({
|
||||
"type": "http.response.body",
|
||||
"body": b"Not Found",
|
||||
})
|
||||
|
||||
|
||||
app.add_middleware(SPAFallbackMiddleware)
|
||||
|
|
|
|||
Reference in a new issue