""" MCP Gateway Proxy with OAuth 2.1 ================================= Aggregates multiple MCP servers behind a single Streamable HTTP endpoint. Implements a self-contained OAuth 2.1 provider compatible with claude.ai: - RFC 8414 Authorization Server Metadata - RFC 9728 Protected Resource Metadata - RFC 7591 Dynamic Client Registration - PKCE (S256) per OAuth 2.1 - Authorization Code Grant with refresh tokens """ import asyncio import base64 import hashlib import html import json import logging import os import secrets import time import uuid from contextlib import asynccontextmanager from typing import Any from urllib.parse import urlencode import httpx from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import HTMLResponse, JSONResponse, RedirectResponse, Response, StreamingResponse from starlette.routing import Route import uvicorn from user_routes import ( create_user, list_users, get_user, delete_user, toggle_user, generate_api_key, revoke_api_key, set_mcp_access ) from user_dashboard_ui import user_management_dashboard as user_dashboard logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") logger = logging.getLogger("mcp-gateway") # --------------------------------------------------------------------------- # Configuration # --------------------------------------------------------------------------- ISSUER_URL = os.environ.get("OAUTH_ISSUER_URL", "https://mcp.wilddragon.net") GATEWAY_PASSWORD = os.environ.get("OAUTH_PASSWORD", "") ACCESS_TOKEN_TTL = int(os.environ.get("OAUTH_ACCESS_TOKEN_TTL", "3600")) REFRESH_TOKEN_TTL = int(os.environ.get("OAUTH_REFRESH_TOKEN_TTL", "2592000")) AUTH_CODE_TTL = 600 STATIC_API_KEY = os.environ.get("GATEWAY_STATIC_API_KEY", "") # --------------------------------------------------------------------------- # In-memory stores # --------------------------------------------------------------------------- REGISTERED_CLIENTS: 
dict[str, dict] = {} AUTH_CODES: dict[str, dict] = {} ACCESS_TOKENS: dict[str, dict] = {} REFRESH_TOKENS: dict[str, dict] = {} PENDING_AUTH: dict[str, dict] = {} def _hash(value: str) -> str: return hashlib.sha256(value.encode()).hexdigest() def _clean_expired(): now = time.time() for store in (AUTH_CODES, ACCESS_TOKENS, REFRESH_TOKENS): expired = [k for k, v in store.items() if v.get("expires_at", 0) < now] for k in expired: del store[k] # --------------------------------------------------------------------------- # Backend MCP aggregation # --------------------------------------------------------------------------- def load_backends() -> dict[str, str]: backends = {} for key, value in os.environ.items(): if key.startswith("MCP_BACKEND_"): name = key[len("MCP_BACKEND_"):].lower() backends[name] = value logger.info(f"Backend configured: {name} -> {value}") return backends BACKENDS: dict[str, str] = load_backends() TOOL_REGISTRY: dict[str, str] = {} TOOL_DEFINITIONS: dict[str, dict] = {} BACKEND_SESSIONS: dict[str, str | None] = {} GATEWAY_SESSIONS: dict[str, bool] = {} def parse_sse_response(text: str) -> dict | None: for line in text.strip().split("\n"): line = line.strip() if line.startswith("data: "): try: return json.loads(line[6:]) except json.JSONDecodeError: continue return None async def mcp_request(backend_url: str, method: str, params: dict | None = None, request_id: Any = 1, session_id: str | None = None) -> dict: payload = {"jsonrpc": "2.0", "method": method, "id": request_id} if params is not None: payload["params"] = params headers = { "Content-Type": "application/json", "Accept": "application/json, text/event-stream", "Host": "localhost", } if session_id: headers["Mcp-Session-Id"] = session_id async with httpx.AsyncClient(timeout=30) as client: resp = await client.post(backend_url, json=payload, headers=headers) new_session = resp.headers.get("Mcp-Session-Id") or resp.headers.get("mcp-session-id") content_type = resp.headers.get("content-type", "") 
if resp.status_code in (200, 201): try: if "text/event-stream" in content_type: parsed = parse_sse_response(resp.text) return {"result": parsed, "session_id": new_session} else: return {"result": resp.json(), "session_id": new_session} except Exception: return {"result": None, "session_id": new_session} elif resp.status_code == 202: return {"result": None, "session_id": new_session} else: logger.error(f"Backend {backend_url} returned {resp.status_code}: {resp.text[:200]}") return {"result": None, "session_id": new_session} def _normalize_schema(obj: Any) -> None: """Recursively normalize JSON schemas so mcpo can parse them. Fixes: - type as list (e.g. ["string","null"]) -> "string" - items as list (tuple validation) -> {"type": "string"} - anyOf/oneOf containing non-dict entries """ if isinstance(obj, dict): if isinstance(obj.get("type"), list): obj["type"] = "string" # Convert tuple-style items (list) to simple dict schema if "items" in obj and isinstance(obj["items"], list): obj["items"] = {"type": "string"} # Clean anyOf/oneOf entries that might contain non-dicts for key in ("anyOf", "oneOf"): if key in obj and isinstance(obj[key], list): obj[key] = [x for x in obj[key] if isinstance(x, dict)] if len(obj[key]) == 1: # Collapse single-option anyOf/oneOf obj.update(obj.pop(key)[0]) elif len(obj[key]) == 0: del obj[key] obj["type"] = "string" for v in obj.values(): _normalize_schema(v) elif isinstance(obj, list): for item in obj: _normalize_schema(item) async def initialize_backend(name: str, url: str) -> list[dict]: logger.info(f"Initializing backend: {name} at {url}") for attempt in range(3): try: init_result = await mcp_request(url, "initialize", { "protocolVersion": "2024-11-05", "capabilities": {}, "clientInfo": {"name": "mcp-gateway-proxy", "version": "1.0.0"}, }) session_id = init_result.get("session_id") BACKEND_SESSIONS[name] = session_id logger.info(f" {name}: initialized (session: {session_id})") await mcp_request(url, "notifications/initialized", {}, 
request_id=None, session_id=session_id) tools_result = await mcp_request(url, "tools/list", {}, request_id=2, session_id=session_id) tools_data = tools_result.get("result", {}) if isinstance(tools_data, dict): if "result" in tools_data: tools_data = tools_data["result"] tools = tools_data.get("tools", []) else: tools = [] logger.info(f" {name}: discovered {len(tools)} tools") for tool in tools: original_name = tool["name"] prefixed_name = f"{name}_{original_name}" tool["name"] = prefixed_name # Recursively normalize list-type schemas (e.g. ["string","null"]) to "string" # so mcpo can parse the schema without crashing _normalize_schema(tool.get("inputSchema", {})) TOOL_REGISTRY[prefixed_name] = name TOOL_DEFINITIONS[prefixed_name] = tool return tools except Exception as e: if attempt < 2: logger.info(f" {name}: attempt {attempt+1} failed, retrying in 5s...") await asyncio.sleep(5) else: logger.error(f" {name}: failed to initialize after 3 attempts - {e}") return [] return [] async def forward_tool_call(backend_name: str, tool_name: str, arguments: dict, request_id: Any) -> dict: url = BACKENDS[backend_name] session_id = BACKEND_SESSIONS.get(backend_name) prefix = f"{backend_name}_" original_name = tool_name[len(prefix):] if tool_name.startswith(prefix) else tool_name if not session_id: await initialize_backend(backend_name, url) session_id = BACKEND_SESSIONS.get(backend_name) result = await mcp_request( url, "tools/call", {"name": original_name, "arguments": arguments}, request_id=request_id, session_id=session_id, ) response_data = result.get("result", {}) if isinstance(response_data, dict) and response_data.get("error"): error_code = response_data["error"].get("code", 0) if error_code in (-32600, -32601): logger.info(f"Re-initializing {backend_name} after error {error_code}") await initialize_backend(backend_name, url) session_id = BACKEND_SESSIONS.get(backend_name) result = await mcp_request( url, "tools/call", {"name": original_name, "arguments": arguments}, 
request_id=request_id, session_id=session_id, ) response_data = result.get("result", {}) return response_data # --------------------------------------------------------------------------- # OAuth 2.1: Token validation # --------------------------------------------------------------------------- def validate_bearer_token(request: Request) -> dict | None: auth_header = request.headers.get("Authorization", "") if not auth_header.startswith("Bearer "): return None token = auth_header[7:] # Check static API key first (persistent, survives restarts) if STATIC_API_KEY and token == STATIC_API_KEY: return {"client_id": "static", "scope": "mcp:tools", "user": "static"} token_hash = _hash(token) info = ACCESS_TOKENS.get(token_hash) if not info: return None if info["expires_at"] < time.time(): del ACCESS_TOKENS[token_hash] return None return info # --------------------------------------------------------------------------- # Consent page HTML template # --------------------------------------------------------------------------- CONSENT_PAGE_CSS = """ * { margin: 0; padding: 0; box-sizing: border-box; } body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; background: #0f172a; color: #e2e8f0; min-height: 100vh; display: flex; align-items: center; justify-content: center; } .card { background: #1e293b; border-radius: 16px; padding: 40px; max-width: 420px; width: 90%; box-shadow: 0 25px 50px rgba(0,0,0,0.4); } h1 { font-size: 22px; margin-bottom: 8px; color: #f8fafc; } .subtitle { color: #94a3b8; margin-bottom: 24px; font-size: 14px; } .client { background: #334155; border-radius: 8px; padding: 12px 16px; margin-bottom: 24px; font-size: 14px; } .client strong { color: #38bdf8; } label { display: block; font-size: 13px; color: #94a3b8; margin-bottom: 6px; } input[type=password] { width: 100%; padding: 12px 16px; border-radius: 8px; border: 1px solid #475569; background: #0f172a; color: #f8fafc; font-size: 16px; margin-bottom: 20px; outline: none; } 
input[type=password]:focus { border-color: #38bdf8; box-shadow: 0 0 0 3px rgba(56,189,248,0.15); } .buttons { display: flex; gap: 12px; } button { flex: 1; padding: 12px; border-radius: 8px; border: none; font-size: 15px; font-weight: 600; cursor: pointer; transition: all 0.15s; } .approve { background: #38bdf8; color: #0f172a; } .approve:hover { background: #7dd3fc; } .deny { background: #334155; color: #94a3b8; } .deny:hover { background: #475569; color: #e2e8f0; } .scope { font-size: 13px; color: #64748b; margin-bottom: 16px; } .error { background: #7f1d1d; color: #fca5a5; padding: 10px 14px; border-radius: 8px; margin-bottom: 16px; font-size: 13px; } """ def render_consent_page(client_name: str, scope: str, internal_state: str, error_msg: str = "") -> str: error_html = f'
Authorization Request
Scope: {html.escape(scope)}
Please try connecting again.
", status_code=400) redirect_uri = pending["redirect_uri"] state = pending["state"] if action != "approve": qs = urlencode({"error": "access_denied", "state": state}) return RedirectResponse(f"{redirect_uri}?{qs}", status_code=302) if not GATEWAY_PASSWORD: return HTMLResponse("OAUTH_PASSWORD not set.
", status_code=500) if password != GATEWAY_PASSWORD: PENDING_AUTH[internal_state] = pending client = REGISTERED_CLIENTS.get(pending["client_id"], {}) client_name = client.get("client_name", "Unknown Client") return HTMLResponse(render_consent_page(client_name, pending["scope"], internal_state, "Invalid password. Please try again.")) code = secrets.token_urlsafe(48) AUTH_CODES[code] = { "client_id": pending["client_id"], "redirect_uri": redirect_uri, "code_challenge": pending["code_challenge"], "code_challenge_method": pending["code_challenge_method"], "scope": pending["scope"], "expires_at": time.time() + AUTH_CODE_TTL, "user": "owner", } qs = urlencode({"code": code, "state": state}) logger.info(f"OAuth: issued auth code for client {pending['client_id']}") return RedirectResponse(f"{redirect_uri}?{qs}", status_code=302) return JSONResponse({"error": "method_not_allowed"}, status_code=405)


def _oauth_token_json(payload: dict, status_code: int = 200) -> JSONResponse:
    """Build a token-endpoint JSON response.

    RFC 6749 section 5.1 requires ``Cache-Control: no-store`` and
    ``Pragma: no-cache`` on responses carrying tokens, so intermediaries never
    cache credentials. Error responses get the same headers for uniformity.
    """
    return JSONResponse(
        payload,
        status_code=status_code,
        headers={"Cache-Control": "no-store", "Pragma": "no-cache"},
    )


def _oauth_grant_error(description: str) -> JSONResponse:
    """Return an RFC 6749 ``invalid_grant`` error with a human-readable description."""
    return _oauth_token_json({"error": "invalid_grant", "error_description": description}, status_code=400)


def _oauth_str_param(body: dict, key: str) -> str:
    """Return ``body[key]`` when it is a string, else "" (JSON bodies may carry any type)."""
    value = body.get(key, "")
    return value if isinstance(value, str) else ""


def _issue_token_pair(client_id: str, scope: Any, user: Any) -> dict:
    """Mint and store a fresh access/refresh token pair; return the token response body.

    Only SHA-256 hashes of the tokens are stored. The refresh entry remembers
    its access token's hash, so rotation can revoke that access token.
    """
    access_token = secrets.token_urlsafe(48)
    refresh_token = secrets.token_urlsafe(48)
    access_hash = _hash(access_token)
    now = time.time()
    ACCESS_TOKENS[access_hash] = {
        "client_id": client_id,
        "scope": scope,
        "expires_at": now + ACCESS_TOKEN_TTL,
        "user": user,
    }
    REFRESH_TOKENS[_hash(refresh_token)] = {
        "client_id": client_id,
        "scope": scope,
        "expires_at": now + REFRESH_TOKEN_TTL,
        "user": user,
        "access_token_hash": access_hash,
    }
    return {
        "access_token": access_token,
        "token_type": "Bearer",
        "expires_in": ACCESS_TOKEN_TTL,
        "refresh_token": refresh_token,
        "scope": scope,
    }


def _grant_authorization_code(body: dict) -> JSONResponse:
    """Handle ``grant_type=authorization_code``: redeem a code (with PKCE) for tokens."""
    code = _oauth_str_param(body, "code")
    client_id = _oauth_str_param(body, "client_id")
    code_verifier = _oauth_str_param(body, "code_verifier")

    # Pop before validating, so a code can be redeemed at most once.
    code_info = AUTH_CODES.pop(code, None)
    if not code_info:
        return _oauth_grant_error("Code invalid or expired.")
    if code_info["expires_at"] < time.time():
        return _oauth_grant_error("Code expired.")
    if code_info["client_id"] != client_id:
        return _oauth_grant_error("Client mismatch.")

    # The code is bound to the redirect_uri it was issued for. If the client
    # repeats redirect_uri, it must be identical (RFC 6749 section 4.1.3).
    redirect_uri = _oauth_str_param(body, "redirect_uri")
    if redirect_uri and redirect_uri != code_info.get("redirect_uri"):
        return _oauth_grant_error("redirect_uri mismatch.")

    if code_info.get("code_challenge"):
        if not code_verifier:
            return _oauth_grant_error("code_verifier required.")
        # S256: BASE64URL(SHA256(verifier)) without padding (RFC 7636 section 4.6).
        expected = base64.urlsafe_b64encode(
            hashlib.sha256(code_verifier.encode()).digest()
        ).rstrip(b"=")
        if not secrets.compare_digest(expected, str(code_info["code_challenge"]).encode()):
            return _oauth_grant_error("PKCE verification failed.")

    tokens = _issue_token_pair(client_id, code_info["scope"], code_info["user"])
    logger.info(f"OAuth: issued access token for client {client_id}")
    return _oauth_token_json(tokens)


def _grant_refresh_token(body: dict) -> JSONResponse:
    """Handle ``grant_type=refresh_token``: rotate the refresh token and its access token."""
    refresh_token_val = _oauth_str_param(body, "refresh_token")
    client_id = _oauth_str_param(body, "client_id")
    refresh_hash = _hash(refresh_token_val)

    refresh_info = REFRESH_TOKENS.get(refresh_hash)
    if not refresh_info:
        return _oauth_grant_error("Refresh token invalid.")
    if refresh_info["expires_at"] < time.time():
        del REFRESH_TOKENS[refresh_hash]
        return _oauth_grant_error("Refresh token expired.")
    if refresh_info["client_id"] != client_id:
        return _oauth_grant_error("Client mismatch.")

    # Rotation: the presented refresh token and its access token both die here.
    old_access_hash = refresh_info.get("access_token_hash")
    if old_access_hash:
        ACCESS_TOKENS.pop(old_access_hash, None)
    del REFRESH_TOKENS[refresh_hash]

    tokens = _issue_token_pair(client_id, refresh_info["scope"], refresh_info["user"])
    logger.info(f"OAuth: refreshed token for client {client_id}")
    return _oauth_token_json(tokens)


async def oauth_token(request: Request) -> JSONResponse:
    """OAuth 2.1 token endpoint (authorization_code with PKCE, and refresh_token).

    Accepts a form-encoded or JSON body. Errors use the RFC 6749 section 5.2
    format with HTTP 400.
    """
    _clean_expired()
    content_type = request.headers.get("content-type", "")
    if "application/json" in content_type:
        try:
            body = await request.json()
        except ValueError:
            body = {}
    else:
        form = await request.form()
        body = dict(form)
    if not isinstance(body, dict):
        body = {}

    grant_type = body.get("grant_type", "")
    if grant_type == "authorization_code":
        return _grant_authorization_code(body)
    if grant_type == "refresh_token":
        return _grant_refresh_token(body)
    return _oauth_token_json({"error": "unsupported_grant_type"}, status_code=400)


# ---------------------------------------------------------------------------
# MCP Endpoint (OAuth-protected)
# ---------------------------------------------------------------------------
def _mcp_unauthorized(message: str) -> Response:
    """Return a 401 that points the client at the RFC 9728 protected-resource metadata."""
    return Response(
        status_code=401,
        headers={
            "WWW-Authenticate": f'Bearer resource_metadata="{ISSUER_URL}/.well-known/oauth-protected-resource"',
            "Content-Type": "application/json",
        },
        media_type="application/json",
        content=json.dumps({"error": "unauthorized", "message": message}),
    )


def _with_mcp_session_header(request: Request, response: Response) -> Response:
    """Echo the request's ``Mcp-Session-Id`` header (if any) onto *response*."""
    session_id = request.headers.get("Mcp-Session-Id")
    if session_id:
        response.headers["Mcp-Session-Id"] = session_id
    return response


def _jsonrpc_from_backend(result: Any, req_id: Any) -> dict:
    """Turn a backend tools/call reply into the JSON-RPC response sent to the client.

    Full envelopes pass through unchanged. Bare ``result``/``error`` dicts are
    wrapped, and anything else becomes the ``result`` value.
    """
    if isinstance(result, dict):
        if "jsonrpc" in result:
            return result
        if "result" in result:
            return {"jsonrpc": "2.0", "id": req_id, "result": result["result"]}
        if "error" in result:
            return {"jsonrpc": "2.0", "id": req_id, "error": result["error"]}
    return {"jsonrpc": "2.0", "id": req_id, "result": result}


async def handle_mcp(request: Request) -> Response:
    """Streamable HTTP MCP endpoint that fronts every configured backend.

    HEAD and DELETE are unauthenticated. GET and POST require a Bearer token.
    POST handles initialize, tools/list, tools/call (forwarded to the owning
    backend) and ping. Notifications and client responses are acknowledged
    with 202 and no body.
    """
    if request.method == "HEAD":
        return Response(status_code=200, headers={"MCP-Protocol-Version": "2024-11-05"})

    if request.method == "GET":
        if not validate_bearer_token(request):
            return _mcp_unauthorized("Bearer token required.")

        # Authenticated GET — return empty SSE stream (mcpo uses this for server-sent events)
        async def empty_sse():
            yield b": ping\n\n"

        return StreamingResponse(
            empty_sse(),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
        )

    if request.method == "DELETE":
        session_id = request.headers.get("Mcp-Session-Id")
        if session_id:
            GATEWAY_SESSIONS.pop(session_id, None)
        return Response(status_code=200)

    # POST — require Bearer token
    if not validate_bearer_token(request):
        return _mcp_unauthorized("Valid Bearer token required.")

    try:
        body = await request.json()
    except Exception:
        return JSONResponse(
            {"jsonrpc": "2.0", "error": {"code": -32700, "message": "Parse error"}, "id": None},
            status_code=400,
        )
    if not isinstance(body, dict):
        # JSON-RPC batches and non-object payloads are not supported.
        return JSONResponse(
            {"jsonrpc": "2.0", "error": {"code": -32600, "message": "Invalid Request"}, "id": None},
            status_code=400,
        )

    method = body.get("method", "")
    params = body.get("params")
    if not isinstance(params, dict):
        params = {}
    req_id = body.get("id")

    # JSON-RPC notifications (no "id") and client responses (no "method") must
    # not receive a reply body; Streamable HTTP answers them with 202 Accepted.
    is_notification = "method" in body and "id" not in body
    is_client_response = "method" not in body and ("result" in body or "error" in body)
    if is_notification or is_client_response or method == "notifications/initialized":
        return Response(status_code=202)

    if method == "initialize":
        session_id = str(uuid.uuid4())
        GATEWAY_SESSIONS[session_id] = True
        response = JSONResponse({
            "jsonrpc": "2.0",
            "id": req_id,
            "result": {
                "protocolVersion": "2024-11-05",
                "capabilities": {"tools": {"listChanged": False}},
                "serverInfo": {"name": "mcp-gateway-proxy", "version": "1.0.0"},
            },
        })
        response.headers["Mcp-Session-Id"] = session_id
        return response

    if method == "tools/list":
        return _with_mcp_session_header(request, JSONResponse({
            "jsonrpc": "2.0",
            "id": req_id,
            "result": {"tools": list(TOOL_DEFINITIONS.values())},
        }))

    if method == "tools/call":
        tool_name = params.get("name", "")
        arguments = params.get("arguments") or {}
        backend_name = TOOL_REGISTRY.get(tool_name) if isinstance(tool_name, str) else None
        if not backend_name:
            return _with_mcp_session_header(request, JSONResponse({
                "jsonrpc": "2.0",
                "id": req_id,
                "error": {"code": -32601, "message": f"Unknown tool: {tool_name}"},
            }))
        try:
            result = await forward_tool_call(backend_name, tool_name, arguments, req_id)
            # Rendering stays inside the try: backend JSON can hold values
            # (e.g. NaN) that JSONResponse refuses to serialize.
            response = JSONResponse(_jsonrpc_from_backend(result, req_id))
            return _with_mcp_session_header(request, response)
        except Exception as e:
            logger.exception(f"Tool call failed: {tool_name} - {e}")
            return JSONResponse({
                "jsonrpc": "2.0",
                "id": req_id,
                "error": {"code": -32603, "message": str(e)},
            })

    if method == "ping":
        return _with_mcp_session_header(request, JSONResponse({"jsonrpc": "2.0", "id": req_id, "result": {}}))

    return _with_mcp_session_header(request, JSONResponse({
        "jsonrpc": "2.0",
        "id": req_id,
        "error": {"code": -32601, "message": f"Method not found: {method}"},
    }))


# ---------------------------------------------------------------------------
# Health / Status
# ---------------------------------------------------------------------------
async def health(request: Request) -> JSONResponse:
    """Unauthenticated liveness check."""
    return JSONResponse({
        "status": "healthy",
        "oauth": "enabled",
    })


async def status(request: Request) -> JSONResponse:
    """Bearer-protected summary of configured backends and registered tools."""
    if not validate_bearer_token(request):
        return JSONResponse({"error": "unauthorized"}, status_code=401)
    backends = {
        name: {
            "url": url,
            "tools": sum(1 for owner in TOOL_REGISTRY.values() if owner == name),
            "session": BACKEND_SESSIONS.get(name),
        }
        for name, url in BACKENDS.items()
    }
    return JSONResponse({
        "backends": backends,
        "total_tools": len(TOOL_DEFINITIONS),
        "tool_list": sorted(TOOL_DEFINITIONS.keys()),
    })


# ---------------------------------------------------------------------------
# Dashboard
# ---------------------------------------------------------------------------
DISPLAY_NAMES = {
    "erpnext": "ERPNext",
    "truenas": "TrueNAS",
    "homeassistant": "Home Assistant",
    "wave": "Wave Finance",
    "linkedin": "LinkedIn",
}


async def _probe_tool_count(url: str) -> int | None:
    """Run a fresh MCP handshake against *url* and return how many tools it lists.

    Returns None when initialize or tools/list yields no decodable reply, or
    yields a JSON-RPC error. Transport errors propagate.
    NOTE(review): each probe opens a new backend session that is never closed.
    """
    init = await mcp_request(url, "initialize", {
        "protocolVersion": "2024-11-05",
        "capabilities": {},
        "clientInfo": {"name": "mcp-gateway-dashboard", "version": "1.0.0"},
    })
    init_reply = init.get("result")
    if not isinstance(init_reply, dict) or "error" in init_reply:
        return None
    session_id = init.get("session_id")

    # MCP lifecycle: the client sends notifications/initialized before other requests.
    await mcp_request(url, "notifications/initialized", {}, request_id=None, session_id=session_id)
    tools_result = await mcp_request(url, "tools/list", {}, request_id=2, session_id=session_id)
    tools_reply = tools_result.get("result")
    if not isinstance(tools_reply, dict) or "error" in tools_reply:
        return None

    payload = tools_reply.get("result", tools_reply)
    tools = payload.get("tools", []) if isinstance(payload, dict) else []
    return len(tools) if isinstance(tools, list) else 0


async def probe_backend(name: str, url: str) -> dict:
    """Live-probe a backend: initialize it and count its tools in real time.

    If the probe raises or gets no usable reply, falls back to the cached
    TOOL_REGISTRY count (marked ``"note": "cached"``). The result is "healthy"
    when that cached count is non-zero.
    """
    start = time.time()
    try:
        tool_count = await _probe_tool_count(url)
        failure: Any = None if tool_count is not None else "no usable response"
    except Exception as e:
        tool_count, failure = None, e
    elapsed = round((time.time() - start) * 1000)

    if tool_count is not None:
        return {
            "status": "healthy",
            "toolCount": tool_count,
            "responseTime": elapsed,
        }

    logger.warning(f"Dashboard probe failed for {name}: {failure}")
    cached_tools = sum(1 for owner in TOOL_REGISTRY.values() if owner == name)
    return {
        "status": "healthy" if cached_tools > 0 else "unhealthy",
        "toolCount": cached_tools,
        "responseTime": elapsed,
        "note": "cached",
    }


async def dashboard_status(request: Request) -> JSONResponse:
    """Auth-protected endpoint — live-probes all backends concurrently."""
    if not validate_bearer_token(request):
        return JSONResponse({"error": "unauthorized"}, status_code=401)

    probes = await asyncio.gather(
        *(probe_backend(name, url) for name, url in BACKENDS.items()),
        return_exceptions=True,
    )
    checked_at = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())

    services = []
    for (name, url), probe in zip(BACKENDS.items(), probes):
        # gather(return_exceptions=True) may also hand back BaseException
        # subclasses such as CancelledError.
        if isinstance(probe, BaseException):
            probe = {"status": "unhealthy", "toolCount": 0, "responseTime": None}
        services.append({
            "name": DISPLAY_NAMES.get(name, name.capitalize()),
            "key": name,
            "url": url,
            "status": probe["status"],
            "toolCount": probe["toolCount"],
            "responseTime": probe.get("responseTime"),
            "lastCheck": checked_at,
        })

    return JSONResponse({
        "timestamp": checked_at,
        "services": services,
        "summary": {
            "total": len(services),
            "healthy": sum(1 for s in services if s["status"] == "healthy"),
            "unhealthy": sum(1 for s in services if s["status"] == "unhealthy"),
            "totalTools": sum(s["toolCount"] for s in services),
        },
    })


DASHBOARD_HTML = """Service Status Dashboard
Total Services
—
Total Tools
—
Healthy
—