mcp-servers/ssh-mcp/ssh_mcp.py
Zac Gaetano c387d80d1b Replace docker-mcp with ssh-mcp
SSH access to TrueNAS covers all Docker management needs (docker compose,
logs, restarts), making a dedicated Docker MCP redundant.

- Add ssh-mcp (port 8600): execute shell commands, read files, list dirs,
  tail logs, find files, check disk usage on the TrueNAS host via asyncssh
- Remove docker-mcp (port 9000): redundant given SSH access to the host
- Update docker-compose.yml: wire in ssh-mcp service, remove docker-mcp
  service and its MCP_BACKEND_DOCKER gateway env var
2026-03-31 23:25:10 -04:00

494 lines
17 KiB
Python
Executable file

"""
SSH MCP Server
==============
MCP server for executing shell commands and browsing files on a remote host
via SSH. Designed for use with TrueNAS SCALE (or any SSH-enabled host).

Environment variables:
    SSH_HOST        Hostname or IP of the remote machine (required)
    SSH_PORT        SSH port (default: 22)
    SSH_USER        SSH username (default: root)
    SSH_PASSWORD    Password authentication (use this OR SSH_KEY_PATH;
                    SSH_KEY_PATH takes precedence if both are set)
    SSH_KEY_PATH    Path to private key file inside the container
    SSH_PASSPHRASE  Passphrase for the private key (optional)
"""
import asyncio
import json
import os
from typing import Optional
import asyncssh
from mcp.server.fastmcp import FastMCP
from pydantic import BaseModel, Field, ConfigDict
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Connection settings, read once at import time from the container environment.
# An empty string means "not configured" (the connection helper tests truthiness).
SSH_HOST = os.getenv("SSH_HOST", "")
SSH_PORT = int(os.getenv("SSH_PORT", "22"))
SSH_USER = os.getenv("SSH_USER", "root")
SSH_PASSWORD = os.getenv("SSH_PASSWORD", "")
SSH_KEY_PATH = os.getenv("SSH_KEY_PATH", "")
SSH_PASSPHRASE = os.getenv("SSH_PASSPHRASE", "")
# ---------------------------------------------------------------------------
# MCP Server
# ---------------------------------------------------------------------------
# Single server instance; every @mcp.tool(...) decorator in this module registers on it.
mcp = FastMCP("ssh_mcp")
# ---------------------------------------------------------------------------
# SSH connection helper
# ---------------------------------------------------------------------------
async def _connect() -> asyncssh.SSHClientConnection:
    """Open an SSH connection to SSH_HOST using key or password auth.

    Credentials are chosen from the module-level configuration:

    * SSH_KEY_PATH set: public-key auth with that key file, unlocked with
      SSH_PASSPHRASE when one is given. SSH_PASSWORD is ignored in this case.
    * otherwise, SSH_PASSWORD set: password auth.
    * neither set: no credentials are passed explicitly; authentication is
      left to asyncssh's own defaults.

    Raises:
        ValueError: if SSH_HOST is empty, so callers get an actionable message
            instead of an opaque resolver/connect failure.
        Any exception raised by ``asyncssh.connect`` propagates unchanged.
    """
    if not SSH_HOST:
        raise ValueError("SSH_HOST is not set; configure the SSH_HOST environment variable.")

    kwargs: dict = {
        "host": SSH_HOST,
        "port": SSH_PORT,
        "username": SSH_USER,
        # NOTE(security): known_hosts=None disables host-key verification, so
        # any server answering on SSH_HOST:SSH_PORT is trusted. This is only
        # acceptable because the deployment assumes a trusted network.
        "known_hosts": None,
    }
    if SSH_KEY_PATH:
        kwargs["client_keys"] = [SSH_KEY_PATH]
        if SSH_PASSPHRASE:
            kwargs["passphrase"] = SSH_PASSPHRASE
    elif SSH_PASSWORD:
        kwargs["password"] = SSH_PASSWORD
    return await asyncssh.connect(**kwargs)
async def _run(command: str, timeout: int = 30) -> dict:
    """Run a shell command on the remote host and capture its result.

    A fresh SSH connection is opened for every call (10 s connect timeout)
    and closed once the command has finished.

    Args:
        command: Command line executed on the remote host.
        timeout: Seconds to wait for the command itself, after connecting.

    Returns:
        dict with keys:
            ``exit_code`` (int): the remote exit status; -1 for timeouts,
                SSH/local errors, or when no exit status was reported.
            ``stdout`` / ``stderr`` (str): captured output. Local failures
                are described in ``stderr``.
        Ordinary errors are reported through this dict rather than raised.
    """
    try:
        conn = await asyncio.wait_for(_connect(), timeout=10)
        async with conn:
            result = await asyncio.wait_for(
                # errors="replace": output that is not valid UTF-8 (e.g. a
                # binary file) would otherwise make asyncssh's strict decoder
                # fail the whole command instead of returning its output.
                conn.run(command, check=False, errors="replace"),
                timeout=timeout,
            )
        exit_code = result.exit_status
        return {
            # exit_status may be None if the remote side never reported one.
            "exit_code": exit_code if exit_code is not None else -1,
            "stdout": result.stdout or "",
            "stderr": result.stderr or "",
        }
    except asyncio.TimeoutError:
        return {"exit_code": -1, "stdout": "", "stderr": "Error: SSH connection or command timed out."}
    except asyncssh.Error as e:
        return {"exit_code": -1, "stdout": "", "stderr": f"SSH error: {e}"}
    except Exception as e:
        # Tool boundary: surface anything else as a readable error string.
        return {"exit_code": -1, "stdout": "", "stderr": f"Unexpected error: {e}"}
def _format_result(result: dict, command: str) -> str:
"""Format a shell result into a clean, readable string."""
lines = [f"$ {command}", ""]
if result["stdout"]:
lines.append(result["stdout"].rstrip())
if result["stderr"]:
lines.append(f"[stderr]\n{result['stderr'].rstrip()}")
if result["exit_code"] != 0:
lines.append(f"\n[Exit code: {result['exit_code']}]")
return "\n".join(lines)
# ---------------------------------------------------------------------------
# Input models
# ---------------------------------------------------------------------------
class ShellExecInput(BaseModel):
    """Input model for executing a shell command."""

    model_config = ConfigDict(str_strip_whitespace=True, validate_assignment=True, extra="forbid")

    # Command line handed to the remote host by ssh_exec.
    command: str = Field(
        ...,
        description="Shell command to execute on the remote host (e.g., 'df -h', 'zpool status', 'systemctl status smbd')",
        min_length=1,
        max_length=2000,
    )
    # Plain int rather than Optional[int]: an explicit null would pass
    # validation (ge/le are not applied to None) and reach asyncio.wait_for as
    # timeout=None, which waits indefinitely and bypasses the 120 s cap.
    timeout: int = Field(
        default=30,
        description="Command timeout in seconds (default: 30, max: 120)",
        ge=1,
        le=120,
    )
class ReadFileInput(BaseModel):
    """Input model for reading a remote file."""

    model_config = ConfigDict(str_strip_whitespace=True, validate_assignment=True, extra="forbid")

    # File to read; shell-quoted by ssh_read_file before use.
    path: str = Field(
        ...,
        description="Absolute path of the file to read on the remote host (e.g., '/etc/hosts', '/var/log/syslog')",
        min_length=1,
        max_length=500,
    )
    # Plain int rather than Optional[int]: a null would be formatted into the
    # remote command as "head -c None".
    max_bytes: int = Field(
        default=65536,
        description="Maximum bytes to read from the start of the file (default: 65536 / 64 KB)",
        ge=1,
        le=1048576,
    )
class ListDirInput(BaseModel):
    """Input model for listing a remote directory."""

    model_config = ConfigDict(
        extra="forbid",
        validate_assignment=True,
        str_strip_whitespace=True,
    )

    # Directory to list; shell-quoted before it is placed on the ls command line.
    path: str = Field(
        ...,
        min_length=1,
        max_length=500,
        description="Absolute path of the directory to list on the remote host (e.g., '/mnt', '/etc')",
    )
    # Truthy selects "ls -lah" instead of "ls -lh" in ssh_list_dir.
    show_hidden: Optional[bool] = Field(
        False,
        description="Whether to include hidden files and directories (those starting with '.')",
    )
class TailLogInput(BaseModel):
    """Input model for tailing a log file."""

    model_config = ConfigDict(str_strip_whitespace=True, validate_assignment=True, extra="forbid")

    # Log file to tail; shell-quoted by ssh_tail_log before use.
    path: str = Field(
        ...,
        description="Absolute path of the log file to tail on the remote host (e.g., '/var/log/syslog', '/var/log/middlewared/middlewared.log')",
        min_length=1,
        max_length=500,
    )
    # Plain int rather than Optional[int]: a null would be formatted into the
    # remote command as "tail -n None".
    lines: int = Field(
        default=100,
        description="Number of lines to return from the end of the file (default: 100, max: 2000)",
        ge=1,
        le=2000,
    )
    # Genuinely optional: None (or an empty string) means "no filter".
    grep: Optional[str] = Field(
        default=None,
        description="Optional grep filter — only return lines matching this pattern (e.g., 'ERROR', 'failed')",
        max_length=200,
    )
class FindFilesInput(BaseModel):
    """Input model for finding files on the remote host."""

    model_config = ConfigDict(str_strip_whitespace=True, validate_assignment=True, extra="forbid")

    # Starting directory for find; shell-quoted by ssh_find_files.
    path: str = Field(
        ...,
        description="Root directory to search in (e.g., '/mnt', '/etc')",
        min_length=1,
        max_length=500,
    )
    # Passed to find -name, so it is a glob, not a regex.
    pattern: str = Field(
        ...,
        description="Filename pattern to match (e.g., '*.log', 'smb.conf', '*.conf')",
        min_length=1,
        max_length=200,
    )
    # Plain ints rather than Optional[int]: a null would be formatted into the
    # remote command as "-maxdepth None" / "head -n None".
    max_depth: int = Field(
        default=5,
        description="Maximum directory depth to search (default: 5, max: 10)",
        ge=1,
        le=10,
    )
    max_results: int = Field(
        default=50,
        description="Maximum number of results to return (default: 50, max: 200)",
        ge=1,
        le=200,
    )
class DiskUsageInput(BaseModel):
    """Input model for checking disk usage."""

    model_config = ConfigDict(str_strip_whitespace=True, validate_assignment=True, extra="forbid")

    # Plain str rather than Optional[str]: a null path used to reach _quote(),
    # which calls str.replace and raised AttributeError inside the tool.
    # min_length=1 matches the other path fields in this module.
    path: str = Field(
        default="/",
        description="Path to check disk usage for (default: '/'). Use a dataset path like '/mnt/tank' for ZFS.",
        min_length=1,
        max_length=500,
    )
    # Plain bool: a null previously meant "raw sizes" silently, contrary to
    # the documented default of True.
    human_readable: bool = Field(
        default=True,
        description="Display sizes in human-readable format (KB, MB, GB) — default: True",
    )
# ---------------------------------------------------------------------------
# Tools
# ---------------------------------------------------------------------------
@mcp.tool(
    name="ssh_exec",
    annotations={
        "title": "Execute Shell Command",
        "readOnlyHint": False,
        "destructiveHint": True,
        "idempotentHint": False,
        "openWorldHint": False,
    },
)
async def ssh_exec(params: ShellExecInput) -> str:
    """Execute an arbitrary shell command on the remote host via SSH.

    Runs any shell command and returns stdout, stderr, and exit code.
    Use this for system administration tasks, checking service status,
    running TrueNAS CLI tools (e.g., midclt), ZFS commands, and more.

    Args:
        params (ShellExecInput): Validated input containing:
            - command (str): The shell command to run
            - timeout (Optional[int]): Timeout in seconds (default: 30)

    Returns:
        str: Formatted output showing the command, stdout, stderr, and exit code.

    Examples:
        - "zpool status" → ZFS pool health
        - "df -h" → disk usage summary
        - "systemctl status smbd" → Samba service status
        - "midclt call system.info" → TrueNAS system info via CLI
        - "ls -la /mnt/tank" → list dataset contents
        - "cat /etc/hosts" → read a config file
    """
    # The command is sent unmodified; the same text heads the transcript.
    cmd = params.command
    outcome = await _run(cmd, timeout=params.timeout)
    return _format_result(outcome, cmd)
@mcp.tool(
    name="ssh_read_file",
    annotations={
        "title": "Read Remote File",
        "readOnlyHint": True,
        "destructiveHint": False,
        "idempotentHint": True,
        "openWorldHint": False,
    },
)
async def ssh_read_file(params: ReadFileInput) -> str:
    """Read the contents of a file on the remote host.

    Reads a remote file and returns its contents. Automatically truncates
    files larger than max_bytes to avoid context overflow.

    Args:
        params (ReadFileInput): Validated input containing:
            - path (str): Absolute file path on the remote host
            - max_bytes: Max bytes to read (default: 65536)

    Returns:
        str: File contents with path header, or an error message if the file
        cannot be read (missing, a directory, permission denied, SSH failure).

    Examples:
        - path="/etc/smb4.conf" → read Samba config
        - path="/etc/hosts" → read host file
        - path="/var/log/middlewared/middlewared.log" → read TrueNAS middleware log
    """
    # head -c limits the bytes read on the remote side. stderr is deliberately
    # NOT merged into stdout, so a failure is reported as an error rather than
    # returned as if it were file content. "--" ends option parsing so a path
    # starting with "-" is treated as a file name (quoting alone does not).
    command = f"head -c {params.max_bytes} -- {_quote(params.path)}"
    result = await _run(command)
    if result["exit_code"] != 0:
        detail = (result["stderr"] or result["stdout"]).strip() or f"exit code {result['exit_code']}"
        return f"Error reading {params.path}: {detail}"

    content = result["stdout"]
    header = f"=== {params.path} ===\n"
    # Heuristic: receiving at least max_bytes bytes means head -c probably cut
    # the file short. Report bytes, not KB, so small limits do not print "0 KB".
    if len(content.encode()) >= params.max_bytes:
        header += f"[Showing first {params.max_bytes:,} bytes — file may be larger]\n"
    return header + content
@mcp.tool(
    name="ssh_list_dir",
    annotations={
        "title": "List Remote Directory",
        "readOnlyHint": True,
        "destructiveHint": False,
        "idempotentHint": True,
        "openWorldHint": False,
    },
)
async def ssh_list_dir(params: ListDirInput) -> str:
    """List the contents of a directory on the remote host.

    Returns a detailed directory listing with file types, sizes, and permissions.

    Args:
        params (ListDirInput): Validated input containing:
            - path (str): Absolute directory path on the remote host
            - show_hidden (Optional[bool]): Include hidden files (default: False)

    Returns:
        str: Formatted directory listing, or an error message.

    Examples:
        - path="/mnt" → list top-level datasets
        - path="/etc" → list config directory
        - path="/var/log" → list log files
    """
    flags = "-lah" if params.show_hidden else "-lh"
    # "--" ends option parsing so a path beginning with "-" is listed instead
    # of being interpreted as ls flags (shell quoting does not prevent that).
    # 2>&1 folds ls diagnostics into stdout; the exit code decides success.
    command = f"ls {flags} -- {_quote(params.path)} 2>&1"
    result = await _run(command)
    if result["exit_code"] != 0:
        return f"Error listing {params.path}: {result['stderr'] or result['stdout']}"
    return f"=== {params.path} ===\n{result['stdout']}"
@mcp.tool(
    name="ssh_tail_log",
    annotations={
        "title": "Tail Remote Log File",
        "readOnlyHint": True,
        "destructiveHint": False,
        "idempotentHint": True,
        "openWorldHint": False,
    },
)
async def ssh_tail_log(params: TailLogInput) -> str:
    """Read the last N lines of a log file on the remote host, with optional grep filter.

    Efficiently retrieves the tail of a log file for troubleshooting.
    Combine with grep to filter for specific patterns like errors or warnings.
    The filter is case-insensitive and is applied to the last N lines only.

    Args:
        params (TailLogInput): Validated input containing:
            - path (str): Absolute path to the log file
            - lines: Number of lines from the end (default: 100)
            - grep (Optional[str]): Optional pattern to filter lines

    Returns:
        str: Matching log lines, "(no matching lines)" when the filter matched
        nothing, or an error message.

    Examples:
        - path="/var/log/syslog", lines=50 → last 50 syslog entries
        - path="/var/log/middlewared/middlewared.log", grep="ERROR" → recent errors
        - path="/var/log/nginx/error.log", lines=200, grep="502" → recent 502 errors
    """
    # "--" so a path starting with "-" is not parsed as a tail option.
    tail_cmd = f"tail -n {params.lines} -- {_quote(params.path)}"
    if params.grep:
        # "-e" marks the argument as the pattern even if it starts with "-".
        command = f"{tail_cmd} | grep -i -e {_quote(params.grep)}"
    else:
        command = tail_cmd
    # stderr is kept separate from stdout on purpose: diagnostics (missing
    # file, invalid pattern) must be reported as errors, not as log lines.
    result = await _run(command, timeout=15)

    # The pipeline's status is grep's. grep exits 1 without diagnostics when
    # nothing matched; that is an empty result, not a failure.
    no_matches = bool(params.grep) and result["exit_code"] == 1 and not result["stderr"].strip()
    if result["exit_code"] != 0 and not result["stdout"] and not no_matches:
        detail = result["stderr"].strip() or f"exit code {result['exit_code']}"
        return f"Error reading {params.path}: {detail}"

    header = f"=== {params.path} (last {params.lines} lines"
    if params.grep:
        header += f", grep: {params.grep}"
    header += ") ===\n"

    output = result["stdout"].strip()
    if not output:
        return header + ("(no matching lines)" if params.grep else "(no output)")
    return header + output
@mcp.tool(
    name="ssh_find_files",
    annotations={
        "title": "Find Files on Remote Host",
        "readOnlyHint": True,
        "destructiveHint": False,
        "idempotentHint": True,
        "openWorldHint": False,
    },
)
async def ssh_find_files(params: FindFilesInput) -> str:
    """Search for files matching a pattern on the remote host.

    Uses the `find` command to locate files by name pattern within
    a directory tree. Useful for locating config files, logs, or datasets.

    Args:
        params (FindFilesInput): Validated input containing:
            - path (str): Root directory to search
            - pattern (str): Filename pattern to match (glob-style)
            - max_depth: Max directory depth (default: 5)
            - max_results: Max results to return (default: 50)

    Returns:
        str: Newline-separated list of matching file paths, a message if none
        were found, or an error message if the search itself failed.

    Examples:
        - path="/etc", pattern="*.conf" → find all config files
        - path="/mnt", pattern="*.log" → find log files on datasets
        - path="/", pattern="smb.conf" → locate Samba config
    """
    # find's own diagnostics (e.g. permission denied on subdirectories) are
    # discarded so they do not pollute the result list.
    command = (
        f"find {_quote(params.path)} -maxdepth {params.max_depth} "
        f"-name {_quote(params.pattern)} 2>/dev/null | head -n {params.max_results}"
    )
    result = await _run(command, timeout=30)
    matches = [line for line in result["stdout"].strip().splitlines() if line]

    # Without pipefail the pipeline reports head's status, so a non-zero exit
    # code with no output most likely comes from _run itself (SSH failure or
    # timeout). Report it rather than claiming that nothing matched.
    if result["exit_code"] != 0 and not matches:
        detail = result["stderr"].strip() or f"exit code {result['exit_code']}"
        return f"Error searching {params.path}: {detail}"
    if not matches:
        return f"No files matching '{params.pattern}' found under {params.path}"

    header = f"Found {len(matches)} file(s) matching '{params.pattern}' under {params.path}:\n"
    if len(matches) >= params.max_results:
        # head -n cut the list; tell the caller there may be more.
        header += f"[Result limit of {params.max_results} reached — more files may match]\n"
    return header + "\n".join(matches)
@mcp.tool(
    name="ssh_disk_usage",
    annotations={
        "title": "Check Remote Disk Usage",
        "readOnlyHint": True,
        "destructiveHint": False,
        "idempotentHint": True,
        "openWorldHint": False,
    },
)
async def ssh_disk_usage(params: DiskUsageInput) -> str:
    """Check disk usage on the remote host.

    Returns disk usage for filesystems, with optional focus on a specific path.
    For TrueNAS, use ZFS dataset paths like '/mnt/tank' for accurate pool usage.

    Args:
        params (DiskUsageInput): Validated input containing:
            - path: Path to check (default: '/'; an empty value also means '/')
            - human_readable: Use human-readable sizes (default: True)

    Returns:
        str: Disk usage summary showing used, available, and total space.

    Examples:
        - path="/" → overall system disk usage
        - path="/mnt/tank" → usage for a specific ZFS pool/dataset
        - path="/mnt" → usage across all mounted datasets
    """
    # Fall back to the documented default so a None/empty path cannot reach
    # _quote() (which would raise AttributeError on None).
    path = params.path or "/"
    flags = "-h " if params.human_readable else ""
    # "--" so a path starting with "-" is not parsed as a df option.
    command = f"df {flags}-- {_quote(path)} 2>&1"
    result = await _run(command)
    if result["exit_code"] != 0:
        return f"Error checking disk usage for {path}: {result['stderr'] or result['stdout']}"
    return f"=== Disk Usage: {path} ===\n{result['stdout']}"
# ---------------------------------------------------------------------------
# Utility
# ---------------------------------------------------------------------------
def _quote(s: str) -> str:
"""Shell-quote a string to prevent injection."""
# Replace single quotes with '\'' (end quote, literal quote, start quote)
return "'" + s.replace("'", "'\\''") + "'"