mcp-servers/gateway-proxy/gateway_proxy_fixed.py

708 lines
30 KiB
Python
Raw Normal View History

"""
MCP Gateway Proxy with OAuth 2.1
=================================
Aggregates multiple MCP servers behind a single Streamable HTTP endpoint.
Implements a self-contained OAuth 2.1 provider compatible with claude.ai:
- RFC 8414 Authorization Server Metadata
- RFC 9728 Protected Resource Metadata
- RFC 7591 Dynamic Client Registration
- PKCE (S256) per OAuth 2.1
- Authorization Code Grant with refresh tokens
"""
import asyncio
import base64
import hashlib
import html
import json
import logging
import os
import secrets
import time
import uuid
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import urlencode, urlsplit

import httpx
import uvicorn
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import HTMLResponse, JSONResponse, RedirectResponse, Response, StreamingResponse
from starlette.routing import Route
# Root logging configuration plus the gateway's own named logger.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger("mcp-gateway")

# The OpenAI-compatible routes are optional: when the module cannot be
# imported, OPENAI_AVAILABLE is False and the /v1/* routes are not mounted.
try:
    from openai_routes import chat_completions, list_models
except ImportError as import_error:
    import traceback

    OPENAI_AVAILABLE = False
    logger.error(f"✗ Failed to import OpenAI routes: {import_error}")
    logger.error(f"Traceback: {traceback.format_exc()}")
else:
    OPENAI_AVAILABLE = True
    logger.info("✓ OpenAI-compatible routes imported successfully")
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Public base URL of this gateway. It is the OAuth issuer and the prefix of every
# advertised endpoint. A trailing slash is stripped so that expressions like
# f"{ISSUER_URL}/token" never produce a double slash.
ISSUER_URL = os.environ.get("OAUTH_ISSUER_URL", "https://mcp.wilddragon.net").rstrip("/")
# Shared password checked on the consent page; an empty value means "not configured".
GATEWAY_PASSWORD = os.environ.get("OAUTH_PASSWORD", "")
# Lifetimes in seconds (they are added to time.time() when entries are created).
ACCESS_TOKEN_TTL = int(os.environ.get("OAUTH_ACCESS_TOKEN_TTL", "3600"))
REFRESH_TOKEN_TTL = int(os.environ.get("OAUTH_REFRESH_TOKEN_TTL", "2592000"))
AUTH_CODE_TTL = int(os.environ.get("OAUTH_AUTH_CODE_TTL", "600"))
# ---------------------------------------------------------------------------
# In-memory stores
# ---------------------------------------------------------------------------
# In-memory OAuth state; nothing here survives a process restart.
REGISTERED_CLIENTS: dict[str, dict] = {}  # client_id -> registered client metadata
AUTH_CODES: dict[str, dict] = {}  # raw authorization code -> grant details
ACCESS_TOKENS: dict[str, dict] = {}  # sha256(access token) -> token details
REFRESH_TOKENS: dict[str, dict] = {}  # sha256(refresh token) -> token details
PENDING_AUTH: dict[str, dict] = {}  # consent-page internal_state -> pending request


def _hash(value: str) -> str:
    """Return the hex SHA-256 digest of *value* (UTF-8 encoded).

    Token stores are keyed by this hash rather than the raw token.
    """
    return hashlib.sha256(value.encode()).hexdigest()


def _clean_expired() -> None:
    """Delete every entry whose ``expires_at`` timestamp has passed.

    ``PENDING_AUTH`` is swept too, so abandoned consent pages do not pile up.
    Entries that carry no ``expires_at`` are never purged here.
    """
    now = time.time()
    for store in (AUTH_CODES, ACCESS_TOKENS, REFRESH_TOKENS, PENDING_AUTH):
        # Collect keys first: deleting while iterating a dict raises RuntimeError.
        expired = [key for key, entry in store.items() if entry.get("expires_at", float("inf")) < now]
        for key in expired:
            del store[key]
# ---------------------------------------------------------------------------
# Backend MCP aggregation
# ---------------------------------------------------------------------------
def load_backends() -> dict[str, str]:
backends = {}
for key, value in os.environ.items():
if key.startswith("MCP_BACKEND_"):
name = key[len("MCP_BACKEND_"):].lower()
backends[name] = value
logger.info(f"Backend configured: {name} -> {value}")
return backends
BACKENDS: dict[str, str] = load_backends()
TOOL_REGISTRY: dict[str, str] = {}
TOOL_DEFINITIONS: dict[str, dict] = {}
BACKEND_SESSIONS: dict[str, str | None] = {}
GATEWAY_SESSIONS: dict[str, bool] = {}
def parse_sse_response(text: str) -> dict | None:
for line in text.strip().split("\n"):
line = line.strip()
if line.startswith("data: "):
try:
return json.loads(line[6:])
except json.JSONDecodeError:
continue
return None
async def mcp_request(
    backend_url: str,
    method: str,
    params: dict | None = None,
    request_id: Any = 1,
    session_id: str | None = None,
    timeout: float = 30,
) -> dict:
    """Send one JSON-RPC message to a backend MCP server and decode the reply.

    Args:
        backend_url: URL of the backend's MCP endpoint.
        method: JSON-RPC method name.
        params: Optional ``params`` object; omitted from the payload when None.
        request_id: JSON-RPC id. ``None`` sends a notification, so the ``id``
            member is omitted entirely, as JSON-RPC 2.0 requires.
        session_id: Backend session id, sent as ``Mcp-Session-Id`` when set.
        timeout: httpx timeout in seconds.

    Returns:
        ``{"result": message, "session_id": id_or_None, "status_code": http_status}``.
        ``message`` is the decoded backend reply, or ``None`` for HTTP 202 or an
        undecodable 2xx body. On any other status it is always a JSON-RPC error
        message: the backend's own error object if it sent one, otherwise a
        synthesized one. HTTP 404 (how Streamable HTTP servers reject an unknown
        session) is always reported as code -32600 so callers that
        re-initialize on that code can recover the session.

    Raises:
        httpx.HTTPError: on transport failures (connection refused, timeout, ...).
    """
    payload: dict[str, Any] = {"jsonrpc": "2.0", "method": method}
    if request_id is not None:
        payload["id"] = request_id
    if params is not None:
        payload["params"] = params
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json, text/event-stream",
        # NOTE(review): Host is pinned to localhost, presumably so backends with
        # Host allow-lists accept proxied calls - confirm before changing.
        "Host": "localhost",
    }
    if session_id:
        headers["Mcp-Session-Id"] = session_id

    async with httpx.AsyncClient(timeout=timeout) as client:
        resp = await client.post(backend_url, json=payload, headers=headers)
        status = resp.status_code
        # httpx header lookup is case-insensitive, so one lookup covers every spelling.
        new_session = resp.headers.get("Mcp-Session-Id")
        if status == 202:
            return {"result": None, "session_id": new_session, "status_code": status}
        message = _decode_backend_message(resp)
        if status in (200, 201):
            return {"result": message, "session_id": new_session, "status_code": status}
        logger.error(f"Backend {backend_url} returned {status}: {resp.text[:200]}")
        if status == 404 or not (isinstance(message, dict) and "error" in message):
            message = {
                "jsonrpc": "2.0",
                "id": request_id,
                "error": {
                    "code": -32600 if status == 404 else -32000,
                    "message": f"Backend returned HTTP {status}",
                },
            }
        return {"result": message, "session_id": new_session, "status_code": status}


def _decode_backend_message(resp: "httpx.Response") -> Any:
    """Decode a backend reply body as SSE or JSON; return None if it cannot be decoded."""
    content_type = resp.headers.get("content-type", "")
    try:
        if "text/event-stream" in content_type:
            return parse_sse_response(resp.text)
        return resp.json()
    except ValueError:
        logger.warning(f"Could not decode backend reply (content-type: {content_type or 'none'})")
        return None
def _extract_tools(message: Any) -> list[dict]:
    """Return the well-formed tool definitions from a ``tools/list`` reply.

    Accepts either the full JSON-RPC envelope or its bare ``result`` object.
    Entries that are not objects with a string ``name`` are dropped (with a
    warning) rather than aborting the whole backend.
    """
    if isinstance(message, dict) and "result" in message:
        message = message["result"]
    tools = message.get("tools") if isinstance(message, dict) else None
    if not isinstance(tools, list):
        return []
    valid = [tool for tool in tools if isinstance(tool, dict) and isinstance(tool.get("name"), str)]
    if len(valid) != len(tools):
        logger.warning(f"Ignoring {len(tools) - len(valid)} malformed tool definition(s)")
    return valid


async def initialize_backend(name: str, url: str, attempts: int = 3, retry_delay: float = 5) -> list[dict]:
    """Perform the MCP handshake with one backend and register its tools.

    Sends ``initialize`` and ``notifications/initialized``, then ``tools/list``.
    The backend's session id is stored in ``BACKEND_SESSIONS``. Each discovered
    tool is renamed to ``<name>_<tool>`` and recorded in ``TOOL_REGISTRY`` and
    ``TOOL_DEFINITIONS``.

    Args:
        name: Backend name; also the tool-name prefix.
        url: Backend MCP endpoint URL.
        attempts: Total number of tries before giving up.
        retry_delay: Seconds to wait between tries.

    Returns:
        The registered (prefixed) tool definitions, or ``[]`` if every attempt failed.
    """
    logger.info(f"Initializing backend: {name} at {url}")
    for attempt in range(1, attempts + 1):
        try:
            init_result = await mcp_request(url, "initialize", {
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {"name": "mcp-gateway-proxy", "version": "1.0.0"},
            })
            session_id = init_result.get("session_id")
            BACKEND_SESSIONS[name] = session_id
            logger.info(f" {name}: initialized (session: {session_id})")
            await mcp_request(url, "notifications/initialized", {}, request_id=None, session_id=session_id)
            tools_result = await mcp_request(url, "tools/list", {}, request_id=2, session_id=session_id)
        except Exception as e:
            if attempt < attempts:
                logger.warning(f" {name}: attempt {attempt} failed ({e}), retrying in {retry_delay}s...")
                await asyncio.sleep(retry_delay)
                continue
            logger.error(f" {name}: failed to initialize after {attempts} attempts - {e}")
            return []
        tools = _extract_tools(tools_result.get("result"))
        logger.info(f" {name}: discovered {len(tools)} tools")
        for tool in tools:
            prefixed_name = f"{name}_{tool['name']}"
            tool["name"] = prefixed_name
            TOOL_REGISTRY[prefixed_name] = name
            TOOL_DEFINITIONS[prefixed_name] = tool
        return tools
    return []
async def forward_tool_call(backend_name: str, tool_name: str, arguments: dict, request_id: Any) -> dict | None:
    """Forward a ``tools/call`` to the backend that owns *tool_name*.

    The ``<backend>_`` prefix is stripped to recover the backend's own tool
    name. If no session id is stored for the backend it is initialized first.
    If the backend answers with JSON-RPC error -32600 or -32601, it is
    re-initialized once and the call is retried a single time.

    Returns:
        The decoded backend message (normally a full JSON-RPC response), or
        ``None`` if the backend sent nothing decodable.
    """
    url = BACKENDS[backend_name]
    call_params = {"name": tool_name.removeprefix(f"{backend_name}_"), "arguments": arguments}

    async def send() -> Any:
        # Re-read the session every time: initialize_backend may have replaced it.
        reply = await mcp_request(
            url,
            "tools/call",
            call_params,
            request_id=request_id,
            session_id=BACKEND_SESSIONS.get(backend_name),
        )
        return reply.get("result", {})

    if not BACKEND_SESSIONS.get(backend_name):
        await initialize_backend(backend_name, url)
    response_data = await send()

    error = response_data.get("error") if isinstance(response_data, dict) else None
    error_code = error.get("code", 0) if isinstance(error, dict) else None
    if error_code in (-32600, -32601):
        logger.info(f"Re-initializing {backend_name} after error {error_code}")
        await initialize_backend(backend_name, url)
        response_data = await send()
    return response_data
# ---------------------------------------------------------------------------
# OAuth 2.1: Token validation
# ---------------------------------------------------------------------------
def validate_bearer_token(request: Request) -> dict | None:
    """Return the stored grant for the request's bearer token, or None if invalid.

    Reads ``Authorization: Bearer <token>``. The scheme name is compared
    case-insensitively (RFC 7235 §2.1), and whitespace around the token is
    ignored. The token is looked up by its SHA-256 hash in ``ACCESS_TOKENS``;
    an expired entry is removed and treated as invalid.
    """
    scheme, _, token = request.headers.get("Authorization", "").partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        return None
    token_hash = _hash(token)
    info = ACCESS_TOKENS.get(token_hash)
    if not info:
        return None
    if info["expires_at"] < time.time():
        ACCESS_TOKENS.pop(token_hash, None)
        return None
    return info
# ---------------------------------------------------------------------------
# Consent page HTML template
# ---------------------------------------------------------------------------
# Inline stylesheet for the consent page (dark card layout).
CONSENT_PAGE_CSS = """
* { margin: 0; padding: 0; box-sizing: border-box; }
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: #0f172a; color: #e2e8f0; min-height: 100vh;
display: flex; align-items: center; justify-content: center; }
.card { background: #1e293b; border-radius: 16px; padding: 40px;
max-width: 420px; width: 90%; box-shadow: 0 25px 50px rgba(0,0,0,0.4); }
h1 { font-size: 22px; margin-bottom: 8px; color: #f8fafc; }
.subtitle { color: #94a3b8; margin-bottom: 24px; font-size: 14px; }
.client { background: #334155; border-radius: 8px; padding: 12px 16px;
margin-bottom: 24px; font-size: 14px; }
.client strong { color: #38bdf8; }
label { display: block; font-size: 13px; color: #94a3b8; margin-bottom: 6px; }
input[type=password] { width: 100%; padding: 12px 16px; border-radius: 8px;
border: 1px solid #475569; background: #0f172a; color: #f8fafc;
font-size: 16px; margin-bottom: 20px; outline: none; }
input[type=password]:focus { border-color: #38bdf8; box-shadow: 0 0 0 3px rgba(56,189,248,0.15); }
.buttons { display: flex; gap: 12px; }
button { flex: 1; padding: 12px; border-radius: 8px; border: none; font-size: 15px;
font-weight: 600; cursor: pointer; transition: all 0.15s; }
.approve { background: #38bdf8; color: #0f172a; }
.approve:hover { background: #7dd3fc; }
.deny { background: #334155; color: #94a3b8; }
.deny:hover { background: #475569; color: #e2e8f0; }
.scope { font-size: 13px; color: #64748b; margin-bottom: 16px; }
.error { background: #7f1d1d; color: #fca5a5; padding: 10px 14px; border-radius: 8px;
margin-bottom: 16px; font-size: 13px; }
"""


def render_consent_page(client_name: str, scope: str, internal_state: str, error_msg: str = "") -> str:
    """Render the password/consent page shown by the authorization endpoint.

    Every interpolated value is HTML-escaped. This includes ``internal_state``,
    which sits inside a quoted attribute, so no caller-supplied text can break
    out of the markup.

    Args:
        client_name: Display name of the requesting client.
        scope: Requested scope string, shown verbatim.
        internal_state: Key of the pending request, posted back in a hidden field.
        error_msg: Optional message shown above the form (e.g. a wrong password).

    Returns:
        A complete HTML document whose form POSTs to ``/oauth/authorize``.
    """
    error_html = f'<div class="error">{html.escape(error_msg)}</div>' if error_msg else ""
    return f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>MCP Gateway Authorize</title>
<style>{CONSENT_PAGE_CSS}</style>
</head>
<body>
<div class="card">
<h1>MCP Gateway</h1>
<p class="subtitle">Authorization Request</p>
<div class="client"><strong>{html.escape(client_name)}</strong> wants access to your MCP tools.</div>
<p class="scope">Scope: <code>{html.escape(scope)}</code></p>
{error_html}
<form method="POST" action="/oauth/authorize">
<input type="hidden" name="internal_state" value="{html.escape(internal_state)}" />
<label for="password">Gateway Password</label>
<input type="password" name="password" id="password" placeholder="Enter your password" autofocus />
<div class="buttons">
<button type="submit" name="action" value="approve" class="approve">Authorize</button>
<button type="submit" name="action" value="deny" class="deny">Deny</button>
</div>
</form>
</div>
</body>
</html>"""
# ---------------------------------------------------------------------------
# OAuth 2.1 Endpoints
# ---------------------------------------------------------------------------
async def well_known_protected_resource(request: Request) -> JSONResponse:
    """Serve RFC 9728 protected-resource metadata naming this gateway as its own authorization server."""
    issuer = ISSUER_URL
    metadata = {
        "resource": issuer,
        "authorization_servers": [issuer],
        "scopes_supported": ["mcp:tools"],
        "bearer_methods_supported": ["header"],
    }
    return JSONResponse(metadata)
async def well_known_oauth_authorization_server(request: Request) -> JSONResponse:
    """Serve RFC 8414 authorization-server metadata.

    The advertised endpoints are the root-level aliases (/authorize, /token,
    /register) under ``ISSUER_URL``.
    """
    base = ISSUER_URL
    metadata = {
        "issuer": base,
        "authorization_endpoint": base + "/authorize",
        "token_endpoint": base + "/token",
        "registration_endpoint": base + "/register",
        "scopes_supported": ["mcp:tools"],
        "response_types_supported": ["code"],
        "grant_types_supported": ["authorization_code", "refresh_token"],
        "token_endpoint_auth_methods_supported": ["client_secret_post", "none"],
        "code_challenge_methods_supported": ["S256"],
        "service_documentation": base + "/health",
    }
    return JSONResponse(metadata)
async def oauth_register(request: Request) -> JSONResponse:
    """RFC 7591 Dynamic Client Registration endpoint.

    Accepts a JSON object of client metadata, assigns a random ``client_id`` and
    ``client_secret``, and stores the client in ``REGISTERED_CLIENTS``. The
    following bodies are rejected with 400:
    - undecodable JSON (``invalid_request``);
    - a body that is not a JSON object (``invalid_client_metadata``);
    - ``redirect_uris`` that is not a list of strings (``invalid_redirect_uri``).

    The response echoes the stored metadata, including ``client_id_issued_at``
    and ``client_secret_expires_at``. RFC 7591 §3.2.1 requires the latter
    whenever a secret is issued.
    """
    try:
        body = await request.json()
    except Exception:
        return JSONResponse({"error": "invalid_request"}, status_code=400)
    if not isinstance(body, dict):
        return JSONResponse(
            {"error": "invalid_client_metadata", "error_description": "Registration request must be a JSON object."},
            status_code=400,
        )
    redirect_uris = body.get("redirect_uris") or []
    if not isinstance(redirect_uris, list) or not all(isinstance(uri, str) for uri in redirect_uris):
        return JSONResponse(
            {"error": "invalid_redirect_uri", "error_description": "redirect_uris must be a list of strings."},
            status_code=400,
        )
    client_id = str(uuid.uuid4())
    client_info = {
        "client_id": client_id,
        "client_secret": secrets.token_urlsafe(48),
        "client_id_issued_at": int(time.time()),
        # 0 = the secret never expires (RFC 7591 §3.2.1).
        "client_secret_expires_at": 0,
        # `or` also covers an explicit null, which would break the consent page later.
        "client_name": str(body.get("client_name") or "Unknown Client"),
        "redirect_uris": redirect_uris,
        "grant_types": body.get("grant_types", ["authorization_code", "refresh_token"]),
        "response_types": body.get("response_types", ["code"]),
        "token_endpoint_auth_method": body.get("token_endpoint_auth_method", "client_secret_post"),
    }
    REGISTERED_CLIENTS[client_id] = client_info
    # !r keeps a client-chosen name from injecting newlines into the log.
    logger.info(f"DCR: registered client {client_info['client_name']!r} as {client_id}")
    return JSONResponse(client_info, status_code=201)
# Hosts whose http redirect URIs may differ from the registered one in port only
# (native apps listening on an ephemeral loopback port; OAuth 2.1, RFC 8252 §7.3).
_LOOPBACK_HOSTS = ("127.0.0.1", "::1", "localhost")


def _redirect_uri_matches(registered: str, requested: str) -> bool:
    """Return True if *requested* is acceptable for the *registered* redirect URI.

    Exact string equality is required. The one exception is loopback ``http``
    URIs, which may differ in port only.
    """
    if requested == registered:
        return True
    try:
        reg, req = urlsplit(registered), urlsplit(requested)
    except ValueError:  # e.g. an unterminated IPv6 literal
        return False
    return (
        req.scheme == reg.scheme == "http"
        and req.hostname in _LOOPBACK_HOSTS
        and req.hostname == reg.hostname
        and req.username is None
        and req.path == reg.path
        and req.query == reg.query
        and not req.fragment
    )


def _resolve_redirect_uri(client: dict, requested: str) -> str | None:
    """Return the redirect URI for this authorization request, or None if it is not allowed.

    When the request omits ``redirect_uri``, the client's single registered URI
    is used. If the client has several (or no) registered URIs, the parameter
    is mandatory.
    """
    registered = [uri for uri in (client.get("redirect_uris") or []) if isinstance(uri, str)]
    if not requested:
        return registered[0] if len(registered) == 1 else None
    if any(_redirect_uri_matches(uri, requested) for uri in registered):
        return requested
    return None


def _redirect_with_params(redirect_uri: str, params: dict[str, str]) -> RedirectResponse:
    """302 to *redirect_uri* with *params* added to its query string.

    Any existing query string is preserved. Empty values are dropped, so a
    missing ``state`` is not echoed back as ``state=``.
    """
    query = urlencode({key: value for key, value in params.items() if value})
    separator = "&" if "?" in redirect_uri else "?"
    return RedirectResponse(f"{redirect_uri}{separator}{query}", status_code=302)


async def _authorize_get(request: Request) -> Response:
    """Validate an authorization request, park it in PENDING_AUTH and render the consent page."""
    params = request.query_params
    if params.get("response_type", "code") != "code":
        return JSONResponse({"error": "unsupported_response_type"}, status_code=400)
    client_id = params.get("client_id", "")
    client = REGISTERED_CLIENTS.get(client_id)
    if not client:
        return JSONResponse({"error": "invalid_client", "error_description": "Client not registered."}, status_code=400)
    redirect_uri = _resolve_redirect_uri(client, params.get("redirect_uri", ""))
    if redirect_uri is None:
        # Report locally: redirecting to an unvetted URI would hand the result
        # (and later the code) to whoever controls it (RFC 6749 §4.1.2.1).
        return JSONResponse(
            {"error": "invalid_request", "error_description": "redirect_uri is missing or not registered for this client."},
            status_code=400,
        )
    state = params.get("state", "")
    code_challenge = params.get("code_challenge", "")
    code_challenge_method = params.get("code_challenge_method", "S256")
    if not code_challenge or code_challenge_method != "S256":
        # OAuth 2.1 makes PKCE mandatory; S256 is the only method advertised and verified.
        return _redirect_with_params(redirect_uri, {
            "error": "invalid_request",
            "error_description": "PKCE with code_challenge_method=S256 is required.",
            "state": state,
        })
    scope = params.get("scope", "mcp:tools")
    internal_state = secrets.token_urlsafe(32)
    PENDING_AUTH[internal_state] = {
        "client_id": client_id,
        "redirect_uri": redirect_uri,
        "code_challenge": code_challenge,
        "code_challenge_method": code_challenge_method,
        "scope": scope,
        "state": state,
        # Bounds how long the consent page stays usable.
        "expires_at": time.time() + AUTH_CODE_TTL,
    }
    client_name = client.get("client_name") or "Unknown Client"
    return HTMLResponse(render_consent_page(client_name, scope, internal_state))


async def _authorize_post(request: Request) -> Response:
    """Handle the consent form: deny, re-prompt on a bad password, or issue an auth code."""
    form = await request.form()
    internal_state = str(form.get("internal_state", ""))
    password = form.get("password", "")
    action = form.get("action", "")
    pending = PENDING_AUTH.pop(internal_state, None)
    if not pending or pending.get("expires_at", float("inf")) < time.time():
        return HTMLResponse("<h1>Authorization expired</h1><p>Please try connecting again.</p>", status_code=400)
    redirect_uri = pending["redirect_uri"]
    state = pending["state"]
    if action != "approve":
        return _redirect_with_params(redirect_uri, {"error": "access_denied", "state": state})
    if not GATEWAY_PASSWORD:
        return HTMLResponse("<h1>Server misconfigured</h1><p>OAUTH_PASSWORD not set.</p>", status_code=500)
    # Constant-time comparison on bytes (compare_digest rejects non-ASCII str).
    if not isinstance(password, str) or not secrets.compare_digest(password.encode(), GATEWAY_PASSWORD.encode()):
        PENDING_AUTH[internal_state] = pending  # allow another attempt on the same page
        client = REGISTERED_CLIENTS.get(pending["client_id"], {})
        client_name = client.get("client_name") or "Unknown Client"
        return HTMLResponse(render_consent_page(client_name, pending["scope"], internal_state, "Invalid password. Please try again."))
    code = secrets.token_urlsafe(48)
    AUTH_CODES[code] = {
        "client_id": pending["client_id"],
        "redirect_uri": redirect_uri,
        "code_challenge": pending["code_challenge"],
        "code_challenge_method": pending["code_challenge_method"],
        "scope": pending["scope"],
        "expires_at": time.time() + AUTH_CODE_TTL,
        "user": "owner",
    }
    logger.info(f"OAuth: issued auth code for client {pending['client_id']}")
    return _redirect_with_params(redirect_uri, {"code": code, "state": state})


async def oauth_authorize(request: Request) -> Response:
    """OAuth 2.1 authorization endpoint.

    GET validates the request (registered client, registered redirect URI,
    PKCE S256) and renders the password consent page. POST processes the
    submitted form and redirects back to the client with either an
    authorization code or ``error=access_denied``.
    """
    _clean_expired()
    if request.method == "GET":
        return await _authorize_get(request)
    if request.method == "POST":
        return await _authorize_post(request)
    return JSONResponse({"error": "method_not_allowed"}, status_code=405)
# RFC 6749 §5.1: token endpoint responses must not be cached.
_TOKEN_RESPONSE_HEADERS = {"Cache-Control": "no-store", "Pragma": "no-cache"}


def _token_error(error: str, description: str = "", status_code: int = 400) -> JSONResponse:
    """Build an RFC 6749 §5.2 error response."""
    payload = {"error": error}
    if description:
        payload["error_description"] = description
    return JSONResponse(payload, status_code=status_code, headers=_TOKEN_RESPONSE_HEADERS)


async def _read_token_request(request: Request) -> dict[str, str]:
    """Return the token-request parameters as a flat ``str -> str`` dict.

    Accepts a JSON object or a form-encoded body. A JSON array, scalar or
    undecodable JSON yields an empty dict. JSON ``null`` values are dropped and
    all other values are converted to strings.
    """
    if "application/json" in request.headers.get("content-type", ""):
        try:
            raw = await request.json()
        except Exception:
            raw = {}
    else:
        raw = dict(await request.form())
    if not isinstance(raw, dict):
        return {}
    return {str(key): str(value) for key, value in raw.items() if value is not None}


def _client_secret_acceptable(client_id: str, presented: str) -> bool:
    """Return False if a ``client_secret`` was presented and does not match the registered one.

    NOTE(review): clients that send no secret are still accepted, because the
    metadata advertises the ``none`` auth method. Tighten this if every client
    is confidential.
    """
    if not presented:
        return True
    client = REGISTERED_CLIENTS.get(client_id) or {}
    expected = str(client.get("client_secret") or "")
    return bool(expected) and secrets.compare_digest(presented.encode(), expected.encode())


def _issue_tokens(client_id: str, scope: str, user: str) -> JSONResponse:
    """Mint a new access/refresh token pair, store their hashes and return the token response."""
    access_token = secrets.token_urlsafe(48)
    refresh_token = secrets.token_urlsafe(48)
    access_hash = _hash(access_token)
    now = time.time()
    ACCESS_TOKENS[access_hash] = {
        "client_id": client_id,
        "scope": scope,
        "expires_at": now + ACCESS_TOKEN_TTL,
        "user": user,
    }
    REFRESH_TOKENS[_hash(refresh_token)] = {
        "client_id": client_id,
        "scope": scope,
        "expires_at": now + REFRESH_TOKEN_TTL,
        "user": user,
        # Lets the next refresh revoke the access token it replaces.
        "access_token_hash": access_hash,
    }
    return JSONResponse(
        {
            "access_token": access_token,
            "token_type": "Bearer",
            "expires_in": ACCESS_TOKEN_TTL,
            "refresh_token": refresh_token,
            "scope": scope,
        },
        headers=_TOKEN_RESPONSE_HEADERS,
    )


def _exchange_authorization_code(params: dict[str, str], client_id: str) -> JSONResponse:
    """Handle ``grant_type=authorization_code``: single-use code, client binding and PKCE S256."""
    # Pop first: a code is consumed by any redemption attempt, valid or not.
    code_info = AUTH_CODES.pop(params.get("code", ""), None)
    if not code_info:
        return _token_error("invalid_grant", "Code invalid or expired.")
    if code_info["expires_at"] < time.time():
        return _token_error("invalid_grant", "Code expired.")
    if code_info["client_id"] != client_id:
        return _token_error("invalid_grant", "Client mismatch.")
    redirect_uri = params.get("redirect_uri", "")
    if redirect_uri and redirect_uri != code_info.get("redirect_uri"):
        return _token_error("invalid_grant", "redirect_uri mismatch.")
    challenge = code_info.get("code_challenge")
    if challenge:
        verifier = params.get("code_verifier", "")
        if not verifier:
            return _token_error("invalid_grant", "code_verifier required.")
        digest = hashlib.sha256(verifier.encode()).digest()
        expected = base64.urlsafe_b64encode(digest).rstrip(b"=")
        if not secrets.compare_digest(expected, challenge.encode()):
            return _token_error("invalid_grant", "PKCE verification failed.")
    response = _issue_tokens(client_id, code_info["scope"], code_info["user"])
    logger.info(f"OAuth: issued access token for client {client_id}")
    return response


def _exchange_refresh_token(params: dict[str, str], client_id: str) -> JSONResponse:
    """Handle ``grant_type=refresh_token`` with rotation.

    Both the presented refresh token and the access token issued with it are revoked.
    """
    refresh_hash = _hash(params.get("refresh_token", ""))
    refresh_info = REFRESH_TOKENS.get(refresh_hash)
    if not refresh_info:
        return _token_error("invalid_grant", "Refresh token invalid.")
    if refresh_info["expires_at"] < time.time():
        REFRESH_TOKENS.pop(refresh_hash, None)
        return _token_error("invalid_grant", "Refresh token expired.")
    if refresh_info["client_id"] != client_id:
        return _token_error("invalid_grant", "Client mismatch.")
    old_access_hash = refresh_info.get("access_token_hash")
    if old_access_hash:
        ACCESS_TOKENS.pop(old_access_hash, None)
    REFRESH_TOKENS.pop(refresh_hash, None)
    response = _issue_tokens(client_id, refresh_info["scope"], refresh_info["user"])
    logger.info(f"OAuth: refreshed token for client {client_id}")
    return response


async def oauth_token(request: Request) -> JSONResponse:
    """OAuth 2.1 token endpoint for the ``authorization_code`` and ``refresh_token`` grants.

    Accepts JSON or form-encoded parameters. Every response carries
    ``Cache-Control: no-store``. A wrong ``client_secret`` (when one is sent)
    yields ``invalid_client`` with status 401.
    """
    _clean_expired()
    params = await _read_token_request(request)
    grant_type = params.get("grant_type", "")
    if grant_type not in ("authorization_code", "refresh_token"):
        return _token_error("unsupported_grant_type")
    client_id = params.get("client_id", "")
    if not _client_secret_acceptable(client_id, params.get("client_secret", "")):
        return _token_error("invalid_client", "Client authentication failed.", status_code=401)
    if grant_type == "authorization_code":
        return _exchange_authorization_code(params, client_id)
    return _exchange_refresh_token(params, client_id)
# ---------------------------------------------------------------------------
# MCP Endpoint (OAuth-protected)
# ---------------------------------------------------------------------------
def _unauthorized(message: str) -> Response:
    """Build a 401 JSON response whose WWW-Authenticate header points at the RFC 9728 metadata."""
    return Response(
        status_code=401,
        headers={
            "WWW-Authenticate": f'Bearer resource_metadata="{ISSUER_URL}/.well-known/oauth-protected-resource"',
            "Content-Type": "application/json",
        },
        media_type="application/json",
        content=json.dumps({"error": "unauthorized", "message": message}),
    )


def _echo_session(response: Response, request: Request) -> Response:
    """Copy the client's ``Mcp-Session-Id`` request header (if any) onto *response*."""
    session_id = request.headers.get("Mcp-Session-Id")
    if session_id:
        response.headers["Mcp-Session-Id"] = session_id
    return response


def _rpc_error(req_id: Any, code: int, message: str, status_code: int = 200) -> JSONResponse:
    """Build a JSON-RPC 2.0 error response."""
    return JSONResponse(
        {"jsonrpc": "2.0", "id": req_id, "error": {"code": code, "message": message}},
        status_code=status_code,
    )


def _tool_call_response(result: Any, req_id: Any) -> JSONResponse:
    """Wrap a backend ``tools/call`` reply for the client.

    A full JSON-RPC message is passed through unchanged. A bare ``result`` or
    ``error`` object is re-wrapped with the client's id. Anything else
    (including None) becomes the ``result``.
    """
    if isinstance(result, dict):
        if "jsonrpc" in result:
            return JSONResponse(result)
        if "result" in result:
            return JSONResponse({"jsonrpc": "2.0", "id": req_id, "result": result["result"]})
        if "error" in result:
            return JSONResponse({"jsonrpc": "2.0", "id": req_id, "error": result["error"]})
    return JSONResponse({"jsonrpc": "2.0", "id": req_id, "result": result})


async def handle_mcp(request: Request) -> Response:
    """Streamable HTTP MCP endpoint that aggregates every backend's tools.

    - HEAD: unauthenticated probe advertising the protocol version.
    - GET: 401 without a valid token, which starts OAuth discovery. With a
      valid token it returns 405, because this gateway offers no
      server-initiated SSE stream.
    - DELETE: forget the given gateway session (token required).
    - POST: a single JSON-RPC message. ``notifications/*`` get 202 with no
      body. ``initialize``, ``tools/list``, ``tools/call`` and ``ping`` are
      handled; anything else gets "Method not found".
    """
    if request.method == "HEAD":
        return Response(status_code=200, headers={"MCP-Protocol-Version": "2024-11-05"})

    token_info = validate_bearer_token(request)
    if not token_info:
        message = "Bearer token required." if request.method == "GET" else "Valid Bearer token required."
        return _unauthorized(message)

    if request.method == "GET":
        return Response(status_code=405, headers={"Allow": "HEAD, POST, DELETE"})

    if request.method == "DELETE":
        session_id = request.headers.get("Mcp-Session-Id")
        if session_id:
            GATEWAY_SESSIONS.pop(session_id, None)
        return Response(status_code=200)

    try:
        body = await request.json()
    except Exception:
        return _rpc_error(None, -32700, "Parse error", status_code=400)
    if not isinstance(body, dict):
        # A JSON-RPC batch (array) or a scalar; only single messages are handled.
        return _rpc_error(None, -32600, "Invalid Request", status_code=400)

    method = body.get("method")
    if not isinstance(method, str):
        method = ""
    params = body.get("params")
    if not isinstance(params, dict):
        params = {}
    req_id = body.get("id")

    # Notifications must be acknowledged without a JSON-RPC reply body.
    if method.startswith("notifications/"):
        return Response(status_code=202)

    if method == "initialize":
        session_id = str(uuid.uuid4())
        GATEWAY_SESSIONS[session_id] = True
        response = JSONResponse({
            "jsonrpc": "2.0",
            "id": req_id,
            "result": {
                "protocolVersion": "2024-11-05",
                "capabilities": {"tools": {"listChanged": False}},
                "serverInfo": {"name": "mcp-gateway-proxy", "version": "1.0.0"},
            },
        })
        response.headers["Mcp-Session-Id"] = session_id
        return response

    if method == "tools/list":
        listing = JSONResponse({
            "jsonrpc": "2.0",
            "id": req_id,
            "result": {"tools": list(TOOL_DEFINITIONS.values())},
        })
        return _echo_session(listing, request)

    if method == "tools/call":
        tool_name = params.get("name", "")
        backend_name = TOOL_REGISTRY.get(tool_name) if isinstance(tool_name, str) else None
        if not backend_name:
            return _rpc_error(req_id, -32601, f"Unknown tool: {tool_name}")
        arguments = params.get("arguments") or {}
        try:
            result = await forward_tool_call(backend_name, tool_name, arguments, req_id)
        except Exception as e:
            logger.exception(f"Tool call failed: {tool_name} - {e}")
            return _rpc_error(req_id, -32603, str(e))
        return _echo_session(_tool_call_response(result, req_id), request)

    if method == "ping":
        return JSONResponse({"jsonrpc": "2.0", "id": req_id, "result": {}})

    return _rpc_error(req_id, -32601, f"Method not found: {method}")
# ---------------------------------------------------------------------------
# Health / Status
# ---------------------------------------------------------------------------
async def health(request: Request) -> JSONResponse:
    """Unauthenticated health probe reporting gateway counters."""
    snapshot = dict(
        status="healthy",
        backends=len(BACKENDS),
        tools=len(TOOL_DEFINITIONS),
        oauth="enabled",
        active_tokens=len(ACCESS_TOKENS),
        registered_clients=len(REGISTERED_CLIENTS),
    )
    return JSONResponse(snapshot)
async def status(request: Request) -> JSONResponse:
    """Bearer-protected report of each backend's URL, tool count and session id."""
    if not validate_bearer_token(request):
        return JSONResponse({"error": "unauthorized"}, status_code=401)
    per_backend = {}
    for backend, backend_url in BACKENDS.items():
        owned = sum(1 for owner in TOOL_REGISTRY.values() if owner == backend)
        per_backend[backend] = {
            "url": backend_url,
            "tools": owned,
            "session": BACKEND_SESSIONS.get(backend),
        }
    report = {
        "backends": per_backend,
        "total_tools": len(TOOL_DEFINITIONS),
        "tool_list": sorted(TOOL_DEFINITIONS),
    }
    return JSONResponse(report)
# ---------------------------------------------------------------------------
# Startup
# ---------------------------------------------------------------------------
async def startup() -> None:
    """Log the OAuth configuration, wait for backends to boot, then discover their tools.

    The initial wait defaults to 10 seconds. It can be changed with the
    ``BACKEND_STARTUP_DELAY`` environment variable (seconds; ``0`` skips it).
    An unparsable value falls back to the default. Backends are initialized
    one after another, in ``BACKENDS`` order.
    """
    if not GATEWAY_PASSWORD:
        logger.error("OAUTH_PASSWORD is not set! OAuth login will fail.")
    else:
        logger.info("OAuth password configured")
    logger.info(f"OAuth issuer: {ISSUER_URL}")
    logger.info(f"Starting MCP Gateway Proxy with {len(BACKENDS)} backends")
    try:
        delay = float(os.environ.get("BACKEND_STARTUP_DELAY", "10"))
    except ValueError:
        logger.warning("BACKEND_STARTUP_DELAY is not a number; using 10s")
        delay = 10.0
    if delay > 0:
        logger.info(f"Waiting {delay:g}s for backends to start...")
        await asyncio.sleep(delay)
    for name, url in BACKENDS.items():
        tools = await initialize_backend(name, url)
        if not tools:
            # Tool calls only re-initialize backends whose tools are already
            # registered, so a backend with no tools stays unavailable until restart.
            logger.warning(f" {name}: no tools discovered; restart the gateway once the backend is ready")
    logger.info(f"Gateway ready: {len(TOOL_DEFINITIONS)} tools from {len(BACKENDS)} backends")


@asynccontextmanager
async def lifespan(app):
    """Starlette lifespan hook: run startup() before serving; nothing is torn down afterwards."""
    await startup()
    yield
# Route table, in registration order: (path, endpoint, allowed methods).
routes = [
    Route(path, endpoint, methods=methods)
    for path, endpoint, methods in (
        # Well-known discovery (Claude tries both with and without /mcp suffix)
        ("/.well-known/oauth-protected-resource", well_known_protected_resource, ["GET"]),
        ("/.well-known/oauth-protected-resource/mcp", well_known_protected_resource, ["GET"]),
        ("/.well-known/oauth-authorization-server", well_known_oauth_authorization_server, ["GET"]),
        ("/.well-known/oauth-authorization-server/mcp", well_known_oauth_authorization_server, ["GET"]),
        ("/.well-known/openid-configuration", well_known_oauth_authorization_server, ["GET"]),
        # OAuth endpoints under /oauth/*
        ("/oauth/register", oauth_register, ["POST"]),
        ("/oauth/authorize", oauth_authorize, ["GET", "POST"]),
        ("/oauth/token", oauth_token, ["POST"]),
        # Root-level aliases (Claude may construct these from the base URL)
        ("/register", oauth_register, ["POST"]),
        ("/authorize", oauth_authorize, ["GET", "POST"]),
        ("/token", oauth_token, ["POST"]),
        # MCP endpoint (OAuth-protected)
        ("/mcp", handle_mcp, ["GET", "HEAD", "POST", "DELETE"]),
        # Monitoring
        ("/health", health, ["GET"]),
        ("/status", status, ["GET"]),
    )
]

# Mount the OpenAI-compatible endpoints only when their module imported.
if OPENAI_AVAILABLE:
    routes += [
        Route("/v1/models", list_models, methods=["GET"]),
        Route("/v1/chat/completions", chat_completions, methods=["POST"]),
    ]
    logger.info("OpenAI-compatible endpoints registered at /v1/*")

app = Starlette(lifespan=lifespan, routes=routes)

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "4444")), log_level="info")