Add mcp-gateway/gateway-proxy/oauth_storage.py
This commit is contained in:
parent
17593ff97f
commit
aa6a408b34
1 changed files with 143 additions and 0 deletions
143
mcp-gateway/gateway-proxy/oauth_storage.py
Normal file
143
mcp-gateway/gateway-proxy/oauth_storage.py
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
"""
|
||||
Persistent OAuth Client Storage
|
||||
Stores OAuth client registrations to disk so they survive gateway restarts.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
|
||||
logger = logging.getLogger("mcp-gateway.oauth-storage")
|
||||
|
||||
# Path to persist OAuth clients
|
||||
OAUTH_STORAGE_FILE = os.environ.get("OAUTH_STORAGE_FILE", "/data/oauth_clients.json")
|
||||
|
||||
|
||||
def ensure_storage_dir() -> None:
|
||||
"""Create storage directory if it doesn't exist"""
|
||||
storage_dir = os.path.dirname(OAUTH_STORAGE_FILE)
|
||||
if storage_dir and not os.path.exists(storage_dir):
|
||||
try:
|
||||
os.makedirs(storage_dir, exist_ok=True)
|
||||
logger.info(f"Created OAuth storage directory: {storage_dir}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create OAuth storage directory: {e}")
|
||||
|
||||
|
||||
def load_oauth_clients() -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
Load OAuth clients from persistent storage.
|
||||
Returns empty dict if file doesn't exist.
|
||||
"""
|
||||
ensure_storage_dir()
|
||||
|
||||
if not os.path.exists(OAUTH_STORAGE_FILE):
|
||||
logger.info(f"OAuth storage file not found: {OAUTH_STORAGE_FILE}")
|
||||
return {}
|
||||
|
||||
try:
|
||||
with open(OAUTH_STORAGE_FILE, 'r') as f:
|
||||
clients = json.load(f)
|
||||
logger.info(f"Loaded {len(clients)} OAuth clients from storage")
|
||||
return clients
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse OAuth storage file: {e}")
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load OAuth clients: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
def save_oauth_clients(clients: Dict[str, Dict[str, Any]]) -> bool:
|
||||
"""
|
||||
Save OAuth clients to persistent storage.
|
||||
Returns True if successful, False otherwise.
|
||||
"""
|
||||
ensure_storage_dir()
|
||||
|
||||
try:
|
||||
# Write to temp file first, then atomic rename (safer)
|
||||
temp_file = OAUTH_STORAGE_FILE + ".tmp"
|
||||
|
||||
with open(temp_file, 'w') as f:
|
||||
json.dump(clients, f, indent=2)
|
||||
|
||||
# Atomic rename
|
||||
os.replace(temp_file, OAUTH_STORAGE_FILE)
|
||||
logger.debug(f"Saved {len(clients)} OAuth clients to storage")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save OAuth clients: {e}")
|
||||
# Clean up temp file if it exists
|
||||
try:
|
||||
if os.path.exists(temp_file):
|
||||
os.remove(temp_file)
|
||||
except:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def get_client(client_id: str) -> Dict[str, Any] | None:
|
||||
"""Get a specific OAuth client by ID"""
|
||||
clients = load_oauth_clients()
|
||||
return clients.get(client_id)
|
||||
|
||||
|
||||
def register_client(client_id: str, client_info: Dict[str, Any]) -> bool:
|
||||
"""Register a new OAuth client"""
|
||||
clients = load_oauth_clients()
|
||||
clients[client_id] = client_info
|
||||
success = save_oauth_clients(clients)
|
||||
|
||||
if success:
|
||||
logger.info(f"Registered OAuth client: {client_id} ({client_info.get('client_name', 'Unknown')})")
|
||||
else:
|
||||
logger.error(f"Failed to register OAuth client: {client_id}")
|
||||
|
||||
return success
|
||||
|
||||
|
||||
def unregister_client(client_id: str) -> bool:
|
||||
"""Unregister an OAuth client"""
|
||||
clients = load_oauth_clients()
|
||||
|
||||
if client_id not in clients:
|
||||
logger.warning(f"Client {client_id} not found for unregistration")
|
||||
return False
|
||||
|
||||
del clients[client_id]
|
||||
success = save_oauth_clients(clients)
|
||||
|
||||
if success:
|
||||
logger.info(f"Unregistered OAuth client: {client_id}")
|
||||
else:
|
||||
logger.error(f"Failed to unregister OAuth client: {client_id}")
|
||||
|
||||
return success
|
||||
|
||||
|
||||
def client_exists(client_id: str) -> bool:
|
||||
"""Check if a client is registered"""
|
||||
clients = load_oauth_clients()
|
||||
return client_id in clients
|
||||
|
||||
|
||||
def get_all_clients() -> Dict[str, Dict[str, Any]]:
|
||||
"""Get all registered OAuth clients"""
|
||||
return load_oauth_clients()
|
||||
|
||||
|
||||
def clear_all_clients() -> bool:
|
||||
"""Clear all OAuth clients (use with caution!)"""
|
||||
logger.warning("Clearing all OAuth clients")
|
||||
return save_oauth_clients({})
|
||||
|
||||
|
||||
# Validation helper
|
||||
def validate_client_secret(client_id: str, client_secret: str) -> bool:
|
||||
"""Validate a client's secret"""
|
||||
client = get_client(client_id)
|
||||
if not client:
|
||||
return False
|
||||
return client.get("client_secret") == client_secret
|
||||
Loading…
Reference in a new issue