From 5ab91785b0590488ff015ab5dd00be43a35c0fb8 Mon Sep 17 00:00:00 2001 From: zgaetano Date: Tue, 31 Mar 2026 15:33:09 -0400 Subject: [PATCH] Remove mcp-gateway/gateway-proxy/gateway_proxy.py --- mcp-gateway/gateway-proxy/gateway_proxy.py | 1106 -------------------- 1 file changed, 1106 deletions(-) delete mode 100644 mcp-gateway/gateway-proxy/gateway_proxy.py diff --git a/mcp-gateway/gateway-proxy/gateway_proxy.py b/mcp-gateway/gateway-proxy/gateway_proxy.py deleted file mode 100644 index 1adc32e..0000000 --- a/mcp-gateway/gateway-proxy/gateway_proxy.py +++ /dev/null @@ -1,1106 +0,0 @@ -""" -MCP Gateway Proxy with OAuth 2.1 -================================= -Aggregates multiple MCP servers behind a single Streamable HTTP endpoint. -Implements a self-contained OAuth 2.1 provider compatible with claude.ai: - - RFC 8414 Authorization Server Metadata - - RFC 9728 Protected Resource Metadata - - RFC 7591 Dynamic Client Registration - - PKCE (S256) per OAuth 2.1 - - Authorization Code Grant with refresh tokens -""" - -import asyncio -import base64 -import hashlib -import html -import json -import logging -import os -import secrets -import time -import uuid -from contextlib import asynccontextmanager -from typing import Any -from urllib.parse import urlencode - -import httpx -from starlette.applications import Starlette -from starlette.requests import Request -from starlette.responses import HTMLResponse, JSONResponse, RedirectResponse, Response, StreamingResponse -from starlette.routing import Route -import uvicorn - -from user_routes import ( - create_user, list_users, get_user, delete_user, - toggle_user, generate_api_key, revoke_api_key, set_mcp_access -) -from user_dashboard_ui import user_management_dashboard as user_dashboard - -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") -logger = logging.getLogger("mcp-gateway") - -# --------------------------------------------------------------------------- -# Configuration -# --------------------------------------------------------------------------- - -ISSUER_URL = os.environ.get("OAUTH_ISSUER_URL", "https://mcp.wilddragon.net") -GATEWAY_PASSWORD = os.environ.get("OAUTH_PASSWORD", "") -ACCESS_TOKEN_TTL = int(os.environ.get("OAUTH_ACCESS_TOKEN_TTL", "3600")) -REFRESH_TOKEN_TTL = int(os.environ.get("OAUTH_REFRESH_TOKEN_TTL", "2592000")) -AUTH_CODE_TTL = 600 -STATIC_API_KEY = os.environ.get("GATEWAY_STATIC_API_KEY", "") - -# --------------------------------------------------------------------------- -# In-memory stores -# --------------------------------------------------------------------------- - -REGISTERED_CLIENTS: dict[str, dict] = {} -AUTH_CODES: dict[str, dict] = {} -ACCESS_TOKENS: dict[str, dict] = {} -REFRESH_TOKENS: dict[str, dict] = {} -PENDING_AUTH: dict[str, dict] = {} - - -def _hash(value: str) -> str: - return hashlib.sha256(value.encode()).hexdigest() - - -def _clean_expired(): - now = time.time() - for store in (AUTH_CODES, ACCESS_TOKENS, REFRESH_TOKENS): - expired = [k for k, v in store.items() if v.get("expires_at", 0) < now] - for k in expired: - del store[k] - - -# --------------------------------------------------------------------------- -# Backend MCP aggregation -# --------------------------------------------------------------------------- - -def load_backends() -> dict[str, str]: - backends = {} - for key, value in os.environ.items(): - if key.startswith("MCP_BACKEND_"): - name = key[len("MCP_BACKEND_"):].lower() - backends[name] = value - logger.info(f"Backend configured: {name} -> {value}") - return backends - - -BACKENDS: dict[str, str] = load_backends() -TOOL_REGISTRY: dict[str, str] = {} -TOOL_DEFINITIONS: dict[str, dict] = {} -BACKEND_SESSIONS: dict[str, str | None] = {} -GATEWAY_SESSIONS: dict[str, bool] = {} - - -def parse_sse_response(text: str) -> dict | None: - for line in text.strip().split("\n"): - line = line.strip() - if line.startswith("data: "): - try: - return json.loads(line[6:]) - except json.JSONDecodeError: - continue - return None - - -async def mcp_request(backend_url: str, method: str, params: dict | None = None, request_id: Any = 1, session_id: str | None = None) -> dict: - payload = {"jsonrpc": "2.0", "method": method, "id": request_id} - if params is not None: - payload["params"] = params - - headers = { - "Content-Type": "application/json", - "Accept": "application/json, text/event-stream", - "Host": "localhost", - } - if session_id: - headers["Mcp-Session-Id"] = session_id - - async with httpx.AsyncClient(timeout=30) as client: - resp = await client.post(backend_url, json=payload, headers=headers) - new_session = resp.headers.get("Mcp-Session-Id") or resp.headers.get("mcp-session-id") - content_type = resp.headers.get("content-type", "") - - if resp.status_code in (200, 201): - try: - if "text/event-stream" in content_type: - parsed = parse_sse_response(resp.text) - return {"result": parsed, "session_id": new_session} - else: - return {"result": resp.json(), "session_id": new_session} - except Exception: - return {"result": None, "session_id": new_session} - elif resp.status_code == 202: - return {"result": None, "session_id": new_session} - else: - logger.error(f"Backend {backend_url} returned {resp.status_code}: {resp.text[:200]}") - return {"result": None, "session_id": new_session} - - -def _normalize_schema(obj: Any) -> None: - """Recursively normalize JSON schemas so mcpo can parse them. - Fixes: - - type as list (e.g. ["string","null"]) -> "string" - - items as list (tuple validation) -> {"type": "string"} - - anyOf/oneOf containing non-dict entries - """ - if isinstance(obj, dict): - if isinstance(obj.get("type"), list): - obj["type"] = "string" - # Convert tuple-style items (list) to simple dict schema - if "items" in obj and isinstance(obj["items"], list): - obj["items"] = {"type": "string"} - # Clean anyOf/oneOf entries that might contain non-dicts - for key in ("anyOf", "oneOf"): - if key in obj and isinstance(obj[key], list): - obj[key] = [x for x in obj[key] if isinstance(x, dict)] - if len(obj[key]) == 1: - # Collapse single-option anyOf/oneOf - obj.update(obj.pop(key)[0]) - elif len(obj[key]) == 0: - del obj[key] - obj["type"] = "string" - for v in obj.values(): - _normalize_schema(v) - elif isinstance(obj, list): - for item in obj: - _normalize_schema(item) - - -async def initialize_backend(name: str, url: str) -> list[dict]: - logger.info(f"Initializing backend: {name} at {url}") - for attempt in range(3): - try: - init_result = await mcp_request(url, "initialize", { - "protocolVersion": "2024-11-05", - "capabilities": {}, - "clientInfo": {"name": "mcp-gateway-proxy", "version": "1.0.0"}, - }) - session_id = init_result.get("session_id") - BACKEND_SESSIONS[name] = session_id - logger.info(f" {name}: initialized (session: {session_id})") - - await mcp_request(url, "notifications/initialized", {}, request_id=None, session_id=session_id) - - tools_result = await mcp_request(url, "tools/list", {}, request_id=2, session_id=session_id) - tools_data = tools_result.get("result", {}) - - if isinstance(tools_data, dict): - if "result" in tools_data: - tools_data = tools_data["result"] - tools = tools_data.get("tools", []) - else: - tools = [] - - logger.info(f" {name}: discovered {len(tools)} tools") - for tool in tools: - original_name = tool["name"] - prefixed_name = f"{name}_{original_name}" - tool["name"] = prefixed_name - # Recursively normalize list-type schemas (e.g. ["string","null"]) to "string" - # so mcpo can parse the schema without crashing - _normalize_schema(tool.get("inputSchema", {})) - TOOL_REGISTRY[prefixed_name] = name - TOOL_DEFINITIONS[prefixed_name] = tool - - return tools - except Exception as e: - if attempt < 2: - logger.info(f" {name}: attempt {attempt+1} failed, retrying in 5s...") - await asyncio.sleep(5) - else: - logger.error(f" {name}: failed to initialize after 3 attempts - {e}") - return [] - return [] - - -async def forward_tool_call(backend_name: str, tool_name: str, arguments: dict, request_id: Any) -> dict: - url = BACKENDS[backend_name] - session_id = BACKEND_SESSIONS.get(backend_name) - - prefix = f"{backend_name}_" - original_name = tool_name[len(prefix):] if tool_name.startswith(prefix) else tool_name - - if not session_id: - await initialize_backend(backend_name, url) - session_id = BACKEND_SESSIONS.get(backend_name) - - result = await mcp_request( - url, "tools/call", - {"name": original_name, "arguments": arguments}, - request_id=request_id, - session_id=session_id, - ) - response_data = result.get("result", {}) - - if isinstance(response_data, dict) and response_data.get("error"): - error_code = response_data["error"].get("code", 0) - if error_code in (-32600, -32601): - logger.info(f"Re-initializing {backend_name} after error {error_code}") - await initialize_backend(backend_name, url) - session_id = BACKEND_SESSIONS.get(backend_name) - result = await mcp_request( - url, "tools/call", - {"name": original_name, "arguments": arguments}, - request_id=request_id, - session_id=session_id, - ) - response_data = result.get("result", {}) - - return response_data - - -# --------------------------------------------------------------------------- -# OAuth 2.1: Token validation -# --------------------------------------------------------------------------- - -def validate_bearer_token(request: Request) -> dict | None: - auth_header = request.headers.get("Authorization", "") - if not auth_header.startswith("Bearer "): - return None - token = auth_header[7:] - # Check static API key first (persistent, survives restarts) - if STATIC_API_KEY and token == STATIC_API_KEY: - return {"client_id": "static", "scope": "mcp:tools", "user": "static"} - token_hash = _hash(token) - info = ACCESS_TOKENS.get(token_hash) - if not info: - return None - if info["expires_at"] < time.time(): - del ACCESS_TOKENS[token_hash] - return None - return info - - -# --------------------------------------------------------------------------- -# Consent page HTML template -# --------------------------------------------------------------------------- - -CONSENT_PAGE_CSS = """ -* { margin: 0; padding: 0; box-sizing: border-box; } -body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; - background: #0f172a; color: #e2e8f0; min-height: 100vh; - display: flex; align-items: center; justify-content: center; } -.card { background: #1e293b; border-radius: 16px; padding: 40px; - max-width: 420px; width: 90%; box-shadow: 0 25px 50px rgba(0,0,0,0.4); } -h1 { font-size: 22px; margin-bottom: 8px; color: #f8fafc; } -.subtitle { color: #94a3b8; margin-bottom: 24px; font-size: 14px; } -.client { background: #334155; border-radius: 8px; padding: 12px 16px; - margin-bottom: 24px; font-size: 14px; } -.client strong { color: #38bdf8; } -label { display: block; font-size: 13px; color: #94a3b8; margin-bottom: 6px; } -input[type=password] { width: 100%; padding: 12px 16px; border-radius: 8px; - border: 1px solid #475569; background: #0f172a; color: #f8fafc; - font-size: 16px; margin-bottom: 20px; outline: none; } -input[type=password]:focus { border-color: #38bdf8; box-shadow: 0 0 0 3px rgba(56,189,248,0.15); } -.buttons { display: flex; gap: 12px; } -button { flex: 1; padding: 12px; border-radius: 8px; border: none; font-size: 15px; - font-weight: 600; cursor: pointer; transition: all 0.15s; } -.approve { background: #38bdf8; color: #0f172a; } -.approve:hover { background: #7dd3fc; } -.deny { background: #334155; color: #94a3b8; } -.deny:hover { background: #475569; color: #e2e8f0; } -.scope { font-size: 13px; color: #64748b; margin-bottom: 16px; } -.error { background: #7f1d1d; color: #fca5a5; padding: 10px 14px; border-radius: 8px; - margin-bottom: 16px; font-size: 13px; } -""" - - -def render_consent_page(client_name: str, scope: str, internal_state: str, error_msg: str = "") -> str: - error_html = f'
{html.escape(error_msg)}
' if error_msg else "" - return f""" - - - - - MCP Gateway — Authorize - - - -
-

MCP Gateway

-

Authorization Request

-
{html.escape(client_name)} wants access to your MCP tools.
-

Scope: {html.escape(scope)}

- {error_html} -
- - - -
- - -
-
-
- -""" - - -# --------------------------------------------------------------------------- -# OAuth 2.1 Endpoints -# --------------------------------------------------------------------------- - -async def well_known_protected_resource(request: Request) -> JSONResponse: - return JSONResponse({ - "resource": ISSUER_URL, - "authorization_servers": [ISSUER_URL], - "scopes_supported": ["mcp:tools"], - "bearer_methods_supported": ["header"], - }) - - -async def well_known_oauth_authorization_server(request: Request) -> JSONResponse: - return JSONResponse({ - "issuer": ISSUER_URL, - "authorization_endpoint": f"{ISSUER_URL}/authorize", - "token_endpoint": f"{ISSUER_URL}/token", - "registration_endpoint": f"{ISSUER_URL}/register", - "scopes_supported": ["mcp:tools"], - "response_types_supported": ["code"], - "grant_types_supported": ["authorization_code", "refresh_token"], - "token_endpoint_auth_methods_supported": ["client_secret_post", "none"], - "code_challenge_methods_supported": ["S256"], - "service_documentation": f"{ISSUER_URL}/health", - }) - - -async def oauth_register(request: Request) -> JSONResponse: - try: - body = await request.json() - except Exception: - return JSONResponse({"error": "invalid_request"}, status_code=400) - - client_id = str(uuid.uuid4()) - client_secret = secrets.token_urlsafe(48) - - client_info = { - "client_id": client_id, - "client_secret": client_secret, - "client_name": body.get("client_name", "Unknown Client"), - "redirect_uris": body.get("redirect_uris", []), - "grant_types": body.get("grant_types", ["authorization_code", "refresh_token"]), - "response_types": body.get("response_types", ["code"]), - "token_endpoint_auth_method": body.get("token_endpoint_auth_method", "client_secret_post"), - } - REGISTERED_CLIENTS[client_id] = client_info - logger.info(f"DCR: registered client '{client_info['client_name']}' as {client_id}") - - return JSONResponse(client_info, status_code=201) - - -async def oauth_authorize(request: Request) -> Response: - _clean_expired() - - if request.method == "GET": - params = request.query_params - client_id = params.get("client_id", "") - redirect_uri = params.get("redirect_uri", "") - state = params.get("state", "") - scope = params.get("scope", "mcp:tools") - code_challenge = params.get("code_challenge", "") - code_challenge_method = params.get("code_challenge_method", "S256") - response_type = params.get("response_type", "code") - - if response_type != "code": - return JSONResponse({"error": "unsupported_response_type"}, status_code=400) - - client = REGISTERED_CLIENTS.get(client_id) - if not client: - # Auto-register unknown clients at authorize time (handles clients that - # skip DCR or whose registration was lost on gateway restart) - client = { - "client_id": client_id, - "client_secret": secrets.token_urlsafe(48), - "client_name": f"auto-{client_id[:8]}", - "redirect_uris": [redirect_uri] if redirect_uri else [], - "grant_types": ["authorization_code", "refresh_token"], - "response_types": ["code"], - "token_endpoint_auth_method": "none", - } - REGISTERED_CLIENTS[client_id] = client - logger.info(f"OAuth: auto-registered unknown client {client_id} at authorize") - - internal_state = secrets.token_urlsafe(32) - PENDING_AUTH[internal_state] = { - "client_id": client_id, - "redirect_uri": redirect_uri, - "code_challenge": code_challenge, - "code_challenge_method": code_challenge_method, - "scope": scope, - "state": state, - } - - client_name = client.get("client_name", "Unknown Client") - return HTMLResponse(render_consent_page(client_name, scope, internal_state)) - - if request.method == "POST": - form = await request.form() - internal_state = form.get("internal_state", "") - password = form.get("password", "") - action = form.get("action", "") - - pending = PENDING_AUTH.pop(internal_state, None) - if not pending: - return HTMLResponse("

Authorization expired

Please try connecting again.

", status_code=400) - - redirect_uri = pending["redirect_uri"] - state = pending["state"] - - if action != "approve": - qs = urlencode({"error": "access_denied", "state": state}) - return RedirectResponse(f"{redirect_uri}?{qs}", status_code=302) - - if not GATEWAY_PASSWORD: - return HTMLResponse("

Server misconfigured

OAUTH_PASSWORD not set.

", status_code=500) - - if password != GATEWAY_PASSWORD: - PENDING_AUTH[internal_state] = pending - client = REGISTERED_CLIENTS.get(pending["client_id"], {}) - client_name = client.get("client_name", "Unknown Client") - return HTMLResponse(render_consent_page(client_name, pending["scope"], internal_state, "Invalid password. Please try again.")) - - code = secrets.token_urlsafe(48) - AUTH_CODES[code] = { - "client_id": pending["client_id"], - "redirect_uri": redirect_uri, - "code_challenge": pending["code_challenge"], - "code_challenge_method": pending["code_challenge_method"], - "scope": pending["scope"], - "expires_at": time.time() + AUTH_CODE_TTL, - "user": "owner", - } - - qs = urlencode({"code": code, "state": state}) - logger.info(f"OAuth: issued auth code for client {pending['client_id']}") - return RedirectResponse(f"{redirect_uri}?{qs}", status_code=302) - - return JSONResponse({"error": "method_not_allowed"}, status_code=405) - - -async def oauth_token(request: Request) -> JSONResponse: - _clean_expired() - - content_type = request.headers.get("content-type", "") - if "application/json" in content_type: - try: - body = await request.json() - except Exception: - body = {} - else: - form = await request.form() - body = dict(form) - - grant_type = body.get("grant_type", "") - - if grant_type == "authorization_code": - code = body.get("code", "") - client_id = body.get("client_id", "") - code_verifier = body.get("code_verifier", "") - - code_info = AUTH_CODES.pop(code, None) - if not code_info: - return JSONResponse({"error": "invalid_grant", "error_description": "Code invalid or expired."}, status_code=400) - - if code_info["expires_at"] < time.time(): - return JSONResponse({"error": "invalid_grant", "error_description": "Code expired."}, status_code=400) - - if code_info["client_id"] != client_id: - return JSONResponse({"error": "invalid_grant", "error_description": "Client mismatch."}, status_code=400) - - if code_info.get("code_challenge"): - if not code_verifier: - return JSONResponse({"error": "invalid_grant", "error_description": "code_verifier required."}, status_code=400) - expected = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).rstrip(b"=").decode() - if expected != code_info["code_challenge"]: - return JSONResponse({"error": "invalid_grant", "error_description": "PKCE verification failed."}, status_code=400) - - access_token = secrets.token_urlsafe(48) - refresh_token = secrets.token_urlsafe(48) - access_hash = _hash(access_token) - refresh_hash = _hash(refresh_token) - - ACCESS_TOKENS[access_hash] = { - "client_id": client_id, - "scope": code_info["scope"], - "expires_at": time.time() + ACCESS_TOKEN_TTL, - "user": code_info["user"], - } - REFRESH_TOKENS[refresh_hash] = { - "client_id": client_id, - "scope": code_info["scope"], - "expires_at": time.time() + REFRESH_TOKEN_TTL, - "user": code_info["user"], - "access_token_hash": access_hash, - } - - logger.info(f"OAuth: issued access token for client {client_id}") - return JSONResponse({ - "access_token": access_token, - "token_type": "Bearer", - "expires_in": ACCESS_TOKEN_TTL, - "refresh_token": refresh_token, - "scope": code_info["scope"], - }) - - elif grant_type == "refresh_token": - refresh_token_val = body.get("refresh_token", "") - client_id = body.get("client_id", "") - - refresh_hash = _hash(refresh_token_val) - refresh_info = REFRESH_TOKENS.get(refresh_hash) - - if not refresh_info: - return JSONResponse({"error": "invalid_grant", "error_description": "Refresh token invalid."}, status_code=400) - - if refresh_info["expires_at"] < time.time(): - del REFRESH_TOKENS[refresh_hash] - return JSONResponse({"error": "invalid_grant", "error_description": "Refresh token expired."}, status_code=400) - - if refresh_info["client_id"] != client_id: - return JSONResponse({"error": "invalid_grant", "error_description": "Client mismatch."}, status_code=400) - - old_access_hash = refresh_info.get("access_token_hash") - if old_access_hash and old_access_hash in ACCESS_TOKENS: - del ACCESS_TOKENS[old_access_hash] - - del REFRESH_TOKENS[refresh_hash] - - new_access_token = secrets.token_urlsafe(48) - new_refresh_token = secrets.token_urlsafe(48) - new_access_hash = _hash(new_access_token) - new_refresh_hash = _hash(new_refresh_token) - - ACCESS_TOKENS[new_access_hash] = { - "client_id": client_id, - "scope": refresh_info["scope"], - "expires_at": time.time() + ACCESS_TOKEN_TTL, - "user": refresh_info["user"], - } - REFRESH_TOKENS[new_refresh_hash] = { - "client_id": client_id, - "scope": refresh_info["scope"], - "expires_at": time.time() + REFRESH_TOKEN_TTL, - "user": refresh_info["user"], - "access_token_hash": new_access_hash, - } - - logger.info(f"OAuth: refreshed token for client {client_id}") - return JSONResponse({ - "access_token": new_access_token, - "token_type": "Bearer", - "expires_in": ACCESS_TOKEN_TTL, - "refresh_token": new_refresh_token, - "scope": refresh_info["scope"], - }) - - return JSONResponse({"error": "unsupported_grant_type"}, status_code=400) - - -# --------------------------------------------------------------------------- -# MCP Endpoint (OAuth-protected) -# --------------------------------------------------------------------------- - -async def handle_mcp(request: Request) -> Response: - if request.method == "HEAD": - return Response(status_code=200, headers={"MCP-Protocol-Version": "2024-11-05"}) - - if request.method == "GET": - token_info = validate_bearer_token(request) - if not token_info: - return Response( - status_code=401, - headers={ - "WWW-Authenticate": f'Bearer resource_metadata="{ISSUER_URL}/.well-known/oauth-protected-resource"', - "Content-Type": "application/json", - }, - media_type="application/json", - content=json.dumps({"error": "unauthorized", "message": "Bearer token required."}), - ) - # Authenticated GET — return empty SSE stream (mcpo uses this for server-sent events) - async def empty_sse(): - yield b": ping\n\n" - return StreamingResponse(empty_sse(), media_type="text/event-stream", - headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}) - - if request.method == "DELETE": - session_id = request.headers.get("Mcp-Session-Id") - if session_id and session_id in GATEWAY_SESSIONS: - del GATEWAY_SESSIONS[session_id] - return Response(status_code=200) - - # POST — require Bearer token - token_info = validate_bearer_token(request) - if not token_info: - return Response( - status_code=401, - headers={ - "WWW-Authenticate": f'Bearer resource_metadata="{ISSUER_URL}/.well-known/oauth-protected-resource"', - "Content-Type": "application/json", - }, - media_type="application/json", - content=json.dumps({"error": "unauthorized", "message": "Valid Bearer token required."}), - ) - - try: - body = await request.json() - except Exception: - return JSONResponse( - {"jsonrpc": "2.0", "error": {"code": -32700, "message": "Parse error"}, "id": None}, - status_code=400, - ) - - method = body.get("method", "") - params = body.get("params", {}) - req_id = body.get("id") - - if method == "initialize": - session_id = str(uuid.uuid4()) - GATEWAY_SESSIONS[session_id] = True - response = JSONResponse({ - "jsonrpc": "2.0", - "id": req_id, - "result": { - "protocolVersion": "2024-11-05", - "capabilities": {"tools": {"listChanged": False}}, - "serverInfo": {"name": "mcp-gateway-proxy", "version": "1.0.0"}, - }, - }) - response.headers["Mcp-Session-Id"] = session_id - return response - - if method == "notifications/initialized": - return Response(status_code=202) - - if method == "tools/list": - tools = list(TOOL_DEFINITIONS.values()) - response = JSONResponse({ - "jsonrpc": "2.0", - "id": req_id, - "result": {"tools": tools}, - }) - session_id = request.headers.get("Mcp-Session-Id") - if session_id: - response.headers["Mcp-Session-Id"] = session_id - return response - - if method == "tools/call": - tool_name = params.get("name", "") - arguments = params.get("arguments", {}) - - backend_name = TOOL_REGISTRY.get(tool_name) - if not backend_name: - return JSONResponse({ - "jsonrpc": "2.0", - "id": req_id, - "error": {"code": -32601, "message": f"Unknown tool: {tool_name}"}, - }) - - try: - result = await forward_tool_call(backend_name, tool_name, arguments, req_id) - - if isinstance(result, dict) and "jsonrpc" in result: - response = JSONResponse(result) - elif isinstance(result, dict) and "result" in result: - response = JSONResponse({ - "jsonrpc": "2.0", - "id": req_id, - "result": result["result"], - }) - elif isinstance(result, dict) and "error" in result: - response = JSONResponse({ - "jsonrpc": "2.0", - "id": req_id, - "error": result["error"], - }) - else: - response = JSONResponse({ - "jsonrpc": "2.0", - "id": req_id, - "result": result, - }) - - session_id = request.headers.get("Mcp-Session-Id") - if session_id: - response.headers["Mcp-Session-Id"] = session_id - return response - - except Exception as e: - logger.error(f"Tool call failed: {tool_name} - {e}") - return JSONResponse({ - "jsonrpc": "2.0", - "id": req_id, - "error": {"code": -32603, "message": str(e)}, - }) - - if method == "ping": - return JSONResponse({"jsonrpc": "2.0", "id": req_id, "result": {}}) - - return JSONResponse({ - "jsonrpc": "2.0", - "id": req_id, - "error": {"code": -32601, "message": f"Method not found: {method}"}, - }) - - -# --------------------------------------------------------------------------- -# Health / Status -# --------------------------------------------------------------------------- - -async def health(request: Request) -> JSONResponse: - return JSONResponse({ - "status": "healthy", - "backends": len(BACKENDS), - "tools": len(TOOL_DEFINITIONS), - "oauth": "enabled", - "active_tokens": len(ACCESS_TOKENS), - "registered_clients": len(REGISTERED_CLIENTS), - }) - - -async def status(request: Request) -> JSONResponse: - token_info = validate_bearer_token(request) - if not token_info: - return JSONResponse({"error": "unauthorized"}, status_code=401) - - return JSONResponse({ - "backends": { - name: { - "url": url, - "tools": len([t for t, b in TOOL_REGISTRY.items() if b == name]), - "session": BACKEND_SESSIONS.get(name), - } - for name, url in BACKENDS.items() - }, - "total_tools": len(TOOL_DEFINITIONS), - "tool_list": sorted(TOOL_DEFINITIONS.keys()), - }) - - -# --------------------------------------------------------------------------- -# Dashboard -# --------------------------------------------------------------------------- - -DISPLAY_NAMES = { - "erpnext": "ERPNext", - "truenas": "TrueNAS", - "homeassistant": "Home Assistant", - "wave": "Wave Finance", - "linkedin": "LinkedIn", -} - -async def probe_backend(name: str, url: str) -> dict: - """Live-probe a backend: initialize it and count its tools in real time.""" - start = time.time() - try: - result = await mcp_request(url, "initialize", { - "protocolVersion": "2024-11-05", - "capabilities": {}, - "clientInfo": {"name": "mcp-gateway-dashboard", "version": "1.0.0"}, - }) - session_id = result.get("session_id") - tools_result = await mcp_request(url, "tools/list", {}, request_id=2, session_id=session_id) - tools_data = tools_result.get("result", {}) - if isinstance(tools_data, dict): - if "result" in tools_data: - tools_data = tools_data["result"] - tools = tools_data.get("tools", []) - else: - tools = [] - elapsed = round((time.time() - start) * 1000) - return { - "status": "healthy", - "toolCount": len(tools), - "responseTime": elapsed, - } - except Exception as e: - elapsed = round((time.time() - start) * 1000) - logger.warning(f"Dashboard probe failed for {name}: {e}") - # Fall back to cached registry values if live probe fails - cached_tools = len([t for t, b in TOOL_REGISTRY.items() if b == name]) - return { - "status": "healthy" if cached_tools > 0 else "unhealthy", - "toolCount": cached_tools, - "responseTime": elapsed, - "note": "cached", - } - - -async def dashboard_status(request: Request) -> JSONResponse: - """Public endpoint — no auth required — live-probes all backends.""" - probes = await asyncio.gather( - *[probe_backend(name, url) for name, url in BACKENDS.items()], - return_exceptions=True - ) - services = [] - for (name, url), probe in zip(BACKENDS.items(), probes): - if isinstance(probe, Exception): - probe = {"status": "unhealthy", "toolCount": 0, "responseTime": None} - services.append({ - "name": DISPLAY_NAMES.get(name, name.capitalize()), - "key": name, - "url": url, - "status": probe["status"], - "toolCount": probe["toolCount"], - "responseTime": probe.get("responseTime"), - "lastCheck": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), - }) - return JSONResponse({ - "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), - "services": services, - "summary": { - "total": len(services), - "healthy": len([s for s in services if s["status"] == "healthy"]), - "unhealthy": len([s for s in services if s["status"] == "unhealthy"]), - "totalTools": sum(s["toolCount"] for s in services), - } - }) - - -DASHBOARD_HTML = """ - - - - - MCP Gateway Dashboard - - - - -
-
-
-

MCP Gateway

-

Service Status Dashboard

-
- -
- - -
-
-

Total Services

-

-
-
-

Total Tools

-

-
-
-

Healthy

-

-
-
- -

- - -
- - -
- - - -""" - - -async def dashboard(request: Request) -> Response: - return Response(DASHBOARD_HTML, media_type="text/html") - - -# --------------------------------------------------------------------------- -# Startup -# --------------------------------------------------------------------------- - -async def startup(): - if not GATEWAY_PASSWORD: - logger.error("OAUTH_PASSWORD is not set! OAuth login will fail.") - else: - logger.info("OAuth password configured") - - logger.info(f"OAuth issuer: {ISSUER_URL}") - logger.info(f"Starting MCP Gateway Proxy with {len(BACKENDS)} backends") - logger.info("Waiting 10s for backends to start...") - await asyncio.sleep(10) - - for name, url in BACKENDS.items(): - tools = await initialize_backend(name, url) - if not tools: - logger.warning(f" {name}: no tools discovered — will retry on first request") - - logger.info(f"Gateway ready: {len(TOOL_DEFINITIONS)} tools from {len(BACKENDS)} backends") - - -@asynccontextmanager -async def lifespan(app): - await startup() - yield - - -# --------------------------------------------------------------------------- -# LinkedIn OAuth proxy routes -# --------------------------------------------------------------------------- - -LINKEDIN_MCP_URL = "http://mcp-linkedin:8500" - - -async def linkedin_auth_proxy(request: Request): - """Proxy /linkedin/auth to the linkedin-mcp container.""" - async with httpx.AsyncClient() as client: - resp = await client.get(f"{LINKEDIN_MCP_URL}/linkedin/auth", follow_redirects=False) - if resp.status_code in (301, 302, 307, 308): - return RedirectResponse(resp.headers["location"]) - return Response(content=resp.content, status_code=resp.status_code, media_type=resp.headers.get("content-type", "text/html")) - - -async def linkedin_callback_proxy(request: Request): - """Proxy /linkedin/callback to the linkedin-mcp container.""" - query = str(request.url.query) - async with httpx.AsyncClient() as client: - resp = await client.get(f"{LINKEDIN_MCP_URL}/linkedin/callback?{query}", follow_redirects=False) - return Response(content=resp.content, status_code=resp.status_code, media_type=resp.headers.get("content-type", "text/html")) - - -async def linkedin_status_proxy(request: Request): - """Proxy /linkedin/status to the linkedin-mcp container.""" - async with httpx.AsyncClient() as client: - resp = await client.get(f"{LINKEDIN_MCP_URL}/linkedin/status", follow_redirects=False) - return Response(content=resp.content, status_code=resp.status_code, media_type=resp.headers.get("content-type", "text/html")) - - -# Build routes list -routes = [ - # Well-known discovery (Claude tries both with and without /mcp suffix) - Route("/.well-known/oauth-protected-resource", well_known_protected_resource, methods=["GET"]), - Route("/.well-known/oauth-protected-resource/mcp", well_known_protected_resource, methods=["GET"]), - Route("/.well-known/oauth-authorization-server", well_known_oauth_authorization_server, methods=["GET"]), - Route("/.well-known/oauth-authorization-server/mcp", well_known_oauth_authorization_server, methods=["GET"]), - Route("/.well-known/openid-configuration", well_known_oauth_authorization_server, methods=["GET"]), - - # OAuth endpoints at /oauth/* (spec-standard) - Route("/oauth/register", oauth_register, methods=["POST"]), - Route("/oauth/authorize", oauth_authorize, methods=["GET", "POST"]), - Route("/oauth/token", oauth_token, methods=["POST"]), - - # OAuth endpoints at root (Claude may construct these from base URL) - Route("/register", oauth_register, methods=["POST"]), - Route("/authorize", oauth_authorize, methods=["GET", "POST"]), - Route("/token", oauth_token, methods=["POST"]), - - # MCP endpoint (OAuth-protected) - Route("/mcp", handle_mcp, methods=["GET", "HEAD", "POST", "DELETE"]), - - # Monitoring - Route("/health", health, methods=["GET"]), - Route("/status", status, methods=["GET"]), - - # Dashboard (no auth required) - Route("/dashboard", dashboard, methods=["GET"]), - Route("/dashboard/status", dashboard_status, methods=["GET"]), - - # LinkedIn OAuth flow (proxied to linkedin-mcp container) - Route("/linkedin/auth", linkedin_auth_proxy, methods=["GET"]), - Route("/linkedin/callback", linkedin_callback_proxy, methods=["GET"]), - Route("/linkedin/status", linkedin_status_proxy, methods=["GET"]), - - - # Admin / User Management UI - Route("/admin", user_dashboard, methods=["GET"]), - - # User management API - Route("/users", list_users, methods=["GET"]), - Route("/users", create_user, methods=["POST"]), - Route("/users/{username}", get_user, methods=["GET"]), - Route("/users/{username}", delete_user, methods=["DELETE"]), - Route("/users/{username}/enable", toggle_user, methods=["PATCH"]), - Route("/users/{username}/keys", generate_api_key, methods=["POST"]), - Route("/users/{username}/mcp-access", set_mcp_access, methods=["PUT"]), - Route("/keys/revoke", revoke_api_key, methods=["POST"]), -] - -app = Starlette( - routes=routes, - lifespan=lifespan, -) - -if __name__ == "__main__": - port = int(os.environ.get("PORT", "4444")) - uvicorn.run(app, host="0.0.0.0", port=port, log_level="info") \ No newline at end of file