diff --git a/mcp-gateway/gateway-proxy/gateway_proxy_fixed.py b/mcp-gateway/gateway-proxy/gateway_proxy_fixed.py deleted file mode 100644 index 36237f8..0000000 --- a/mcp-gateway/gateway-proxy/gateway_proxy_fixed.py +++ /dev/null @@ -1,707 +0,0 @@ -""" -MCP Gateway Proxy with OAuth 2.1 -================================= -Aggregates multiple MCP servers behind a single Streamable HTTP endpoint. -Implements a self-contained OAuth 2.1 provider compatible with claude.ai: - - RFC 8414 Authorization Server Metadata - - RFC 9728 Protected Resource Metadata - - RFC 7591 Dynamic Client Registration - - PKCE (S256) per OAuth 2.1 - - Authorization Code Grant with refresh tokens -""" - -import asyncio -import base64 -import hashlib -import html -import json -import logging -import os -import secrets -import time -import uuid -from contextlib import asynccontextmanager -from typing import Any -from urllib.parse import urlencode - -import httpx -from starlette.applications import Starlette -from starlette.requests import Request -from starlette.responses import HTMLResponse, JSONResponse, RedirectResponse, Response, StreamingResponse -from starlette.routing import Route -import uvicorn - -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") -logger = logging.getLogger("mcp-gateway") - -# Import OpenAI-compatible routes -try: - from openai_routes import chat_completions, list_models - OPENAI_AVAILABLE = True - logger.info("✓ OpenAI-compatible routes imported successfully") -except ImportError as e: - OPENAI_AVAILABLE = False - logger.error(f"✗ Failed to import OpenAI routes: {e}") - import traceback - logger.error(f"Traceback: {traceback.format_exc()}") - -# --------------------------------------------------------------------------- -# Configuration -# --------------------------------------------------------------------------- - -ISSUER_URL = os.environ.get("OAUTH_ISSUER_URL", "https://mcp.wilddragon.net") -GATEWAY_PASSWORD = os.environ.get("OAUTH_PASSWORD", "") -ACCESS_TOKEN_TTL = int(os.environ.get("OAUTH_ACCESS_TOKEN_TTL", "3600")) -REFRESH_TOKEN_TTL = int(os.environ.get("OAUTH_REFRESH_TOKEN_TTL", "2592000")) -AUTH_CODE_TTL = 600 - -# --------------------------------------------------------------------------- -# In-memory stores -# --------------------------------------------------------------------------- - -REGISTERED_CLIENTS: dict[str, dict] = {} -AUTH_CODES: dict[str, dict] = {} -ACCESS_TOKENS: dict[str, dict] = {} -REFRESH_TOKENS: dict[str, dict] = {} -PENDING_AUTH: dict[str, dict] = {} - -def _hash(value: str) -> str: - return hashlib.sha256(value.encode()).hexdigest() - -def _clean_expired(): - now = time.time() - for store in (AUTH_CODES, ACCESS_TOKENS, REFRESH_TOKENS): - expired = [k for k, v in store.items() if v.get("expires_at", 0) < now] - for k in expired: - del store[k] - -# --------------------------------------------------------------------------- -# Backend MCP aggregation -# --------------------------------------------------------------------------- - -def load_backends() -> dict[str, str]: - backends = {} - for key, value in os.environ.items(): - if key.startswith("MCP_BACKEND_"): - name = key[len("MCP_BACKEND_"):].lower() - backends[name] = value - logger.info(f"Backend configured: {name} -> {value}") - return backends - -BACKENDS: dict[str, str] = load_backends() -TOOL_REGISTRY: dict[str, str] = {} -TOOL_DEFINITIONS: dict[str, dict] = {} -BACKEND_SESSIONS: dict[str, str | None] = {} -GATEWAY_SESSIONS: dict[str, bool] = {} - -def parse_sse_response(text: str) -> dict | None: - for line in text.strip().split("\n"): - line = line.strip() - if line.startswith("data: "): - try: - return json.loads(line[6:]) - except json.JSONDecodeError: - continue - return None - -async def mcp_request(backend_url: str, method: str, params: dict | None = None, request_id: Any = 1, session_id: str | None = None) -> dict: - payload = {"jsonrpc": "2.0", "method": method, "id": request_id} - if params is not None: - payload["params"] = params - headers = { - "Content-Type": "application/json", - "Accept": "application/json, text/event-stream", - "Host": "localhost", - } - if session_id: - headers["Mcp-Session-Id"] = session_id - async with httpx.AsyncClient(timeout=30) as client: - resp = await client.post(backend_url, json=payload, headers=headers) - new_session = resp.headers.get("Mcp-Session-Id") or resp.headers.get("mcp-session-id") - content_type = resp.headers.get("content-type", "") - if resp.status_code in (200, 201): - try: - if "text/event-stream" in content_type: - parsed = parse_sse_response(resp.text) - return {"result": parsed, "session_id": new_session} - else: - return {"result": resp.json(), "session_id": new_session} - except Exception: - return {"result": None, "session_id": new_session} - elif resp.status_code == 202: - return {"result": None, "session_id": new_session} - else: - logger.error(f"Backend {backend_url} returned {resp.status_code}: {resp.text[:200]}") - return {"result": None, "session_id": new_session} - -async def initialize_backend(name: str, url: str) -> list[dict]: - logger.info(f"Initializing backend: {name} at {url}") - for attempt in range(3): - try: - init_result = await mcp_request(url, "initialize", { - "protocolVersion": "2024-11-05", - "capabilities": {}, - "clientInfo": {"name": "mcp-gateway-proxy", "version": "1.0.0"}, - }) - session_id = init_result.get("session_id") - BACKEND_SESSIONS[name] = session_id - logger.info(f" {name}: initialized (session: {session_id})") - await mcp_request(url, "notifications/initialized", {}, request_id=None, session_id=session_id) - tools_result = await mcp_request(url, "tools/list", {}, request_id=2, session_id=session_id) - tools_data = tools_result.get("result", {}) - if isinstance(tools_data, dict): - if "result" in tools_data: - tools_data = tools_data["result"] - tools = tools_data.get("tools", []) - else: - tools = [] - logger.info(f" {name}: discovered {len(tools)} tools") - for tool in tools: - original_name = tool["name"] - prefixed_name = f"{name}_{original_name}" - tool["name"] = prefixed_name - TOOL_REGISTRY[prefixed_name] = name - TOOL_DEFINITIONS[prefixed_name] = tool - return tools - except Exception as e: - if attempt < 2: - logger.info(f" {name}: attempt {attempt+1} failed, retrying in 5s...") - await asyncio.sleep(5) - else: - logger.error(f" {name}: failed to initialize after 3 attempts - {e}") - return [] - return [] - -async def forward_tool_call(backend_name: str, tool_name: str, arguments: dict, request_id: Any) -> dict: - url = BACKENDS[backend_name] - session_id = BACKEND_SESSIONS.get(backend_name) - prefix = f"{backend_name}_" - original_name = tool_name[len(prefix):] if tool_name.startswith(prefix) else tool_name - if not session_id: - await initialize_backend(backend_name, url) - session_id = BACKEND_SESSIONS.get(backend_name) - result = await mcp_request( - url, "tools/call", - {"name": original_name, "arguments": arguments}, - request_id=request_id, - session_id=session_id, - ) - response_data = result.get("result", {}) - if isinstance(response_data, dict) and response_data.get("error"): - error_code = response_data["error"].get("code", 0) - if error_code in (-32600, -32601): - logger.info(f"Re-initializing {backend_name} after error {error_code}") - await initialize_backend(backend_name, url) - session_id = BACKEND_SESSIONS.get(backend_name) - result = await mcp_request( - url, "tools/call", - {"name": original_name, "arguments": arguments}, - request_id=request_id, - session_id=session_id, - ) - response_data = result.get("result", {}) - return response_data - -# --------------------------------------------------------------------------- -# OAuth 2.1: Token validation -# --------------------------------------------------------------------------- - -def validate_bearer_token(request: Request) -> dict | None: - auth_header = request.headers.get("Authorization", "") - if not auth_header.startswith("Bearer "): - return None - token = auth_header[7:] - token_hash = _hash(token) - info = ACCESS_TOKENS.get(token_hash) - if not info: - return None - if info["expires_at"] < time.time(): - del ACCESS_TOKENS[token_hash] - return None - return info - -# --------------------------------------------------------------------------- -# Consent page HTML template -# --------------------------------------------------------------------------- - -CONSENT_PAGE_CSS = """ -* { margin: 0; padding: 0; box-sizing: border-box; } -body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; - background: #0f172a; color: #e2e8f0; min-height: 100vh; - display: flex; align-items: center; justify-content: center; } -.card { background: #1e293b; border-radius: 16px; padding: 40px; - max-width: 420px; width: 90%; box-shadow: 0 25px 50px rgba(0,0,0,0.4); } -h1 { font-size: 22px; margin-bottom: 8px; color: #f8fafc; } -.subtitle { color: #94a3b8; margin-bottom: 24px; font-size: 14px; } -.client { background: #334155; border-radius: 8px; padding: 12px 16px; - margin-bottom: 24px; font-size: 14px; } -.client strong { color: #38bdf8; } -label { display: block; font-size: 13px; color: #94a3b8; margin-bottom: 6px; } -input[type=password] { width: 100%; padding: 12px 16px; border-radius: 8px; - border: 1px solid #475569; background: #0f172a; color: #f8fafc; - font-size: 16px; margin-bottom: 20px; outline: none; } -input[type=password]:focus { border-color: #38bdf8; box-shadow: 0 0 0 3px rgba(56,189,248,0.15); } -.buttons { display: flex; gap: 12px; } -button { flex: 1; padding: 12px; border-radius: 8px; border: none; font-size: 15px; - font-weight: 600; cursor: pointer; transition: all 0.15s; } -.approve { background: #38bdf8; color: #0f172a; } -.approve:hover { background: #7dd3fc; } -.deny { background: #334155; color: #94a3b8; } -.deny:hover { background: #475569; color: #e2e8f0; } -.scope { font-size: 13px; color: #64748b; margin-bottom: 16px; } -.error { background: #7f1d1d; color: #fca5a5; padding: 10px 14px; border-radius: 8px; - margin-bottom: 16px; font-size: 13px; } -""" - -def render_consent_page(client_name: str, scope: str, internal_state: str, error_msg: str = "") -> str: - error_html = f'
Authorization Request
-Scope: {html.escape(scope)}
Please try connecting again.
", status_code=400) - redirect_uri = pending["redirect_uri"] - state = pending["state"] - if action != "approve": - qs = urlencode({"error": "access_denied", "state": state}) - return RedirectResponse(f"{redirect_uri}?{qs}", status_code=302) - if not GATEWAY_PASSWORD: - return HTMLResponse("OAUTH_PASSWORD not set.
", status_code=500) - if password != GATEWAY_PASSWORD: - PENDING_AUTH[internal_state] = pending - client = REGISTERED_CLIENTS.get(pending["client_id"], {}) - client_name = client.get("client_name", "Unknown Client") - return HTMLResponse(render_consent_page(client_name, pending["scope"], internal_state, "Invalid password. Please try again.")) - code = secrets.token_urlsafe(48) - AUTH_CODES[code] = { - "client_id": pending["client_id"], - "redirect_uri": redirect_uri, - "code_challenge": pending["code_challenge"], - "code_challenge_method": pending["code_challenge_method"], - "scope": pending["scope"], - "expires_at": time.time() + AUTH_CODE_TTL, - "user": "owner", - } - qs = urlencode({"code": code, "state": state}) - logger.info(f"OAuth: issued auth code for client {pending['client_id']}") - return RedirectResponse(f"{redirect_uri}?{qs}", status_code=302) - return JSONResponse({"error": "method_not_allowed"}, status_code=405) - -async def oauth_token(request: Request) -> JSONResponse: - _clean_expired() - content_type = request.headers.get("content-type", "") - if "application/json" in content_type: - try: - body = await request.json() - except Exception: - body = {} - else: - form = await request.form() - body = dict(form) - grant_type = body.get("grant_type", "") - if grant_type == "authorization_code": - code = body.get("code", "") - client_id = body.get("client_id", "") - code_verifier = body.get("code_verifier", "") - code_info = AUTH_CODES.pop(code, None) - if not code_info: - return JSONResponse({"error": "invalid_grant", "error_description": "Code invalid or expired."}, status_code=400) - if code_info["expires_at"] < time.time(): - return JSONResponse({"error": "invalid_grant", "error_description": "Code expired."}, status_code=400) - if code_info["client_id"] != client_id: - return JSONResponse({"error": "invalid_grant", "error_description": "Client mismatch."}, status_code=400) - if code_info.get("code_challenge"): - if not code_verifier: - return JSONResponse({"error": "invalid_grant", "error_description": "code_verifier required."}, status_code=400) - expected = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).rstrip(b"=").decode() - if expected != code_info["code_challenge"]: - return JSONResponse({"error": "invalid_grant", "error_description": "PKCE verification failed."}, status_code=400) - access_token = secrets.token_urlsafe(48) - refresh_token = secrets.token_urlsafe(48) - access_hash = _hash(access_token) - refresh_hash = _hash(refresh_token) - ACCESS_TOKENS[access_hash] = { - "client_id": client_id, - "scope": code_info["scope"], - "expires_at": time.time() + ACCESS_TOKEN_TTL, - "user": code_info["user"], - } - REFRESH_TOKENS[refresh_hash] = { - "client_id": client_id, - "scope": code_info["scope"], - "expires_at": time.time() + REFRESH_TOKEN_TTL, - "user": code_info["user"], - "access_token_hash": access_hash, - } - logger.info(f"OAuth: issued access token for client {client_id}") - return JSONResponse({ - "access_token": access_token, - "token_type": "Bearer", - "expires_in": ACCESS_TOKEN_TTL, - "refresh_token": refresh_token, - "scope": code_info["scope"], - }) - elif grant_type == "refresh_token": - refresh_token_val = body.get("refresh_token", "") - client_id = body.get("client_id", "") - refresh_hash = _hash(refresh_token_val) - refresh_info = REFRESH_TOKENS.get(refresh_hash) - if not refresh_info: - return JSONResponse({"error": "invalid_grant", "error_description": "Refresh token invalid."}, status_code=400) - if refresh_info["expires_at"] < time.time(): - del REFRESH_TOKENS[refresh_hash] - return JSONResponse({"error": "invalid_grant", "error_description": "Refresh token expired."}, status_code=400) - if refresh_info["client_id"] != client_id: - return JSONResponse({"error": "invalid_grant", "error_description": "Client mismatch."}, status_code=400) - old_access_hash = refresh_info.get("access_token_hash") - if old_access_hash and old_access_hash in ACCESS_TOKENS: - del ACCESS_TOKENS[old_access_hash] - del REFRESH_TOKENS[refresh_hash] - new_access_token = secrets.token_urlsafe(48) - new_refresh_token = secrets.token_urlsafe(48) - new_access_hash = _hash(new_access_token) - new_refresh_hash = _hash(new_refresh_token) - ACCESS_TOKENS[new_access_hash] = { - "client_id": client_id, - "scope": refresh_info["scope"], - "expires_at": time.time() + ACCESS_TOKEN_TTL, - "user": refresh_info["user"], - } - REFRESH_TOKENS[new_refresh_hash] = { - "client_id": client_id, - "scope": refresh_info["scope"], - "expires_at": time.time() + REFRESH_TOKEN_TTL, - "user": refresh_info["user"], - "access_token_hash": new_access_hash, - } - logger.info(f"OAuth: refreshed token for client {client_id}") - return JSONResponse({ - "access_token": new_access_token, - "token_type": "Bearer", - "expires_in": ACCESS_TOKEN_TTL, - "refresh_token": new_refresh_token, - "scope": refresh_info["scope"], - }) - return JSONResponse({"error": "unsupported_grant_type"}, status_code=400) - -# --------------------------------------------------------------------------- -# MCP Endpoint (OAuth-protected) -# --------------------------------------------------------------------------- - -async def handle_mcp(request: Request) -> Response: - if request.method == "HEAD": - return Response(status_code=200, headers={"MCP-Protocol-Version": "2024-11-05"}) - if request.method == "GET": - return Response( - status_code=401, - headers={ - "WWW-Authenticate": f'Bearer resource_metadata="{ISSUER_URL}/.well-known/oauth-protected-resource"', - "Content-Type": "application/json", - }, - media_type="application/json", - content=json.dumps({"error": "unauthorized", "message": "Bearer token required."}), - ) - if request.method == "DELETE": - session_id = request.headers.get("Mcp-Session-Id") - if session_id and session_id in GATEWAY_SESSIONS: - del GATEWAY_SESSIONS[session_id] - return Response(status_code=200) - # POST — require Bearer token - token_info = validate_bearer_token(request) - if not token_info: - return Response( - status_code=401, - headers={ - "WWW-Authenticate": f'Bearer resource_metadata="{ISSUER_URL}/.well-known/oauth-protected-resource"', - "Content-Type": "application/json", - }, - media_type="application/json", - content=json.dumps({"error": "unauthorized", "message": "Valid Bearer token required."}), - ) - try: - body = await request.json() - except Exception: - return JSONResponse( - {"jsonrpc": "2.0", "error": {"code": -32700, "message": "Parse error"}, "id": None}, - status_code=400, - ) - method = body.get("method", "") - params = body.get("params", {}) - req_id = body.get("id") - if method == "initialize": - session_id = str(uuid.uuid4()) - GATEWAY_SESSIONS[session_id] = True - response = JSONResponse({ - "jsonrpc": "2.0", - "id": req_id, - "result": { - "protocolVersion": "2024-11-05", - "capabilities": {"tools": {"listChanged": False}}, - "serverInfo": {"name": "mcp-gateway-proxy", "version": "1.0.0"}, - }, - }) - response.headers["Mcp-Session-Id"] = session_id - return response - if method == "notifications/initialized": - return Response(status_code=202) - if method == "tools/list": - tools = list(TOOL_DEFINITIONS.values()) - response = JSONResponse({ - "jsonrpc": "2.0", - "id": req_id, - "result": {"tools": tools}, - }) - session_id = request.headers.get("Mcp-Session-Id") - if session_id: - response.headers["Mcp-Session-Id"] = session_id - return response - if method == "tools/call": - tool_name = params.get("name", "") - arguments = params.get("arguments", {}) - backend_name = TOOL_REGISTRY.get(tool_name) - if not backend_name: - return JSONResponse({ - "jsonrpc": "2.0", - "id": req_id, - "error": {"code": -32601, "message": f"Unknown tool: {tool_name}"}, - }) - try: - result = await forward_tool_call(backend_name, tool_name, arguments, req_id) - if isinstance(result, dict) and "jsonrpc" in result: - response = JSONResponse(result) - elif isinstance(result, dict) and "result" in result: - response = JSONResponse({ - "jsonrpc": "2.0", - "id": req_id, - "result": result["result"], - }) - elif isinstance(result, dict) and "error" in result: - response = JSONResponse({ - "jsonrpc": "2.0", - "id": req_id, - "error": result["error"], - }) - else: - response = JSONResponse({ - "jsonrpc": "2.0", - "id": req_id, - "result": result, - }) - session_id = request.headers.get("Mcp-Session-Id") - if session_id: - response.headers["Mcp-Session-Id"] = session_id - return response - except Exception as e: - logger.error(f"Tool call failed: {tool_name} - {e}") - return JSONResponse({ - "jsonrpc": "2.0", - "id": req_id, - "error": {"code": -32603, "message": str(e)}, - }) - if method == "ping": - return JSONResponse({"jsonrpc": "2.0", "id": req_id, "result": {}}) - return JSONResponse({ - "jsonrpc": "2.0", - "id": req_id, - "error": {"code": -32601, "message": f"Method not found: {method}"}, - }) - -# --------------------------------------------------------------------------- -# Health / Status -# --------------------------------------------------------------------------- - -async def health(request: Request) -> JSONResponse: - return JSONResponse({ - "status": "healthy", - "backends": len(BACKENDS), - "tools": len(TOOL_DEFINITIONS), - "oauth": "enabled", - "active_tokens": len(ACCESS_TOKENS), - "registered_clients": len(REGISTERED_CLIENTS), - }) - -async def status(request: Request) -> JSONResponse: - token_info = validate_bearer_token(request) - if not token_info: - return JSONResponse({"error": "unauthorized"}, status_code=401) - return JSONResponse({ - "backends": { - name: { - "url": url, - "tools": len([t for t, b in TOOL_REGISTRY.items() if b == name]), - "session": BACKEND_SESSIONS.get(name), - } - for name, url in BACKENDS.items() - }, - "total_tools": len(TOOL_DEFINITIONS), - "tool_list": sorted(TOOL_DEFINITIONS.keys()), - }) - -# --------------------------------------------------------------------------- -# Startup -# --------------------------------------------------------------------------- - -async def startup(): - if not GATEWAY_PASSWORD: - logger.error("OAUTH_PASSWORD is not set! OAuth login will fail.") - else: - logger.info("OAuth password configured") - logger.info(f"OAuth issuer: {ISSUER_URL}") - logger.info(f"Starting MCP Gateway Proxy with {len(BACKENDS)} backends") - logger.info("Waiting 10s for backends to start...") - await asyncio.sleep(10) - for name, url in BACKENDS.items(): - tools = await initialize_backend(name, url) - if not tools: - logger.warning(f" {name}: no tools discovered — will retry on first request") - logger.info(f"Gateway ready: {len(TOOL_DEFINITIONS)} tools from {len(BACKENDS)} backends") - -@asynccontextmanager -async def lifespan(app): - await startup() - yield - -# Build routes list -routes = [ - # Well-known discovery (Claude tries both with and without /mcp suffix) - Route("/.well-known/oauth-protected-resource", well_known_protected_resource, methods=["GET"]), - Route("/.well-known/oauth-protected-resource/mcp", well_known_protected_resource, methods=["GET"]), - Route("/.well-known/oauth-authorization-server", well_known_oauth_authorization_server, methods=["GET"]), - Route("/.well-known/oauth-authorization-server/mcp", well_known_oauth_authorization_server, methods=["GET"]), - Route("/.well-known/openid-configuration", well_known_oauth_authorization_server, methods=["GET"]), - # OAuth endpoints at /oauth/* (spec-standard) - Route("/oauth/register", oauth_register, methods=["POST"]), - Route("/oauth/authorize", oauth_authorize, methods=["GET", "POST"]), - Route("/oauth/token", oauth_token, methods=["POST"]), - # OAuth endpoints at root (Claude may construct these from base URL) - Route("/register", oauth_register, methods=["POST"]), - Route("/authorize", oauth_authorize, methods=["GET", "POST"]), - Route("/token", oauth_token, methods=["POST"]), - # MCP endpoint (OAuth-protected) - Route("/mcp", handle_mcp, methods=["GET", "HEAD", "POST", "DELETE"]), - # Monitoring - Route("/health", health, methods=["GET"]), - Route("/status", status, methods=["GET"]), -] - -# Add OpenAI-compatible endpoints if available -if OPENAI_AVAILABLE: - routes.extend([ - # OpenAI-compatible endpoints - Route("/v1/models", list_models, methods=["GET"]), - Route("/v1/chat/completions", chat_completions, methods=["POST"]), - ]) - logger.info("OpenAI-compatible endpoints registered at /v1/*") - -app = Starlette( - routes=routes, - lifespan=lifespan, -) - -if __name__ == "__main__": - port = int(os.environ.get("PORT", "4444")) - uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")