143 lines
4.2 KiB
Python
143 lines
4.2 KiB
Python
"""
|
|
Persistent OAuth Client Storage
|
|
Stores OAuth client registrations to disk so they survive gateway restarts.
|
|
"""
|
|
|
|
import hmac
import json
import logging
import os
import tempfile
from typing import Any, Dict
|
|
|
|
# Named logger shared by every function in this module.
logger: logging.Logger = logging.getLogger("mcp-gateway.oauth-storage")

# Location of the JSON file that holds persisted OAuth client registrations.
# Overridable through the OAUTH_STORAGE_FILE environment variable, which is
# read once when this module is imported.
OAUTH_STORAGE_FILE: str = os.getenv("OAUTH_STORAGE_FILE", "/data/oauth_clients.json")
|
|
|
|
|
|
def ensure_storage_dir() -> None:
    """Make sure the directory that should contain OAUTH_STORAGE_FILE exists.

    Does nothing when the configured path has no directory component or the
    directory is already present. A failure to create it is logged, not raised.
    """
    parent = os.path.dirname(OAUTH_STORAGE_FILE)

    # Guard clause: nothing to create.
    if not parent or os.path.exists(parent):
        return

    try:
        os.makedirs(parent, exist_ok=True)
        logger.info(f"Created OAuth storage directory: {parent}")
    except Exception as err:
        logger.error(f"Failed to create OAuth storage directory: {err}")
|
|
|
|
|
|
def load_oauth_clients() -> Dict[str, Dict[str, Any]]:
    """
    Load OAuth clients from persistent storage.

    Ensures the storage directory exists, then reads OAUTH_STORAGE_FILE as
    UTF-8 JSON.

    Returns:
        Mapping of client_id -> client registration dict. An empty dict is
        returned (and the reason logged) when the file does not exist, cannot
        be read, is not valid JSON, or does not hold a JSON object at the top
        level.

    NOTE(review): because every failure collapses to ``{}``, a caller that
    loads, modifies and saves (e.g. ``register_client``) will overwrite a
    corrupt file and drop the registrations it contained. Consider backing up
    the file before overwriting it.
    """
    ensure_storage_dir()

    # EAFP instead of os.path.exists() + open(): avoids the check/open race
    # and reports exactly the same "not found" condition.
    try:
        with open(OAUTH_STORAGE_FILE, "r", encoding="utf-8") as f:
            clients = json.load(f)
    except FileNotFoundError:
        logger.info(f"OAuth storage file not found: {OAUTH_STORAGE_FILE}")
        return {}
    except json.JSONDecodeError as e:
        logger.error(f"Failed to parse OAuth storage file: {e}")
        return {}
    except Exception as e:
        logger.error(f"Failed to load OAuth clients: {e}")
        return {}

    # Callers index the result by client_id, so a list, null, number, etc.
    # is treated as corrupt data here instead of crashing them later.
    if not isinstance(clients, dict):
        logger.error(
            "OAuth storage file does not contain a JSON object "
            f"(got {type(clients).__name__}); ignoring it"
        )
        return {}

    logger.info(f"Loaded {len(clients)} OAuth clients from storage")
    return clients
|
|
|
|
|
|
def save_oauth_clients(clients: Dict[str, Dict[str, Any]]) -> bool:
    """
    Save OAuth clients to persistent storage.

    The mapping is written to a uniquely named temporary file in the same
    directory as OAUTH_STORAGE_FILE, flushed and fsync'ed, then moved over the
    target with ``os.replace``. Readers therefore see either the old file or
    the complete new one, never a partial write.

    The temporary file comes from ``tempfile.mkstemp``, so it (and the final
    file after the rename) is readable/writable only by the owning user. This
    is intentional, because the payload contains client secrets.

    Args:
        clients: Mapping of client_id -> client registration dict; must be
            JSON-serializable.

    Returns:
        True if successful, False otherwise (the error is logged).
    """
    ensure_storage_dir()

    storage_dir = os.path.dirname(OAUTH_STORAGE_FILE) or "."
    # Bound before the try so the cleanup in `finally` never sees an
    # unbound name; reset to None once the rename has consumed the file.
    temp_file = None
    try:
        # Unique name: concurrent saves can't clobber each other's temp file.
        fd, temp_file = tempfile.mkstemp(
            dir=storage_dir,
            prefix=os.path.basename(OAUTH_STORAGE_FILE) + ".",
            suffix=".tmp",
        )
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(clients, f, indent=2)
            f.flush()
            # Persist the data before the rename; otherwise a crash can leave
            # an empty/truncated file under the final name.
            os.fsync(f.fileno())

        # Atomic rename (same directory, hence same filesystem).
        os.replace(temp_file, OAUTH_STORAGE_FILE)
        temp_file = None
        logger.debug(f"Saved {len(clients)} OAuth clients to storage")
        return True
    except Exception as e:
        logger.error(f"Failed to save OAuth clients: {e}")
        return False
    finally:
        # Best-effort removal of a temp file left behind by a failed save.
        if temp_file is not None:
            try:
                os.remove(temp_file)
            except OSError as cleanup_err:
                logger.warning(
                    f"Failed to remove temporary OAuth storage file {temp_file}: {cleanup_err}"
                )
|
|
|
|
|
|
def get_client(client_id: str) -> Dict[str, Any] | None:
    """Look up a single registered OAuth client.

    Args:
        client_id: ID of the client to fetch.

    Returns:
        The stored registration dict, or None when no entry exists for the ID.
    """
    return load_oauth_clients().get(client_id)
|
|
|
|
|
|
def register_client(client_id: str, client_info: Dict[str, Any]) -> bool:
    """Store (or overwrite) one OAuth client registration and persist it.

    Loads the current registrations, sets ``client_id`` to ``client_info``
    (replacing any existing entry for that ID), and saves the full mapping.

    NOTE(review): the load-modify-save sequence is not guarded by any lock
    here; concurrent calls may lose each other's updates. Confirm callers
    serialize registrations.

    Returns:
        The result of saving: True on success, False otherwise.
    """
    registry = load_oauth_clients()
    registry[client_id] = client_info
    saved = save_oauth_clients(registry)

    if not saved:
        logger.error(f"Failed to register OAuth client: {client_id}")
        return saved

    display_name = client_info.get('client_name', 'Unknown')
    logger.info(f"Registered OAuth client: {client_id} ({display_name})")
    return saved
|
|
|
|
|
|
def unregister_client(client_id: str) -> bool:
    """Remove one OAuth client registration and persist the change.

    Args:
        client_id: ID of the client to remove.

    Returns:
        False if the client is not registered or saving fails; True otherwise.
    """
    registry = load_oauth_clients()

    try:
        del registry[client_id]
    except KeyError:
        logger.warning(f"Client {client_id} not found for unregistration")
        return False

    saved = save_oauth_clients(registry)
    if not saved:
        logger.error(f"Failed to unregister OAuth client: {client_id}")
    else:
        logger.info(f"Unregistered OAuth client: {client_id}")
    return saved
|
|
|
|
|
|
def client_exists(client_id: str) -> bool:
    """Return True when a registration keyed by ``client_id`` is stored."""
    return client_id in load_oauth_clients()
|
|
|
|
|
|
def get_all_clients() -> Dict[str, Dict[str, Any]]:
    """Return every stored OAuth client registration.

    The mapping (client_id -> registration dict) is freshly read from storage
    on each call, so mutating it does not change what is on disk.
    """
    all_clients = load_oauth_clients()
    return all_clients
|
|
|
|
|
|
def clear_all_clients() -> bool:
    """Erase every stored OAuth client registration. Destructive!

    Overwrites the storage file with an empty mapping.

    Returns:
        True if the empty mapping was saved, False otherwise.
    """
    logger.warning("Clearing all OAuth clients")
    empty_registry: Dict[str, Dict[str, Any]] = {}
    return save_oauth_clients(empty_registry)
|
|
|
|
|
|
# Validation helper
def validate_client_secret(client_id: str, client_secret: str) -> bool:
    """
    Validate a client's secret.

    Uses ``hmac.compare_digest`` so the running time does not depend on how
    many leading characters of the presented secret are correct; a plain
    ``==`` leaks that through timing. Only string secrets are compared, so a
    client with no stored secret can never be validated by passing ``None``.

    Args:
        client_id: ID of the registered client.
        client_secret: Secret presented by the caller.

    Returns:
        True only if the client exists, has a string ``client_secret``, and
        that secret equals ``client_secret``; False otherwise.
    """
    client = get_client(client_id)
    if not client:
        return False

    stored_secret = client.get("client_secret")
    if not isinstance(stored_secret, str) or not isinstance(client_secret, str):
        return False

    # compare_digest rejects non-ASCII str, so compare the encoded bytes;
    # surrogatepass keeps lone surrogates (possible via JSON) from raising.
    return hmac.compare_digest(
        stored_secret.encode("utf-8", "surrogatepass"),
        client_secret.encode("utf-8", "surrogatepass"),
    )
|