mcp-servers/mcp-gateway/gateway-proxy/oauth_storage.py

144 lines
4.2 KiB
Python
Raw Normal View History

"""
Persistent OAuth Client Storage
Stores OAuth client registrations to disk so they survive gateway restarts.
"""
import hmac
import json
import logging
import os
import tempfile
from typing import Dict, Any
logger = logging.getLogger("mcp-gateway.oauth-storage")

# JSON file that holds the registered OAuth clients; the location can be
# overridden through the OAUTH_STORAGE_FILE environment variable.
OAUTH_STORAGE_FILE = os.getenv("OAUTH_STORAGE_FILE", "/data/oauth_clients.json")
def ensure_storage_dir() -> None:
    """
    Make sure the parent directory of OAUTH_STORAGE_FILE exists.

    Does nothing when the path has no directory component or the directory
    is already present. A failure to create it is logged, not raised.
    """
    parent = os.path.dirname(OAUTH_STORAGE_FILE)
    # Guard clause: nothing to create.
    if not parent or os.path.exists(parent):
        return
    try:
        os.makedirs(parent, exist_ok=True)
        logger.info(f"Created OAuth storage directory: {parent}")
    except Exception as err:
        logger.error(f"Failed to create OAuth storage directory: {err}")
def load_oauth_clients() -> Dict[str, Dict[str, Any]]:
    """
    Load OAuth clients from persistent storage.

    Returns:
        The mapping stored in OAUTH_STORAGE_FILE (client_id -> registration
        dict). An empty dict is returned when the file does not exist,
        cannot be read, is not valid JSON, or whose top-level JSON value is
        not an object. Errors are logged, never raised.

    Note:
        Because every failure collapses to ``{}``, a load-modify-save caller
        such as register_client will overwrite an unreadable or corrupt file
        with only its own data. NOTE(review): consider preserving a corrupt
        file instead of letting the next save replace it.
    """
    ensure_storage_dir()
    try:
        # Explicit UTF-8 so decoding does not depend on the host locale.
        with open(OAUTH_STORAGE_FILE, 'r', encoding='utf-8') as f:
            clients = json.load(f)
    except FileNotFoundError:
        # EAFP instead of exists()+open(): no race if the file disappears.
        logger.info(f"OAuth storage file not found: {OAUTH_STORAGE_FILE}")
        return {}
    except json.JSONDecodeError as e:
        logger.error(f"Failed to parse OAuth storage file: {e}")
        return {}
    except Exception as e:
        logger.error(f"Failed to load OAuth clients: {e}")
        return {}
    # Callers use dict operations (.get, `in`, item assignment). A JSON list
    # or scalar would make them misbehave, so reject anything else.
    if not isinstance(clients, dict):
        logger.error(
            f"OAuth storage file does not contain a JSON object "
            f"(got {type(clients).__name__}); ignoring it"
        )
        return {}
    logger.info(f"Loaded {len(clients)} OAuth clients from storage")
    return clients
def save_oauth_clients(clients: Dict[str, Dict[str, Any]]) -> bool:
    """
    Save OAuth clients to persistent storage.

    The data is written to a uniquely named temporary file in the same
    directory as OAUTH_STORAGE_FILE, flushed and fsync'ed, then moved over
    the target with os.replace. Readers therefore see either the old file or
    the complete new one.

    Args:
        clients: Mapping of client_id -> registration dict; must be
            JSON-serializable.

    Returns:
        True if successful, False otherwise (the error is logged).
    """
    ensure_storage_dir()
    storage_dir = os.path.dirname(OAUTH_STORAGE_FILE) or "."
    # Bound before the try so the cleanup path can always reference it.
    temp_file = None
    try:
        # mkstemp picks a unique name, so concurrent saves no longer share and
        # clobber one ".tmp" file. It also creates the file readable/writable
        # only by the owner; os.replace keeps that mode, so the stored client
        # secrets are not exposed to other local users.
        # NOTE(review): this tightens the final file's permissions compared
        # with the umask-based default; confirm no other user must read it.
        fd, temp_file = tempfile.mkstemp(
            dir=storage_dir,
            prefix=os.path.basename(OAUTH_STORAGE_FILE) + ".",
            suffix=".tmp",
        )
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            json.dump(clients, f, indent=2)
            # Force the bytes to disk before the rename. Otherwise a crash can
            # leave an empty/truncated file under the final name.
            f.flush()
            os.fsync(f.fileno())
        # Atomic rename (same directory, hence same filesystem)
        os.replace(temp_file, OAUTH_STORAGE_FILE)
        temp_file = None  # consumed by the rename; nothing to clean up
        logger.debug(f"Saved {len(clients)} OAuth clients to storage")
        return True
    except Exception as e:
        logger.error(f"Failed to save OAuth clients: {e}")
        # Best-effort cleanup of a leftover temp file. Only filesystem errors
        # are ignored; KeyboardInterrupt/SystemExit still propagate.
        if temp_file is not None:
            try:
                os.remove(temp_file)
            except OSError:
                pass
        return False
def get_client(client_id: str) -> Dict[str, Any] | None:
    """
    Look up a single OAuth client registration.

    Args:
        client_id: Key the client was registered under.

    Returns:
        The stored registration, or None if no such client is registered.
    """
    return load_oauth_clients().get(client_id)
def register_client(client_id: str, client_info: Dict[str, Any]) -> bool:
    """
    Add (or overwrite) an OAuth client registration and persist it.

    Args:
        client_id: Key under which the registration is stored.
        client_info: Registration data. Its optional 'client_name' entry is
            used only in the log message.

    Returns:
        True if the updated registry was saved, False otherwise.
    """
    # NOTE(review): load -> modify -> save is not guarded by any lock here;
    # concurrent registrations could lose updates. Confirm callers serialize.
    registry = load_oauth_clients()
    registry[client_id] = client_info
    saved = save_oauth_clients(registry)
    if not saved:
        logger.error(f"Failed to register OAuth client: {client_id}")
        return saved
    display_name = client_info.get('client_name', 'Unknown')
    logger.info(f"Registered OAuth client: {client_id} ({display_name})")
    return saved
def unregister_client(client_id: str) -> bool:
    """
    Remove an OAuth client registration and persist the change.

    Args:
        client_id: Key of the registration to remove.

    Returns:
        False if the client is unknown or the save fails, True otherwise.
    """
    registry = load_oauth_clients()
    if client_id in registry:
        registry.pop(client_id)
        saved = save_oauth_clients(registry)
        if saved:
            logger.info(f"Unregistered OAuth client: {client_id}")
        else:
            logger.error(f"Failed to unregister OAuth client: {client_id}")
        return saved
    logger.warning(f"Client {client_id} not found for unregistration")
    return False
def client_exists(client_id: str) -> bool:
    """Return True if *client_id* is present in the stored registry."""
    return client_id in load_oauth_clients()
def get_all_clients() -> Dict[str, Dict[str, Any]]:
    """
    Return every registered OAuth client, keyed by client_id.

    The registry is re-read from storage on each call.
    """
    everything = load_oauth_clients()
    return everything
def clear_all_clients() -> bool:
    """
    Remove every OAuth client by persisting an empty registry.

    Destructive: all registrations are lost.

    Returns:
        True if the empty registry was saved, False otherwise.
    """
    logger.warning("Clearing all OAuth clients")
    empty_registry: Dict[str, Dict[str, Any]] = {}
    return save_oauth_clients(empty_registry)
# Validation helper
def validate_client_secret(client_id: str, client_secret: str) -> bool:
    """
    Check a presented client secret against the stored one.

    Args:
        client_id: Identifier of the registered client.
        client_secret: Secret presented by the caller.

    Returns:
        True only if the client exists, has a stored string secret, and the
        presented secret equals it exactly. Unknown clients, clients without
        a stored secret, and non-string secrets all yield False.
    """
    client = get_client(client_id)
    if not client:
        return False
    stored_secret = client.get("client_secret")
    # A missing/None stored secret must never match. The previous `==` check
    # accepted a None secret for clients registered without one.
    if not isinstance(stored_secret, str) or not isinstance(client_secret, str):
        return False
    # Constant-time comparison, so response timing does not reveal how much
    # of the secret was guessed correctly. Compare bytes because
    # compare_digest rejects non-ASCII str arguments.
    return hmac.compare_digest(
        stored_secret.encode("utf-8"), client_secret.encode("utf-8")
    )