# mcp-servers/gateway-proxy/gateway_proxy.py
#
# 1222 lines
# 49 KiB
# Python
# Raw Normal View History
#
# 2026-03-31 15:33:37 -04:00
"""
MCP Gateway Proxy with OAuth 2.1
=================================
Aggregates multiple MCP servers behind a single Streamable HTTP endpoint.
Implements a self-contained OAuth 2.1 provider compatible with claude.ai:
- RFC 8414 Authorization Server Metadata
- RFC 9728 Protected Resource Metadata
- RFC 7591 Dynamic Client Registration
- PKCE (S256) per OAuth 2.1
- Authorization Code Grant with refresh tokens
"""
import asyncio
import base64
import hashlib
import html
import json
import logging
import os
import secrets
import time
import uuid
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import urlencode
import httpx
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import HTMLResponse, JSONResponse, RedirectResponse, Response, StreamingResponse
from starlette.routing import Route
import uvicorn
from user_routes import (
create_user, list_users, get_user, delete_user,
toggle_user, generate_api_key, revoke_api_key, set_mcp_access
)
from user_dashboard_ui import user_management_dashboard as user_dashboard
# Process-wide logging: INFO and above, timestamped, via logging.basicConfig's
# default handler. Every message in this module goes through `logger`.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger("mcp-gateway")
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Public base URL of this gateway; used as the OAuth issuer, the protected
# resource identifier and the prefix of every advertised endpoint.
ISSUER_URL = os.environ.get("OAUTH_ISSUER_URL", "https://mcp.wilddragon.net")
# Single shared password for the consent page and the GUI login. When empty,
# /authorize answers "Server misconfigured" and GUI login always fails.
GATEWAY_PASSWORD = os.environ.get("OAUTH_PASSWORD", "")
# Lifetimes in seconds (added to time.time() when tokens/codes are minted).
ACCESS_TOKEN_TTL = int(os.environ.get("OAUTH_ACCESS_TOKEN_TTL", "3600"))
REFRESH_TOKEN_TTL = int(os.environ.get("OAUTH_REFRESH_TOKEN_TTL", "2592000"))
AUTH_CODE_TTL = 600
# Optional long-lived bearer token accepted in addition to OAuth tokens; it
# survives restarts because it is not kept in the in-memory stores.
# Empty disables it.
STATIC_API_KEY = os.environ.get("GATEWAY_STATIC_API_KEY", "")
# ---------------------------------------------------------------------------
# In-memory stores
# ---------------------------------------------------------------------------
# All OAuth state lives in process memory and is lost on restart.
# client_id -> client metadata (from DCR or auto-registration at /authorize).
REGISTERED_CLIENTS: dict[str, dict] = {}
# raw authorization code -> grant details (single use, popped at /token).
AUTH_CODES: dict[str, dict] = {}
# sha256(access token) -> {client_id, scope, expires_at, user}.
ACCESS_TOKENS: dict[str, dict] = {}
# sha256(refresh token) -> grant details plus the paired access-token hash.
REFRESH_TOKENS: dict[str, dict] = {}
# internal_state (consent form hidden field) -> pending authorize request.
PENDING_AUTH: dict[str, dict] = {}
# Browser session store (cookie-based login for /admin and /dashboard)
GUI_SESSIONS: dict[str, float] = {} # token -> expires_at
GUI_SESSION_TTL = 8 * 3600 # 8 hours
def _create_gui_session() -> str:
    """Create a browser (GUI) session and return its opaque cookie token.

    The token maps to an absolute expiry ``GUI_SESSION_TTL`` seconds from now.
    Expired sessions are pruned first: ``_validate_gui_session`` only evicts a
    session when that exact cookie is presented again, so sweeping here keeps
    the store from growing with abandoned logins.
    """
    now = time.time()
    stale = [token for token, expires_at in GUI_SESSIONS.items() if expires_at < now]
    for token in stale:
        del GUI_SESSIONS[token]
    token = secrets.token_hex(32)
    GUI_SESSIONS[token] = now + GUI_SESSION_TTL
    return token
def _validate_gui_session(request: Request) -> bool:
    """Return True if the ``mcp_session`` cookie names a live GUI session.

    An expired (or unknown) session is dropped from ``GUI_SESSIONS`` as a side
    effect; a missing cookie simply yields False.
    """
    session_token = request.cookies.get("mcp_session")
    if not session_token:
        return False
    if GUI_SESSIONS.get(session_token, 0) >= time.time():
        return True
    GUI_SESSIONS.pop(session_token, None)
    return False
GUI_LOGIN_HTML = """<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>MCP Gateway Login</title>
<style>
* {{ margin: 0; padding: 0; box-sizing: border-box; }}
body {{ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: #0f172a; color: #e2e8f0; min-height: 100vh;
display: flex; align-items: center; justify-content: center; }}
.card {{ background: #1e293b; border-radius: 16px; padding: 40px;
max-width: 380px; width: 90%; box-shadow: 0 25px 50px rgba(0,0,0,0.4); }}
h1 {{ font-size: 22px; margin-bottom: 6px; color: #f8fafc; }}
.subtitle {{ color: #94a3b8; margin-bottom: 28px; font-size: 14px; }}
label {{ display: block; font-size: 13px; color: #94a3b8; margin-bottom: 6px; }}
input[type=password] {{ width: 100%; padding: 12px 16px; border-radius: 8px;
border: 1px solid #475569; background: #0f172a; color: #f8fafc;
font-size: 16px; margin-bottom: 20px; outline: none; }}
input[type=password]:focus {{ border-color: #38bdf8; box-shadow: 0 0 0 3px rgba(56,189,248,0.15); }}
button {{ width: 100%; padding: 12px; border-radius: 8px; border: none; font-size: 15px;
font-weight: 600; cursor: pointer; background: #38bdf8; color: #0f172a; }}
button:hover {{ background: #7dd3fc; }}
.error {{ background: #7f1d1d; color: #fca5a5; padding: 10px 14px; border-radius: 8px;
margin-bottom: 16px; font-size: 13px; }}
.lock {{ font-size: 32px; margin-bottom: 16px; }}
</style>
</head>
<body>
<div class="card">
<div class="lock">&#128274;</div>
<h1>MCP Gateway</h1>
<p class="subtitle">Enter your gateway password to continue</p>
{error_html}
<form method="POST" action="/gui-login">
<input type="hidden" name="next" value="{next_url}" />
<label for="password">Password</label>
<input type="password" name="password" id="password" placeholder="Gateway password" autofocus />
<button type="submit">Sign In</button>
</form>
</div>
</body>
</html>"""
# (source-viewer revision marker: 2026-03-31 15:33:37 -04:00)
def _hash(value: str) -> str:
    """Return the hex SHA-256 digest of *value* encoded as UTF-8.

    Access and refresh tokens are stored under this digest, so the raw
    secrets never sit in the in-memory token stores.
    """
    digest = hashlib.sha256()
    digest.update(value.encode("utf-8"))
    return digest.hexdigest()
def _clean_expired():
    """Drop expired entries from the in-memory OAuth and GUI session stores.

    * ``AUTH_CODES``, ``ACCESS_TOKENS``, ``REFRESH_TOKENS``: entries whose
      ``expires_at`` has passed (a missing ``expires_at`` counts as expired).
    * ``PENDING_AUTH``: consent requests are only removed when the form is
      submitted, so abandoned ones would pile up. Entries that carry an
      ``expires_at`` in the past are dropped; entries without one are kept.
    * ``GUI_SESSIONS``: values are bare expiry timestamps.
    """
    now = time.time()
    for store in (AUTH_CODES, ACCESS_TOKENS, REFRESH_TOKENS):
        for key in [k for k, v in store.items() if v.get("expires_at", 0) < now]:
            del store[key]
    for key in [k for k, v in PENDING_AUTH.items() if v.get("expires_at", now) < now]:
        del PENDING_AUTH[key]
    for key in [k for k, expires_at in GUI_SESSIONS.items() if expires_at < now]:
        del GUI_SESSIONS[key]
# ---------------------------------------------------------------------------
# Backend MCP aggregation
# ---------------------------------------------------------------------------
def load_backends() -> dict[str, str]:
    """Collect backend MCP endpoints from ``MCP_BACKEND_<NAME>`` env vars.

    The backend key is ``<NAME>`` lower-cased; the variable's value is the
    backend URL. Each discovered backend is logged.
    """
    prefix = "MCP_BACKEND_"
    found: dict[str, str] = {}
    for env_key, env_value in os.environ.items():
        if not env_key.startswith(prefix):
            continue
        backend = env_key.removeprefix(prefix).lower()
        found[backend] = env_value
        logger.info(f"Backend configured: {backend} -> {env_value}")
    return found
# backend key -> backend MCP URL, read from the environment at import time.
BACKENDS: dict[str, str] = load_backends()
# "<backend>_<tool>" -> backend key that owns the tool.
TOOL_REGISTRY: dict[str, str] = {}
# "<backend>_<tool>" -> tool definition (renamed, schema-normalized).
TOOL_DEFINITIONS: dict[str, dict] = {}
# backend key -> Mcp-Session-Id returned by that backend (None if it sent none).
BACKEND_SESSIONS: dict[str, str | None] = {}
# Session ids this gateway handed out on "initialize". NOTE(review): membership
# is not checked when serving POSTs in the code visible here.
GATEWAY_SESSIONS: dict[str, bool] = {}
def parse_sse_response(text: str) -> dict | None:
for line in text.strip().split("\n"):
line = line.strip()
if line.startswith("data: "):
try:
return json.loads(line[6:])
except json.JSONDecodeError:
continue
return None
async def mcp_request(backend_url: str, method: str, params: dict | None = None, request_id: Any = 1, session_id: str | None = None) -> dict:
    """POST one JSON-RPC message to a backend MCP server and decode the reply.

    Args:
        backend_url: Backend Streamable-HTTP endpoint.
        method: JSON-RPC method name.
        params: Optional params object.
        request_id: JSON-RPC id. ``None`` sends a notification, so the ``id``
            member is omitted entirely (JSON-RPC/MCP forbid ids on
            notifications).
        session_id: Backend ``Mcp-Session-Id`` to send, if any.

    Returns:
        ``{"result": ..., "session_id": ..., "status_code": ...}``.
        ``result`` is the decoded reply (JSON or SSE). It is ``None`` for 202,
        for non-2xx replies, and for bodies that could not be decoded.
        ``session_id`` is the backend's ``Mcp-Session-Id`` response header.
        ``status_code`` is the HTTP status, so callers can react to e.g. an
        expired session (404).

    Raises:
        httpx transport errors (connection failures, timeouts) propagate.
    """
    payload: dict[str, Any] = {"jsonrpc": "2.0", "method": method}
    if request_id is not None:
        payload["id"] = request_id
    if params is not None:
        payload["params"] = params
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json, text/event-stream",
        "Host": "localhost",
    }
    if session_id:
        headers["Mcp-Session-Id"] = session_id
    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.post(backend_url, json=payload, headers=headers)
        # httpx header lookups are case-insensitive.
        reply: dict[str, Any] = {
            "result": None,
            "session_id": resp.headers.get("Mcp-Session-Id"),
            "status_code": resp.status_code,
        }
        if resp.status_code in (200, 201):
            try:
                if "text/event-stream" in resp.headers.get("content-type", ""):
                    reply["result"] = parse_sse_response(resp.text)
                else:
                    reply["result"] = resp.json()
            except ValueError as exc:
                logger.warning(f"Backend {backend_url} sent an undecodable reply to {method}: {exc}")
        elif resp.status_code != 202:
            logger.error(f"Backend {backend_url} returned {resp.status_code}: {resp.text[:200]}")
    return reply
def _normalize_schema(obj: Any) -> None:
    """Recursively normalize JSON schemas in place so mcpo can parse them.

    Fixes:
    - anyOf/oneOf: non-dict entries are dropped. A single remaining option is
      merged into the parent. No remaining options leaves ``type: "string"``.
    - type as list (e.g. ["string","null"]) -> "string"
    - items as list (tuple validation) -> {"type": "string"}

    Combinators are collapsed *before* the type/items fixes. Otherwise a
    list-valued ``type`` merged up from a single-option anyOf/oneOf would slip
    through un-normalized. Collapsing repeats until no single-option
    combinator is left, which handles nested ones.
    """
    if isinstance(obj, list):
        for item in obj:
            _normalize_schema(item)
        return
    if not isinstance(obj, dict):
        return

    changed = True
    while changed:
        changed = False
        for key in ("anyOf", "oneOf"):
            options = obj.get(key)
            if not isinstance(options, list):
                continue
            options = [x for x in options if isinstance(x, dict)]
            if len(options) == 1:
                del obj[key]
                obj.update(options[0])
                changed = True
            elif not options:
                del obj[key]
                obj["type"] = "string"
            else:
                obj[key] = options

    if isinstance(obj.get("type"), list):
        obj["type"] = "string"
    if isinstance(obj.get("items"), list):
        obj["items"] = {"type": "string"}
    for value in obj.values():
        _normalize_schema(value)
def _extract_tool_list(payload: Any) -> list:
    """Return the ``tools`` array from a ``tools/list`` reply, or ``[]``.

    *payload* is what ``mcp_request`` stored under ``"result"``. That is
    normally the JSON-RPC envelope ``{"result": {"tools": [...]}}``, but a
    bare ``{"tools": [...]}`` object is accepted too.
    """
    if isinstance(payload, dict) and "result" in payload:
        payload = payload["result"]
    if not isinstance(payload, dict):
        return []
    tools = payload.get("tools")
    return tools if isinstance(tools, list) else []


async def initialize_backend(name: str, url: str, max_attempts: int = 3, retry_delay: float = 5.0) -> list[dict]:
    """Handshake with one backend and register its tools under ``<name>_`` names.

    Runs ``initialize`` -> ``notifications/initialized`` -> ``tools/list``.
    Stores the backend's session id in ``BACKEND_SESSIONS``. Each tool is
    renamed to ``<name>_<tool>``, its input schema is normalized, and it is
    added to ``TOOL_REGISTRY`` / ``TOOL_DEFINITIONS``.

    Args:
        name: Backend key; also used as the tool-name prefix.
        url: Backend MCP endpoint.
        max_attempts: Total attempts before giving up (default 3).
        retry_delay: Seconds to wait between attempts (default 5).

    Returns:
        The registered (renamed) tool definitions, or ``[]`` if every attempt
        failed. Tool entries without a ``name`` are skipped with a warning
        instead of aborting the whole handshake.
    """
    logger.info(f"Initializing backend: {name} at {url}")
    for attempt in range(1, max_attempts + 1):
        try:
            init_result = await mcp_request(url, "initialize", {
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {"name": "mcp-gateway-proxy", "version": "1.0.0"},
            })
            session_id = init_result.get("session_id")
            BACKEND_SESSIONS[name] = session_id
            logger.info(f" {name}: initialized (session: {session_id})")
            await mcp_request(url, "notifications/initialized", {}, request_id=None, session_id=session_id)
            tools_result = await mcp_request(url, "tools/list", {}, request_id=2, session_id=session_id)
            tools = _extract_tool_list(tools_result.get("result"))
            logger.info(f" {name}: discovered {len(tools)} tools")
            registered: list[dict] = []
            for tool in tools:
                if not isinstance(tool, dict) or not tool.get("name"):
                    logger.warning(f" {name}: skipping tool definition without a name")
                    continue
                prefixed_name = f"{name}_{tool['name']}"
                tool["name"] = prefixed_name
                # Normalize list-type schemas (e.g. ["string","null"]) so mcpo
                # can parse them without crashing.
                _normalize_schema(tool.get("inputSchema", {}))
                TOOL_REGISTRY[prefixed_name] = name
                TOOL_DEFINITIONS[prefixed_name] = tool
                registered.append(tool)
            return registered
        except Exception as e:
            if attempt < max_attempts:
                logger.info(f" {name}: attempt {attempt} failed ({e}), retrying in {retry_delay:g}s...")
                await asyncio.sleep(retry_delay)
            else:
                logger.error(f" {name}: failed to initialize after {max_attempts} attempts - {e}")
    return []
def _should_reinitialize(result: dict) -> bool:
    """Return True when an ``mcp_request`` reply means the backend session is unusable.

    Triggers on HTTP 404, the Streamable-HTTP "session expired" signal. It is
    only visible when ``mcp_request`` reports ``status_code``. It also
    triggers on JSON-RPC error codes -32600 / -32601 in the reply envelope,
    which this gateway treats as a lost session.
    """
    if result.get("status_code") == 404:
        return True
    envelope = result.get("result")
    if not isinstance(envelope, dict):
        return False
    error = envelope.get("error")
    return isinstance(error, dict) and error.get("code", 0) in (-32600, -32601)


async def forward_tool_call(backend_name: str, tool_name: str, arguments: dict, request_id: Any) -> dict:
    """Forward a ``tools/call`` to the backend that owns *tool_name*.

    The ``<backend>_`` prefix is stripped before forwarding. A backend that
    has never been initialized is initialized first. If the reply signals a
    lost session, the backend is re-initialized once and the call is retried.

    Returns:
        The decoded backend reply (usually a full JSON-RPC envelope). It may
        be ``None`` if the backend sent nothing usable.

    Raises:
        KeyError if *backend_name* is not configured; transport errors from
        ``mcp_request`` propagate.
    """
    url = BACKENDS[backend_name]
    prefix = f"{backend_name}_"
    original_name = tool_name[len(prefix):] if tool_name.startswith(prefix) else tool_name
    call_params = {"name": original_name, "arguments": arguments}

    # Only initialize if this backend never completed an initialize. A stored
    # None means the backend returned no Mcp-Session-Id (stateless), which
    # previously triggered a full re-initialize on every single call.
    if backend_name not in BACKEND_SESSIONS:
        await initialize_backend(backend_name, url)

    result = await mcp_request(
        url, "tools/call", call_params,
        request_id=request_id,
        session_id=BACKEND_SESSIONS.get(backend_name),
    )
    if _should_reinitialize(result):
        logger.info(f"Re-initializing {backend_name} after session error (HTTP {result.get('status_code')})")
        await initialize_backend(backend_name, url)
        result = await mcp_request(
            url, "tools/call", call_params,
            request_id=request_id,
            session_id=BACKEND_SESSIONS.get(backend_name),
        )
    return result.get("result", {})
# ---------------------------------------------------------------------------
# OAuth 2.1: Token validation
# ---------------------------------------------------------------------------
def validate_bearer_token(request: Request) -> dict | None:
    """Return the token record for a valid ``Authorization: Bearer`` header, else None.

    Accepts either the static API key (if configured) or an unexpired OAuth
    access token. Access tokens are looked up by their SHA-256 hash. An
    expired access token is evicted when presented. The auth scheme is matched
    case-insensitively (RFC 7235 §2.1). The static key is compared in constant
    time so response timing does not leak it.
    """
    scheme, _, token = request.headers.get("Authorization", "").partition(" ")
    if scheme.lower() != "bearer" or not token:
        return None
    # Check static API key first (persistent, survives restarts)
    if STATIC_API_KEY and secrets.compare_digest(token.encode(), STATIC_API_KEY.encode()):
        return {"client_id": "static", "scope": "mcp:tools", "user": "static"}
    token_hash = _hash(token)
    info = ACCESS_TOKENS.get(token_hash)
    if not info:
        return None
    if info["expires_at"] < time.time():
        ACCESS_TOKENS.pop(token_hash, None)
        return None
    return info
# ---------------------------------------------------------------------------
# Consent page HTML template
# ---------------------------------------------------------------------------
CONSENT_PAGE_CSS = """
* { margin: 0; padding: 0; box-sizing: border-box; }
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: #0f172a; color: #e2e8f0; min-height: 100vh;
display: flex; align-items: center; justify-content: center; }
.card { background: #1e293b; border-radius: 16px; padding: 40px;
max-width: 420px; width: 90%; box-shadow: 0 25px 50px rgba(0,0,0,0.4); }
h1 { font-size: 22px; margin-bottom: 8px; color: #f8fafc; }
.subtitle { color: #94a3b8; margin-bottom: 24px; font-size: 14px; }
.client { background: #334155; border-radius: 8px; padding: 12px 16px;
margin-bottom: 24px; font-size: 14px; }
.client strong { color: #38bdf8; }
label { display: block; font-size: 13px; color: #94a3b8; margin-bottom: 6px; }
input[type=password] { width: 100%; padding: 12px 16px; border-radius: 8px;
border: 1px solid #475569; background: #0f172a; color: #f8fafc;
font-size: 16px; margin-bottom: 20px; outline: none; }
input[type=password]:focus { border-color: #38bdf8; box-shadow: 0 0 0 3px rgba(56,189,248,0.15); }
.buttons { display: flex; gap: 12px; }
button { flex: 1; padding: 12px; border-radius: 8px; border: none; font-size: 15px;
font-weight: 600; cursor: pointer; transition: all 0.15s; }
.approve { background: #38bdf8; color: #0f172a; }
.approve:hover { background: #7dd3fc; }
.deny { background: #334155; color: #94a3b8; }
.deny:hover { background: #475569; color: #e2e8f0; }
.scope { font-size: 13px; color: #64748b; margin-bottom: 16px; }
.error { background: #7f1d1d; color: #fca5a5; padding: 10px 14px; border-radius: 8px;
margin-bottom: 16px; font-size: 13px; }
"""
def render_consent_page(client_name: str, scope: str, internal_state: str, error_msg: str = "") -> str:
    """Render the OAuth consent/password page as an HTML string.

    Every interpolated value is HTML-escaped, including ``internal_state``,
    which lands inside an attribute. Values are coerced with ``str()`` first,
    so a missing client name (``None``) renders instead of crashing.
    The form posts ``internal_state``, ``password`` and ``action``
    (approve/deny) to ``/oauth/authorize``.
    """
    error_html = f'<div class="error">{html.escape(str(error_msg))}</div>' if error_msg else ""
    safe_client = html.escape(str(client_name))
    safe_scope = html.escape(str(scope))
    safe_state = html.escape(str(internal_state))
    return f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>MCP Gateway Authorize</title>
<style>{CONSENT_PAGE_CSS}</style>
</head>
<body>
<div class="card">
<h1>MCP Gateway</h1>
<p class="subtitle">Authorization Request</p>
<div class="client"><strong>{safe_client}</strong> wants access to your MCP tools.</div>
<p class="scope">Scope: <code>{safe_scope}</code></p>
{error_html}
<form method="POST" action="/oauth/authorize">
<input type="hidden" name="internal_state" value="{safe_state}" />
<label for="password">Gateway Password</label>
<input type="password" name="password" id="password" placeholder="Enter your password" autofocus />
<div class="buttons">
<button type="submit" name="action" value="approve" class="approve">Authorize</button>
<button type="submit" name="action" value="deny" class="deny">Deny</button>
</div>
</form>
</div>
</body>
</html>"""
# ---------------------------------------------------------------------------
# OAuth 2.1 Endpoints
# ---------------------------------------------------------------------------
async def well_known_protected_resource(request: Request) -> JSONResponse:
    """Serve RFC 9728 protected-resource metadata.

    Names this gateway as both the resource and its own authorization
    server, and says bearer tokens are accepted only in the header.
    """
    metadata = dict(
        resource=ISSUER_URL,
        authorization_servers=[ISSUER_URL],
        scopes_supported=["mcp:tools"],
        bearer_methods_supported=["header"],
    )
    return JSONResponse(metadata)
async def well_known_oauth_authorization_server(request: Request) -> JSONResponse:
    """Serve RFC 8414 authorization-server metadata.

    Advertises /authorize, /token and /register under ``ISSUER_URL``. It
    declares S256-only PKCE and either public ("none") or client_secret_post
    client authentication at the token endpoint.
    """
    base = ISSUER_URL
    metadata = {
        "issuer": base,
        "authorization_endpoint": base + "/authorize",
        "token_endpoint": base + "/token",
        "registration_endpoint": base + "/register",
        "scopes_supported": ["mcp:tools"],
        "response_types_supported": ["code"],
        "grant_types_supported": ["authorization_code", "refresh_token"],
        "token_endpoint_auth_methods_supported": ["client_secret_post", "none"],
        "code_challenge_methods_supported": ["S256"],
        "service_documentation": base + "/health",
    }
    return JSONResponse(metadata)
async def oauth_register(request: Request) -> JSONResponse:
    """RFC 7591 dynamic client registration.

    Returns the stored client record with HTTP 201. Error cases:
    - Undecodable JSON yields ``invalid_request``.
    - A body that is not an object yields ``invalid_client_metadata``.
    - A ``redirect_uris`` that is not a list of non-empty strings yields
      ``invalid_redirect_uri``.

    ``client_id_issued_at`` and ``client_secret_expires_at`` (0 = never) are
    included because RFC 7591 §3.2.1 requires the latter whenever a secret is
    issued.
    """
    try:
        body = await request.json()
    except Exception:
        return JSONResponse({"error": "invalid_request"}, status_code=400)
    if not isinstance(body, dict):
        return JSONResponse(
            {"error": "invalid_client_metadata", "error_description": "Registration body must be a JSON object."},
            status_code=400,
        )
    redirect_uris = body.get("redirect_uris") or []
    if not isinstance(redirect_uris, list) or not all(isinstance(u, str) and u for u in redirect_uris):
        return JSONResponse(
            {"error": "invalid_redirect_uri", "error_description": "redirect_uris must be a list of non-empty strings."},
            status_code=400,
        )
    client_id = str(uuid.uuid4())
    client_info = {
        "client_id": client_id,
        "client_secret": secrets.token_urlsafe(48),
        "client_id_issued_at": int(time.time()),
        "client_secret_expires_at": 0,
        # Coerce to str: the name is later HTML-escaped on the consent page.
        "client_name": str(body.get("client_name") or "Unknown Client"),
        "redirect_uris": redirect_uris,
        "grant_types": body.get("grant_types", ["authorization_code", "refresh_token"]),
        "response_types": body.get("response_types", ["code"]),
        "token_endpoint_auth_method": body.get("token_endpoint_auth_method", "client_secret_post"),
    }
    REGISTERED_CLIENTS[client_id] = client_info
    logger.info(f"DCR: registered client '{client_info['client_name']}' as {client_id}")
    return JSONResponse(client_info, status_code=201)
def _append_query_params(url: str, params: dict[str, str]) -> str:
    """Return *url* with the non-empty *params* appended as a query string.

    Uses ``&`` when *url* already carries a query, so redirect URIs that
    include their own parameters stay well-formed.
    """
    query = urlencode({key: value for key, value in params.items() if value})
    if not query:
        return url
    separator = "&" if "?" in url else "?"
    return f"{url}{separator}{query}"


async def oauth_authorize(request: Request) -> Response:
    """OAuth 2.1 authorization endpoint: consent page plus password check.

    GET:
    - Validates the request (``response_type=code``, a ``client_id``, and
      S256 if a PKCE challenge is sent).
    - Auto-registers unknown clients with the supplied redirect URI, so
      clients whose DCR record was lost on restart keep working.
    - Requires ``redirect_uri`` to exactly match a registered URI. The error
      is returned directly, never redirected.
    - Parks the request in ``PENDING_AUTH``, expiring after ``AUTH_CODE_TTL``
      seconds, and renders the consent page.

    POST:
    - Consumes the pending request.
    - On "deny", redirects with ``access_denied``.
    - On approve with the correct password (constant-time compare), issues a
      single-use code and redirects back with ``code``/``state``.
    - A wrong password re-renders the page and keeps the request pending.
    """
    _clean_expired()
    if request.method == "GET":
        params = request.query_params
        client_id = params.get("client_id", "")
        redirect_uri = params.get("redirect_uri", "")
        state = params.get("state", "")
        scope = params.get("scope", "mcp:tools")
        code_challenge = params.get("code_challenge", "")
        code_challenge_method = params.get("code_challenge_method", "S256")
        response_type = params.get("response_type", "code")
        if response_type != "code":
            return JSONResponse({"error": "unsupported_response_type"}, status_code=400)
        if not client_id:
            return JSONResponse({"error": "invalid_request", "error_description": "client_id required."}, status_code=400)
        # /token only verifies S256, so any other method could never succeed.
        if code_challenge and code_challenge_method != "S256":
            return JSONResponse(
                {"error": "invalid_request", "error_description": "Only the S256 code_challenge_method is supported."},
                status_code=400,
            )
        client = REGISTERED_CLIENTS.get(client_id)
        if not client:
            if not redirect_uri:
                return JSONResponse({"error": "invalid_request", "error_description": "redirect_uri required."}, status_code=400)
            # Auto-register unknown clients at authorize time (handles clients that
            # skip DCR or whose registration was lost on gateway restart)
            client = {
                "client_id": client_id,
                "client_secret": secrets.token_urlsafe(48),
                "client_name": f"auto-{client_id[:8]}",
                "redirect_uris": [redirect_uri],
                "grant_types": ["authorization_code", "refresh_token"],
                "response_types": ["code"],
                "token_endpoint_auth_method": "none",
            }
            REGISTERED_CLIENTS[client_id] = client
            logger.info(f"OAuth: auto-registered unknown client {client_id} at authorize")
        registered_uris = client.get("redirect_uris") or []
        if not redirect_uri and len(registered_uris) == 1:
            redirect_uri = registered_uris[0]
        if not redirect_uri:
            return JSONResponse({"error": "invalid_request", "error_description": "redirect_uri required."}, status_code=400)
        # Exact match (OAuth 2.1): never send a code to an unregistered URI.
        if registered_uris and redirect_uri not in registered_uris:
            return JSONResponse(
                {"error": "invalid_request", "error_description": "redirect_uri does not match a registered redirect URI."},
                status_code=400,
            )
        internal_state = secrets.token_urlsafe(32)
        PENDING_AUTH[internal_state] = {
            "client_id": client_id,
            "redirect_uri": redirect_uri,
            "code_challenge": code_challenge,
            "code_challenge_method": code_challenge_method,
            "scope": scope,
            "state": state,
            "expires_at": time.time() + AUTH_CODE_TTL,
        }
        client_name = client.get("client_name") or "Unknown Client"
        return HTMLResponse(render_consent_page(client_name, scope, internal_state))

    if request.method == "POST":
        form = await request.form()
        internal_state = str(form.get("internal_state", ""))
        password = form.get("password", "")
        action = form.get("action", "")
        pending = PENDING_AUTH.pop(internal_state, None)
        if not pending or pending.get("expires_at", float("inf")) < time.time():
            return HTMLResponse("<h1>Authorization expired</h1><p>Please try connecting again.</p>", status_code=400)
        redirect_uri = pending["redirect_uri"]
        state = pending["state"]
        if action != "approve":
            denied_url = _append_query_params(redirect_uri, {"error": "access_denied", "state": state})
            return RedirectResponse(denied_url, status_code=302)
        if not GATEWAY_PASSWORD:
            return HTMLResponse("<h1>Server misconfigured</h1><p>OAUTH_PASSWORD not set.</p>", status_code=500)
        password_ok = isinstance(password, str) and secrets.compare_digest(
            password.encode(), GATEWAY_PASSWORD.encode()
        )
        if not password_ok:
            # Keep the request pending so the user can retry on the same page.
            PENDING_AUTH[internal_state] = pending
            client = REGISTERED_CLIENTS.get(pending["client_id"], {})
            client_name = client.get("client_name") or "Unknown Client"
            return HTMLResponse(render_consent_page(client_name, pending["scope"], internal_state, "Invalid password. Please try again."))
        code = secrets.token_urlsafe(48)
        AUTH_CODES[code] = {
            "client_id": pending["client_id"],
            "redirect_uri": redirect_uri,
            "code_challenge": pending["code_challenge"],
            "code_challenge_method": pending["code_challenge_method"],
            "scope": pending["scope"],
            "expires_at": time.time() + AUTH_CODE_TTL,
            "user": "owner",
        }
        logger.info(f"OAuth: issued auth code for client {pending['client_id']}")
        return RedirectResponse(_append_query_params(redirect_uri, {"code": code, "state": state}), status_code=302)

    return JSONResponse({"error": "method_not_allowed"}, status_code=405)
# RFC 6749 §5.1: token responses must not be cached.
_TOKEN_RESPONSE_HEADERS = {"Cache-Control": "no-store", "Pragma": "no-cache"}


def _token_error(error: str, description: str | None = None) -> JSONResponse:
    """Build an RFC 6749 §5.2 error reply (HTTP 400, never cached)."""
    payload = {"error": error}
    if description:
        payload["error_description"] = description
    return JSONResponse(payload, status_code=400, headers=_TOKEN_RESPONSE_HEADERS)


def _issue_token_pair(client_id: str, scope: str, user: str) -> JSONResponse:
    """Mint an access/refresh token pair, store their hashes, and build the reply."""
    access_token = secrets.token_urlsafe(48)
    refresh_token = secrets.token_urlsafe(48)
    access_hash = _hash(access_token)
    now = time.time()
    ACCESS_TOKENS[access_hash] = {
        "client_id": client_id,
        "scope": scope,
        "expires_at": now + ACCESS_TOKEN_TTL,
        "user": user,
    }
    REFRESH_TOKENS[_hash(refresh_token)] = {
        "client_id": client_id,
        "scope": scope,
        "expires_at": now + REFRESH_TOKEN_TTL,
        "user": user,
        # Lets a refresh revoke the access token it was issued alongside.
        "access_token_hash": access_hash,
    }
    return JSONResponse({
        "access_token": access_token,
        "token_type": "Bearer",
        "expires_in": ACCESS_TOKEN_TTL,
        "refresh_token": refresh_token,
        "scope": scope,
    }, headers=_TOKEN_RESPONSE_HEADERS)


async def oauth_token(request: Request) -> JSONResponse:
    """OAuth 2.1 token endpoint for the authorization_code and refresh_token grants.

    The body may be JSON or form-encoded. A malformed or non-object JSON body
    is treated as empty.

    authorization_code: the code is single use and must belong to the
    requesting ``client_id``. If a PKCE challenge was recorded, the S256
    verifier must match (constant-time compare).

    refresh_token: the refresh token is rotated. The old refresh token and
    its paired access token are revoked before a new pair is issued.

    Client secrets are not checked here. Every reply carries
    ``Cache-Control: no-store``.
    """
    _clean_expired()
    if "application/json" in request.headers.get("content-type", ""):
        try:
            body = await request.json()
        except Exception:
            body = {}
    else:
        body = dict(await request.form())
    if not isinstance(body, dict):
        body = {}

    def field(key: str) -> str:
        # Missing -> ""; non-string JSON values are stringified so later
        # .encode()/hash calls cannot crash.
        value = body.get(key)
        return "" if value is None else str(value)

    grant_type = field("grant_type")
    if grant_type == "authorization_code":
        client_id = field("client_id")
        code_verifier = field("code_verifier")
        code_info = AUTH_CODES.pop(field("code"), None)
        if not code_info:
            return _token_error("invalid_grant", "Code invalid or expired.")
        if code_info["expires_at"] < time.time():
            return _token_error("invalid_grant", "Code expired.")
        if code_info["client_id"] != client_id:
            return _token_error("invalid_grant", "Client mismatch.")
        challenge = code_info.get("code_challenge")
        if challenge:
            if not code_verifier:
                return _token_error("invalid_grant", "code_verifier required.")
            expected = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).rstrip(b"=").decode()
            if not secrets.compare_digest(expected.encode(), str(challenge).encode()):
                return _token_error("invalid_grant", "PKCE verification failed.")
        response = _issue_token_pair(client_id, code_info["scope"], code_info["user"])
        logger.info(f"OAuth: issued access token for client {client_id}")
        return response

    if grant_type == "refresh_token":
        client_id = field("client_id")
        refresh_hash = _hash(field("refresh_token"))
        refresh_info = REFRESH_TOKENS.get(refresh_hash)
        if not refresh_info:
            return _token_error("invalid_grant", "Refresh token invalid.")
        if refresh_info["expires_at"] < time.time():
            REFRESH_TOKENS.pop(refresh_hash, None)
            return _token_error("invalid_grant", "Refresh token expired.")
        if refresh_info["client_id"] != client_id:
            return _token_error("invalid_grant", "Client mismatch.")
        old_access_hash = refresh_info.get("access_token_hash")
        if old_access_hash:
            ACCESS_TOKENS.pop(old_access_hash, None)
        del REFRESH_TOKENS[refresh_hash]
        response = _issue_token_pair(client_id, refresh_info["scope"], refresh_info["user"])
        logger.info(f"OAuth: refreshed token for client {client_id}")
        return response

    return _token_error("unsupported_grant_type")
# ---------------------------------------------------------------------------
# MCP Endpoint (OAuth-protected)
# ---------------------------------------------------------------------------
def _mcp_unauthorized(message: str) -> Response:
    """Build the 401 reply that points MCP clients at the RFC 9728 metadata."""
    return Response(
        status_code=401,
        headers={
            "WWW-Authenticate": f'Bearer resource_metadata="{ISSUER_URL}/.well-known/oauth-protected-resource"',
            "Content-Type": "application/json",
        },
        media_type="application/json",
        content=json.dumps({"error": "unauthorized", "message": message}),
    )


def _echo_session_id(request: Request, response: Response) -> Response:
    """Copy the client's ``Mcp-Session-Id`` request header onto *response*."""
    session_id = request.headers.get("Mcp-Session-Id")
    if session_id:
        response.headers["Mcp-Session-Id"] = session_id
    return response


async def handle_mcp(request: Request) -> Response:
    """Streamable-HTTP MCP endpoint that aggregates the backend tools.

    * HEAD: unauthenticated probe advertising the protocol version.
    * GET: Bearer-protected; returns an SSE stream carrying one ping comment.
    * DELETE: Bearer-protected; forgets the gateway session in
      ``Mcp-Session-Id``.
    * POST: Bearer-protected JSON-RPC. Handles ``initialize``, any
      ``notifications/*`` (202, no body), ``tools/list``, ``tools/call``
      (forwarded to the owning backend) and ``ping``. Other methods get
      -32601. A malformed request gets -32700 (unparseable) or -32600 (not an
      object / bad method).
    """
    if request.method == "HEAD":
        return Response(status_code=200, headers={"MCP-Protocol-Version": "2024-11-05"})
    if request.method == "GET":
        if not validate_bearer_token(request):
            return _mcp_unauthorized("Bearer token required.")

        # Authenticated GET — return empty SSE stream (mcpo uses this for server-sent events)
        async def empty_sse():
            yield b": ping\n\n"

        return StreamingResponse(empty_sse(), media_type="text/event-stream",
                                 headers={"Cache-Control": "no-cache", "Connection": "keep-alive"})

    # DELETE and POST mutate gateway state or reach backends: require a token.
    if not validate_bearer_token(request):
        return _mcp_unauthorized("Valid Bearer token required.")
    if request.method == "DELETE":
        session_id = request.headers.get("Mcp-Session-Id")
        if session_id:
            GATEWAY_SESSIONS.pop(session_id, None)
        return Response(status_code=200)

    try:
        body = await request.json()
    except Exception:
        return JSONResponse(
            {"jsonrpc": "2.0", "error": {"code": -32700, "message": "Parse error"}, "id": None},
            status_code=400,
        )
    if not isinstance(body, dict) or not isinstance(body.get("method", ""), str):
        return JSONResponse(
            {"jsonrpc": "2.0", "error": {"code": -32600, "message": "Invalid Request"}, "id": None},
            status_code=400,
        )
    method = body.get("method", "")
    params = body.get("params")
    if not isinstance(params, dict):
        params = {}
    req_id = body.get("id")

    if method == "initialize":
        session_id = str(uuid.uuid4())
        GATEWAY_SESSIONS[session_id] = True
        response = JSONResponse({
            "jsonrpc": "2.0",
            "id": req_id,
            "result": {
                "protocolVersion": "2024-11-05",
                "capabilities": {"tools": {"listChanged": False}},
                "serverInfo": {"name": "mcp-gateway-proxy", "version": "1.0.0"},
            },
        })
        response.headers["Mcp-Session-Id"] = session_id
        return response

    # Notifications never get a JSON-RPC reply; Streamable HTTP answers 202.
    if method.startswith("notifications/"):
        return Response(status_code=202)

    if method == "tools/list":
        return _echo_session_id(request, JSONResponse({
            "jsonrpc": "2.0",
            "id": req_id,
            "result": {"tools": list(TOOL_DEFINITIONS.values())},
        }))

    if method == "tools/call":
        tool_name = params.get("name", "")
        arguments = params.get("arguments")
        if not isinstance(arguments, dict):
            arguments = {}
        backend_name = TOOL_REGISTRY.get(tool_name) if isinstance(tool_name, str) else None
        if not backend_name:
            # MCP reports unknown tools as -32602 (invalid params).
            return JSONResponse({
                "jsonrpc": "2.0",
                "id": req_id,
                "error": {"code": -32602, "message": f"Unknown tool: {tool_name}"},
            })
        try:
            result = await forward_tool_call(backend_name, tool_name, arguments, req_id)
        except Exception as e:
            logger.error(f"Tool call failed: {tool_name} - {e}")
            return JSONResponse({
                "jsonrpc": "2.0",
                "id": req_id,
                "error": {"code": -32603, "message": str(e)},
            })
        if isinstance(result, dict) and "jsonrpc" in result:
            # Backend already returned a full JSON-RPC envelope; pass it through.
            payload = result
        elif isinstance(result, dict) and "result" in result:
            payload = {"jsonrpc": "2.0", "id": req_id, "result": result["result"]}
        elif isinstance(result, dict) and "error" in result:
            payload = {"jsonrpc": "2.0", "id": req_id, "error": result["error"]}
        else:
            payload = {"jsonrpc": "2.0", "id": req_id, "result": result}
        return _echo_session_id(request, JSONResponse(payload))

    if method == "ping":
        return JSONResponse({"jsonrpc": "2.0", "id": req_id, "result": {}})

    return JSONResponse({
        "jsonrpc": "2.0",
        "id": req_id,
        "error": {"code": -32601, "message": f"Method not found: {method}"},
    })
# ---------------------------------------------------------------------------
# GUI Login (password → session cookie for /admin and /dashboard)
# ---------------------------------------------------------------------------
async def gui_login_get(request: Request) -> HTMLResponse:
    """GET /gui-login — render the password form.

    ``next`` (default ``/admin``) is HTML-escaped into a hidden field. It is
    not validated here; the POST handler sanitizes it before redirecting.
    """
    requested_next = request.query_params.get("next", "/admin")
    return HTMLResponse(GUI_LOGIN_HTML.format(next_url=html.escape(requested_next), error_html=""))
def _is_safe_local_redirect(target: object) -> bool:
    """Return True if *target* is a same-origin absolute path that is safe to redirect to.

    Rejects the following:
    - anything that is not a string starting with a single ``/``;
    - protocol-relative URLs (``//host``);
    - any backslash, because browsers normalise ``/\\host`` to ``//host``;
    - ASCII control characters.
    """
    if not isinstance(target, str) or not target.startswith("/"):
        return False
    if target.startswith("//") or "\\" in target:
        return False
    return not any(ord(ch) < 0x20 or ord(ch) == 0x7F for ch in target)


async def gui_login_post(request: Request) -> Response:
    """POST /gui-login — validate password, set session cookie, redirect.

    ``next`` must be a safe same-origin path; anything else falls back to
    ``/admin``. The password is compared in constant time. On success, an
    HttpOnly, Secure, SameSite=Lax ``mcp_session`` cookie valid for
    ``GUI_SESSION_TTL`` seconds is set and a 303 redirect is returned. On
    failure, the form is re-rendered with HTTP 401.
    """
    form = await request.form()
    password = form.get("password", "")
    next_url = form.get("next", "/admin")
    if not _is_safe_local_redirect(next_url):
        next_url = "/admin"
    password_ok = (
        bool(GATEWAY_PASSWORD)
        and isinstance(password, str)
        and secrets.compare_digest(password.encode(), GATEWAY_PASSWORD.encode())
    )
    if not password_ok:
        error_html = '<div class="error">Incorrect password. Please try again.</div>'
        page = GUI_LOGIN_HTML.format(error_html=error_html, next_url=html.escape(next_url))
        return HTMLResponse(page, status_code=401)
    session_token = _create_gui_session()
    response = RedirectResponse(next_url, status_code=303)
    response.set_cookie(
        "mcp_session", session_token,
        max_age=GUI_SESSION_TTL,
        httponly=True,
        samesite="lax",
        secure=True,
    )
    return response
async def gui_logout(request: Request) -> Response:
    """GET /gui-logout — end the GUI session and send the browser to the login page.

    Forgets the server-side session named by the ``mcp_session`` cookie (if
    any), deletes the cookie, and 303-redirects to ``/gui-login?next=/admin``.
    """
    if session_token := request.cookies.get("mcp_session"):
        GUI_SESSIONS.pop(session_token, None)
    redirect = RedirectResponse("/gui-login?next=/admin", status_code=303)
    redirect.delete_cookie("mcp_session")
    return redirect
# (source-viewer revision marker: 2026-03-31 15:33:37 -04:00)
# ---------------------------------------------------------------------------
# Health / Status
# ---------------------------------------------------------------------------
async def health(request: Request) -> JSONResponse:
    """Unauthenticated liveness check; always reports healthy with OAuth enabled."""
    payload = {"status": "healthy", "oauth": "enabled"}
    return JSONResponse(payload)
async def status(request: Request) -> JSONResponse:
    """Bearer-protected snapshot of the cached backend/tool registry (no live probing).

    For each configured backend, reports its URL, how many registered tools
    it owns, and its stored session id. Also reports the overall tool count
    and the sorted tool names.
    """
    if not validate_bearer_token(request):
        return JSONResponse({"error": "unauthorized"}, status_code=401)
    per_backend = {}
    for backend_name, backend_url in BACKENDS.items():
        per_backend[backend_name] = {
            "url": backend_url,
            "tools": sum(1 for owner in TOOL_REGISTRY.values() if owner == backend_name),
            "session": BACKEND_SESSIONS.get(backend_name),
        }
    return JSONResponse({
        "backends": per_backend,
        "total_tools": len(TOOL_DEFINITIONS),
        "tool_list": sorted(TOOL_DEFINITIONS),
    })
# ---------------------------------------------------------------------------
# Dashboard
# ---------------------------------------------------------------------------
# Human-readable dashboard labels keyed by backend key (the lower-cased
# MCP_BACKEND_<NAME> suffix). Backends not listed here fall back to
# name.capitalize() in dashboard_status.
DISPLAY_NAMES = {
"erpnext": "ERPNext",
"truenas": "TrueNAS",
"homeassistant": "Home Assistant",
"wave": "Wave Finance",
"linkedin": "LinkedIn",
}
async def probe_backend(name: str, url: str) -> dict:
    """Live-probe one backend: run the MCP handshake and count its tools.

    Returns ``{"status", "toolCount", "responseTime"}`` (milliseconds, from a
    monotonic clock).

    The probe counts as failed if the transport raises, or if ``initialize``
    yields no usable reply (non-2xx HTTP, undecodable body, or a JSON-RPC
    error). Previously such a backend was reported healthy with 0 tools.

    On failure the result falls back to the tool count cached in
    ``TOOL_REGISTRY`` and adds ``"note": "cached"``. That fallback still
    reports ``"healthy"`` whenever cached tools exist.
    """
    start = time.monotonic()

    def elapsed_ms() -> int:
        return round((time.monotonic() - start) * 1000)

    try:
        init_result = await mcp_request(url, "initialize", {
            "protocolVersion": "2024-11-05",
            "capabilities": {},
            "clientInfo": {"name": "mcp-gateway-dashboard", "version": "1.0.0"},
        })
        envelope = init_result.get("result")
        if not isinstance(envelope, dict) or "error" in envelope:
            raise RuntimeError(f"initialize returned no usable result: {envelope!r}")
        session_id = init_result.get("session_id")
        # Same handshake as initialize_backend: the MCP lifecycle requires the
        # initialized notification before any further request.
        await mcp_request(url, "notifications/initialized", {}, request_id=None, session_id=session_id)
        tools_result = await mcp_request(url, "tools/list", {}, request_id=2, session_id=session_id)
        tools_data = tools_result.get("result", {})
        if isinstance(tools_data, dict):
            if "result" in tools_data:
                tools_data = tools_data["result"]
            tools = tools_data.get("tools", [])
        else:
            tools = []
        return {
            "status": "healthy",
            "toolCount": len(tools),
            "responseTime": elapsed_ms(),
        }
    except Exception as e:
        elapsed = elapsed_ms()
        logger.warning(f"Dashboard probe failed for {name}: {e}")
        # Fall back to cached registry values if live probe fails
        cached_tools = sum(1 for owner in TOOL_REGISTRY.values() if owner == name)
        return {
            "status": "healthy" if cached_tools > 0 else "unhealthy",
            "toolCount": cached_tools,
            "responseTime": elapsed,
            "note": "cached",
        }
async def dashboard_status(request: Request) -> JSONResponse:
    """Auth-protected endpoint — live-probes all backends.

    Returns a 401 JSON error unless the request carries a valid GUI session or
    bearer token. Otherwise probes every backend in ``BACKENDS`` concurrently via
    ``probe_backend`` and returns a JSON snapshot: one entry per service
    (display name, key, URL, status, tool count, response time, check time)
    plus a summary with total/healthy/unhealthy service counts and total tools.
    """
    if not _validate_gui_session(request) and not validate_bearer_token(request):
        return JSONResponse({"error": "unauthorized"}, status_code=401)

    # Snapshot the backend list once so probe results are zipped against the
    # exact same (name, url) pairs that were probed, even if BACKENDS changes
    # while we are awaiting.
    backend_items = list(BACKENDS.items())
    probes = await asyncio.gather(
        *(probe_backend(name, url) for name, url in backend_items),
        return_exceptions=True,
    )
    # One timestamp for the whole snapshot so all entries agree.
    now = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    services = []
    for (name, url), probe in zip(backend_items, probes):
        # With return_exceptions=True, gather() also returns BaseException
        # subclasses such as asyncio.CancelledError, which are not Exception.
        if isinstance(probe, BaseException):
            probe = {"status": "unhealthy", "toolCount": 0, "responseTime": None}
        services.append({
            "name": DISPLAY_NAMES.get(name, name.capitalize()),
            "key": name,
            "url": url,
            "status": probe["status"],
            "toolCount": probe["toolCount"],
            "responseTime": probe.get("responseTime"),
            "lastCheck": now,
        })
    return JSONResponse({
        "timestamp": now,
        "services": services,
        "summary": {
            "total": len(services),
            "healthy": sum(1 for s in services if s["status"] == "healthy"),
            "unhealthy": sum(1 for s in services if s["status"] == "unhealthy"),
            "totalTools": sum(s["toolCount"] for s in services),
        }
    })
# Self-contained status page returned by dashboard(). Its script polls
# /dashboard/status every 30s. A 401 sends the browser back to the GUI login,
# other non-OK responses are reported in the error box, and every
# server-provided value is HTML-escaped before it is inserted via innerHTML.
DASHBOARD_HTML = """<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>MCP Gateway Dashboard</title>
<script src="https://cdn.tailwindcss.com"></script>
<style>
@keyframes spin { to { transform: rotate(360deg); } }
.spin { animation: spin 1s linear infinite; }
</style>
</head>
<body class="min-h-screen bg-slate-900 text-slate-100 p-8">
<div class="max-w-6xl mx-auto">
<div class="flex items-center justify-between mb-8">
<div>
<h1 class="text-3xl font-bold text-white mb-1">MCP Gateway</h1>
<p class="text-slate-400 text-sm">Service Status Dashboard</p>
</div>
<button onclick="load()" id="refresh-btn"
class="flex items-center gap-2 px-4 py-2 bg-blue-600 hover:bg-blue-500 rounded-lg text-sm font-medium transition-colors">
<svg id="refresh-icon" class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2"
d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15"/>
</svg>
Refresh
</button>
</div>
<!-- Summary cards -->
<div class="grid grid-cols-3 gap-4 mb-8" id="summary">
<div class="bg-slate-800 rounded-xl p-5 border border-slate-700">
<p class="text-slate-400 text-xs font-medium uppercase tracking-wide mb-1">Total Services</p>
<p class="text-3xl font-bold" id="total-services"></p>
</div>
<div class="bg-slate-800 rounded-xl p-5 border border-slate-700">
<p class="text-slate-400 text-xs font-medium uppercase tracking-wide mb-1">Total Tools</p>
<p class="text-3xl font-bold text-green-400" id="total-tools"></p>
</div>
<div class="bg-slate-800 rounded-xl p-5 border border-slate-700">
<p class="text-slate-400 text-xs font-medium uppercase tracking-wide mb-1">Healthy</p>
<p class="text-3xl font-bold text-green-400" id="healthy-count"></p>
</div>
</div>
<p class="text-xs text-slate-500 mb-4" id="last-updated"></p>
<!-- Service cards -->
<div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-2 gap-4" id="services"></div>
<div id="error-msg" class="hidden bg-red-900/50 border border-red-700 rounded-xl p-4 text-red-300 text-sm"></div>
</div>
<script>
// Escape a value before interpolating it into innerHTML markup.
function esc(value) {
  const map = {'&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;', "'": '&#39;'};
  return String(value == null ? '' : value).replace(/[&<>"']/g, ch => map[ch]);
}
async function load() {
  const icon = document.getElementById('refresh-icon');
  const errorEl = document.getElementById('error-msg');
  icon.classList.add('spin');
  try {
    const r = await fetch('/dashboard/status');
    if (r.status === 401) {
      // GUI session expired or missing: go back through the login page.
      window.location.href = '/gui-login?next=/dashboard';
      return;
    }
    if (!r.ok) {
      throw new Error('HTTP ' + r.status);
    }
    const data = await r.json();
    renderSummary(data.summary);
    renderServices(data.services);
    document.getElementById('last-updated').textContent =
      'Last updated: ' + new Date().toLocaleTimeString();
    errorEl.classList.add('hidden');
  } catch(e) {
    errorEl.textContent = 'Failed to load status: ' + e.message;
    errorEl.classList.remove('hidden');
  } finally {
    icon.classList.remove('spin');
  }
}
function renderSummary(s) {
  document.getElementById('total-services').textContent = s.total;
  document.getElementById('total-tools').textContent = s.totalTools;
  document.getElementById('healthy-count').textContent = s.healthy + ' / ' + s.total;
}
function renderServices(services) {
  const el = document.getElementById('services');
  el.innerHTML = services.map(s => {
    const healthy = s.status === 'healthy';
    const border = healthy ? 'border-green-700' : 'border-red-700';
    const badge = healthy
      ? '<span class="px-2 py-0.5 rounded-full text-xs font-medium bg-green-900 text-green-300">Healthy</span>'
      : '<span class="px-2 py-0.5 rounded-full text-xs font-medium bg-red-900 text-red-300">Unhealthy</span>';
    const dot = healthy ? 'bg-green-400' : 'bg-red-400';
    const rt = s.responseTime != null ? esc(s.responseTime) + 'ms' : '';
    return `
<div class="bg-slate-800 rounded-xl p-5 border ${border}">
<div class="flex items-center justify-between mb-4">
<div class="flex items-center gap-2">
<div class="w-2.5 h-2.5 rounded-full ${dot}"></div>
<h3 class="font-semibold text-white">${esc(s.name)}</h3>
</div>
${badge}
</div>
<div class="space-y-2">
<div class="flex justify-between items-center bg-slate-900 rounded-lg px-3 py-2">
<span class="text-slate-400 text-sm">Tools Available</span>
<span class="text-white font-bold text-lg">${esc(s.toolCount)}</span>
</div>
<div class="flex justify-between items-center bg-slate-900 rounded-lg px-3 py-2">
<span class="text-slate-400 text-sm">Response Time</span>
<span class="text-slate-300 text-sm font-mono">${rt}</span>
</div>
<div class="bg-slate-900 rounded-lg px-3 py-2">
<p class="text-slate-400 text-xs mb-0.5">Endpoint</p>
<p class="text-slate-300 text-xs font-mono truncate">${esc(s.url)}</p>
</div>
</div>
</div>`;
  }).join('');
}
load();
setInterval(load, 30000);
</script>
</body>
</html>"""
async def dashboard(request: Request) -> Response:
    """Serve the status dashboard page to requests with a valid GUI session.

    Requests without a valid GUI session get a 303 redirect to ``/gui-login``
    with ``next=/dashboard`` (presumably used by the login handler to return
    the user here afterwards — confirm in gui_login_post).
    """
    if not _validate_gui_session(request):
        return RedirectResponse("/gui-login?next=/dashboard", status_code=303)
    return Response(DASHBOARD_HTML, media_type="text/html")
# ---------------------------------------------------------------------------
# Startup
# ---------------------------------------------------------------------------
async def startup():
    """Run one-time gateway initialization before requests are served.

    Logs the OAuth configuration, waits for backend containers to come up,
    then calls ``initialize_backend`` for every entry in ``BACKENDS``. A backend
    whose initialization raises is logged (with traceback) and skipped, so one
    broken backend no longer aborts startup of the whole gateway.

    The initial wait defaults to 10 seconds and can be overridden with the
    ``GATEWAY_BACKEND_STARTUP_DELAY`` environment variable (seconds; invalid
    values fall back to 10, negative values are treated as 0).
    """
    if not GATEWAY_PASSWORD:
        logger.error("OAUTH_PASSWORD is not set! OAuth login will fail.")
    else:
        logger.info("OAuth password configured")
    logger.info(f"OAuth issuer: {ISSUER_URL}")
    logger.info(f"Starting MCP Gateway Proxy with {len(BACKENDS)} backends")

    raw_delay = os.environ.get("GATEWAY_BACKEND_STARTUP_DELAY", "10")
    try:
        delay = max(0.0, float(raw_delay))
    except ValueError:
        logger.warning("Invalid GATEWAY_BACKEND_STARTUP_DELAY=%r; using 10s", raw_delay)
        delay = 10.0
    # ':g' renders the default 10.0 as "10", matching the historical message.
    logger.info(f"Waiting {delay:g}s for backends to start...")
    await asyncio.sleep(delay)

    for name, url in BACKENDS.items():
        try:
            tools = await initialize_backend(name, url)
        except Exception:
            # Keep starting the remaining backends instead of failing the lifespan.
            logger.exception(f"  {name}: initialization failed")
            continue
        if not tools:
            logger.warning(f"  {name}: no tools discovered — will retry on first request")
    logger.info(f"Gateway ready: {len(TOOL_DEFINITIONS)} tools from {len(BACKENDS)} backends")
@asynccontextmanager
async def lifespan(app):
    """Starlette lifespan hook.

    Runs ``startup()`` to completion before the server begins handling
    requests, then suspends for the lifetime of the application. There is no
    shutdown work.
    """
    await startup()  # must finish before control passes to the server
    yield  # the application serves requests while suspended here
# ---------------------------------------------------------------------------
# LinkedIn OAuth proxy routes
# ---------------------------------------------------------------------------
# Base URL of the linkedin-mcp container that owns the LinkedIn OAuth flow.
# Overridable via the LINKEDIN_MCP_URL environment variable; any trailing
# slash is stripped so paths can be appended directly.
LINKEDIN_MCP_URL = os.environ.get("LINKEDIN_MCP_URL", "http://mcp-linkedin:8500").rstrip("/")
# Upstream status codes that are relayed to the client as redirects.
_LINKEDIN_REDIRECT_STATUSES = (301, 302, 303, 307, 308)

async def _proxy_linkedin_get(path: str, query: str = "") -> Response:
    """Forward a GET to the linkedin-mcp container and relay its response.

    Upstream redirects are not followed. When they carry a Location header
    they are relayed to the client with the same status code. Forwarding a
    bare 3xx without Location would leave the client with an empty response.
    Transport failures (connection refused, timeouts, ...) are logged and
    turned into a 502 instead of an unhandled 500.

    Args:
        path: Path on the linkedin-mcp container, e.g. ``/linkedin/auth``.
        query: Raw query string to forward (without the leading ``?``).
    """
    target = f"{LINKEDIN_MCP_URL}{path}"
    if query:
        target = f"{target}?{query}"
    try:
        async with httpx.AsyncClient() as client:
            resp = await client.get(target, follow_redirects=False)
    except httpx.RequestError as exc:
        logger.warning("LinkedIn proxy request to %s failed: %s", path, exc)
        return Response("LinkedIn service unavailable", status_code=502, media_type="text/plain")
    location = resp.headers.get("location")
    if resp.status_code in _LINKEDIN_REDIRECT_STATUSES and location:
        return RedirectResponse(location, status_code=resp.status_code)
    return Response(
        content=resp.content,
        status_code=resp.status_code,
        media_type=resp.headers.get("content-type", "text/html"),
    )

async def linkedin_auth_proxy(request: Request):
    """Proxy /linkedin/auth to the linkedin-mcp container."""
    return await _proxy_linkedin_get("/linkedin/auth")

async def linkedin_callback_proxy(request: Request):
    """Proxy /linkedin/callback, including its query string, to the linkedin-mcp container."""
    return await _proxy_linkedin_get("/linkedin/callback", str(request.url.query))

async def linkedin_status_proxy(request: Request):
    """Proxy /linkedin/status to the linkedin-mcp container."""
    return await _proxy_linkedin_get("/linkedin/status")
# Build routes list. Starlette matches in order; paths registered twice with
# different methods (e.g. /gui-login, /users) dispatch by HTTP method.
routes = [
    # Well-known discovery (Claude tries both with and without /mcp suffix)
    Route("/.well-known/oauth-protected-resource", well_known_protected_resource, methods=["GET"]),
    Route("/.well-known/oauth-protected-resource/mcp", well_known_protected_resource, methods=["GET"]),
    Route("/.well-known/oauth-authorization-server", well_known_oauth_authorization_server, methods=["GET"]),
    Route("/.well-known/oauth-authorization-server/mcp", well_known_oauth_authorization_server, methods=["GET"]),
    Route("/.well-known/openid-configuration", well_known_oauth_authorization_server, methods=["GET"]),
    # OAuth endpoints at /oauth/* (spec-standard)
    Route("/oauth/register", oauth_register, methods=["POST"]),
    Route("/oauth/authorize", oauth_authorize, methods=["GET", "POST"]),
    Route("/oauth/token", oauth_token, methods=["POST"]),
    # OAuth endpoints at root (Claude may construct these from base URL)
    Route("/register", oauth_register, methods=["POST"]),
    Route("/authorize", oauth_authorize, methods=["GET", "POST"]),
    Route("/token", oauth_token, methods=["POST"]),
    # MCP endpoint (OAuth-protected)
    Route("/mcp", handle_mcp, methods=["GET", "HEAD", "POST", "DELETE"]),
    # Monitoring (/status requires a bearer token)
    Route("/health", health, methods=["GET"]),
    Route("/status", status, methods=["GET"]),
    # GUI login/logout and dashboard. /dashboard requires a GUI session
    # (otherwise redirects to /gui-login); /dashboard/status accepts either a
    # GUI session or a bearer token.
    Route("/gui-login", gui_login_get, methods=["GET"]),
    Route("/gui-login", gui_login_post, methods=["POST"]),
    Route("/gui-logout", gui_logout, methods=["GET"]),
    Route("/dashboard", dashboard, methods=["GET"]),
    Route("/dashboard/status", dashboard_status, methods=["GET"]),
    # LinkedIn OAuth flow (proxied to linkedin-mcp container)
    Route("/linkedin/auth", linkedin_auth_proxy, methods=["GET"]),
    Route("/linkedin/callback", linkedin_callback_proxy, methods=["GET"]),
    Route("/linkedin/status", linkedin_status_proxy, methods=["GET"]),
    # Admin / User Management UI
    # NOTE(review): auth for the admin UI and user API is enforced (if at all)
    # inside user_dashboard_ui / user_routes, which are not visible here — confirm.
    Route("/admin", user_dashboard, methods=["GET"]),
    # User management API
    Route("/users", list_users, methods=["GET"]),
    Route("/users", create_user, methods=["POST"]),
    Route("/users/{username}", get_user, methods=["GET"]),
    Route("/users/{username}", delete_user, methods=["DELETE"]),
    Route("/users/{username}/enable", toggle_user, methods=["PATCH"]),
    Route("/users/{username}/keys", generate_api_key, methods=["POST"]),
    Route("/users/{username}/mcp-access", set_mcp_access, methods=["PUT"]),
    Route("/keys/revoke", revoke_api_key, methods=["POST"]),
]
# The ASGI application: every route above, with gateway startup wired in
# through the lifespan hook.
app = Starlette(lifespan=lifespan, routes=routes)
if __name__ == "__main__":
    # Bind address and port come from the environment (HOST / PORT); the
    # defaults listen on all interfaces, port 4444, as before.
    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", "4444"))
    uvicorn.run(app, host=host, port=port, log_level="info")