diff --git a/gateway-proxy/oauth_storage.py b/gateway-proxy/oauth_storage.py new file mode 100644 index 0000000..888846d --- /dev/null +++ b/gateway-proxy/oauth_storage.py @@ -0,0 +1,143 @@ +""" +Persistent OAuth Client Storage +Stores OAuth client registrations to disk so they survive gateway restarts. +""" + +import json +import os +import logging +from typing import Dict, Any + +logger = logging.getLogger("mcp-gateway.oauth-storage") + +# Path to persist OAuth clients +OAUTH_STORAGE_FILE = os.environ.get("OAUTH_STORAGE_FILE", "/data/oauth_clients.json") + + +def ensure_storage_dir() -> None: + """Create storage directory if it doesn't exist""" + storage_dir = os.path.dirname(OAUTH_STORAGE_FILE) + if storage_dir and not os.path.exists(storage_dir): + try: + os.makedirs(storage_dir, exist_ok=True) + logger.info(f"Created OAuth storage directory: {storage_dir}") + except Exception as e: + logger.error(f"Failed to create OAuth storage directory: {e}") + + +def load_oauth_clients() -> Dict[str, Dict[str, Any]]: + """ + Load OAuth clients from persistent storage. + Returns empty dict if file doesn't exist. + """ + ensure_storage_dir() + + if not os.path.exists(OAUTH_STORAGE_FILE): + logger.info(f"OAuth storage file not found: {OAUTH_STORAGE_FILE}") + return {} + + try: + with open(OAUTH_STORAGE_FILE, 'r') as f: + clients = json.load(f) + logger.info(f"Loaded {len(clients)} OAuth clients from storage") + return clients + except json.JSONDecodeError as e: + logger.error(f"Failed to parse OAuth storage file: {e}") + return {} + except Exception as e: + logger.error(f"Failed to load OAuth clients: {e}") + return {} + + +def save_oauth_clients(clients: Dict[str, Dict[str, Any]]) -> bool: + """ + Save OAuth clients to persistent storage. + Returns True if successful, False otherwise. + """ + ensure_storage_dir() + + try: + # Write to temp file first, then atomic rename (safer) + temp_file = OAUTH_STORAGE_FILE + ".tmp" + + with open(temp_file, 'w') as f: + json.dump(clients, f, indent=2) + + # Atomic rename + os.replace(temp_file, OAUTH_STORAGE_FILE) + logger.debug(f"Saved {len(clients)} OAuth clients to storage") + return True + except Exception as e: + logger.error(f"Failed to save OAuth clients: {e}") + # Clean up temp file if it exists + try: + if os.path.exists(temp_file): + os.remove(temp_file) + except: + pass + return False + + +def get_client(client_id: str) -> Dict[str, Any] | None: + """Get a specific OAuth client by ID""" + clients = load_oauth_clients() + return clients.get(client_id) + + +def register_client(client_id: str, client_info: Dict[str, Any]) -> bool: + """Register a new OAuth client""" + clients = load_oauth_clients() + clients[client_id] = client_info + success = save_oauth_clients(clients) + + if success: + logger.info(f"Registered OAuth client: {client_id} ({client_info.get('client_name', 'Unknown')})") + else: + logger.error(f"Failed to register OAuth client: {client_id}") + + return success + + +def unregister_client(client_id: str) -> bool: + """Unregister an OAuth client""" + clients = load_oauth_clients() + + if client_id not in clients: + logger.warning(f"Client {client_id} not found for unregistration") + return False + + del clients[client_id] + success = save_oauth_clients(clients) + + if success: + logger.info(f"Unregistered OAuth client: {client_id}") + else: + logger.error(f"Failed to unregister OAuth client: {client_id}") + + return success + + +def client_exists(client_id: str) -> bool: + """Check if a client is registered""" + clients = load_oauth_clients() + return client_id in clients + + +def get_all_clients() -> Dict[str, Dict[str, Any]]: + """Get all registered OAuth clients""" + return load_oauth_clients() + + +def clear_all_clients() -> bool: + """Clear all OAuth clients (use with caution!)""" + logger.warning("Clearing all OAuth clients") + return save_oauth_clients({}) + + +# Validation helper +def validate_client_secret(client_id: str, client_secret: str) -> bool: + """Validate a client's secret""" + client = get_client(client_id) + if not client: + return False + return client.get("client_secret") == client_secret