Add gateway-proxy/gateway_proxy_fixed.py
This commit is contained in:
parent
759c2131f7
commit
dac5fc77bb
1 changed files with 707 additions and 0 deletions
707
gateway-proxy/gateway_proxy_fixed.py
Normal file
707
gateway-proxy/gateway_proxy_fixed.py
Normal file
|
|
@ -0,0 +1,707 @@
|
||||||
|
"""
|
||||||
|
MCP Gateway Proxy with OAuth 2.1
|
||||||
|
=================================
|
||||||
|
Aggregates multiple MCP servers behind a single Streamable HTTP endpoint.
|
||||||
|
Implements a self-contained OAuth 2.1 provider compatible with claude.ai:
|
||||||
|
- RFC 8414 Authorization Server Metadata
|
||||||
|
- RFC 9728 Protected Resource Metadata
|
||||||
|
- RFC 7591 Dynamic Client Registration
|
||||||
|
- PKCE (S256) per OAuth 2.1
|
||||||
|
- Authorization Code Grant with refresh tokens
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import html
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import secrets
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Any
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from starlette.applications import Starlette
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import HTMLResponse, JSONResponse, RedirectResponse, Response, StreamingResponse
|
||||||
|
from starlette.routing import Route
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||||
|
logger = logging.getLogger("mcp-gateway")
|
||||||
|
|
||||||
|
# Import OpenAI-compatible routes
|
||||||
|
try:
|
||||||
|
from openai_routes import chat_completions, list_models
|
||||||
|
OPENAI_AVAILABLE = True
|
||||||
|
logger.info("✓ OpenAI-compatible routes imported successfully")
|
||||||
|
except ImportError as e:
|
||||||
|
OPENAI_AVAILABLE = False
|
||||||
|
logger.error(f"✗ Failed to import OpenAI routes: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Configuration
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
ISSUER_URL = os.environ.get("OAUTH_ISSUER_URL", "https://mcp.wilddragon.net")
|
||||||
|
GATEWAY_PASSWORD = os.environ.get("OAUTH_PASSWORD", "")
|
||||||
|
ACCESS_TOKEN_TTL = int(os.environ.get("OAUTH_ACCESS_TOKEN_TTL", "3600"))
|
||||||
|
REFRESH_TOKEN_TTL = int(os.environ.get("OAUTH_REFRESH_TOKEN_TTL", "2592000"))
|
||||||
|
AUTH_CODE_TTL = 600
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# In-memory stores
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
REGISTERED_CLIENTS: dict[str, dict] = {}
|
||||||
|
AUTH_CODES: dict[str, dict] = {}
|
||||||
|
ACCESS_TOKENS: dict[str, dict] = {}
|
||||||
|
REFRESH_TOKENS: dict[str, dict] = {}
|
||||||
|
PENDING_AUTH: dict[str, dict] = {}
|
||||||
|
|
||||||
|
def _hash(value: str) -> str:
|
||||||
|
return hashlib.sha256(value.encode()).hexdigest()
|
||||||
|
|
||||||
|
def _clean_expired():
|
||||||
|
now = time.time()
|
||||||
|
for store in (AUTH_CODES, ACCESS_TOKENS, REFRESH_TOKENS):
|
||||||
|
expired = [k for k, v in store.items() if v.get("expires_at", 0) < now]
|
||||||
|
for k in expired:
|
||||||
|
del store[k]
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Backend MCP aggregation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def load_backends() -> dict[str, str]:
|
||||||
|
backends = {}
|
||||||
|
for key, value in os.environ.items():
|
||||||
|
if key.startswith("MCP_BACKEND_"):
|
||||||
|
name = key[len("MCP_BACKEND_"):].lower()
|
||||||
|
backends[name] = value
|
||||||
|
logger.info(f"Backend configured: {name} -> {value}")
|
||||||
|
return backends
|
||||||
|
|
||||||
|
BACKENDS: dict[str, str] = load_backends()
|
||||||
|
TOOL_REGISTRY: dict[str, str] = {}
|
||||||
|
TOOL_DEFINITIONS: dict[str, dict] = {}
|
||||||
|
BACKEND_SESSIONS: dict[str, str | None] = {}
|
||||||
|
GATEWAY_SESSIONS: dict[str, bool] = {}
|
||||||
|
|
||||||
|
def parse_sse_response(text: str) -> dict | None:
|
||||||
|
for line in text.strip().split("\n"):
|
||||||
|
line = line.strip()
|
||||||
|
if line.startswith("data: "):
|
||||||
|
try:
|
||||||
|
return json.loads(line[6:])
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def mcp_request(backend_url: str, method: str, params: dict | None = None, request_id: Any = 1, session_id: str | None = None) -> dict:
|
||||||
|
payload = {"jsonrpc": "2.0", "method": method, "id": request_id}
|
||||||
|
if params is not None:
|
||||||
|
payload["params"] = params
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept": "application/json, text/event-stream",
|
||||||
|
"Host": "localhost",
|
||||||
|
}
|
||||||
|
if session_id:
|
||||||
|
headers["Mcp-Session-Id"] = session_id
|
||||||
|
async with httpx.AsyncClient(timeout=30) as client:
|
||||||
|
resp = await client.post(backend_url, json=payload, headers=headers)
|
||||||
|
new_session = resp.headers.get("Mcp-Session-Id") or resp.headers.get("mcp-session-id")
|
||||||
|
content_type = resp.headers.get("content-type", "")
|
||||||
|
if resp.status_code in (200, 201):
|
||||||
|
try:
|
||||||
|
if "text/event-stream" in content_type:
|
||||||
|
parsed = parse_sse_response(resp.text)
|
||||||
|
return {"result": parsed, "session_id": new_session}
|
||||||
|
else:
|
||||||
|
return {"result": resp.json(), "session_id": new_session}
|
||||||
|
except Exception:
|
||||||
|
return {"result": None, "session_id": new_session}
|
||||||
|
elif resp.status_code == 202:
|
||||||
|
return {"result": None, "session_id": new_session}
|
||||||
|
else:
|
||||||
|
logger.error(f"Backend {backend_url} returned {resp.status_code}: {resp.text[:200]}")
|
||||||
|
return {"result": None, "session_id": new_session}
|
||||||
|
|
||||||
|
async def initialize_backend(name: str, url: str) -> list[dict]:
|
||||||
|
logger.info(f"Initializing backend: {name} at {url}")
|
||||||
|
for attempt in range(3):
|
||||||
|
try:
|
||||||
|
init_result = await mcp_request(url, "initialize", {
|
||||||
|
"protocolVersion": "2024-11-05",
|
||||||
|
"capabilities": {},
|
||||||
|
"clientInfo": {"name": "mcp-gateway-proxy", "version": "1.0.0"},
|
||||||
|
})
|
||||||
|
session_id = init_result.get("session_id")
|
||||||
|
BACKEND_SESSIONS[name] = session_id
|
||||||
|
logger.info(f" {name}: initialized (session: {session_id})")
|
||||||
|
await mcp_request(url, "notifications/initialized", {}, request_id=None, session_id=session_id)
|
||||||
|
tools_result = await mcp_request(url, "tools/list", {}, request_id=2, session_id=session_id)
|
||||||
|
tools_data = tools_result.get("result", {})
|
||||||
|
if isinstance(tools_data, dict):
|
||||||
|
if "result" in tools_data:
|
||||||
|
tools_data = tools_data["result"]
|
||||||
|
tools = tools_data.get("tools", [])
|
||||||
|
else:
|
||||||
|
tools = []
|
||||||
|
logger.info(f" {name}: discovered {len(tools)} tools")
|
||||||
|
for tool in tools:
|
||||||
|
original_name = tool["name"]
|
||||||
|
prefixed_name = f"{name}_{original_name}"
|
||||||
|
tool["name"] = prefixed_name
|
||||||
|
TOOL_REGISTRY[prefixed_name] = name
|
||||||
|
TOOL_DEFINITIONS[prefixed_name] = tool
|
||||||
|
return tools
|
||||||
|
except Exception as e:
|
||||||
|
if attempt < 2:
|
||||||
|
logger.info(f" {name}: attempt {attempt+1} failed, retrying in 5s...")
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
else:
|
||||||
|
logger.error(f" {name}: failed to initialize after 3 attempts - {e}")
|
||||||
|
return []
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def forward_tool_call(backend_name: str, tool_name: str, arguments: dict, request_id: Any) -> dict:
|
||||||
|
url = BACKENDS[backend_name]
|
||||||
|
session_id = BACKEND_SESSIONS.get(backend_name)
|
||||||
|
prefix = f"{backend_name}_"
|
||||||
|
original_name = tool_name[len(prefix):] if tool_name.startswith(prefix) else tool_name
|
||||||
|
if not session_id:
|
||||||
|
await initialize_backend(backend_name, url)
|
||||||
|
session_id = BACKEND_SESSIONS.get(backend_name)
|
||||||
|
result = await mcp_request(
|
||||||
|
url, "tools/call",
|
||||||
|
{"name": original_name, "arguments": arguments},
|
||||||
|
request_id=request_id,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
response_data = result.get("result", {})
|
||||||
|
if isinstance(response_data, dict) and response_data.get("error"):
|
||||||
|
error_code = response_data["error"].get("code", 0)
|
||||||
|
if error_code in (-32600, -32601):
|
||||||
|
logger.info(f"Re-initializing {backend_name} after error {error_code}")
|
||||||
|
await initialize_backend(backend_name, url)
|
||||||
|
session_id = BACKEND_SESSIONS.get(backend_name)
|
||||||
|
result = await mcp_request(
|
||||||
|
url, "tools/call",
|
||||||
|
{"name": original_name, "arguments": arguments},
|
||||||
|
request_id=request_id,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
response_data = result.get("result", {})
|
||||||
|
return response_data
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# OAuth 2.1: Token validation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def validate_bearer_token(request: Request) -> dict | None:
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if not auth_header.startswith("Bearer "):
|
||||||
|
return None
|
||||||
|
token = auth_header[7:]
|
||||||
|
token_hash = _hash(token)
|
||||||
|
info = ACCESS_TOKENS.get(token_hash)
|
||||||
|
if not info:
|
||||||
|
return None
|
||||||
|
if info["expires_at"] < time.time():
|
||||||
|
del ACCESS_TOKENS[token_hash]
|
||||||
|
return None
|
||||||
|
return info
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Consent page HTML template
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
CONSENT_PAGE_CSS = """
|
||||||
|
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||||
|
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||||
|
background: #0f172a; color: #e2e8f0; min-height: 100vh;
|
||||||
|
display: flex; align-items: center; justify-content: center; }
|
||||||
|
.card { background: #1e293b; border-radius: 16px; padding: 40px;
|
||||||
|
max-width: 420px; width: 90%; box-shadow: 0 25px 50px rgba(0,0,0,0.4); }
|
||||||
|
h1 { font-size: 22px; margin-bottom: 8px; color: #f8fafc; }
|
||||||
|
.subtitle { color: #94a3b8; margin-bottom: 24px; font-size: 14px; }
|
||||||
|
.client { background: #334155; border-radius: 8px; padding: 12px 16px;
|
||||||
|
margin-bottom: 24px; font-size: 14px; }
|
||||||
|
.client strong { color: #38bdf8; }
|
||||||
|
label { display: block; font-size: 13px; color: #94a3b8; margin-bottom: 6px; }
|
||||||
|
input[type=password] { width: 100%; padding: 12px 16px; border-radius: 8px;
|
||||||
|
border: 1px solid #475569; background: #0f172a; color: #f8fafc;
|
||||||
|
font-size: 16px; margin-bottom: 20px; outline: none; }
|
||||||
|
input[type=password]:focus { border-color: #38bdf8; box-shadow: 0 0 0 3px rgba(56,189,248,0.15); }
|
||||||
|
.buttons { display: flex; gap: 12px; }
|
||||||
|
button { flex: 1; padding: 12px; border-radius: 8px; border: none; font-size: 15px;
|
||||||
|
font-weight: 600; cursor: pointer; transition: all 0.15s; }
|
||||||
|
.approve { background: #38bdf8; color: #0f172a; }
|
||||||
|
.approve:hover { background: #7dd3fc; }
|
||||||
|
.deny { background: #334155; color: #94a3b8; }
|
||||||
|
.deny:hover { background: #475569; color: #e2e8f0; }
|
||||||
|
.scope { font-size: 13px; color: #64748b; margin-bottom: 16px; }
|
||||||
|
.error { background: #7f1d1d; color: #fca5a5; padding: 10px 14px; border-radius: 8px;
|
||||||
|
margin-bottom: 16px; font-size: 13px; }
|
||||||
|
"""
|
||||||
|
|
||||||
|
def render_consent_page(client_name: str, scope: str, internal_state: str, error_msg: str = "") -> str:
|
||||||
|
error_html = f'<div class="error">{html.escape(error_msg)}</div>' if error_msg else ""
|
||||||
|
return f"""<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>MCP Gateway — Authorize</title>
|
||||||
|
<style>{CONSENT_PAGE_CSS}</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="card">
|
||||||
|
<h1>MCP Gateway</h1>
|
||||||
|
<p class="subtitle">Authorization Request</p>
|
||||||
|
<div class="client"><strong>{html.escape(client_name)}</strong> wants access to your MCP tools.</div>
|
||||||
|
<p class="scope">Scope: <code>{html.escape(scope)}</code></p>
|
||||||
|
{error_html}
|
||||||
|
<form method="POST" action="/oauth/authorize">
|
||||||
|
<input type="hidden" name="internal_state" value="{internal_state}" />
|
||||||
|
<label for="password">Gateway Password</label>
|
||||||
|
<input type="password" name="password" id="password" placeholder="Enter your password" autofocus />
|
||||||
|
<div class="buttons">
|
||||||
|
<button type="submit" name="action" value="approve" class="approve">Authorize</button>
|
||||||
|
<button type="submit" name="action" value="deny" class="deny">Deny</button>
|
||||||
|
</div>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>"""
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# OAuth 2.1 Endpoints
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def well_known_protected_resource(request: Request) -> JSONResponse:
|
||||||
|
return JSONResponse({
|
||||||
|
"resource": ISSUER_URL,
|
||||||
|
"authorization_servers": [ISSUER_URL],
|
||||||
|
"scopes_supported": ["mcp:tools"],
|
||||||
|
"bearer_methods_supported": ["header"],
|
||||||
|
})
|
||||||
|
|
||||||
|
async def well_known_oauth_authorization_server(request: Request) -> JSONResponse:
|
||||||
|
return JSONResponse({
|
||||||
|
"issuer": ISSUER_URL,
|
||||||
|
"authorization_endpoint": f"{ISSUER_URL}/authorize",
|
||||||
|
"token_endpoint": f"{ISSUER_URL}/token",
|
||||||
|
"registration_endpoint": f"{ISSUER_URL}/register",
|
||||||
|
"scopes_supported": ["mcp:tools"],
|
||||||
|
"response_types_supported": ["code"],
|
||||||
|
"grant_types_supported": ["authorization_code", "refresh_token"],
|
||||||
|
"token_endpoint_auth_methods_supported": ["client_secret_post", "none"],
|
||||||
|
"code_challenge_methods_supported": ["S256"],
|
||||||
|
"service_documentation": f"{ISSUER_URL}/health",
|
||||||
|
})
|
||||||
|
|
||||||
|
async def oauth_register(request: Request) -> JSONResponse:
|
||||||
|
try:
|
||||||
|
body = await request.json()
|
||||||
|
except Exception:
|
||||||
|
return JSONResponse({"error": "invalid_request"}, status_code=400)
|
||||||
|
client_id = str(uuid.uuid4())
|
||||||
|
client_secret = secrets.token_urlsafe(48)
|
||||||
|
client_info = {
|
||||||
|
"client_id": client_id,
|
||||||
|
"client_secret": client_secret,
|
||||||
|
"client_name": body.get("client_name", "Unknown Client"),
|
||||||
|
"redirect_uris": body.get("redirect_uris", []),
|
||||||
|
"grant_types": body.get("grant_types", ["authorization_code", "refresh_token"]),
|
||||||
|
"response_types": body.get("response_types", ["code"]),
|
||||||
|
"token_endpoint_auth_method": body.get("token_endpoint_auth_method", "client_secret_post"),
|
||||||
|
}
|
||||||
|
REGISTERED_CLIENTS[client_id] = client_info
|
||||||
|
logger.info(f"DCR: registered client '{client_info['client_name']}' as {client_id}")
|
||||||
|
return JSONResponse(client_info, status_code=201)
|
||||||
|
|
||||||
|
async def oauth_authorize(request: Request) -> Response:
|
||||||
|
_clean_expired()
|
||||||
|
if request.method == "GET":
|
||||||
|
params = request.query_params
|
||||||
|
client_id = params.get("client_id", "")
|
||||||
|
redirect_uri = params.get("redirect_uri", "")
|
||||||
|
state = params.get("state", "")
|
||||||
|
scope = params.get("scope", "mcp:tools")
|
||||||
|
code_challenge = params.get("code_challenge", "")
|
||||||
|
code_challenge_method = params.get("code_challenge_method", "S256")
|
||||||
|
response_type = params.get("response_type", "code")
|
||||||
|
if response_type != "code":
|
||||||
|
return JSONResponse({"error": "unsupported_response_type"}, status_code=400)
|
||||||
|
client = REGISTERED_CLIENTS.get(client_id)
|
||||||
|
if not client:
|
||||||
|
return JSONResponse({"error": "invalid_client", "error_description": "Client not registered."}, status_code=400)
|
||||||
|
internal_state = secrets.token_urlsafe(32)
|
||||||
|
PENDING_AUTH[internal_state] = {
|
||||||
|
"client_id": client_id,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"code_challenge": code_challenge,
|
||||||
|
"code_challenge_method": code_challenge_method,
|
||||||
|
"scope": scope,
|
||||||
|
"state": state,
|
||||||
|
}
|
||||||
|
client_name = client.get("client_name", "Unknown Client")
|
||||||
|
return HTMLResponse(render_consent_page(client_name, scope, internal_state))
|
||||||
|
if request.method == "POST":
|
||||||
|
form = await request.form()
|
||||||
|
internal_state = form.get("internal_state", "")
|
||||||
|
password = form.get("password", "")
|
||||||
|
action = form.get("action", "")
|
||||||
|
pending = PENDING_AUTH.pop(internal_state, None)
|
||||||
|
if not pending:
|
||||||
|
return HTMLResponse("<h1>Authorization expired</h1><p>Please try connecting again.</p>", status_code=400)
|
||||||
|
redirect_uri = pending["redirect_uri"]
|
||||||
|
state = pending["state"]
|
||||||
|
if action != "approve":
|
||||||
|
qs = urlencode({"error": "access_denied", "state": state})
|
||||||
|
return RedirectResponse(f"{redirect_uri}?{qs}", status_code=302)
|
||||||
|
if not GATEWAY_PASSWORD:
|
||||||
|
return HTMLResponse("<h1>Server misconfigured</h1><p>OAUTH_PASSWORD not set.</p>", status_code=500)
|
||||||
|
if password != GATEWAY_PASSWORD:
|
||||||
|
PENDING_AUTH[internal_state] = pending
|
||||||
|
client = REGISTERED_CLIENTS.get(pending["client_id"], {})
|
||||||
|
client_name = client.get("client_name", "Unknown Client")
|
||||||
|
return HTMLResponse(render_consent_page(client_name, pending["scope"], internal_state, "Invalid password. Please try again."))
|
||||||
|
code = secrets.token_urlsafe(48)
|
||||||
|
AUTH_CODES[code] = {
|
||||||
|
"client_id": pending["client_id"],
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"code_challenge": pending["code_challenge"],
|
||||||
|
"code_challenge_method": pending["code_challenge_method"],
|
||||||
|
"scope": pending["scope"],
|
||||||
|
"expires_at": time.time() + AUTH_CODE_TTL,
|
||||||
|
"user": "owner",
|
||||||
|
}
|
||||||
|
qs = urlencode({"code": code, "state": state})
|
||||||
|
logger.info(f"OAuth: issued auth code for client {pending['client_id']}")
|
||||||
|
return RedirectResponse(f"{redirect_uri}?{qs}", status_code=302)
|
||||||
|
return JSONResponse({"error": "method_not_allowed"}, status_code=405)
|
||||||
|
|
||||||
|
async def oauth_token(request: Request) -> JSONResponse:
|
||||||
|
_clean_expired()
|
||||||
|
content_type = request.headers.get("content-type", "")
|
||||||
|
if "application/json" in content_type:
|
||||||
|
try:
|
||||||
|
body = await request.json()
|
||||||
|
except Exception:
|
||||||
|
body = {}
|
||||||
|
else:
|
||||||
|
form = await request.form()
|
||||||
|
body = dict(form)
|
||||||
|
grant_type = body.get("grant_type", "")
|
||||||
|
if grant_type == "authorization_code":
|
||||||
|
code = body.get("code", "")
|
||||||
|
client_id = body.get("client_id", "")
|
||||||
|
code_verifier = body.get("code_verifier", "")
|
||||||
|
code_info = AUTH_CODES.pop(code, None)
|
||||||
|
if not code_info:
|
||||||
|
return JSONResponse({"error": "invalid_grant", "error_description": "Code invalid or expired."}, status_code=400)
|
||||||
|
if code_info["expires_at"] < time.time():
|
||||||
|
return JSONResponse({"error": "invalid_grant", "error_description": "Code expired."}, status_code=400)
|
||||||
|
if code_info["client_id"] != client_id:
|
||||||
|
return JSONResponse({"error": "invalid_grant", "error_description": "Client mismatch."}, status_code=400)
|
||||||
|
if code_info.get("code_challenge"):
|
||||||
|
if not code_verifier:
|
||||||
|
return JSONResponse({"error": "invalid_grant", "error_description": "code_verifier required."}, status_code=400)
|
||||||
|
expected = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).rstrip(b"=").decode()
|
||||||
|
if expected != code_info["code_challenge"]:
|
||||||
|
return JSONResponse({"error": "invalid_grant", "error_description": "PKCE verification failed."}, status_code=400)
|
||||||
|
access_token = secrets.token_urlsafe(48)
|
||||||
|
refresh_token = secrets.token_urlsafe(48)
|
||||||
|
access_hash = _hash(access_token)
|
||||||
|
refresh_hash = _hash(refresh_token)
|
||||||
|
ACCESS_TOKENS[access_hash] = {
|
||||||
|
"client_id": client_id,
|
||||||
|
"scope": code_info["scope"],
|
||||||
|
"expires_at": time.time() + ACCESS_TOKEN_TTL,
|
||||||
|
"user": code_info["user"],
|
||||||
|
}
|
||||||
|
REFRESH_TOKENS[refresh_hash] = {
|
||||||
|
"client_id": client_id,
|
||||||
|
"scope": code_info["scope"],
|
||||||
|
"expires_at": time.time() + REFRESH_TOKEN_TTL,
|
||||||
|
"user": code_info["user"],
|
||||||
|
"access_token_hash": access_hash,
|
||||||
|
}
|
||||||
|
logger.info(f"OAuth: issued access token for client {client_id}")
|
||||||
|
return JSONResponse({
|
||||||
|
"access_token": access_token,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"expires_in": ACCESS_TOKEN_TTL,
|
||||||
|
"refresh_token": refresh_token,
|
||||||
|
"scope": code_info["scope"],
|
||||||
|
})
|
||||||
|
elif grant_type == "refresh_token":
|
||||||
|
refresh_token_val = body.get("refresh_token", "")
|
||||||
|
client_id = body.get("client_id", "")
|
||||||
|
refresh_hash = _hash(refresh_token_val)
|
||||||
|
refresh_info = REFRESH_TOKENS.get(refresh_hash)
|
||||||
|
if not refresh_info:
|
||||||
|
return JSONResponse({"error": "invalid_grant", "error_description": "Refresh token invalid."}, status_code=400)
|
||||||
|
if refresh_info["expires_at"] < time.time():
|
||||||
|
del REFRESH_TOKENS[refresh_hash]
|
||||||
|
return JSONResponse({"error": "invalid_grant", "error_description": "Refresh token expired."}, status_code=400)
|
||||||
|
if refresh_info["client_id"] != client_id:
|
||||||
|
return JSONResponse({"error": "invalid_grant", "error_description": "Client mismatch."}, status_code=400)
|
||||||
|
old_access_hash = refresh_info.get("access_token_hash")
|
||||||
|
if old_access_hash and old_access_hash in ACCESS_TOKENS:
|
||||||
|
del ACCESS_TOKENS[old_access_hash]
|
||||||
|
del REFRESH_TOKENS[refresh_hash]
|
||||||
|
new_access_token = secrets.token_urlsafe(48)
|
||||||
|
new_refresh_token = secrets.token_urlsafe(48)
|
||||||
|
new_access_hash = _hash(new_access_token)
|
||||||
|
new_refresh_hash = _hash(new_refresh_token)
|
||||||
|
ACCESS_TOKENS[new_access_hash] = {
|
||||||
|
"client_id": client_id,
|
||||||
|
"scope": refresh_info["scope"],
|
||||||
|
"expires_at": time.time() + ACCESS_TOKEN_TTL,
|
||||||
|
"user": refresh_info["user"],
|
||||||
|
}
|
||||||
|
REFRESH_TOKENS[new_refresh_hash] = {
|
||||||
|
"client_id": client_id,
|
||||||
|
"scope": refresh_info["scope"],
|
||||||
|
"expires_at": time.time() + REFRESH_TOKEN_TTL,
|
||||||
|
"user": refresh_info["user"],
|
||||||
|
"access_token_hash": new_access_hash,
|
||||||
|
}
|
||||||
|
logger.info(f"OAuth: refreshed token for client {client_id}")
|
||||||
|
return JSONResponse({
|
||||||
|
"access_token": new_access_token,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"expires_in": ACCESS_TOKEN_TTL,
|
||||||
|
"refresh_token": new_refresh_token,
|
||||||
|
"scope": refresh_info["scope"],
|
||||||
|
})
|
||||||
|
return JSONResponse({"error": "unsupported_grant_type"}, status_code=400)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# MCP Endpoint (OAuth-protected)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def handle_mcp(request: Request) -> Response:
|
||||||
|
if request.method == "HEAD":
|
||||||
|
return Response(status_code=200, headers={"MCP-Protocol-Version": "2024-11-05"})
|
||||||
|
if request.method == "GET":
|
||||||
|
return Response(
|
||||||
|
status_code=401,
|
||||||
|
headers={
|
||||||
|
"WWW-Authenticate": f'Bearer resource_metadata="{ISSUER_URL}/.well-known/oauth-protected-resource"',
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
media_type="application/json",
|
||||||
|
content=json.dumps({"error": "unauthorized", "message": "Bearer token required."}),
|
||||||
|
)
|
||||||
|
if request.method == "DELETE":
|
||||||
|
session_id = request.headers.get("Mcp-Session-Id")
|
||||||
|
if session_id and session_id in GATEWAY_SESSIONS:
|
||||||
|
del GATEWAY_SESSIONS[session_id]
|
||||||
|
return Response(status_code=200)
|
||||||
|
# POST — require Bearer token
|
||||||
|
token_info = validate_bearer_token(request)
|
||||||
|
if not token_info:
|
||||||
|
return Response(
|
||||||
|
status_code=401,
|
||||||
|
headers={
|
||||||
|
"WWW-Authenticate": f'Bearer resource_metadata="{ISSUER_URL}/.well-known/oauth-protected-resource"',
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
media_type="application/json",
|
||||||
|
content=json.dumps({"error": "unauthorized", "message": "Valid Bearer token required."}),
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
body = await request.json()
|
||||||
|
except Exception:
|
||||||
|
return JSONResponse(
|
||||||
|
{"jsonrpc": "2.0", "error": {"code": -32700, "message": "Parse error"}, "id": None},
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
method = body.get("method", "")
|
||||||
|
params = body.get("params", {})
|
||||||
|
req_id = body.get("id")
|
||||||
|
if method == "initialize":
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
GATEWAY_SESSIONS[session_id] = True
|
||||||
|
response = JSONResponse({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": req_id,
|
||||||
|
"result": {
|
||||||
|
"protocolVersion": "2024-11-05",
|
||||||
|
"capabilities": {"tools": {"listChanged": False}},
|
||||||
|
"serverInfo": {"name": "mcp-gateway-proxy", "version": "1.0.0"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
response.headers["Mcp-Session-Id"] = session_id
|
||||||
|
return response
|
||||||
|
if method == "notifications/initialized":
|
||||||
|
return Response(status_code=202)
|
||||||
|
if method == "tools/list":
|
||||||
|
tools = list(TOOL_DEFINITIONS.values())
|
||||||
|
response = JSONResponse({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": req_id,
|
||||||
|
"result": {"tools": tools},
|
||||||
|
})
|
||||||
|
session_id = request.headers.get("Mcp-Session-Id")
|
||||||
|
if session_id:
|
||||||
|
response.headers["Mcp-Session-Id"] = session_id
|
||||||
|
return response
|
||||||
|
if method == "tools/call":
|
||||||
|
tool_name = params.get("name", "")
|
||||||
|
arguments = params.get("arguments", {})
|
||||||
|
backend_name = TOOL_REGISTRY.get(tool_name)
|
||||||
|
if not backend_name:
|
||||||
|
return JSONResponse({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": req_id,
|
||||||
|
"error": {"code": -32601, "message": f"Unknown tool: {tool_name}"},
|
||||||
|
})
|
||||||
|
try:
|
||||||
|
result = await forward_tool_call(backend_name, tool_name, arguments, req_id)
|
||||||
|
if isinstance(result, dict) and "jsonrpc" in result:
|
||||||
|
response = JSONResponse(result)
|
||||||
|
elif isinstance(result, dict) and "result" in result:
|
||||||
|
response = JSONResponse({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": req_id,
|
||||||
|
"result": result["result"],
|
||||||
|
})
|
||||||
|
elif isinstance(result, dict) and "error" in result:
|
||||||
|
response = JSONResponse({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": req_id,
|
||||||
|
"error": result["error"],
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
response = JSONResponse({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": req_id,
|
||||||
|
"result": result,
|
||||||
|
})
|
||||||
|
session_id = request.headers.get("Mcp-Session-Id")
|
||||||
|
if session_id:
|
||||||
|
response.headers["Mcp-Session-Id"] = session_id
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Tool call failed: {tool_name} - {e}")
|
||||||
|
return JSONResponse({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": req_id,
|
||||||
|
"error": {"code": -32603, "message": str(e)},
|
||||||
|
})
|
||||||
|
if method == "ping":
|
||||||
|
return JSONResponse({"jsonrpc": "2.0", "id": req_id, "result": {}})
|
||||||
|
return JSONResponse({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": req_id,
|
||||||
|
"error": {"code": -32601, "message": f"Method not found: {method}"},
|
||||||
|
})
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Health / Status
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def health(request: Request) -> JSONResponse:
|
||||||
|
return JSONResponse({
|
||||||
|
"status": "healthy",
|
||||||
|
"backends": len(BACKENDS),
|
||||||
|
"tools": len(TOOL_DEFINITIONS),
|
||||||
|
"oauth": "enabled",
|
||||||
|
"active_tokens": len(ACCESS_TOKENS),
|
||||||
|
"registered_clients": len(REGISTERED_CLIENTS),
|
||||||
|
})
|
||||||
|
|
||||||
|
async def status(request: Request) -> JSONResponse:
|
||||||
|
token_info = validate_bearer_token(request)
|
||||||
|
if not token_info:
|
||||||
|
return JSONResponse({"error": "unauthorized"}, status_code=401)
|
||||||
|
return JSONResponse({
|
||||||
|
"backends": {
|
||||||
|
name: {
|
||||||
|
"url": url,
|
||||||
|
"tools": len([t for t, b in TOOL_REGISTRY.items() if b == name]),
|
||||||
|
"session": BACKEND_SESSIONS.get(name),
|
||||||
|
}
|
||||||
|
for name, url in BACKENDS.items()
|
||||||
|
},
|
||||||
|
"total_tools": len(TOOL_DEFINITIONS),
|
||||||
|
"tool_list": sorted(TOOL_DEFINITIONS.keys()),
|
||||||
|
})
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Startup
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def startup():
|
||||||
|
if not GATEWAY_PASSWORD:
|
||||||
|
logger.error("OAUTH_PASSWORD is not set! OAuth login will fail.")
|
||||||
|
else:
|
||||||
|
logger.info("OAuth password configured")
|
||||||
|
logger.info(f"OAuth issuer: {ISSUER_URL}")
|
||||||
|
logger.info(f"Starting MCP Gateway Proxy with {len(BACKENDS)} backends")
|
||||||
|
logger.info("Waiting 10s for backends to start...")
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
for name, url in BACKENDS.items():
|
||||||
|
tools = await initialize_backend(name, url)
|
||||||
|
if not tools:
|
||||||
|
logger.warning(f" {name}: no tools discovered — will retry on first request")
|
||||||
|
logger.info(f"Gateway ready: {len(TOOL_DEFINITIONS)} tools from {len(BACKENDS)} backends")
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app):
|
||||||
|
await startup()
|
||||||
|
yield
|
||||||
|
|
||||||
|
# Build routes list
|
||||||
|
routes = [
|
||||||
|
# Well-known discovery (Claude tries both with and without /mcp suffix)
|
||||||
|
Route("/.well-known/oauth-protected-resource", well_known_protected_resource, methods=["GET"]),
|
||||||
|
Route("/.well-known/oauth-protected-resource/mcp", well_known_protected_resource, methods=["GET"]),
|
||||||
|
Route("/.well-known/oauth-authorization-server", well_known_oauth_authorization_server, methods=["GET"]),
|
||||||
|
Route("/.well-known/oauth-authorization-server/mcp", well_known_oauth_authorization_server, methods=["GET"]),
|
||||||
|
Route("/.well-known/openid-configuration", well_known_oauth_authorization_server, methods=["GET"]),
|
||||||
|
# OAuth endpoints at /oauth/* (spec-standard)
|
||||||
|
Route("/oauth/register", oauth_register, methods=["POST"]),
|
||||||
|
Route("/oauth/authorize", oauth_authorize, methods=["GET", "POST"]),
|
||||||
|
Route("/oauth/token", oauth_token, methods=["POST"]),
|
||||||
|
# OAuth endpoints at root (Claude may construct these from base URL)
|
||||||
|
Route("/register", oauth_register, methods=["POST"]),
|
||||||
|
Route("/authorize", oauth_authorize, methods=["GET", "POST"]),
|
||||||
|
Route("/token", oauth_token, methods=["POST"]),
|
||||||
|
# MCP endpoint (OAuth-protected)
|
||||||
|
Route("/mcp", handle_mcp, methods=["GET", "HEAD", "POST", "DELETE"]),
|
||||||
|
# Monitoring
|
||||||
|
Route("/health", health, methods=["GET"]),
|
||||||
|
Route("/status", status, methods=["GET"]),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add OpenAI-compatible endpoints if available
|
||||||
|
if OPENAI_AVAILABLE:
|
||||||
|
routes.extend([
|
||||||
|
# OpenAI-compatible endpoints
|
||||||
|
Route("/v1/models", list_models, methods=["GET"]),
|
||||||
|
Route("/v1/chat/completions", chat_completions, methods=["POST"]),
|
||||||
|
])
|
||||||
|
logger.info("OpenAI-compatible endpoints registered at /v1/*")
|
||||||
|
|
||||||
|
app = Starlette(
|
||||||
|
routes=routes,
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
port = int(os.environ.get("PORT", "4444"))
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")
|
||||||
Loading…
Reference in a new issue