"""
MCP Gateway Proxy with OAuth 2.1
=================================
Aggregates multiple MCP servers behind a single Streamable HTTP endpoint.
Implements a self-contained OAuth 2.1 provider compatible with claude.ai:
- RFC 8414 Authorization Server Metadata
- RFC 9728 Protected Resource Metadata
- RFC 7591 Dynamic Client Registration
- PKCE (S256) per OAuth 2.1
- Authorization Code Grant with refresh tokens
"""
import asyncio
import base64
import hashlib
import html
import json
import logging
import os
import secrets
import time
import uuid
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import urlencode
import httpx
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import HTMLResponse, JSONResponse, RedirectResponse, Response, StreamingResponse
from starlette.routing import Route
import uvicorn
from user_routes import (
create_user, list_users, get_user, delete_user,
toggle_user, generate_api_key, revoke_api_key, set_mcp_access
)
from user_dashboard_ui import user_management_dashboard as user_dashboard
# Logging is configured once at import time; `logger` is the module-wide logger.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger("mcp-gateway")
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Public base URL of the gateway. It is advertised as the OAuth issuer and is
# the prefix of every endpoint URL in the discovery documents.
ISSUER_URL = os.environ.get("OAUTH_ISSUER_URL", "https://mcp.wilddragon.net")
# Shared password checked by the OAuth consent form and the GUI login form.
# When it is empty, neither form can be passed.
GATEWAY_PASSWORD = os.environ.get("OAUTH_PASSWORD", "")
# Token lifetimes in seconds; they are added to time.time() at issuance.
ACCESS_TOKEN_TTL = int(os.environ.get("OAUTH_ACCESS_TOKEN_TTL", "3600"))
REFRESH_TOKEN_TTL = int(os.environ.get("OAUTH_REFRESH_TOKEN_TTL", "2592000"))
# Lifetime of an authorization code, in seconds.
AUTH_CODE_TTL = 600
# Optional fixed bearer token accepted by validate_bearer_token(); an empty
# string disables it.
STATIC_API_KEY = os.environ.get("GATEWAY_STATIC_API_KEY", "")
# ---------------------------------------------------------------------------
# In-memory stores
# ---------------------------------------------------------------------------
# All stores are plain process-local dicts, so their contents do not survive a restart.
# client_id -> client metadata (from /register or auto-registration at /authorize).
REGISTERED_CLIENTS: dict[str, dict] = {}
# Authorization code -> grant details, including "expires_at".
AUTH_CODES: dict[str, dict] = {}
# SHA-256 hex digest of the access token -> token info, including "expires_at".
ACCESS_TOKENS: dict[str, dict] = {}
# SHA-256 hex digest of the refresh token -> token info, including "expires_at".
REFRESH_TOKENS: dict[str, dict] = {}
# internal_state -> parameters of an authorize request that is awaiting consent.
PENDING_AUTH: dict[str, dict] = {}
# Browser session store (cookie-based login for /admin and /dashboard)
GUI_SESSIONS: dict[str, float] = {} # session token -> expiry as a Unix timestamp
GUI_SESSION_TTL = 8 * 3600 # 8 hours, in seconds
def _create_gui_session() -> str:
    """Mint a browser session token and record when it expires.

    Returns:
        A 64-character hex token; the caller stores it in the session cookie.
    """
    session_token = secrets.token_hex(32)
    expires_at = time.time() + GUI_SESSION_TTL
    GUI_SESSIONS[session_token] = expires_at
    return session_token
def _validate_gui_session(request: Request) -> bool:
    """Tell whether the request carries a live ``mcp_session`` cookie.

    A token that is unknown or past its expiry is dropped from GUI_SESSIONS
    as a side effect.

    Returns:
        True when the cookie maps to a session that has not expired yet.
    """
    session_token = request.cookies.get("mcp_session")
    if session_token:
        # Unknown tokens read as expiry 0, i.e. always in the past.
        if GUI_SESSIONS.get(session_token, 0) >= time.time():
            return True
        GUI_SESSIONS.pop(session_token, None)
    return False
# Login page for the browser GUI, rendered with str.format().
# Placeholders:
#   {error_html} - ready-made markup for an error banner ("" when there is none)
#   {next_url}   - post-login redirect target, already HTML-escaped by the caller
# The form posts the fields gui_login_post() reads: "password" and "next".
# Literal CSS braces are doubled so that str.format() leaves them alone.
GUI_LOGIN_HTML = """<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>MCP Gateway — Login</title>
<style>
body {{ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: #0f172a; color: #e2e8f0; min-height: 100vh; margin: 0;
display: flex; align-items: center; justify-content: center; }}
.card {{ background: #1e293b; border-radius: 16px; padding: 40px; max-width: 420px; width: 90%; }}
.icon {{ font-size: 32px; margin-bottom: 12px; }}
h1 {{ font-size: 22px; margin: 0 0 8px 0; color: #f8fafc; }}
.subtitle {{ color: #94a3b8; margin: 0 0 24px 0; font-size: 14px; }}
label {{ display: block; font-size: 13px; color: #94a3b8; margin-bottom: 6px; }}
input[type=password] {{ width: 100%; box-sizing: border-box; padding: 12px 16px; border-radius: 8px;
border: 1px solid #475569; background: #0f172a; color: #f8fafc; font-size: 16px; margin-bottom: 20px; }}
button {{ width: 100%; padding: 12px; border-radius: 8px; border: none; font-size: 15px;
font-weight: 600; cursor: pointer; background: #38bdf8; color: #0f172a; }}
.error {{ background: #7f1d1d; color: #fca5a5; padding: 10px 14px; border-radius: 8px;
margin-bottom: 16px; font-size: 13px; }}
</style>
</head>
<body>
<div class="card">
<div class="icon">🔒</div>
<h1>MCP Gateway</h1>
<p class="subtitle">Enter your gateway password to continue</p>
{error_html}
<form method="POST" action="/gui-login">
<input type="hidden" name="next" value="{next_url}">
<label for="password">Password</label>
<input type="password" id="password" name="password" autofocus required>
<button type="submit">Sign in</button>
</form>
</div>
</body>
</html>
"""
def _hash(value: str) -> str:
return hashlib.sha256(value.encode()).hexdigest()
def _clean_expired() -> None:
    """Drop expired entries from the in-memory token and session stores.

    Covers AUTH_CODES, ACCESS_TOKENS and REFRESH_TOKENS (dict values carrying
    an "expires_at" timestamp; a missing timestamp counts as expired) and
    GUI_SESSIONS (values are the expiry timestamp itself). Without the
    GUI_SESSIONS sweep, a session whose cookie is never presented again
    would stay in memory for the life of the process.
    """
    now = time.time()
    for store in (AUTH_CODES, ACCESS_TOKENS, REFRESH_TOKENS):
        # Collect the keys first: deleting while iterating a dict raises.
        for key in [k for k, v in store.items() if v.get("expires_at", 0) < now]:
            del store[key]
    for token in [t for t, expires_at in GUI_SESSIONS.items() if expires_at < now]:
        del GUI_SESSIONS[token]
# ---------------------------------------------------------------------------
# Backend MCP aggregation
# ---------------------------------------------------------------------------
def load_backends() -> dict[str, str]:
    """Collect backend definitions from ``MCP_BACKEND_*`` environment variables.

    The part of the variable name after the prefix, lower-cased, becomes the
    backend name; the variable's value is stored as its URL. Each backend
    found is logged.

    Returns:
        Mapping of backend name to the configured value.
    """
    prefix = "MCP_BACKEND_"
    discovered = {}
    for env_name, env_value in os.environ.items():
        if not env_name.startswith(prefix):
            continue
        backend = env_name[len(prefix):].lower()
        discovered[backend] = env_value
        logger.info(f"Backend configured: {backend} -> {env_value}")
    return discovered
# Backend name -> configured URL, read from the environment once at import time.
BACKENDS: dict[str, str] = load_backends()
# Prefixed tool name ("<backend>_<tool>") -> name of the backend that serves it.
TOOL_REGISTRY: dict[str, str] = {}
# Prefixed tool name -> tool definition as returned to gateway clients by tools/list.
TOOL_DEFINITIONS: dict[str, dict] = {}
# Backend name -> Mcp-Session-Id from its last initialize response (None if it sent none).
BACKEND_SESSIONS: dict[str, str | None] = {}
# Session ids this gateway handed to its own clients in response to "initialize".
GATEWAY_SESSIONS: dict[str, bool] = {}
def parse_sse_response(text: str) -> dict | None:
for line in text.strip().split("\n"):
line = line.strip()
if line.startswith("data: "):
try:
return json.loads(line[6:])
except json.JSONDecodeError:
continue
return None
async def mcp_request(backend_url: str, method: str, params: dict | None = None, request_id: Any = 1, session_id: str | None = None) -> dict:
    """POST one JSON-RPC message to a backend MCP server.

    Args:
        backend_url: URL of the backend's MCP endpoint.
        method: JSON-RPC method name.
        params: Optional ``params`` member; omitted from the payload when None.
        request_id: JSON-RPC id. Pass None to send a notification: the ``id``
            member is then left out, because JSON-RPC 2.0 defines a
            notification as a request without an ``id``.
        session_id: Value for the ``Mcp-Session-Id`` header, if any.

    Returns:
        ``{"result": <decoded body or None>, "session_id": <header or None>}``.
        ``result`` is None for a 202, for any non-2xx status, and when the
        body cannot be decoded. HTTP error statuses are logged, not raised.

    Raises:
        httpx.HTTPError: on transport failures (connect error, timeout, ...).
    """
    payload: dict[str, Any] = {"jsonrpc": "2.0", "method": method}
    if request_id is not None:
        payload["id"] = request_id
    if params is not None:
        payload["params"] = params
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json, text/event-stream",
        # NOTE(review): the Host header is forced to "localhost" for every
        # backend; presumably the backends validate Host - confirm.
        "Host": "localhost",
    }
    if session_id:
        headers["Mcp-Session-Id"] = session_id
    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.post(backend_url, json=payload, headers=headers)
    new_session = resp.headers.get("Mcp-Session-Id") or resp.headers.get("mcp-session-id")
    content_type = resp.headers.get("content-type", "")
    if resp.status_code in (200, 201):
        try:
            if "text/event-stream" in content_type:
                return {"result": parse_sse_response(resp.text), "session_id": new_session}
            return {"result": resp.json(), "session_id": new_session}
        except Exception as exc:
            # Best effort: an undecodable body is reported as "no result",
            # but it is logged so the failure is not invisible.
            logger.warning(f"Backend {backend_url} sent an undecodable {method} response: {exc}")
            return {"result": None, "session_id": new_session}
    if resp.status_code != 202:
        logger.error(f"Backend {backend_url} returned {resp.status_code}: {resp.text[:200]}")
    # 202 (accepted, no body) and error statuses both yield no result.
    return {"result": None, "session_id": new_session}
def _normalize_schema(obj: Any) -> None:
"""Recursively normalize JSON schemas so mcpo can parse them.
Fixes:
- type as list (e.g. ["string","null"]) -> "string"
- items as list (tuple validation) -> {"type": "string"}
- anyOf/oneOf containing non-dict entries
"""
if isinstance(obj, dict):
if isinstance(obj.get("type"), list):
obj["type"] = "string"
# Convert tuple-style items (list) to simple dict schema
if "items" in obj and isinstance(obj["items"], list):
obj["items"] = {"type": "string"}
# Clean anyOf/oneOf entries that might contain non-dicts
for key in ("anyOf", "oneOf"):
if key in obj and isinstance(obj[key], list):
obj[key] = [x for x in obj[key] if isinstance(x, dict)]
if len(obj[key]) == 1:
# Collapse single-option anyOf/oneOf
obj.update(obj.pop(key)[0])
elif len(obj[key]) == 0:
del obj[key]
obj["type"] = "string"
for v in obj.values():
_normalize_schema(v)
elif isinstance(obj, list):
for item in obj:
_normalize_schema(item)
async def initialize_backend(name: str, url: str) -> list[dict]:
    """Run the MCP handshake with one backend and register its tools.

    Sends ``initialize``, then the ``notifications/initialized`` notification,
    then ``tools/list``. Each discovered tool is renamed to
    ``"<name>_<tool>"``, its input schema is normalized, and it is recorded in
    TOOL_REGISTRY and TOOL_DEFINITIONS. The backend's session id is stored in
    BACKEND_SESSIONS.

    When the backend returns a tool list, this backend's previous entries are
    removed first, so tools it no longer offers do not linger after a
    re-initialization. When no list comes back (for example an error reply),
    the previous registrations are kept.

    Up to 3 attempts are made, 5 seconds apart; only exceptions trigger a
    retry.

    Args:
        name: Backend name, used as the tool-name prefix.
        url: Backend MCP endpoint URL.

    Returns:
        The (renamed) tool definitions, or an empty list on failure or when
        the backend listed no tools.
    """
    logger.info(f"Initializing backend: {name} at {url}")
    max_attempts = 3
    for attempt in range(1, max_attempts + 1):
        try:
            init_result = await mcp_request(url, "initialize", {
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {"name": "mcp-gateway-proxy", "version": "1.0.0"},
            })
            session_id = init_result.get("session_id")
            BACKEND_SESSIONS[name] = session_id
            logger.info(f" {name}: initialized (session: {session_id})")
            await mcp_request(url, "notifications/initialized", {}, request_id=None, session_id=session_id)
            tools_result = await mcp_request(url, "tools/list", {}, request_id=2, session_id=session_id)

            tools_data = tools_result.get("result", {})
            # Unwrap the JSON-RPC envelope when the body is a full response.
            if isinstance(tools_data, dict) and isinstance(tools_data.get("result"), dict):
                tools_data = tools_data["result"]
            listed = isinstance(tools_data, dict) and isinstance(tools_data.get("tools"), list)
            tools = tools_data["tools"] if listed else []
            logger.info(f" {name}: discovered {len(tools)} tools")

            if listed:
                # Forget this backend's old tools only once a fresh list is in
                # hand, so a transient failure cannot wipe the registry.
                for stale in [t for t, owner in TOOL_REGISTRY.items() if owner == name]:
                    del TOOL_REGISTRY[stale]
                    TOOL_DEFINITIONS.pop(stale, None)

            for tool in tools:
                prefixed_name = f"{name}_{tool['name']}"
                tool["name"] = prefixed_name
                # Recursively normalize list-type schemas (e.g. ["string","null"])
                # so mcpo can parse the schema without crashing
                _normalize_schema(tool.get("inputSchema", {}))
                TOOL_REGISTRY[prefixed_name] = name
                TOOL_DEFINITIONS[prefixed_name] = tool
            return tools
        except Exception as e:
            if attempt < max_attempts:
                logger.info(f" {name}: attempt {attempt} failed ({e}), retrying in 5s...")
                await asyncio.sleep(5)
            else:
                logger.error(f" {name}: failed to initialize after {max_attempts} attempts - {e}")
    return []
async def forward_tool_call(backend_name: str, tool_name: str, arguments: dict, request_id: Any) -> dict:
    """Relay a ``tools/call`` to the backend that owns the tool.

    The ``"<backend_name>_"`` prefix is stripped from ``tool_name`` before the
    call. A backend without a stored session is initialized first. If the
    reply carries a JSON-RPC error with code -32600 or -32601, the backend is
    re-initialized and the call is repeated once.

    Args:
        backend_name: Key into BACKENDS / BACKEND_SESSIONS.
        tool_name: Gateway-side (prefixed) tool name.
        arguments: Tool arguments, forwarded unchanged.
        request_id: JSON-RPC id to use for the backend request.

    Returns:
        The ``result`` member of what mcp_request() returned for the last call.

    Raises:
        KeyError: if ``backend_name`` is not configured.
    """
    backend_url = BACKENDS[backend_name]
    upstream_tool = tool_name.removeprefix(f"{backend_name}_")

    async def _call_backend() -> Any:
        # The session is looked up per call so a re-initialization is honoured.
        reply = await mcp_request(
            backend_url, "tools/call",
            {"name": upstream_tool, "arguments": arguments},
            request_id=request_id,
            session_id=BACKEND_SESSIONS.get(backend_name),
        )
        return reply.get("result", {})

    if not BACKEND_SESSIONS.get(backend_name):
        await initialize_backend(backend_name, backend_url)

    payload = await _call_backend()
    if isinstance(payload, dict) and payload.get("error"):
        code = payload["error"].get("code", 0)
        if code in (-32600, -32601):
            logger.info(f"Re-initializing {backend_name} after error {code}")
            await initialize_backend(backend_name, backend_url)
            payload = await _call_backend()
    return payload
# ---------------------------------------------------------------------------
# OAuth 2.1: Token validation
# ---------------------------------------------------------------------------
def validate_bearer_token(request: Request) -> dict | None:
    """Authenticate a request from its ``Authorization: Bearer`` header.

    The auth scheme is matched case-insensitively. The static API key, when
    configured, is checked first with a constant-time comparison so that
    response timing does not leak how much of a guess was right. Otherwise
    the token's SHA-256 digest is looked up in ACCESS_TOKENS; an expired
    entry is deleted on sight.

    Returns:
        Token info (``client_id``, ``scope``, ``user``, ...) for a valid
        token, or None when the header is missing, malformed, unknown or
        expired.
    """
    auth_header = request.headers.get("Authorization", "")
    scheme, _, token = auth_header.partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        return None
    # Check static API key first (persistent, survives restarts)
    if STATIC_API_KEY and secrets.compare_digest(token.encode(), STATIC_API_KEY.encode()):
        return {"client_id": "static", "scope": "mcp:tools", "user": "static"}
    token_hash = _hash(token)
    info = ACCESS_TOKENS.get(token_hash)
    if not info:
        return None
    if info["expires_at"] < time.time():
        del ACCESS_TOKENS[token_hash]
        return None
    return info
# ---------------------------------------------------------------------------
# Consent page HTML template
# ---------------------------------------------------------------------------
# Stylesheet for the OAuth consent page. This is plain CSS text, not a
# str.format() template, so its braces are single.
# Selectors defined: .card, .subtitle, .client, .scope, .error, .buttons,
# .approve, .deny, plus h1, label, button and input[type=password].
CONSENT_PAGE_CSS = """
* { margin: 0; padding: 0; box-sizing: border-box; }
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: #0f172a; color: #e2e8f0; min-height: 100vh;
display: flex; align-items: center; justify-content: center; }
.card { background: #1e293b; border-radius: 16px; padding: 40px;
max-width: 420px; width: 90%; box-shadow: 0 25px 50px rgba(0,0,0,0.4); }
h1 { font-size: 22px; margin-bottom: 8px; color: #f8fafc; }
.subtitle { color: #94a3b8; margin-bottom: 24px; font-size: 14px; }
.client { background: #334155; border-radius: 8px; padding: 12px 16px;
margin-bottom: 24px; font-size: 14px; }
.client strong { color: #38bdf8; }
label { display: block; font-size: 13px; color: #94a3b8; margin-bottom: 6px; }
input[type=password] { width: 100%; padding: 12px 16px; border-radius: 8px;
border: 1px solid #475569; background: #0f172a; color: #f8fafc;
font-size: 16px; margin-bottom: 20px; outline: none; }
input[type=password]:focus { border-color: #38bdf8; box-shadow: 0 0 0 3px rgba(56,189,248,0.15); }
.buttons { display: flex; gap: 12px; }
button { flex: 1; padding: 12px; border-radius: 8px; border: none; font-size: 15px;
font-weight: 600; cursor: pointer; transition: all 0.15s; }
.approve { background: #38bdf8; color: #0f172a; }
.approve:hover { background: #7dd3fc; }
.deny { background: #334155; color: #94a3b8; }
.deny:hover { background: #475569; color: #e2e8f0; }
.scope { font-size: 13px; color: #64748b; margin-bottom: 16px; }
.error { background: #7f1d1d; color: #fca5a5; padding: 10px 14px; border-radius: 8px;
margin-bottom: 16px; font-size: 13px; }
"""
def render_consent_page(client_name: str, scope: str, internal_state: str, error_msg: str = "") -> str:
    """Render the HTML consent page shown by the authorize endpoint.

    The page embeds CONSENT_PAGE_CSS and a form that posts back to the URL it
    was served from. The form carries the three fields the POST handler
    reads: ``internal_state`` (hidden), ``password`` and ``action``
    ("approve" or "deny"). All caller-supplied text is HTML-escaped.

    Args:
        client_name: Display name of the client asking for access.
        scope: Requested scope string.
        internal_state: Key of the pending authorization in PENDING_AUTH.
        error_msg: Optional message shown in an error banner.

    Returns:
        The complete HTML document.
    """
    error_html = f'<div class="error">{html.escape(error_msg)}</div>' if error_msg else ""
    # "Approve" comes first in document order so that pressing Enter in the
    # password field submits an approval rather than a denial.
    return f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>MCP Gateway — Authorize</title>
<style>{CONSENT_PAGE_CSS}</style>
</head>
<body>
<div class="card">
<h1>MCP Gateway</h1>
<p class="subtitle">Authorization Request</p>
<div class="client"><strong>{html.escape(client_name)}</strong> wants access to your MCP tools.</div>
<div class="scope">Scope: {html.escape(scope)}</div>
{error_html}
<form method="POST">
<input type="hidden" name="internal_state" value="{html.escape(internal_state)}">
<label for="password">Gateway password</label>
<input type="password" id="password" name="password" autofocus>
<div class="buttons">
<button type="submit" name="action" value="approve" class="approve">Approve</button>
<button type="submit" name="action" value="deny" class="deny">Deny</button>
</div>
</form>
</div>
</body>
</html>
"""
# ---------------------------------------------------------------------------
# OAuth 2.1 Endpoints
# ---------------------------------------------------------------------------
async def well_known_protected_resource(request: Request) -> JSONResponse:
    """Serve the RFC 9728 protected-resource metadata document.

    The gateway names itself (ISSUER_URL) both as the resource and as its
    only authorization server.
    """
    metadata = {
        "resource": ISSUER_URL,
        "authorization_servers": [ISSUER_URL],
        "scopes_supported": ["mcp:tools"],
        "bearer_methods_supported": ["header"],
    }
    return JSONResponse(metadata)
async def well_known_oauth_authorization_server(request: Request) -> JSONResponse:
    """Serve the RFC 8414 authorization-server metadata document.

    All advertised endpoint URLs are ISSUER_URL followed by a root-level path.
    """
    def endpoint(path: str) -> str:
        return f"{ISSUER_URL}/{path}"

    document = {
        "issuer": ISSUER_URL,
        "authorization_endpoint": endpoint("authorize"),
        "token_endpoint": endpoint("token"),
        "registration_endpoint": endpoint("register"),
        "scopes_supported": ["mcp:tools"],
        "response_types_supported": ["code"],
        "grant_types_supported": ["authorization_code", "refresh_token"],
        "token_endpoint_auth_methods_supported": ["client_secret_post", "none"],
        "code_challenge_methods_supported": ["S256"],
        "service_documentation": endpoint("health"),
    }
    return JSONResponse(document)
async def oauth_register(request: Request) -> JSONResponse:
    """Dynamic Client Registration endpoint (RFC 7591).

    Accepts a JSON object of client metadata, issues a fresh ``client_id``
    and ``client_secret``, stores the record in REGISTERED_CLIENTS and echoes
    it back with status 201.

    Error responses (all 400):
        invalid_request          - the body is not valid JSON
        invalid_client_metadata  - the body is JSON but not an object
        invalid_redirect_uri     - ``redirect_uris`` is not a list of
                                   non-empty strings
    """
    try:
        body = await request.json()
    except Exception:
        return JSONResponse({"error": "invalid_request"}, status_code=400)
    # A JSON array or scalar has no .get(); reject it instead of crashing.
    if not isinstance(body, dict):
        return JSONResponse(
            {"error": "invalid_client_metadata", "error_description": "Registration body must be a JSON object."},
            status_code=400,
        )
    redirect_uris = body.get("redirect_uris", [])
    if not isinstance(redirect_uris, list) or not all(isinstance(uri, str) and uri for uri in redirect_uris):
        return JSONResponse(
            {"error": "invalid_redirect_uri", "error_description": "redirect_uris must be a list of non-empty strings."},
            status_code=400,
        )
    client_id = str(uuid.uuid4())
    client_info = {
        "client_id": client_id,
        "client_secret": secrets.token_urlsafe(48),
        "client_id_issued_at": int(time.time()),
        # RFC 7591 requires this member whenever a secret is issued; 0 = never expires.
        "client_secret_expires_at": 0,
        "client_name": body.get("client_name", "Unknown Client"),
        "redirect_uris": redirect_uris,
        "grant_types": body.get("grant_types", ["authorization_code", "refresh_token"]),
        "response_types": body.get("response_types", ["code"]),
        "token_endpoint_auth_method": body.get("token_endpoint_auth_method", "client_secret_post"),
    }
    REGISTERED_CLIENTS[client_id] = client_info
    logger.info(f"DCR: registered client '{client_info['client_name']}' as {client_id}")
    return JSONResponse(client_info, status_code=201)
def _redirect_with_params(redirect_uri: str, params: dict) -> RedirectResponse:
    """Build a 302 to ``redirect_uri`` with ``params`` added to its query string."""
    # A redirect URI may already carry a query; a second "?" would corrupt it.
    separator = "&" if "?" in redirect_uri else "?"
    return RedirectResponse(f"{redirect_uri}{separator}{urlencode(params)}", status_code=302)


async def oauth_authorize(request: Request) -> Response:
    """Authorization endpoint: consent page on GET, decision handling on POST.

    GET validates the request, auto-registers unknown clients, parks the
    request parameters in PENDING_AUTH under a random ``internal_state`` and
    renders the consent page.

    POST looks the pending request up again. A denial redirects back with
    ``error=access_denied``; a wrong password re-renders the page with an
    error; a correct password issues a single-use authorization code and
    redirects back with ``code`` and ``state``.

    Checks applied here:
    - ``client_id`` and ``redirect_uri`` must be present.
    - Only the S256 ``code_challenge_method`` is accepted (the only one the
      metadata advertises and the token endpoint verifies).
    - For clients that registered redirect URIs via /register, the requested
      ``redirect_uri`` must be one of them.
    - Pending requests expire after AUTH_CODE_TTL seconds and are pruned on
      every call, so abandoned consent pages do not accumulate.
    - The password is compared in constant time.
    """
    _clean_expired()
    now = time.time()
    # PENDING_AUTH is not swept by _clean_expired(); prune it here.
    for stale in [k for k, v in PENDING_AUTH.items() if v.get("expires_at", now) < now]:
        del PENDING_AUTH[stale]

    if request.method == "GET":
        params = request.query_params
        client_id = params.get("client_id", "")
        redirect_uri = params.get("redirect_uri", "")
        state = params.get("state", "")
        scope = params.get("scope", "mcp:tools")
        code_challenge = params.get("code_challenge", "")
        code_challenge_method = params.get("code_challenge_method", "S256")
        response_type = params.get("response_type", "code")

        if response_type != "code":
            return JSONResponse({"error": "unsupported_response_type"}, status_code=400)
        if not client_id or not redirect_uri:
            return JSONResponse(
                {"error": "invalid_request", "error_description": "client_id and redirect_uri are required."},
                status_code=400,
            )
        if code_challenge and code_challenge_method != "S256":
            return JSONResponse(
                {"error": "invalid_request", "error_description": "Only the S256 code_challenge_method is supported."},
                status_code=400,
            )

        client = REGISTERED_CLIENTS.get(client_id)
        if not client:
            # Auto-register unknown clients at authorize time (handles clients that
            # skip DCR or whose registration was lost on gateway restart).
            # Such clients are unverified by construction; the password on the
            # consent page is what gates access for them.
            client = {
                "client_id": client_id,
                "client_secret": secrets.token_urlsafe(48),
                "client_name": f"auto-{client_id[:8]}",
                "redirect_uris": [redirect_uri],
                "grant_types": ["authorization_code", "refresh_token"],
                "response_types": ["code"],
                "token_endpoint_auth_method": "none",
                "auto_registered": True,
            }
            REGISTERED_CLIENTS[client_id] = client
            logger.info(f"OAuth: auto-registered unknown client {client_id} at authorize")
        elif not client.get("auto_registered"):
            # Never send a code to a URI the client did not register.
            registered_uris = client.get("redirect_uris") or []
            if registered_uris and redirect_uri not in registered_uris:
                return JSONResponse(
                    {"error": "invalid_request", "error_description": "redirect_uri is not registered for this client."},
                    status_code=400,
                )

        internal_state = secrets.token_urlsafe(32)
        PENDING_AUTH[internal_state] = {
            "client_id": client_id,
            "redirect_uri": redirect_uri,
            "code_challenge": code_challenge,
            "code_challenge_method": code_challenge_method,
            "scope": scope,
            "state": state,
            "expires_at": now + AUTH_CODE_TTL,
        }
        client_name = client.get("client_name", "Unknown Client")
        return HTMLResponse(render_consent_page(client_name, scope, internal_state))

    if request.method == "POST":
        form = await request.form()
        internal_state = form.get("internal_state", "")
        password = form.get("password", "")
        action = form.get("action", "")

        pending = PENDING_AUTH.pop(internal_state, None) if isinstance(internal_state, str) else None
        if not pending or pending.get("expires_at", now) < now:
            return HTMLResponse(
                "<h1>Authorization expired</h1><p>Please try connecting again.</p>",
                status_code=400,
            )
        redirect_uri = pending["redirect_uri"]
        state = pending["state"]

        if action != "approve":
            return _redirect_with_params(redirect_uri, {"error": "access_denied", "state": state})
        if not GATEWAY_PASSWORD:
            return HTMLResponse(
                "<h1>Server misconfigured</h1><p>OAUTH_PASSWORD not set.</p>",
                status_code=500,
            )

        supplied = password if isinstance(password, str) else ""
        # Compare as bytes: compare_digest() rejects non-ASCII str arguments.
        if not secrets.compare_digest(supplied.encode(), GATEWAY_PASSWORD.encode()):
            # Put the request back so the user can retry on the same page.
            PENDING_AUTH[internal_state] = pending
            client = REGISTERED_CLIENTS.get(pending["client_id"], {})
            client_name = client.get("client_name", "Unknown Client")
            return HTMLResponse(render_consent_page(
                client_name, pending["scope"], internal_state, "Invalid password. Please try again.",
            ))

        code = secrets.token_urlsafe(48)
        AUTH_CODES[code] = {
            "client_id": pending["client_id"],
            "redirect_uri": redirect_uri,
            "code_challenge": pending["code_challenge"],
            "code_challenge_method": pending["code_challenge_method"],
            "scope": pending["scope"],
            "expires_at": time.time() + AUTH_CODE_TTL,
            "user": "owner",
        }
        logger.info(f"OAuth: issued auth code for client {pending['client_id']}")
        return _redirect_with_params(redirect_uri, {"code": code, "state": state})

    return JSONResponse({"error": "method_not_allowed"}, status_code=405)
def _issue_token_pair(client_id: str, scope: str, user: str) -> JSONResponse:
    """Mint an access/refresh token pair, store their hashes, build the reply.

    Only SHA-256 digests are kept in ACCESS_TOKENS / REFRESH_TOKENS. The
    refresh record remembers the access token's digest so that rotation can
    revoke it. The response is marked non-cacheable, as RFC 6749 section 5.1
    requires for any response that contains tokens.
    """
    access_token = secrets.token_urlsafe(48)
    refresh_token = secrets.token_urlsafe(48)
    access_hash = _hash(access_token)
    ACCESS_TOKENS[access_hash] = {
        "client_id": client_id,
        "scope": scope,
        "expires_at": time.time() + ACCESS_TOKEN_TTL,
        "user": user,
    }
    REFRESH_TOKENS[_hash(refresh_token)] = {
        "client_id": client_id,
        "scope": scope,
        "expires_at": time.time() + REFRESH_TOKEN_TTL,
        "user": user,
        "access_token_hash": access_hash,
    }
    return JSONResponse(
        {
            "access_token": access_token,
            "token_type": "Bearer",
            "expires_in": ACCESS_TOKEN_TTL,
            "refresh_token": refresh_token,
            "scope": scope,
        },
        headers={"Cache-Control": "no-store", "Pragma": "no-cache"},
    )


async def oauth_token(request: Request) -> JSONResponse:
    """Token endpoint: ``authorization_code`` and ``refresh_token`` grants.

    Parameters are read from a JSON object or from a form body. Values that
    are not strings are treated as missing.

    authorization_code: the code is single-use (removed on first sight), must
    be unexpired and bound to the presenting ``client_id``; when a
    ``code_challenge`` was stored, the S256 PKCE check must pass.

    refresh_token: the token must exist, be unexpired and belong to
    ``client_id``. It is rotated: the old refresh token and the access token
    tied to it are revoked and a new pair is issued.

    Returns:
        A token response on success, otherwise a 400 with an OAuth error code
        (``invalid_grant`` or ``unsupported_grant_type``).
    """
    _clean_expired()
    content_type = request.headers.get("content-type", "")
    if "application/json" in content_type:
        try:
            body = await request.json()
        except Exception:
            body = {}
        if not isinstance(body, dict):
            body = {}
    else:
        body = dict(await request.form())

    def field(name: str) -> str:
        # JSON bodies may carry numbers, lists, ...; only strings are usable here.
        value = body.get(name, "")
        return value if isinstance(value, str) else ""

    def invalid_grant(description: str) -> JSONResponse:
        return JSONResponse({"error": "invalid_grant", "error_description": description}, status_code=400)

    grant_type = field("grant_type")
    client_id = field("client_id")

    if grant_type == "authorization_code":
        code_verifier = field("code_verifier")
        # pop(): a code can be redeemed once, even if a later check fails.
        code_info = AUTH_CODES.pop(field("code"), None)
        if not code_info:
            return invalid_grant("Code invalid or expired.")
        if code_info["expires_at"] < time.time():
            return invalid_grant("Code expired.")
        if code_info["client_id"] != client_id:
            return invalid_grant("Client mismatch.")
        if code_info.get("code_challenge"):
            if not code_verifier:
                return invalid_grant("code_verifier required.")
            expected = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).rstrip(b"=")
            if not secrets.compare_digest(expected, str(code_info["code_challenge"]).encode()):
                return invalid_grant("PKCE verification failed.")
        response = _issue_token_pair(client_id, code_info["scope"], code_info["user"])
        logger.info(f"OAuth: issued access token for client {client_id}")
        return response

    if grant_type == "refresh_token":
        refresh_hash = _hash(field("refresh_token"))
        refresh_info = REFRESH_TOKENS.get(refresh_hash)
        if not refresh_info:
            return invalid_grant("Refresh token invalid.")
        if refresh_info["expires_at"] < time.time():
            del REFRESH_TOKENS[refresh_hash]
            return invalid_grant("Refresh token expired.")
        if refresh_info["client_id"] != client_id:
            return invalid_grant("Client mismatch.")
        # Rotation: revoke the old pair before issuing the new one.
        ACCESS_TOKENS.pop(refresh_info.get("access_token_hash"), None)
        del REFRESH_TOKENS[refresh_hash]
        response = _issue_token_pair(client_id, refresh_info["scope"], refresh_info["user"])
        logger.info(f"OAuth: refreshed token for client {client_id}")
        return response

    return JSONResponse({"error": "unsupported_grant_type"}, status_code=400)
# ---------------------------------------------------------------------------
# MCP Endpoint (OAuth-protected)
# ---------------------------------------------------------------------------
def _mcp_unauthorized(message: str) -> Response:
    """Build the 401 challenge that points clients at the resource metadata."""
    return Response(
        status_code=401,
        headers={
            "WWW-Authenticate": f'Bearer resource_metadata="{ISSUER_URL}/.well-known/oauth-protected-resource"',
            "Content-Type": "application/json",
        },
        media_type="application/json",
        content=json.dumps({"error": "unauthorized", "message": message}),
    )


def _jsonrpc_error(req_id: Any, code: int, message: str, status_code: int = 200) -> JSONResponse:
    """Build a JSON-RPC 2.0 error response."""
    return JSONResponse(
        {"jsonrpc": "2.0", "id": req_id, "error": {"code": code, "message": message}},
        status_code=status_code,
    )


def _echo_session(request: Request, response: Response) -> Response:
    """Copy the caller's ``Mcp-Session-Id`` header, if any, onto the response."""
    session_id = request.headers.get("Mcp-Session-Id")
    if session_id:
        response.headers["Mcp-Session-Id"] = session_id
    return response


async def handle_mcp(request: Request) -> Response:
    """Streamable HTTP MCP endpoint for gateway clients.

    HEAD   - unauthenticated probe; answers with the protocol version header.
    GET    - authenticated; returns a one-comment SSE stream.
    DELETE - authenticated; forgets the gateway session named in
             ``Mcp-Session-Id``.
    POST   - authenticated; handles one JSON-RPC message: ``initialize``,
             ``notifications/*`` (acknowledged with 202), ``tools/list``,
             ``tools/call`` (forwarded to the owning backend) and ``ping``.

    Malformed input is answered with JSON-RPC errors rather than a 500:
    unparsable JSON gives -32700, a body that is not an object gives -32600.
    A backend that produces no response for a tool call gives -32603.
    """
    if request.method == "HEAD":
        return Response(status_code=200, headers={"MCP-Protocol-Version": "2024-11-05"})

    if request.method == "GET":
        if not validate_bearer_token(request):
            return _mcp_unauthorized("Bearer token required.")

        # Authenticated GET — return empty SSE stream (mcpo uses this for server-sent events)
        async def empty_sse():
            yield b": ping\n\n"

        return StreamingResponse(
            empty_sse(),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
        )

    # DELETE and POST both change or use server state: require a Bearer token.
    if not validate_bearer_token(request):
        return _mcp_unauthorized("Valid Bearer token required.")

    if request.method == "DELETE":
        session_id = request.headers.get("Mcp-Session-Id")
        if session_id:
            GATEWAY_SESSIONS.pop(session_id, None)
        return Response(status_code=200)

    try:
        body = await request.json()
    except Exception:
        return _jsonrpc_error(None, -32700, "Parse error", status_code=400)
    # Batches (arrays) and scalars are not supported by this endpoint.
    if not isinstance(body, dict):
        return _jsonrpc_error(None, -32600, "Invalid Request", status_code=400)

    method = body.get("method", "")
    params = body.get("params")
    if not isinstance(params, dict):
        # Covers a missing member as well as an explicit "params": null.
        params = {}
    req_id = body.get("id")

    if method == "initialize":
        session_id = str(uuid.uuid4())
        GATEWAY_SESSIONS[session_id] = True
        response = JSONResponse({
            "jsonrpc": "2.0",
            "id": req_id,
            "result": {
                "protocolVersion": "2024-11-05",
                "capabilities": {"tools": {"listChanged": False}},
                "serverInfo": {"name": "mcp-gateway-proxy", "version": "1.0.0"},
            },
        })
        response.headers["Mcp-Session-Id"] = session_id
        return response

    # Notifications never get a JSON-RPC reply; acknowledge them all with 202.
    if isinstance(method, str) and method.startswith("notifications/"):
        return Response(status_code=202)

    if method == "tools/list":
        response = JSONResponse({
            "jsonrpc": "2.0",
            "id": req_id,
            "result": {"tools": list(TOOL_DEFINITIONS.values())},
        })
        return _echo_session(request, response)

    if method == "tools/call":
        tool_name = params.get("name", "")
        arguments = params.get("arguments") or {}
        backend_name = TOOL_REGISTRY.get(tool_name) if isinstance(tool_name, str) else None
        if not backend_name:
            return _jsonrpc_error(req_id, -32601, f"Unknown tool: {tool_name}")
        try:
            result = await forward_tool_call(backend_name, tool_name, arguments, req_id)
        except Exception as e:
            logger.error(f"Tool call failed: {tool_name} - {e}")
            return _jsonrpc_error(req_id, -32603, str(e))

        if result is None:
            # The backend answered with an HTTP error or an unreadable body;
            # report that instead of a success carrying a null result.
            logger.error(f"Tool call failed: {tool_name} - backend {backend_name} returned no response")
            return _jsonrpc_error(req_id, -32603, f"Backend '{backend_name}' returned no response for {tool_name}")

        if isinstance(result, dict) and "jsonrpc" in result:
            # Already a complete JSON-RPC response: pass it through untouched.
            response = JSONResponse(result)
        elif isinstance(result, dict) and "result" in result:
            response = JSONResponse({"jsonrpc": "2.0", "id": req_id, "result": result["result"]})
        elif isinstance(result, dict) and "error" in result:
            response = JSONResponse({"jsonrpc": "2.0", "id": req_id, "error": result["error"]})
        else:
            response = JSONResponse({"jsonrpc": "2.0", "id": req_id, "result": result})
        return _echo_session(request, response)

    if method == "ping":
        return JSONResponse({"jsonrpc": "2.0", "id": req_id, "result": {}})

    return _jsonrpc_error(req_id, -32601, f"Method not found: {method}")
# ---------------------------------------------------------------------------
# GUI Login (password → session cookie for /admin and /dashboard)
# ---------------------------------------------------------------------------
async def gui_login_get(request: Request) -> HTMLResponse:
    """GET /gui-login — render the login form without an error banner.

    The ``next`` query parameter (default ``/admin``) is HTML-escaped and
    handed to the template as ``next_url``.
    """
    target = html.escape(request.query_params.get("next", "/admin"))
    return HTMLResponse(GUI_LOGIN_HTML.format(error_html="", next_url=target))
async def gui_login_post(request: Request) -> Response:
    """POST /gui-login — validate password, set session cookie, redirect.

    Form fields: ``password`` and ``next`` (redirect target, default
    ``/admin``).

    The redirect target is accepted only if it is a same-host absolute path.
    Besides ``//host``, targets containing a backslash or a control character
    are rejected, because browsers normalize ``/\\host`` and ``/<TAB>/host``
    into a protocol-relative URL. Rejected targets fall back to ``/admin``.

    Returns:
        A 303 redirect carrying the ``mcp_session`` cookie on success, or the
        login page with status 401 when the password is wrong or no gateway
        password is configured.
    """
    form = await request.form()
    password = form.get("password", "")
    next_url = form.get("next", "/admin")
    if not isinstance(password, str):
        password = ""

    # Sanitise redirect target — only allow relative paths on this host
    unsafe_target = (
        not isinstance(next_url, str)
        or not next_url.startswith("/")
        or next_url.startswith("//")
        or "\\" in next_url
        or any(ord(ch) < 0x20 for ch in next_url)
    )
    if unsafe_target:
        next_url = "/admin"

    # Constant-time comparison; bytes because compare_digest() rejects
    # non-ASCII str arguments.
    password_ok = bool(GATEWAY_PASSWORD) and secrets.compare_digest(
        password.encode(), GATEWAY_PASSWORD.encode()
    )
    if not password_ok:
        error_html = '<div class="error">Incorrect password. Please try again.</div>'
        page = GUI_LOGIN_HTML.format(error_html=error_html, next_url=html.escape(next_url))
        return HTMLResponse(page, status_code=401)

    session_token = _create_gui_session()
    response = RedirectResponse(next_url, status_code=303)
    response.set_cookie(
        "mcp_session", session_token,
        max_age=GUI_SESSION_TTL,
        httponly=True,
        samesite="lax",
        secure=True,
    )
    return response
async def gui_logout(request: Request) -> Response:
    """GET /gui-logout — end the browser session.

    Removes the session named by the ``mcp_session`` cookie from
    GUI_SESSIONS, clears the cookie and redirects to the login page.
    """
    session_token = request.cookies.get("mcp_session")
    if session_token:
        GUI_SESSIONS.pop(session_token, None)
    to_login = RedirectResponse("/gui-login?next=/admin", status_code=303)
    to_login.delete_cookie("mcp_session")
    return to_login
# ---------------------------------------------------------------------------
# Health / Status
# ---------------------------------------------------------------------------
async def health(request: Request) -> JSONResponse:
    """Liveness endpoint: always answers with a fixed "healthy" payload."""
    payload = {
        "status": "healthy",
        "oauth": "enabled",
    }
    return JSONResponse(payload)
async def status(request: Request) -> JSONResponse:
    """Bearer-protected summary of the configured backends and their tools.

    Reports, per backend, its URL, the number of registered tools and the
    stored session id, plus the overall tool count and a sorted tool list.
    Answers 401 without a valid bearer token.
    """
    if not validate_bearer_token(request):
        return JSONResponse({"error": "unauthorized"}, status_code=401)

    backend_report = {}
    for backend, backend_url in BACKENDS.items():
        tool_count = sum(1 for owner in TOOL_REGISTRY.values() if owner == backend)
        backend_report[backend] = {
            "url": backend_url,
            "tools": tool_count,
            "session": BACKEND_SESSIONS.get(backend),
        }

    return JSONResponse({
        "backends": backend_report,
        "total_tools": len(TOOL_DEFINITIONS),
        "tool_list": sorted(TOOL_DEFINITIONS.keys()),
    })
# ---------------------------------------------------------------------------
# Dashboard
# ---------------------------------------------------------------------------
# Backend key -> human-readable name shown by the dashboard. Backends that
# are not listed here are shown as their key, capitalized.
DISPLAY_NAMES = {
"erpnext": "ERPNext",
"truenas": "TrueNAS",
"homeassistant": "Home Assistant",
"wave": "Wave Finance",
"linkedin": "LinkedIn",
}
async def probe_backend(name: str, url: str) -> dict:
    """Live-probe a backend: initialize it and count its tools in real time.

    mcp_request() reports an HTTP error status as a ``None`` result instead of
    raising, so an ``initialize`` that yields no result is treated as a failed
    probe here. Otherwise a backend answering with an error status would be
    shown as healthy with zero tools.

    On failure the probe falls back to the cached TOOL_REGISTRY: the backend
    is reported healthy if it has registered tools, unhealthy otherwise, and
    the result is tagged ``"note": "cached"``.

    Args:
        name: Backend key (used for logging and the registry lookup).
        url: Backend MCP endpoint URL.

    Returns:
        Dict with ``status``, ``toolCount``, ``responseTime`` (milliseconds)
        and, on the fallback path, ``note``.
    """
    start = time.time()
    try:
        init = await mcp_request(url, "initialize", {
            "protocolVersion": "2024-11-05",
            "capabilities": {},
            "clientInfo": {"name": "mcp-gateway-dashboard", "version": "1.0.0"},
        })
        if init.get("result") is None:
            raise RuntimeError("initialize returned no usable response")
        session_id = init.get("session_id")
        tools_result = await mcp_request(url, "tools/list", {}, request_id=2, session_id=session_id)
        tools_data = tools_result.get("result", {})
        # Unwrap the JSON-RPC envelope when the body is a full response.
        if isinstance(tools_data, dict) and isinstance(tools_data.get("result"), dict):
            tools_data = tools_data["result"]
        tools = tools_data.get("tools", []) if isinstance(tools_data, dict) else []
        return {
            "status": "healthy",
            "toolCount": len(tools),
            "responseTime": round((time.time() - start) * 1000),
        }
    except Exception as e:
        elapsed = round((time.time() - start) * 1000)
        logger.warning(f"Dashboard probe failed for {name}: {e}")
        # Fall back to cached registry values if live probe fails
        cached_tools = sum(1 for owner in TOOL_REGISTRY.values() if owner == name)
        return {
            "status": "healthy" if cached_tools > 0 else "unhealthy",
            "toolCount": cached_tools,
            "responseTime": elapsed,
            "note": "cached",
        }
async def dashboard_status(request: Request) -> JSONResponse:
    """Auth-protected endpoint — live-probes all backends.

    Accepts either a valid GUI session cookie or a valid bearer token. All
    backends are probed concurrently; a probe that surfaces as an exception
    is reported as unhealthy with no tools.

    Returns:
        JSON with a timestamp, one entry per backend and a summary block, or
        a 401 when the caller is not authenticated.
    """
    if not (_validate_gui_session(request) or validate_bearer_token(request)):
        return JSONResponse({"error": "unauthorized"}, status_code=401)

    def utc_stamp() -> str:
        return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())

    outcomes = await asyncio.gather(
        *[probe_backend(backend, backend_url) for backend, backend_url in BACKENDS.items()],
        return_exceptions=True
    )

    services = []
    for (backend, backend_url), outcome in zip(BACKENDS.items(), outcomes):
        if isinstance(outcome, Exception):
            outcome = {"status": "unhealthy", "toolCount": 0, "responseTime": None}
        entry = {
            "name": DISPLAY_NAMES.get(backend, backend.capitalize()),
            "key": backend,
            "url": backend_url,
            "status": outcome["status"],
            "toolCount": outcome["toolCount"],
            "responseTime": outcome.get("responseTime"),
            "lastCheck": utc_stamp(),
        }
        services.append(entry)

    statuses = [service["status"] for service in services]
    summary = {
        "total": len(services),
        "healthy": statuses.count("healthy"),
        "unhealthy": statuses.count("unhealthy"),
        "totalTools": sum(service["toolCount"] for service in services),
    }
    return JSONResponse({
        "timestamp": utc_stamp(),
        "services": services,
        "summary": summary,
    })
# Static page body served by dashboard() with media type text/html.
# NOTE(review): only the page's text content is present here, with no markup
# or script, so nothing in this template requests /dashboard/status. Confirm
# against the deployed template before relying on it.
DASHBOARD_HTML = """
MCP Gateway Dashboard
MCP Gateway
Service Status Dashboard
"""
async def dashboard(request: Request) -> Response:
    """GET /dashboard — serve the dashboard page to logged-in browsers.

    Requests without a valid GUI session are redirected to the login page,
    which returns them here afterwards.
    """
    if _validate_gui_session(request):
        return Response(DASHBOARD_HTML, media_type="text/html")
    return RedirectResponse("/gui-login?next=/dashboard", status_code=303)
# ---------------------------------------------------------------------------
# Startup
# ---------------------------------------------------------------------------
async def startup():
    """Log the OAuth configuration, wait for backends, then discover tools.

    Sleeps 10 seconds before contacting any backend, then initializes the
    backends one after another. A backend that yields no tools only produces
    a warning.
    """
    if GATEWAY_PASSWORD:
        logger.info("OAuth password configured")
    else:
        logger.error("OAUTH_PASSWORD is not set! OAuth login will fail.")
    logger.info(f"OAuth issuer: {ISSUER_URL}")
    logger.info(f"Starting MCP Gateway Proxy with {len(BACKENDS)} backends")

    logger.info("Waiting 10s for backends to start...")
    await asyncio.sleep(10)

    for backend, backend_url in BACKENDS.items():
        discovered = await initialize_backend(backend, backend_url)
        if discovered:
            continue
        logger.warning(f" {backend}: no tools discovered — will retry on first request")

    logger.info(f"Gateway ready: {len(TOOL_DEFINITIONS)} tools from {len(BACKENDS)} backends")
@asynccontextmanager
async def lifespan(app):
    """Starlette lifespan hook: run startup() before the app starts serving.

    Nothing is done on shutdown.
    """
    await startup()
    yield
# ---------------------------------------------------------------------------
# LinkedIn OAuth proxy routes
# ---------------------------------------------------------------------------
# Base URL of the linkedin-mcp container that the /linkedin/* proxy routes
# forward to. It is hard-coded, not read from the environment.
LINKEDIN_MCP_URL = "http://mcp-linkedin:8500"
async def linkedin_auth_proxy(request: Request):
    """Proxy /linkedin/auth to the linkedin-mcp container.

    An upstream 301/302/307/308 is turned into a redirect to its Location
    header; any other response is relayed with its status, body and content
    type.
    """
    upstream_url = f"{LINKEDIN_MCP_URL}/linkedin/auth"
    async with httpx.AsyncClient() as http:
        upstream = await http.get(upstream_url, follow_redirects=False)
    if upstream.status_code in (301, 302, 307, 308):
        return RedirectResponse(upstream.headers["location"])
    content_type = upstream.headers.get("content-type", "text/html")
    return Response(content=upstream.content, status_code=upstream.status_code, media_type=content_type)
async def linkedin_callback_proxy(request: Request):
    """Proxy /linkedin/callback to the linkedin-mcp container.

    The incoming query string is forwarded as-is; the upstream status, body
    and content type are relayed back.
    """
    upstream_url = f"{LINKEDIN_MCP_URL}/linkedin/callback?{str(request.url.query)}"
    async with httpx.AsyncClient() as http:
        upstream = await http.get(upstream_url, follow_redirects=False)
    content_type = upstream.headers.get("content-type", "text/html")
    return Response(content=upstream.content, status_code=upstream.status_code, media_type=content_type)
async def linkedin_status_proxy(request: Request):
    """Proxy /linkedin/status to the linkedin-mcp container.

    The upstream status, body and content type are relayed back unchanged.
    """
    upstream_url = f"{LINKEDIN_MCP_URL}/linkedin/status"
    async with httpx.AsyncClient() as http:
        upstream = await http.get(upstream_url, follow_redirects=False)
    content_type = upstream.headers.get("content-type", "text/html")
    return Response(content=upstream.content, status_code=upstream.status_code, media_type=content_type)
# Route table for the Starlette application.
routes = [
# Well-known discovery (Claude tries both with and without /mcp suffix).
# The OpenID configuration path serves the same OAuth metadata document.
Route("/.well-known/oauth-protected-resource", well_known_protected_resource, methods=["GET"]),
Route("/.well-known/oauth-protected-resource/mcp", well_known_protected_resource, methods=["GET"]),
Route("/.well-known/oauth-authorization-server", well_known_oauth_authorization_server, methods=["GET"]),
Route("/.well-known/oauth-authorization-server/mcp", well_known_oauth_authorization_server, methods=["GET"]),
Route("/.well-known/openid-configuration", well_known_oauth_authorization_server, methods=["GET"]),
# OAuth endpoints at /oauth/* (spec-standard)
Route("/oauth/register", oauth_register, methods=["POST"]),
Route("/oauth/authorize", oauth_authorize, methods=["GET", "POST"]),
Route("/oauth/token", oauth_token, methods=["POST"]),
# OAuth endpoints at root (Claude may construct these from base URL);
# these are also the URLs advertised in the authorization-server metadata.
Route("/register", oauth_register, methods=["POST"]),
Route("/authorize", oauth_authorize, methods=["GET", "POST"]),
Route("/token", oauth_token, methods=["POST"]),
# MCP endpoint (OAuth-protected)
Route("/mcp", handle_mcp, methods=["GET", "HEAD", "POST", "DELETE"]),
# Monitoring (/health is open; /status checks the bearer token itself)
Route("/health", health, methods=["GET"]),
Route("/status", status, methods=["GET"]),
# Browser login/logout and dashboard. The login routes are open; the
# dashboard handlers enforce authentication themselves (session cookie for
# /dashboard, session cookie or bearer token for /dashboard/status).
Route("/gui-login", gui_login_get, methods=["GET"]),
Route("/gui-login", gui_login_post, methods=["POST"]),
Route("/gui-logout", gui_logout, methods=["GET"]),
Route("/dashboard", dashboard, methods=["GET"]),
Route("/dashboard/status", dashboard_status, methods=["GET"]),
# LinkedIn OAuth flow (proxied to linkedin-mcp container)
Route("/linkedin/auth", linkedin_auth_proxy, methods=["GET"]),
Route("/linkedin/callback", linkedin_callback_proxy, methods=["GET"]),
Route("/linkedin/status", linkedin_status_proxy, methods=["GET"]),
# Admin / User Management UI
# NOTE(review): the handlers for /admin, /users/* and /keys/revoke are
# imported from user_dashboard_ui / user_routes; no authentication is applied
# at this level - confirm those modules protect themselves.
Route("/admin", user_dashboard, methods=["GET"]),
# User management API
Route("/users", list_users, methods=["GET"]),
Route("/users", create_user, methods=["POST"]),
Route("/users/{username}", get_user, methods=["GET"]),
Route("/users/{username}", delete_user, methods=["DELETE"]),
Route("/users/{username}/enable", toggle_user, methods=["PATCH"]),
Route("/users/{username}/keys", generate_api_key, methods=["POST"]),
Route("/users/{username}/mcp-access", set_mcp_access, methods=["PUT"]),
Route("/keys/revoke", revoke_api_key, methods=["POST"]),
]
# ASGI application: the route table above, with startup() run by the lifespan hook.
app = Starlette(routes=routes, lifespan=lifespan)

if __name__ == "__main__":
    # Direct execution: listen on all interfaces; PORT overrides the default 4444.
    port = int(os.environ.get("PORT", "4444"))
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")