""" MCP Gateway Proxy with OAuth 2.1 ================================= Aggregates multiple MCP servers behind a single Streamable HTTP endpoint. Implements a self-contained OAuth 2.1 provider compatible with claude.ai: - RFC 8414 Authorization Server Metadata - RFC 9728 Protected Resource Metadata - RFC 7591 Dynamic Client Registration - PKCE (S256) per OAuth 2.1 - Authorization Code Grant with refresh tokens """ import asyncio import base64 import hashlib import html import json import logging import os import secrets import time import uuid from contextlib import asynccontextmanager from typing import Any from urllib.parse import urlencode import httpx from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import HTMLResponse, JSONResponse, RedirectResponse, Response, StreamingResponse from starlette.routing import Route import uvicorn logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") logger = logging.getLogger("mcp-gateway") # Import OpenAI-compatible routes try: from openai_routes import chat_completions, list_models OPENAI_AVAILABLE = True logger.info("✓ OpenAI-compatible routes imported successfully") except ImportError as e: OPENAI_AVAILABLE = False logger.error(f"✗ Failed to import OpenAI routes: {e}") import traceback logger.error(f"Traceback: {traceback.format_exc()}") # --------------------------------------------------------------------------- # Configuration # --------------------------------------------------------------------------- ISSUER_URL = os.environ.get("OAUTH_ISSUER_URL", "https://mcp.wilddragon.net") GATEWAY_PASSWORD = os.environ.get("OAUTH_PASSWORD", "") ACCESS_TOKEN_TTL = int(os.environ.get("OAUTH_ACCESS_TOKEN_TTL", "3600")) REFRESH_TOKEN_TTL = int(os.environ.get("OAUTH_REFRESH_TOKEN_TTL", "2592000")) AUTH_CODE_TTL = 600 # --------------------------------------------------------------------------- # In-memory stores # --------------------------------------------------------------------------- REGISTERED_CLIENTS: dict[str, dict] = {} AUTH_CODES: dict[str, dict] = {} ACCESS_TOKENS: dict[str, dict] = {} REFRESH_TOKENS: dict[str, dict] = {} PENDING_AUTH: dict[str, dict] = {} def _hash(value: str) -> str: return hashlib.sha256(value.encode()).hexdigest() def _clean_expired(): now = time.time() for store in (AUTH_CODES, ACCESS_TOKENS, REFRESH_TOKENS): expired = [k for k, v in store.items() if v.get("expires_at", 0) < now] for k in expired: del store[k] # --------------------------------------------------------------------------- # Backend MCP aggregation # --------------------------------------------------------------------------- def load_backends() -> dict[str, str]: backends = {} for key, value in os.environ.items(): if key.startswith("MCP_BACKEND_"): name = key[len("MCP_BACKEND_"):].lower() backends[name] = value logger.info(f"Backend configured: {name} -> {value}") return backends BACKENDS: dict[str, str] = load_backends() TOOL_REGISTRY: dict[str, str] = {} TOOL_DEFINITIONS: dict[str, dict] = {} BACKEND_SESSIONS: dict[str, str | None] = {} GATEWAY_SESSIONS: dict[str, bool] = {} def parse_sse_response(text: str) -> dict | None: for line in text.strip().split("\n"): line = line.strip() if line.startswith("data: "): try: return json.loads(line[6:]) except json.JSONDecodeError: continue return None async def mcp_request(backend_url: str, method: str, params: dict | None = None, request_id: Any = 1, session_id: str | None = None) -> dict: payload = {"jsonrpc": "2.0", "method": method, "id": request_id} if params is not None: payload["params"] = params headers = { "Content-Type": "application/json", "Accept": "application/json, text/event-stream", "Host": "localhost", } if session_id: headers["Mcp-Session-Id"] = session_id async with httpx.AsyncClient(timeout=30) as client: resp = await client.post(backend_url, json=payload, headers=headers) new_session = resp.headers.get("Mcp-Session-Id") or resp.headers.get("mcp-session-id") content_type = resp.headers.get("content-type", "") if resp.status_code in (200, 201): try: if "text/event-stream" in content_type: parsed = parse_sse_response(resp.text) return {"result": parsed, "session_id": new_session} else: return {"result": resp.json(), "session_id": new_session} except Exception: return {"result": None, "session_id": new_session} elif resp.status_code == 202: return {"result": None, "session_id": new_session} else: logger.error(f"Backend {backend_url} returned {resp.status_code}: {resp.text[:200]}") return {"result": None, "session_id": new_session} async def initialize_backend(name: str, url: str) -> list[dict]: logger.info(f"Initializing backend: {name} at {url}") for attempt in range(3): try: init_result = await mcp_request(url, "initialize", { "protocolVersion": "2024-11-05", "capabilities": {}, "clientInfo": {"name": "mcp-gateway-proxy", "version": "1.0.0"}, }) session_id = init_result.get("session_id") BACKEND_SESSIONS[name] = session_id logger.info(f" {name}: initialized (session: {session_id})") await mcp_request(url, "notifications/initialized", {}, request_id=None, session_id=session_id) tools_result = await mcp_request(url, "tools/list", {}, request_id=2, session_id=session_id) tools_data = tools_result.get("result", {}) if isinstance(tools_data, dict): if "result" in tools_data: tools_data = tools_data["result"] tools = tools_data.get("tools", []) else: tools = [] logger.info(f" {name}: discovered {len(tools)} tools") for tool in tools: original_name = tool["name"] prefixed_name = f"{name}_{original_name}" tool["name"] = prefixed_name TOOL_REGISTRY[prefixed_name] = name TOOL_DEFINITIONS[prefixed_name] = tool return tools except Exception as e: if attempt < 2: logger.info(f" {name}: attempt {attempt+1} failed, retrying in 5s...") await asyncio.sleep(5) else: logger.error(f" {name}: failed to initialize after 3 attempts - {e}") return [] return [] async def forward_tool_call(backend_name: str, tool_name: str, arguments: dict, request_id: Any) -> dict: url = BACKENDS[backend_name] session_id = BACKEND_SESSIONS.get(backend_name) prefix = f"{backend_name}_" original_name = tool_name[len(prefix):] if tool_name.startswith(prefix) else tool_name if not session_id: await initialize_backend(backend_name, url) session_id = BACKEND_SESSIONS.get(backend_name) result = await mcp_request( url, "tools/call", {"name": original_name, "arguments": arguments}, request_id=request_id, session_id=session_id, ) response_data = result.get("result", {}) if isinstance(response_data, dict) and response_data.get("error"): error_code = response_data["error"].get("code", 0) if error_code in (-32600, -32601): logger.info(f"Re-initializing {backend_name} after error {error_code}") await initialize_backend(backend_name, url) session_id = BACKEND_SESSIONS.get(backend_name) result = await mcp_request( url, "tools/call", {"name": original_name, "arguments": arguments}, request_id=request_id, session_id=session_id, ) response_data = result.get("result", {}) return response_data # --------------------------------------------------------------------------- # OAuth 2.1: Token validation # --------------------------------------------------------------------------- def validate_bearer_token(request: Request) -> dict | None: auth_header = request.headers.get("Authorization", "") if not auth_header.startswith("Bearer "): return None token = auth_header[7:] token_hash = _hash(token) info = ACCESS_TOKENS.get(token_hash) if not info: return None if info["expires_at"] < time.time(): del ACCESS_TOKENS[token_hash] return None return info # --------------------------------------------------------------------------- # Consent page HTML template # --------------------------------------------------------------------------- CONSENT_PAGE_CSS = """ * { margin: 0; padding: 0; box-sizing: border-box; } body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; background: #0f172a; color: #e2e8f0; min-height: 100vh; display: flex; align-items: center; justify-content: center; } .card { background: #1e293b; border-radius: 16px; padding: 40px; max-width: 420px; width: 90%; box-shadow: 0 25px 50px rgba(0,0,0,0.4); } h1 { font-size: 22px; margin-bottom: 8px; color: #f8fafc; } .subtitle { color: #94a3b8; margin-bottom: 24px; font-size: 14px; } .client { background: #334155; border-radius: 8px; padding: 12px 16px; margin-bottom: 24px; font-size: 14px; } .client strong { color: #38bdf8; } label { display: block; font-size: 13px; color: #94a3b8; margin-bottom: 6px; } input[type=password] { width: 100%; padding: 12px 16px; border-radius: 8px; border: 1px solid #475569; background: #0f172a; color: #f8fafc; font-size: 16px; margin-bottom: 20px; outline: none; } input[type=password]:focus { border-color: #38bdf8; box-shadow: 0 0 0 3px rgba(56,189,248,0.15); } .buttons { display: flex; gap: 12px; } button { flex: 1; padding: 12px; border-radius: 8px; border: none; font-size: 15px; font-weight: 600; cursor: pointer; transition: all 0.15s; } .approve { background: #38bdf8; color: #0f172a; } .approve:hover { background: #7dd3fc; } .deny { background: #334155; color: #94a3b8; } .deny:hover { background: #475569; color: #e2e8f0; } .scope { font-size: 13px; color: #64748b; margin-bottom: 16px; } .error { background: #7f1d1d; color: #fca5a5; padding: 10px 14px; border-radius: 8px; margin-bottom: 16px; font-size: 13px; } """ def render_consent_page(client_name: str, scope: str, internal_state: str, error_msg: str = "") -> str: error_html = f'
Authorization Request
Scope: {html.escape(scope)}
Please try connecting again.
", status_code=400) redirect_uri = pending["redirect_uri"] state = pending["state"] if action != "approve": qs = urlencode({"error": "access_denied", "state": state}) return RedirectResponse(f"{redirect_uri}?{qs}", status_code=302) if not GATEWAY_PASSWORD: return HTMLResponse("OAUTH_PASSWORD not set.
", status_code=500) if password != GATEWAY_PASSWORD: PENDING_AUTH[internal_state] = pending client = REGISTERED_CLIENTS.get(pending["client_id"], {}) client_name = client.get("client_name", "Unknown Client") return HTMLResponse(render_consent_page(client_name, pending["scope"], internal_state, "Invalid password. Please try again.")) code = secrets.token_urlsafe(48) AUTH_CODES[code] = { "client_id": pending["client_id"], "redirect_uri": redirect_uri, "code_challenge": pending["code_challenge"], "code_challenge_method": pending["code_challenge_method"], "scope": pending["scope"], "expires_at": time.time() + AUTH_CODE_TTL, "user": "owner", } qs = urlencode({"code": code, "state": state}) logger.info(f"OAuth: issued auth code for client {pending['client_id']}") return RedirectResponse(f"{redirect_uri}?{qs}", status_code=302) return JSONResponse({"error": "method_not_allowed"}, status_code=405) async def oauth_token(request: Request) -> JSONResponse: _clean_expired() content_type = request.headers.get("content-type", "") if "application/json" in content_type: try: body = await request.json() except Exception: body = {} else: form = await request.form() body = dict(form) grant_type = body.get("grant_type", "") if grant_type == "authorization_code": code = body.get("code", "") client_id = body.get("client_id", "") code_verifier = body.get("code_verifier", "") code_info = AUTH_CODES.pop(code, None) if not code_info: return JSONResponse({"error": "invalid_grant", "error_description": "Code invalid or expired."}, status_code=400) if code_info["expires_at"] < time.time(): return JSONResponse({"error": "invalid_grant", "error_description": "Code expired."}, status_code=400) if code_info["client_id"] != client_id: return JSONResponse({"error": "invalid_grant", "error_description": "Client mismatch."}, status_code=400) if code_info.get("code_challenge"): if not code_verifier: return JSONResponse({"error": "invalid_grant", "error_description": "code_verifier required."}, status_code=400) expected = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).rstrip(b"=").decode() if expected != code_info["code_challenge"]: return JSONResponse({"error": "invalid_grant", "error_description": "PKCE verification failed."}, status_code=400) access_token = secrets.token_urlsafe(48) refresh_token = secrets.token_urlsafe(48) access_hash = _hash(access_token) refresh_hash = _hash(refresh_token) ACCESS_TOKENS[access_hash] = { "client_id": client_id, "scope": code_info["scope"], "expires_at": time.time() + ACCESS_TOKEN_TTL, "user": code_info["user"], } REFRESH_TOKENS[refresh_hash] = { "client_id": client_id, "scope": code_info["scope"], "expires_at": time.time() + REFRESH_TOKEN_TTL, "user": code_info["user"], "access_token_hash": access_hash, } logger.info(f"OAuth: issued access token for client {client_id}") return JSONResponse({ "access_token": access_token, "token_type": "Bearer", "expires_in": ACCESS_TOKEN_TTL, "refresh_token": refresh_token, "scope": code_info["scope"], }) elif grant_type == "refresh_token": refresh_token_val = body.get("refresh_token", "") client_id = body.get("client_id", "") refresh_hash = _hash(refresh_token_val) refresh_info = REFRESH_TOKENS.get(refresh_hash) if not refresh_info: return JSONResponse({"error": "invalid_grant", "error_description": "Refresh token invalid."}, status_code=400) if refresh_info["expires_at"] < time.time(): del REFRESH_TOKENS[refresh_hash] return JSONResponse({"error": "invalid_grant", "error_description": "Refresh token expired."}, status_code=400) if refresh_info["client_id"] != client_id: return JSONResponse({"error": "invalid_grant", "error_description": "Client mismatch."}, status_code=400) old_access_hash = refresh_info.get("access_token_hash") if old_access_hash and old_access_hash in ACCESS_TOKENS: del ACCESS_TOKENS[old_access_hash] del REFRESH_TOKENS[refresh_hash] new_access_token = secrets.token_urlsafe(48) new_refresh_token = secrets.token_urlsafe(48) new_access_hash = _hash(new_access_token) new_refresh_hash = _hash(new_refresh_token) ACCESS_TOKENS[new_access_hash] = { "client_id": client_id, "scope": refresh_info["scope"], "expires_at": time.time() + ACCESS_TOKEN_TTL, "user": refresh_info["user"], } REFRESH_TOKENS[new_refresh_hash] = { "client_id": client_id, "scope": refresh_info["scope"], "expires_at": time.time() + REFRESH_TOKEN_TTL, "user": refresh_info["user"], "access_token_hash": new_access_hash, } logger.info(f"OAuth: refreshed token for client {client_id}") return JSONResponse({ "access_token": new_access_token, "token_type": "Bearer", "expires_in": ACCESS_TOKEN_TTL, "refresh_token": new_refresh_token, "scope": refresh_info["scope"], }) return JSONResponse({"error": "unsupported_grant_type"}, status_code=400) # --------------------------------------------------------------------------- # MCP Endpoint (OAuth-protected) # --------------------------------------------------------------------------- async def handle_mcp(request: Request) -> Response: if request.method == "HEAD": return Response(status_code=200, headers={"MCP-Protocol-Version": "2024-11-05"}) if request.method == "GET": return Response( status_code=401, headers={ "WWW-Authenticate": f'Bearer resource_metadata="{ISSUER_URL}/.well-known/oauth-protected-resource"', "Content-Type": "application/json", }, media_type="application/json", content=json.dumps({"error": "unauthorized", "message": "Bearer token required."}), ) if request.method == "DELETE": session_id = request.headers.get("Mcp-Session-Id") if session_id and session_id in GATEWAY_SESSIONS: del GATEWAY_SESSIONS[session_id] return Response(status_code=200) # POST — require Bearer token token_info = validate_bearer_token(request) if not token_info: return Response( status_code=401, headers={ "WWW-Authenticate": f'Bearer resource_metadata="{ISSUER_URL}/.well-known/oauth-protected-resource"', "Content-Type": "application/json", }, media_type="application/json", content=json.dumps({"error": "unauthorized", "message": "Valid Bearer token required."}), ) try: body = await request.json() except Exception: return JSONResponse( {"jsonrpc": "2.0", "error": {"code": -32700, "message": "Parse error"}, "id": None}, status_code=400, ) method = body.get("method", "") params = body.get("params", {}) req_id = body.get("id") if method == "initialize": session_id = str(uuid.uuid4()) GATEWAY_SESSIONS[session_id] = True response = JSONResponse({ "jsonrpc": "2.0", "id": req_id, "result": { "protocolVersion": "2024-11-05", "capabilities": {"tools": {"listChanged": False}}, "serverInfo": {"name": "mcp-gateway-proxy", "version": "1.0.0"}, }, }) response.headers["Mcp-Session-Id"] = session_id return response if method == "notifications/initialized": return Response(status_code=202) if method == "tools/list": tools = list(TOOL_DEFINITIONS.values()) response = JSONResponse({ "jsonrpc": "2.0", "id": req_id, "result": {"tools": tools}, }) session_id = request.headers.get("Mcp-Session-Id") if session_id: response.headers["Mcp-Session-Id"] = session_id return response if method == "tools/call": tool_name = params.get("name", "") arguments = params.get("arguments", {}) backend_name = TOOL_REGISTRY.get(tool_name) if not backend_name: return JSONResponse({ "jsonrpc": "2.0", "id": req_id, "error": {"code": -32601, "message": f"Unknown tool: {tool_name}"}, }) try: result = await forward_tool_call(backend_name, tool_name, arguments, req_id) if isinstance(result, dict) and "jsonrpc" in result: response = JSONResponse(result) elif isinstance(result, dict) and "result" in result: response = JSONResponse({ "jsonrpc": "2.0", "id": req_id, "result": result["result"], }) elif isinstance(result, dict) and "error" in result: response = JSONResponse({ "jsonrpc": "2.0", "id": req_id, "error": result["error"], }) else: response = JSONResponse({ "jsonrpc": "2.0", "id": req_id, "result": result, }) session_id = request.headers.get("Mcp-Session-Id") if session_id: response.headers["Mcp-Session-Id"] = session_id return response except Exception as e: logger.error(f"Tool call failed: {tool_name} - {e}") return JSONResponse({ "jsonrpc": "2.0", "id": req_id, "error": {"code": -32603, "message": str(e)}, }) if method == "ping": return JSONResponse({"jsonrpc": "2.0", "id": req_id, "result": {}}) return JSONResponse({ "jsonrpc": "2.0", "id": req_id, "error": {"code": -32601, "message": f"Method not found: {method}"}, }) # --------------------------------------------------------------------------- # Health / Status # --------------------------------------------------------------------------- async def health(request: Request) -> JSONResponse: return JSONResponse({ "status": "healthy", "backends": len(BACKENDS), "tools": len(TOOL_DEFINITIONS), "oauth": "enabled", "active_tokens": len(ACCESS_TOKENS), "registered_clients": len(REGISTERED_CLIENTS), }) async def status(request: Request) -> JSONResponse: token_info = validate_bearer_token(request) if not token_info: return JSONResponse({"error": "unauthorized"}, status_code=401) return JSONResponse({ "backends": { name: { "url": url, "tools": len([t for t, b in TOOL_REGISTRY.items() if b == name]), "session": BACKEND_SESSIONS.get(name), } for name, url in BACKENDS.items() }, "total_tools": len(TOOL_DEFINITIONS), "tool_list": sorted(TOOL_DEFINITIONS.keys()), }) # --------------------------------------------------------------------------- # Startup # --------------------------------------------------------------------------- async def startup(): if not GATEWAY_PASSWORD: logger.error("OAUTH_PASSWORD is not set! OAuth login will fail.") else: logger.info("OAuth password configured") logger.info(f"OAuth issuer: {ISSUER_URL}") logger.info(f"Starting MCP Gateway Proxy with {len(BACKENDS)} backends") logger.info("Waiting 10s for backends to start...") await asyncio.sleep(10) for name, url in BACKENDS.items(): tools = await initialize_backend(name, url) if not tools: logger.warning(f" {name}: no tools discovered — will retry on first request") logger.info(f"Gateway ready: {len(TOOL_DEFINITIONS)} tools from {len(BACKENDS)} backends") @asynccontextmanager async def lifespan(app): await startup() yield # Build routes list routes = [ # Well-known discovery (Claude tries both with and without /mcp suffix) Route("/.well-known/oauth-protected-resource", well_known_protected_resource, methods=["GET"]), Route("/.well-known/oauth-protected-resource/mcp", well_known_protected_resource, methods=["GET"]), Route("/.well-known/oauth-authorization-server", well_known_oauth_authorization_server, methods=["GET"]), Route("/.well-known/oauth-authorization-server/mcp", well_known_oauth_authorization_server, methods=["GET"]), Route("/.well-known/openid-configuration", well_known_oauth_authorization_server, methods=["GET"]), # OAuth endpoints at /oauth/* (spec-standard) Route("/oauth/register", oauth_register, methods=["POST"]), Route("/oauth/authorize", oauth_authorize, methods=["GET", "POST"]), Route("/oauth/token", oauth_token, methods=["POST"]), # OAuth endpoints at root (Claude may construct these from base URL) Route("/register", oauth_register, methods=["POST"]), Route("/authorize", oauth_authorize, methods=["GET", "POST"]), Route("/token", oauth_token, methods=["POST"]), # MCP endpoint (OAuth-protected) Route("/mcp", handle_mcp, methods=["GET", "HEAD", "POST", "DELETE"]), # Monitoring Route("/health", health, methods=["GET"]), Route("/status", status, methods=["GET"]), ] # Add OpenAI-compatible endpoints if available if OPENAI_AVAILABLE: routes.extend([ # OpenAI-compatible endpoints Route("/v1/models", list_models, methods=["GET"]), Route("/v1/chat/completions", chat_completions, methods=["POST"]), ]) logger.info("OpenAI-compatible endpoints registered at /v1/*") app = Starlette( routes=routes, lifespan=lifespan, ) if __name__ == "__main__": port = int(os.environ.get("PORT", "4444")) uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")