diff --git a/mcp-gateway/gateway-proxy/gateway_proxy_fixed.py b/mcp-gateway/gateway-proxy/gateway_proxy_fixed.py new file mode 100644 index 0000000..36237f8 --- /dev/null +++ b/mcp-gateway/gateway-proxy/gateway_proxy_fixed.py @@ -0,0 +1,707 @@ +""" +MCP Gateway Proxy with OAuth 2.1 +================================= +Aggregates multiple MCP servers behind a single Streamable HTTP endpoint. +Implements a self-contained OAuth 2.1 provider compatible with claude.ai: + - RFC 8414 Authorization Server Metadata + - RFC 9728 Protected Resource Metadata + - RFC 7591 Dynamic Client Registration + - PKCE (S256) per OAuth 2.1 + - Authorization Code Grant with refresh tokens +""" + +import asyncio +import base64 +import hashlib +import html +import json +import logging +import os +import secrets +import time +import uuid +from contextlib import asynccontextmanager +from typing import Any +from urllib.parse import urlencode + +import httpx +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import HTMLResponse, JSONResponse, RedirectResponse, Response, StreamingResponse +from starlette.routing import Route +import uvicorn + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") +logger = logging.getLogger("mcp-gateway") + +# Import OpenAI-compatible routes +try: + from openai_routes import chat_completions, list_models + OPENAI_AVAILABLE = True + logger.info("✓ OpenAI-compatible routes imported successfully") +except ImportError as e: + OPENAI_AVAILABLE = False + logger.error(f"✗ Failed to import OpenAI routes: {e}") + import traceback + logger.error(f"Traceback: {traceback.format_exc()}") + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +ISSUER_URL = os.environ.get("OAUTH_ISSUER_URL", "https://mcp.wilddragon.net") +GATEWAY_PASSWORD = os.environ.get("OAUTH_PASSWORD", "") +ACCESS_TOKEN_TTL = int(os.environ.get("OAUTH_ACCESS_TOKEN_TTL", "3600")) +REFRESH_TOKEN_TTL = int(os.environ.get("OAUTH_REFRESH_TOKEN_TTL", "2592000")) +AUTH_CODE_TTL = 600 + +# --------------------------------------------------------------------------- +# In-memory stores +# --------------------------------------------------------------------------- + +REGISTERED_CLIENTS: dict[str, dict] = {} +AUTH_CODES: dict[str, dict] = {} +ACCESS_TOKENS: dict[str, dict] = {} +REFRESH_TOKENS: dict[str, dict] = {} +PENDING_AUTH: dict[str, dict] = {} + +def _hash(value: str) -> str: + return hashlib.sha256(value.encode()).hexdigest() + +def _clean_expired(): + now = time.time() + for store in (AUTH_CODES, ACCESS_TOKENS, REFRESH_TOKENS): + expired = [k for k, v in store.items() if v.get("expires_at", 0) < now] + for k in expired: + del store[k] + +# --------------------------------------------------------------------------- +# Backend MCP aggregation +# --------------------------------------------------------------------------- + +def load_backends() -> dict[str, str]: + backends = {} + for key, value in os.environ.items(): + if key.startswith("MCP_BACKEND_"): + name = key[len("MCP_BACKEND_"):].lower() + backends[name] = value + logger.info(f"Backend configured: {name} -> {value}") + return backends + +BACKENDS: dict[str, str] = load_backends() +TOOL_REGISTRY: dict[str, str] = {} +TOOL_DEFINITIONS: dict[str, dict] = {} +BACKEND_SESSIONS: dict[str, str | None] = {} +GATEWAY_SESSIONS: dict[str, bool] = {} + +def parse_sse_response(text: str) -> dict | None: + for line in text.strip().split("\n"): + line = line.strip() + if line.startswith("data: "): + try: + return json.loads(line[6:]) + except json.JSONDecodeError: + continue + return None + +async def mcp_request(backend_url: str, method: str, params: dict | None = None, request_id: Any = 1, session_id: str | None = None) -> dict: + payload = {"jsonrpc": "2.0", "method": method, "id": request_id} + if params is not None: + payload["params"] = params + headers = { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + "Host": "localhost", + } + if session_id: + headers["Mcp-Session-Id"] = session_id + async with httpx.AsyncClient(timeout=30) as client: + resp = await client.post(backend_url, json=payload, headers=headers) + new_session = resp.headers.get("Mcp-Session-Id") or resp.headers.get("mcp-session-id") + content_type = resp.headers.get("content-type", "") + if resp.status_code in (200, 201): + try: + if "text/event-stream" in content_type: + parsed = parse_sse_response(resp.text) + return {"result": parsed, "session_id": new_session} + else: + return {"result": resp.json(), "session_id": new_session} + except Exception: + return {"result": None, "session_id": new_session} + elif resp.status_code == 202: + return {"result": None, "session_id": new_session} + else: + logger.error(f"Backend {backend_url} returned {resp.status_code}: {resp.text[:200]}") + return {"result": None, "session_id": new_session} + +async def initialize_backend(name: str, url: str) -> list[dict]: + logger.info(f"Initializing backend: {name} at {url}") + for attempt in range(3): + try: + init_result = await mcp_request(url, "initialize", { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "mcp-gateway-proxy", "version": "1.0.0"}, + }) + session_id = init_result.get("session_id") + BACKEND_SESSIONS[name] = session_id + logger.info(f" {name}: initialized (session: {session_id})") + await mcp_request(url, "notifications/initialized", {}, request_id=None, session_id=session_id) + tools_result = await mcp_request(url, "tools/list", {}, request_id=2, session_id=session_id) + tools_data = tools_result.get("result", {}) + if isinstance(tools_data, dict): + if "result" in tools_data: + tools_data = tools_data["result"] + tools = tools_data.get("tools", []) + else: + tools = [] + logger.info(f" {name}: discovered {len(tools)} tools") + for tool in tools: + original_name = tool["name"] + prefixed_name = f"{name}_{original_name}" + tool["name"] = prefixed_name + TOOL_REGISTRY[prefixed_name] = name + TOOL_DEFINITIONS[prefixed_name] = tool + return tools + except Exception as e: + if attempt < 2: + logger.info(f" {name}: attempt {attempt+1} failed, retrying in 5s...") + await asyncio.sleep(5) + else: + logger.error(f" {name}: failed to initialize after 3 attempts - {e}") + return [] + return [] + +async def forward_tool_call(backend_name: str, tool_name: str, arguments: dict, request_id: Any) -> dict: + url = BACKENDS[backend_name] + session_id = BACKEND_SESSIONS.get(backend_name) + prefix = f"{backend_name}_" + original_name = tool_name[len(prefix):] if tool_name.startswith(prefix) else tool_name + if not session_id: + await initialize_backend(backend_name, url) + session_id = BACKEND_SESSIONS.get(backend_name) + result = await mcp_request( + url, "tools/call", + {"name": original_name, "arguments": arguments}, + request_id=request_id, + session_id=session_id, + ) + response_data = result.get("result", {}) + if isinstance(response_data, dict) and response_data.get("error"): + error_code = response_data["error"].get("code", 0) + if error_code in (-32600, -32601): + logger.info(f"Re-initializing {backend_name} after error {error_code}") + await initialize_backend(backend_name, url) + session_id = BACKEND_SESSIONS.get(backend_name) + result = await mcp_request( + url, "tools/call", + {"name": original_name, "arguments": arguments}, + request_id=request_id, + session_id=session_id, + ) + response_data = result.get("result", {}) + return response_data + +# --------------------------------------------------------------------------- +# OAuth 2.1: Token validation +# --------------------------------------------------------------------------- + +def validate_bearer_token(request: Request) -> dict | None: + auth_header = request.headers.get("Authorization", "") + if not auth_header.startswith("Bearer "): + return None + token = auth_header[7:] + token_hash = _hash(token) + info = ACCESS_TOKENS.get(token_hash) + if not info: + return None + if info["expires_at"] < time.time(): + del ACCESS_TOKENS[token_hash] + return None + return info + +# --------------------------------------------------------------------------- +# Consent page HTML template +# --------------------------------------------------------------------------- + +CONSENT_PAGE_CSS = """ +* { margin: 0; padding: 0; box-sizing: border-box; } +body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; + background: #0f172a; color: #e2e8f0; min-height: 100vh; + display: flex; align-items: center; justify-content: center; } +.card { background: #1e293b; border-radius: 16px; padding: 40px; + max-width: 420px; width: 90%; box-shadow: 0 25px 50px rgba(0,0,0,0.4); } +h1 { font-size: 22px; margin-bottom: 8px; color: #f8fafc; } +.subtitle { color: #94a3b8; margin-bottom: 24px; font-size: 14px; } +.client { background: #334155; border-radius: 8px; padding: 12px 16px; + margin-bottom: 24px; font-size: 14px; } +.client strong { color: #38bdf8; } +label { display: block; font-size: 13px; color: #94a3b8; margin-bottom: 6px; } +input[type=password] { width: 100%; padding: 12px 16px; border-radius: 8px; + border: 1px solid #475569; background: #0f172a; color: #f8fafc; + font-size: 16px; margin-bottom: 20px; outline: none; } +input[type=password]:focus { border-color: #38bdf8; box-shadow: 0 0 0 3px rgba(56,189,248,0.15); } +.buttons { display: flex; gap: 12px; } +button { flex: 1; padding: 12px; border-radius: 8px; border: none; font-size: 15px; + font-weight: 600; cursor: pointer; transition: all 0.15s; } +.approve { background: #38bdf8; color: #0f172a; } +.approve:hover { background: #7dd3fc; } +.deny { background: #334155; color: #94a3b8; } +.deny:hover { background: #475569; color: #e2e8f0; } +.scope { font-size: 13px; color: #64748b; margin-bottom: 16px; } +.error { background: #7f1d1d; color: #fca5a5; padding: 10px 14px; border-radius: 8px; + margin-bottom: 16px; font-size: 13px; } +""" + +def render_consent_page(client_name: str, scope: str, internal_state: str, error_msg: str = "") -> str: + error_html = f'
Authorization Request
+Scope: {html.escape(scope)}
Please try connecting again.
", status_code=400) + redirect_uri = pending["redirect_uri"] + state = pending["state"] + if action != "approve": + qs = urlencode({"error": "access_denied", "state": state}) + return RedirectResponse(f"{redirect_uri}?{qs}", status_code=302) + if not GATEWAY_PASSWORD: + return HTMLResponse("OAUTH_PASSWORD not set.
", status_code=500) + if password != GATEWAY_PASSWORD: + PENDING_AUTH[internal_state] = pending + client = REGISTERED_CLIENTS.get(pending["client_id"], {}) + client_name = client.get("client_name", "Unknown Client") + return HTMLResponse(render_consent_page(client_name, pending["scope"], internal_state, "Invalid password. Please try again.")) + code = secrets.token_urlsafe(48) + AUTH_CODES[code] = { + "client_id": pending["client_id"], + "redirect_uri": redirect_uri, + "code_challenge": pending["code_challenge"], + "code_challenge_method": pending["code_challenge_method"], + "scope": pending["scope"], + "expires_at": time.time() + AUTH_CODE_TTL, + "user": "owner", + } + qs = urlencode({"code": code, "state": state}) + logger.info(f"OAuth: issued auth code for client {pending['client_id']}") + return RedirectResponse(f"{redirect_uri}?{qs}", status_code=302) + return JSONResponse({"error": "method_not_allowed"}, status_code=405) + +async def oauth_token(request: Request) -> JSONResponse: + _clean_expired() + content_type = request.headers.get("content-type", "") + if "application/json" in content_type: + try: + body = await request.json() + except Exception: + body = {} + else: + form = await request.form() + body = dict(form) + grant_type = body.get("grant_type", "") + if grant_type == "authorization_code": + code = body.get("code", "") + client_id = body.get("client_id", "") + code_verifier = body.get("code_verifier", "") + code_info = AUTH_CODES.pop(code, None) + if not code_info: + return JSONResponse({"error": "invalid_grant", "error_description": "Code invalid or expired."}, status_code=400) + if code_info["expires_at"] < time.time(): + return JSONResponse({"error": "invalid_grant", "error_description": "Code expired."}, status_code=400) + if code_info["client_id"] != client_id: + return JSONResponse({"error": "invalid_grant", "error_description": "Client mismatch."}, status_code=400) + if code_info.get("code_challenge"): + if not code_verifier: + return JSONResponse({"error": "invalid_grant", "error_description": "code_verifier required."}, status_code=400) + expected = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).rstrip(b"=").decode() + if expected != code_info["code_challenge"]: + return JSONResponse({"error": "invalid_grant", "error_description": "PKCE verification failed."}, status_code=400) + access_token = secrets.token_urlsafe(48) + refresh_token = secrets.token_urlsafe(48) + access_hash = _hash(access_token) + refresh_hash = _hash(refresh_token) + ACCESS_TOKENS[access_hash] = { + "client_id": client_id, + "scope": code_info["scope"], + "expires_at": time.time() + ACCESS_TOKEN_TTL, + "user": code_info["user"], + } + REFRESH_TOKENS[refresh_hash] = { + "client_id": client_id, + "scope": code_info["scope"], + "expires_at": time.time() + REFRESH_TOKEN_TTL, + "user": code_info["user"], + "access_token_hash": access_hash, + } + logger.info(f"OAuth: issued access token for client {client_id}") + return JSONResponse({ + "access_token": access_token, + "token_type": "Bearer", + "expires_in": ACCESS_TOKEN_TTL, + "refresh_token": refresh_token, + "scope": code_info["scope"], + }) + elif grant_type == "refresh_token": + refresh_token_val = body.get("refresh_token", "") + client_id = body.get("client_id", "") + refresh_hash = _hash(refresh_token_val) + refresh_info = REFRESH_TOKENS.get(refresh_hash) + if not refresh_info: + return JSONResponse({"error": "invalid_grant", "error_description": "Refresh token invalid."}, status_code=400) + if refresh_info["expires_at"] < time.time(): + del REFRESH_TOKENS[refresh_hash] + return JSONResponse({"error": "invalid_grant", "error_description": "Refresh token expired."}, status_code=400) + if refresh_info["client_id"] != client_id: + return JSONResponse({"error": "invalid_grant", "error_description": "Client mismatch."}, status_code=400) + old_access_hash = refresh_info.get("access_token_hash") + if old_access_hash and old_access_hash in ACCESS_TOKENS: + del ACCESS_TOKENS[old_access_hash] + del REFRESH_TOKENS[refresh_hash] + new_access_token = secrets.token_urlsafe(48) + new_refresh_token = secrets.token_urlsafe(48) + new_access_hash = _hash(new_access_token) + new_refresh_hash = _hash(new_refresh_token) + ACCESS_TOKENS[new_access_hash] = { + "client_id": client_id, + "scope": refresh_info["scope"], + "expires_at": time.time() + ACCESS_TOKEN_TTL, + "user": refresh_info["user"], + } + REFRESH_TOKENS[new_refresh_hash] = { + "client_id": client_id, + "scope": refresh_info["scope"], + "expires_at": time.time() + REFRESH_TOKEN_TTL, + "user": refresh_info["user"], + "access_token_hash": new_access_hash, + } + logger.info(f"OAuth: refreshed token for client {client_id}") + return JSONResponse({ + "access_token": new_access_token, + "token_type": "Bearer", + "expires_in": ACCESS_TOKEN_TTL, + "refresh_token": new_refresh_token, + "scope": refresh_info["scope"], + }) + return JSONResponse({"error": "unsupported_grant_type"}, status_code=400) + +# --------------------------------------------------------------------------- +# MCP Endpoint (OAuth-protected) +# --------------------------------------------------------------------------- + +async def handle_mcp(request: Request) -> Response: + if request.method == "HEAD": + return Response(status_code=200, headers={"MCP-Protocol-Version": "2024-11-05"}) + if request.method == "GET": + return Response( + status_code=401, + headers={ + "WWW-Authenticate": f'Bearer resource_metadata="{ISSUER_URL}/.well-known/oauth-protected-resource"', + "Content-Type": "application/json", + }, + media_type="application/json", + content=json.dumps({"error": "unauthorized", "message": "Bearer token required."}), + ) + if request.method == "DELETE": + session_id = request.headers.get("Mcp-Session-Id") + if session_id and session_id in GATEWAY_SESSIONS: + del GATEWAY_SESSIONS[session_id] + return Response(status_code=200) + # POST — require Bearer token + token_info = validate_bearer_token(request) + if not token_info: + return Response( + status_code=401, + headers={ + "WWW-Authenticate": f'Bearer resource_metadata="{ISSUER_URL}/.well-known/oauth-protected-resource"', + "Content-Type": "application/json", + }, + media_type="application/json", + content=json.dumps({"error": "unauthorized", "message": "Valid Bearer token required."}), + ) + try: + body = await request.json() + except Exception: + return JSONResponse( + {"jsonrpc": "2.0", "error": {"code": -32700, "message": "Parse error"}, "id": None}, + status_code=400, + ) + method = body.get("method", "") + params = body.get("params", {}) + req_id = body.get("id") + if method == "initialize": + session_id = str(uuid.uuid4()) + GATEWAY_SESSIONS[session_id] = True + response = JSONResponse({ + "jsonrpc": "2.0", + "id": req_id, + "result": { + "protocolVersion": "2024-11-05", + "capabilities": {"tools": {"listChanged": False}}, + "serverInfo": {"name": "mcp-gateway-proxy", "version": "1.0.0"}, + }, + }) + response.headers["Mcp-Session-Id"] = session_id + return response + if method == "notifications/initialized": + return Response(status_code=202) + if method == "tools/list": + tools = list(TOOL_DEFINITIONS.values()) + response = JSONResponse({ + "jsonrpc": "2.0", + "id": req_id, + "result": {"tools": tools}, + }) + session_id = request.headers.get("Mcp-Session-Id") + if session_id: + response.headers["Mcp-Session-Id"] = session_id + return response + if method == "tools/call": + tool_name = params.get("name", "") + arguments = params.get("arguments", {}) + backend_name = TOOL_REGISTRY.get(tool_name) + if not backend_name: + return JSONResponse({ + "jsonrpc": "2.0", + "id": req_id, + "error": {"code": -32601, "message": f"Unknown tool: {tool_name}"}, + }) + try: + result = await forward_tool_call(backend_name, tool_name, arguments, req_id) + if isinstance(result, dict) and "jsonrpc" in result: + response = JSONResponse(result) + elif isinstance(result, dict) and "result" in result: + response = JSONResponse({ + "jsonrpc": "2.0", + "id": req_id, + "result": result["result"], + }) + elif isinstance(result, dict) and "error" in result: + response = JSONResponse({ + "jsonrpc": "2.0", + "id": req_id, + "error": result["error"], + }) + else: + response = JSONResponse({ + "jsonrpc": "2.0", + "id": req_id, + "result": result, + }) + session_id = request.headers.get("Mcp-Session-Id") + if session_id: + response.headers["Mcp-Session-Id"] = session_id + return response + except Exception as e: + logger.error(f"Tool call failed: {tool_name} - {e}") + return JSONResponse({ + "jsonrpc": "2.0", + "id": req_id, + "error": {"code": -32603, "message": str(e)}, + }) + if method == "ping": + return JSONResponse({"jsonrpc": "2.0", "id": req_id, "result": {}}) + return JSONResponse({ + "jsonrpc": "2.0", + "id": req_id, + "error": {"code": -32601, "message": f"Method not found: {method}"}, + }) + +# --------------------------------------------------------------------------- +# Health / Status +# --------------------------------------------------------------------------- + +async def health(request: Request) -> JSONResponse: + return JSONResponse({ + "status": "healthy", + "backends": len(BACKENDS), + "tools": len(TOOL_DEFINITIONS), + "oauth": "enabled", + "active_tokens": len(ACCESS_TOKENS), + "registered_clients": len(REGISTERED_CLIENTS), + }) + +async def status(request: Request) -> JSONResponse: + token_info = validate_bearer_token(request) + if not token_info: + return JSONResponse({"error": "unauthorized"}, status_code=401) + return JSONResponse({ + "backends": { + name: { + "url": url, + "tools": len([t for t, b in TOOL_REGISTRY.items() if b == name]), + "session": BACKEND_SESSIONS.get(name), + } + for name, url in BACKENDS.items() + }, + "total_tools": len(TOOL_DEFINITIONS), + "tool_list": sorted(TOOL_DEFINITIONS.keys()), + }) + +# --------------------------------------------------------------------------- +# Startup +# --------------------------------------------------------------------------- + +async def startup(): + if not GATEWAY_PASSWORD: + logger.error("OAUTH_PASSWORD is not set! OAuth login will fail.") + else: + logger.info("OAuth password configured") + logger.info(f"OAuth issuer: {ISSUER_URL}") + logger.info(f"Starting MCP Gateway Proxy with {len(BACKENDS)} backends") + logger.info("Waiting 10s for backends to start...") + await asyncio.sleep(10) + for name, url in BACKENDS.items(): + tools = await initialize_backend(name, url) + if not tools: + logger.warning(f" {name}: no tools discovered — will retry on first request") + logger.info(f"Gateway ready: {len(TOOL_DEFINITIONS)} tools from {len(BACKENDS)} backends") + +@asynccontextmanager +async def lifespan(app): + await startup() + yield + +# Build routes list +routes = [ + # Well-known discovery (Claude tries both with and without /mcp suffix) + Route("/.well-known/oauth-protected-resource", well_known_protected_resource, methods=["GET"]), + Route("/.well-known/oauth-protected-resource/mcp", well_known_protected_resource, methods=["GET"]), + Route("/.well-known/oauth-authorization-server", well_known_oauth_authorization_server, methods=["GET"]), + Route("/.well-known/oauth-authorization-server/mcp", well_known_oauth_authorization_server, methods=["GET"]), + Route("/.well-known/openid-configuration", well_known_oauth_authorization_server, methods=["GET"]), + # OAuth endpoints at /oauth/* (spec-standard) + Route("/oauth/register", oauth_register, methods=["POST"]), + Route("/oauth/authorize", oauth_authorize, methods=["GET", "POST"]), + Route("/oauth/token", oauth_token, methods=["POST"]), + # OAuth endpoints at root (Claude may construct these from base URL) + Route("/register", oauth_register, methods=["POST"]), + Route("/authorize", oauth_authorize, methods=["GET", "POST"]), + Route("/token", oauth_token, methods=["POST"]), + # MCP endpoint (OAuth-protected) + Route("/mcp", handle_mcp, methods=["GET", "HEAD", "POST", "DELETE"]), + # Monitoring + Route("/health", health, methods=["GET"]), + Route("/status", status, methods=["GET"]), +] + +# Add OpenAI-compatible endpoints if available +if OPENAI_AVAILABLE: + routes.extend([ + # OpenAI-compatible endpoints + Route("/v1/models", list_models, methods=["GET"]), + Route("/v1/chat/completions", chat_completions, methods=["POST"]), + ]) + logger.info("OpenAI-compatible endpoints registered at /v1/*") + +app = Starlette( + routes=routes, + lifespan=lifespan, +) + +if __name__ == "__main__": + port = int(os.environ.get("PORT", "4444")) + uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")