diff --git a/docker-compose.yml b/docker-compose.yml index bdcde1c..c253e34 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -30,7 +30,6 @@ services: - MCP_BACKEND_MEMORY_BANK=http://mcp-memory-bank:8700/mcp - MCP_BACKEND_PUPPETEER=http://mcp-puppeteer:8800/mcp - MCP_BACKEND_SEQUENTIAL_THINKING=http://mcp-sequential-thinking:8900/mcp - - MCP_BACKEND_DOCKER=http://mcp-docker:9000/mcp - GATEWAY_STATIC_API_KEY=${GATEWAY_STATIC_API_KEY} depends_on: - erpnext-mcp @@ -43,7 +42,6 @@ services: - memory-bank-mcp - puppeteer-mcp - sequential-thinking-mcp - - docker-mcp networks: [mcpnet] healthcheck: test: ["CMD", "python3", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:4444/health', timeout=5)"] @@ -243,25 +241,6 @@ services: start_period: 15s retries: 3 - docker-mcp: - build: - context: ./docker-mcp - dockerfile: Dockerfile - container_name: mcp-docker - restart: unless-stopped - environment: - - PORT=9000 - - DOCKER_HOST=${DOCKER_HOST:-unix:///var/run/docker.sock} - volumes: - - /var/run/docker.sock:/var/run/docker.sock - networks: [mcpnet] - healthcheck: - test: ["CMD", "python3", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:9000/mcp', timeout=5)"] - interval: 30s - timeout: 5s - start_period: 15s - retries: 3 - volumes: gateway-data: linkedin-data: diff --git a/docker-mcp/docker_mcp.py b/docker-mcp/docker_mcp.py deleted file mode 100755 index 52f463f..0000000 --- a/docker-mcp/docker_mcp.py +++ /dev/null @@ -1,532 +0,0 @@ -""" -Docker MCP Server -================= -MCP server providing Docker container and image management capabilities. -Supports listing, inspecting, starting, stopping, removing containers, -managing images, viewing logs, executing commands, and managing networks/volumes. -Connects to the Docker daemon via socket or TCP. -""" - -import json -import os -from typing import Optional, List, Dict, Any - -import docker -from docker.errors import DockerException, NotFound, APIError -from mcp.server.fastmcp import FastMCP - -# --------------------------------------------------------------------------- -# Configuration -# --------------------------------------------------------------------------- - -DOCKER_HOST = os.environ.get("DOCKER_HOST", "unix:///var/run/docker.sock") - -# --------------------------------------------------------------------------- -# MCP Server -# --------------------------------------------------------------------------- - -mcp = FastMCP("docker_mcp") - -# --------------------------------------------------------------------------- -# Docker client -# --------------------------------------------------------------------------- - - -def _client() -> docker.DockerClient: - """Get a Docker client instance.""" - return docker.DockerClient(base_url=DOCKER_HOST, timeout=30) - - -def _safe(func): - """Wrapper for safe Docker API calls.""" - try: - return func() - except NotFound as e: - return {"error": f"Not found: {str(e)}"} - except APIError as e: - return {"error": f"Docker API error: {str(e)}"} - except DockerException as e: - return {"error": f"Docker error: {str(e)}"} - - -# --------------------------------------------------------------------------- -# Container tools -# --------------------------------------------------------------------------- - - -@mcp.tool() -async def list_containers( - all: bool = True, - filters: Optional[Dict[str, Any]] = None, -) -> Dict[str, Any]: - """ - List Docker containers. - - Args: - all: Show all containers (including stopped). Default True. - filters: Optional filter dict (e.g., {"status": "running", "name": "my-app"}) - """ - client = _client() - containers = client.containers.list(all=all, filters=filters or {}) - - result = [] - for c in containers: - result.append({ - "id": c.short_id, - "name": c.name, - "image": str(c.image.tags[0]) if c.image.tags else str(c.image.id[:12]), - "status": c.status, - "state": c.attrs.get("State", {}).get("Status"), - "created": c.attrs.get("Created"), - "ports": c.ports, - }) - - return {"containers": result, "total": len(result)} - - -@mcp.tool() -async def inspect_container(container_id: str) -> Dict[str, Any]: - """ - Get detailed information about a container. - - Args: - container_id: Container ID or name - """ - client = _client() - try: - c = client.containers.get(container_id) - attrs = c.attrs - return { - "id": c.short_id, - "name": c.name, - "image": str(c.image.tags[0]) if c.image.tags else str(c.image.id[:12]), - "status": c.status, - "state": attrs.get("State", {}), - "config": { - "env": attrs.get("Config", {}).get("Env", []), - "cmd": attrs.get("Config", {}).get("Cmd"), - "entrypoint": attrs.get("Config", {}).get("Entrypoint"), - "working_dir": attrs.get("Config", {}).get("WorkingDir"), - }, - "network": attrs.get("NetworkSettings", {}).get("Networks", {}), - "mounts": attrs.get("Mounts", []), - "ports": c.ports, - "created": attrs.get("Created"), - "restart_count": attrs.get("RestartCount", 0), - } - except NotFound: - return {"error": f"Container '{container_id}' not found"} - - -@mcp.tool() -async def container_logs( - container_id: str, - tail: int = 100, - since: Optional[str] = None, - timestamps: bool = False, -) -> Dict[str, Any]: - """ - Get logs from a container. - - Args: - container_id: Container ID or name - tail: Number of lines from the end (default 100) - since: Show logs since timestamp (e.g., '2024-01-01T00:00:00') - timestamps: Include timestamps in log output - """ - client = _client() - try: - c = client.containers.get(container_id) - kwargs: Dict[str, Any] = { - "tail": tail, - "timestamps": timestamps, - "stdout": True, - "stderr": True, - } - if since: - kwargs["since"] = since - - logs = c.logs(**kwargs) - log_text = logs.decode("utf-8", errors="replace") - - return { - "container": container_id, - "lines": tail, - "logs": log_text, - } - except NotFound: - return {"error": f"Container '{container_id}' not found"} - - -@mcp.tool() -async def start_container(container_id: str) -> Dict[str, Any]: - """ - Start a stopped container. - - Args: - container_id: Container ID or name - """ - client = _client() - try: - c = client.containers.get(container_id) - c.start() - c.reload() - return {"status": "started", "container": c.name, "state": c.status} - except NotFound: - return {"error": f"Container '{container_id}' not found"} - - -@mcp.tool() -async def stop_container( - container_id: str, - timeout: int = 10, -) -> Dict[str, Any]: - """ - Stop a running container. - - Args: - container_id: Container ID or name - timeout: Seconds to wait before killing (default 10) - """ - client = _client() - try: - c = client.containers.get(container_id) - c.stop(timeout=timeout) - c.reload() - return {"status": "stopped", "container": c.name, "state": c.status} - except NotFound: - return {"error": f"Container '{container_id}' not found"} - - -@mcp.tool() -async def restart_container( - container_id: str, - timeout: int = 10, -) -> Dict[str, Any]: - """ - Restart a container. - - Args: - container_id: Container ID or name - timeout: Seconds to wait before killing (default 10) - """ - client = _client() - try: - c = client.containers.get(container_id) - c.restart(timeout=timeout) - c.reload() - return {"status": "restarted", "container": c.name, "state": c.status} - except NotFound: - return {"error": f"Container '{container_id}' not found"} - - -@mcp.tool() -async def remove_container( - container_id: str, - force: bool = False, - v: bool = False, -) -> Dict[str, Any]: - """ - Remove a container. - - Args: - container_id: Container ID or name - force: Force remove even if running - v: Remove associated volumes - """ - client = _client() - try: - c = client.containers.get(container_id) - name = c.name - c.remove(force=force, v=v) - return {"status": "removed", "container": name} - except NotFound: - return {"error": f"Container '{container_id}' not found"} - - -@mcp.tool() -async def exec_in_container( - container_id: str, - command: str, - workdir: Optional[str] = None, -) -> Dict[str, Any]: - """ - Execute a command inside a running container. - - Args: - container_id: Container ID or name - command: Command to execute (shell string) - workdir: Optional working directory inside the container - """ - client = _client() - try: - c = client.containers.get(container_id) - kwargs: Dict[str, Any] = {"cmd": command, "stdout": True, "stderr": True} - if workdir: - kwargs["workdir"] = workdir - - exit_code, output = c.exec_run(**kwargs) - return { - "container": container_id, - "command": command, - "exit_code": exit_code, - "output": output.decode("utf-8", errors="replace"), - } - except NotFound: - return {"error": f"Container '{container_id}' not found"} - - -@mcp.tool() -async def container_stats(container_id: str) -> Dict[str, Any]: - """ - Get resource usage stats for a container. - - Args: - container_id: Container ID or name - """ - client = _client() - try: - c = client.containers.get(container_id) - stats = c.stats(stream=False) - - # Parse CPU - cpu_delta = stats["cpu_stats"]["cpu_usage"]["total_usage"] - \ - stats["precpu_stats"]["cpu_usage"]["total_usage"] - system_delta = stats["cpu_stats"]["system_cpu_usage"] - \ - stats["precpu_stats"]["system_cpu_usage"] - num_cpus = len(stats["cpu_stats"]["cpu_usage"].get("percpu_usage", [1])) - cpu_percent = (cpu_delta / system_delta) * num_cpus * 100.0 if system_delta > 0 else 0.0 - - # Parse Memory - mem_usage = stats["memory_stats"].get("usage", 0) - mem_limit = stats["memory_stats"].get("limit", 1) - mem_percent = (mem_usage / mem_limit) * 100.0 - - return { - "container": container_id, - "cpu_percent": round(cpu_percent, 2), - "memory_usage_mb": round(mem_usage / (1024 * 1024), 2), - "memory_limit_mb": round(mem_limit / (1024 * 1024), 2), - "memory_percent": round(mem_percent, 2), - "network_rx_bytes": sum( - v.get("rx_bytes", 0) for v in stats.get("networks", {}).values() - ), - "network_tx_bytes": sum( - v.get("tx_bytes", 0) for v in stats.get("networks", {}).values() - ), - "pids": stats.get("pids_stats", {}).get("current", 0), - } - except NotFound: - return {"error": f"Container '{container_id}' not found"} - except (KeyError, ZeroDivisionError) as e: - return {"error": f"Failed to parse stats: {str(e)}"} - - -# --------------------------------------------------------------------------- -# Image tools -# --------------------------------------------------------------------------- - - -@mcp.tool() -async def list_images( - name: Optional[str] = None, - all: bool = False, -) -> Dict[str, Any]: - """ - List Docker images. - - Args: - name: Optional filter by image name - all: Show all images including intermediate layers - """ - client = _client() - images = client.images.list(name=name, all=all) - - result = [] - for img in images: - result.append({ - "id": img.short_id, - "tags": img.tags, - "size_mb": round(img.attrs.get("Size", 0) / (1024 * 1024), 2), - "created": img.attrs.get("Created"), - }) - - return {"images": result, "total": len(result)} - - -@mcp.tool() -async def pull_image( - image: str, - tag: str = "latest", -) -> Dict[str, Any]: - """ - Pull a Docker image from a registry. - - Args: - image: Image name (e.g., 'nginx', 'python') - tag: Image tag (default: 'latest') - """ - client = _client() - try: - img = client.images.pull(image, tag=tag) - return { - "status": "pulled", - "image": f"{image}:{tag}", - "id": img.short_id, - "size_mb": round(img.attrs.get("Size", 0) / (1024 * 1024), 2), - } - except APIError as e: - return {"error": f"Failed to pull {image}:{tag}: {str(e)}"} - - -@mcp.tool() -async def remove_image( - image: str, - force: bool = False, -) -> Dict[str, Any]: - """ - Remove a Docker image. - - Args: - image: Image ID or name:tag - force: Force removal - """ - client = _client() - try: - client.images.remove(image, force=force) - return {"status": "removed", "image": image} - except NotFound: - return {"error": f"Image '{image}' not found"} - - -# --------------------------------------------------------------------------- -# System tools -# --------------------------------------------------------------------------- - - -@mcp.tool() -async def docker_system_info() -> Dict[str, Any]: - """Get Docker system-wide information.""" - client = _client() - info = client.info() - return { - "docker_version": info.get("ServerVersion"), - "os": info.get("OperatingSystem"), - "arch": info.get("Architecture"), - "cpus": info.get("NCPU"), - "memory_gb": round(info.get("MemTotal", 0) / (1024**3), 2), - "containers_running": info.get("ContainersRunning"), - "containers_stopped": info.get("ContainersStopped"), - "containers_paused": info.get("ContainersPaused"), - "images": info.get("Images"), - "storage_driver": info.get("Driver"), - } - - -@mcp.tool() -async def docker_disk_usage() -> Dict[str, Any]: - """Get Docker disk usage summary.""" - client = _client() - df = client.df() - - containers_size = sum(c.get("SizeRw", 0) for c in df.get("Containers", [])) - images_size = sum(i.get("Size", 0) for i in df.get("Images", [])) - volumes_size = sum(v.get("UsageData", {}).get("Size", 0) for v in df.get("Volumes", [])) - - return { - "containers_size_mb": round(containers_size / (1024 * 1024), 2), - "images_size_mb": round(images_size / (1024 * 1024), 2), - "volumes_size_mb": round(volumes_size / (1024 * 1024), 2), - "total_mb": round((containers_size + images_size + volumes_size) / (1024 * 1024), 2), - "images_count": len(df.get("Images", [])), - "containers_count": len(df.get("Containers", [])), - "volumes_count": len(df.get("Volumes", [])), - } - - -# --------------------------------------------------------------------------- -# Network tools -# --------------------------------------------------------------------------- - - -@mcp.tool() -async def list_networks() -> Dict[str, Any]: - """List Docker networks.""" - client = _client() - networks = client.networks.list() - - result = [] - for n in networks: - result.append({ - "id": n.short_id, - "name": n.name, - "driver": n.attrs.get("Driver"), - "scope": n.attrs.get("Scope"), - "containers": len(n.attrs.get("Containers", {})), - }) - - return {"networks": result, "total": len(result)} - - -# --------------------------------------------------------------------------- -# Volume tools -# --------------------------------------------------------------------------- - - -@mcp.tool() -async def list_volumes() -> Dict[str, Any]: - """List Docker volumes.""" - client = _client() - volumes = client.volumes.list() - - result = [] - for v in volumes: - result.append({ - "name": v.name, - "driver": v.attrs.get("Driver"), - "mountpoint": v.attrs.get("Mountpoint"), - "created": v.attrs.get("CreatedAt"), - }) - - return {"volumes": result, "total": len(result)} - - -@mcp.tool() -async def prune_system( - containers: bool = True, - images: bool = True, - volumes: bool = False, - networks: bool = True, -) -> Dict[str, Any]: - """ - Prune unused Docker resources. - - Args: - containers: Prune stopped containers - images: Prune dangling images - volumes: Prune unused volumes (CAUTION: data loss) - networks: Prune unused networks - """ - client = _client() - results = {} - - if containers: - r = client.containers.prune() - results["containers_deleted"] = len(r.get("ContainersDeleted", []) or []) - results["containers_space_mb"] = round(r.get("SpaceReclaimed", 0) / (1024 * 1024), 2) - - if images: - r = client.images.prune() - results["images_deleted"] = len(r.get("ImagesDeleted", []) or []) - results["images_space_mb"] = round(r.get("SpaceReclaimed", 0) / (1024 * 1024), 2) - - if volumes: - r = client.volumes.prune() - results["volumes_deleted"] = len(r.get("VolumesDeleted", []) or []) - results["volumes_space_mb"] = round(r.get("SpaceReclaimed", 0) / (1024 * 1024), 2) - - if networks: - r = client.networks.prune() - results["networks_deleted"] = len(r.get("NetworksDeleted", []) or []) - - return results diff --git a/docker-mcp/Dockerfile b/ssh-mcp/Dockerfile similarity index 54% rename from docker-mcp/Dockerfile rename to ssh-mcp/Dockerfile index 90ac381..a2f8fc9 100755 --- a/docker-mcp/Dockerfile +++ b/ssh-mcp/Dockerfile @@ -2,17 +2,22 @@ FROM python:3.12-slim-bookworm WORKDIR /app +# Install OpenSSH client libs needed by asyncssh +RUN apt-get update && apt-get install -y --no-install-recommends \ + libssl-dev \ + && rm -rf /var/lib/apt/lists/* + COPY requirements.txt . RUN pip install --no-cache-dir -r requirements.txt -COPY docker_mcp.py . +COPY ssh_mcp.py . COPY entrypoint.py . -ENV PORT=9000 +ENV PORT=8600 -EXPOSE 9000 +EXPOSE 8600 HEALTHCHECK --interval=30s --timeout=5s --start-period=15s \ - CMD python3 -c "import urllib.request; urllib.request.urlopen('http://localhost:9000/mcp', timeout=5)" + CMD python3 -c "import urllib.request; urllib.request.urlopen('http://localhost:8600/mcp', timeout=5)" CMD ["python3", "entrypoint.py"] diff --git a/docker-mcp/entrypoint.py b/ssh-mcp/entrypoint.py similarity index 77% rename from docker-mcp/entrypoint.py rename to ssh-mcp/entrypoint.py index 050c693..360b82a 100755 --- a/docker-mcp/entrypoint.py +++ b/ssh-mcp/entrypoint.py @@ -1,9 +1,9 @@ import os -from docker_mcp import mcp +from ssh_mcp import mcp from mcp.server.fastmcp.server import TransportSecuritySettings mcp.settings.host = "0.0.0.0" -mcp.settings.port = int(os.environ.get("PORT", "9000")) +mcp.settings.port = int(os.environ.get("PORT", "8600")) mcp.settings.transport_security = TransportSecuritySettings( enable_dns_rebinding_protection=False, ) diff --git a/docker-mcp/requirements.txt b/ssh-mcp/requirements.txt similarity index 70% rename from docker-mcp/requirements.txt rename to ssh-mcp/requirements.txt index a62204e..7376e4d 100755 --- a/docker-mcp/requirements.txt +++ b/ssh-mcp/requirements.txt @@ -1,6 +1,5 @@ mcp[cli]>=1.0.0 -httpx>=0.27.0 +asyncssh>=2.14.0 pydantic>=2.0.0 uvicorn>=0.30.0 starlette>=0.38.0 -docker>=7.0.0 diff --git a/ssh-mcp/ssh_mcp.py b/ssh-mcp/ssh_mcp.py new file mode 100755 index 0000000..6933651 --- /dev/null +++ b/ssh-mcp/ssh_mcp.py @@ -0,0 +1,494 @@ +""" +SSH MCP Server +============== +MCP server for executing shell commands and browsing files on a remote host +via SSH. Designed for use with TrueNAS SCALE (or any SSH-enabled host). + +Environment variables: + SSH_HOST Hostname or IP of the remote machine + SSH_PORT SSH port (default: 22) + SSH_USER SSH username + SSH_PASSWORD Password authentication (use this OR SSH_KEY_PATH) + SSH_KEY_PATH Path to private key file inside the container + SSH_PASSPHRASE Passphrase for the private key (optional) +""" + +import asyncio +import json +import os +from typing import Optional + +import asyncssh +from mcp.server.fastmcp import FastMCP +from pydantic import BaseModel, Field, ConfigDict + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +SSH_HOST = os.environ.get("SSH_HOST", "") +SSH_PORT = int(os.environ.get("SSH_PORT", "22")) +SSH_USER = os.environ.get("SSH_USER", "root") +SSH_PASSWORD = os.environ.get("SSH_PASSWORD", "") +SSH_KEY_PATH = os.environ.get("SSH_KEY_PATH", "") +SSH_PASSPHRASE = os.environ.get("SSH_PASSPHRASE", "") + +# --------------------------------------------------------------------------- +# MCP Server +# --------------------------------------------------------------------------- + +mcp = FastMCP("ssh_mcp") + +# --------------------------------------------------------------------------- +# SSH connection helper +# --------------------------------------------------------------------------- + + +async def _connect() -> asyncssh.SSHClientConnection: + """Open an SSH connection using password or key auth.""" + kwargs: dict = { + "host": SSH_HOST, + "port": SSH_PORT, + "username": SSH_USER, + "known_hosts": None, # Accept all host keys (gateway is trusted network) + } + if SSH_KEY_PATH: + kwargs["client_keys"] = [SSH_KEY_PATH] + if SSH_PASSPHRASE: + kwargs["passphrase"] = SSH_PASSPHRASE + elif SSH_PASSWORD: + kwargs["password"] = SSH_PASSWORD + + return await asyncssh.connect(**kwargs) + + +async def _run(command: str, timeout: int = 30) -> dict: + """Run a shell command and return stdout, stderr, and exit code.""" + try: + conn = await asyncio.wait_for(_connect(), timeout=10) + async with conn: + result = await asyncio.wait_for( + conn.run(command, check=False), timeout=timeout + ) + return { + "exit_code": result.exit_status, + "stdout": result.stdout or "", + "stderr": result.stderr or "", + } + except asyncio.TimeoutError: + return {"exit_code": -1, "stdout": "", "stderr": "Error: SSH connection or command timed out."} + except asyncssh.Error as e: + return {"exit_code": -1, "stdout": "", "stderr": f"SSH error: {e}"} + except Exception as e: + return {"exit_code": -1, "stdout": "", "stderr": f"Unexpected error: {e}"} + + +def _format_result(result: dict, command: str) -> str: + """Format a shell result into a clean, readable string.""" + lines = [f"$ {command}", ""] + if result["stdout"]: + lines.append(result["stdout"].rstrip()) + if result["stderr"]: + lines.append(f"[stderr]\n{result['stderr'].rstrip()}") + if result["exit_code"] != 0: + lines.append(f"\n[Exit code: {result['exit_code']}]") + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# Input models +# --------------------------------------------------------------------------- + + +class ShellExecInput(BaseModel): + """Input model for executing a shell command.""" + model_config = ConfigDict(str_strip_whitespace=True, validate_assignment=True, extra="forbid") + + command: str = Field( + ..., + description="Shell command to execute on the remote host (e.g., 'df -h', 'zpool status', 'systemctl status smbd')", + min_length=1, + max_length=2000, + ) + timeout: Optional[int] = Field( + default=30, + description="Command timeout in seconds (default: 30, max: 120)", + ge=1, + le=120, + ) + + +class ReadFileInput(BaseModel): + """Input model for reading a remote file.""" + model_config = ConfigDict(str_strip_whitespace=True, validate_assignment=True, extra="forbid") + + path: str = Field( + ..., + description="Absolute path of the file to read on the remote host (e.g., '/etc/hosts', '/var/log/syslog')", + min_length=1, + max_length=500, + ) + max_bytes: Optional[int] = Field( + default=65536, + description="Maximum bytes to read from the start of the file (default: 65536 / 64 KB)", + ge=1, + le=1048576, + ) + + +class ListDirInput(BaseModel): + """Input model for listing a remote directory.""" + model_config = ConfigDict(str_strip_whitespace=True, validate_assignment=True, extra="forbid") + + path: str = Field( + ..., + description="Absolute path of the directory to list on the remote host (e.g., '/mnt', '/etc')", + min_length=1, + max_length=500, + ) + show_hidden: Optional[bool] = Field( + default=False, + description="Whether to include hidden files and directories (those starting with '.')", + ) + + +class TailLogInput(BaseModel): + """Input model for tailing a log file.""" + model_config = ConfigDict(str_strip_whitespace=True, validate_assignment=True, extra="forbid") + + path: str = Field( + ..., + description="Absolute path of the log file to tail on the remote host (e.g., '/var/log/syslog', '/var/log/middlewared/middlewared.log')", + min_length=1, + max_length=500, + ) + lines: Optional[int] = Field( + default=100, + description="Number of lines to return from the end of the file (default: 100, max: 2000)", + ge=1, + le=2000, + ) + grep: Optional[str] = Field( + default=None, + description="Optional grep filter — only return lines matching this pattern (e.g., 'ERROR', 'failed')", + max_length=200, + ) + + +class FindFilesInput(BaseModel): + """Input model for finding files on the remote host.""" + model_config = ConfigDict(str_strip_whitespace=True, validate_assignment=True, extra="forbid") + + path: str = Field( + ..., + description="Root directory to search in (e.g., '/mnt', '/etc')", + min_length=1, + max_length=500, + ) + pattern: str = Field( + ..., + description="Filename pattern to match (e.g., '*.log', 'smb.conf', '*.conf')", + min_length=1, + max_length=200, + ) + max_depth: Optional[int] = Field( + default=5, + description="Maximum directory depth to search (default: 5, max: 10)", + ge=1, + le=10, + ) + max_results: Optional[int] = Field( + default=50, + description="Maximum number of results to return (default: 50, max: 200)", + ge=1, + le=200, + ) + + +class DiskUsageInput(BaseModel): + """Input model for checking disk usage.""" + model_config = ConfigDict(str_strip_whitespace=True, validate_assignment=True, extra="forbid") + + path: Optional[str] = Field( + default="/", + description="Path to check disk usage for (default: '/'). Use a dataset path like '/mnt/tank' for ZFS.", + max_length=500, + ) + human_readable: Optional[bool] = Field( + default=True, + description="Display sizes in human-readable format (KB, MB, GB) — default: True", + ) + + +# --------------------------------------------------------------------------- +# Tools +# --------------------------------------------------------------------------- + + +@mcp.tool( + name="ssh_exec", + annotations={ + "title": "Execute Shell Command", + "readOnlyHint": False, + "destructiveHint": True, + "idempotentHint": False, + "openWorldHint": False, + }, +) +async def ssh_exec(params: ShellExecInput) -> str: + """Execute an arbitrary shell command on the remote host via SSH. + + Runs any shell command and returns stdout, stderr, and exit code. + Use this for system administration tasks, checking service status, + running TrueNAS CLI tools (e.g., midclt), ZFS commands, and more. + + Args: + params (ShellExecInput): Validated input containing: + - command (str): The shell command to run + - timeout (Optional[int]): Timeout in seconds (default: 30) + + Returns: + str: Formatted output showing the command, stdout, stderr, and exit code. + + Examples: + - "zpool status" → ZFS pool health + - "df -h" → disk usage summary + - "systemctl status smbd" → Samba service status + - "midclt call system.info" → TrueNAS system info via CLI + - "ls -la /mnt/tank" → list dataset contents + - "cat /etc/hosts" → read a config file + """ + result = await _run(params.command, timeout=params.timeout) + return _format_result(result, params.command) + + +@mcp.tool( + name="ssh_read_file", + annotations={ + "title": "Read Remote File", + "readOnlyHint": True, + "destructiveHint": False, + "idempotentHint": True, + "openWorldHint": False, + }, +) +async def ssh_read_file(params: ReadFileInput) -> str: + """Read the contents of a file on the remote host. + + Reads a remote file and returns its contents. Automatically truncates + files larger than max_bytes to avoid context overflow. + + Args: + params (ReadFileInput): Validated input containing: + - path (str): Absolute file path on the remote host + - max_bytes (Optional[int]): Max bytes to read (default: 65536) + + Returns: + str: File contents with path header, or an error message. + + Examples: + - path="/etc/smb4.conf" → read Samba config + - path="/etc/hosts" → read host file + - path="/var/log/middlewared/middlewared.log" → read TrueNAS middleware log + """ + # Use head -c to limit bytes read from the remote side + command = f"head -c {params.max_bytes} {_quote(params.path)} 2>&1" + result = await _run(command) + + if result["exit_code"] != 0 and not result["stdout"]: + return f"Error reading {params.path}: {result['stderr']}" + + content = result["stdout"] + header = f"=== {params.path} ===\n" + + # Warn if output was truncated + if len(content.encode()) >= params.max_bytes: + header += f"[Showing first {params.max_bytes // 1024} KB — file may be larger]\n" + + return header + content + + +@mcp.tool( + name="ssh_list_dir", + annotations={ + "title": "List Remote Directory", + "readOnlyHint": True, + "destructiveHint": False, + "idempotentHint": True, + "openWorldHint": False, + }, +) +async def ssh_list_dir(params: ListDirInput) -> str: + """List the contents of a directory on the remote host. + + Returns a detailed directory listing with file types, sizes, and permissions. + + Args: + params (ListDirInput): Validated input containing: + - path (str): Absolute directory path on the remote host + - show_hidden (Optional[bool]): Include hidden files (default: False) + + Returns: + str: Formatted directory listing, or an error message. + + Examples: + - path="/mnt" → list top-level datasets + - path="/etc" → list config directory + - path="/var/log" → list log files + """ + flags = "-lah" if params.show_hidden else "-lh" + command = f"ls {flags} {_quote(params.path)} 2>&1" + result = await _run(command) + + if result["exit_code"] != 0: + return f"Error listing {params.path}: {result['stderr'] or result['stdout']}" + + return f"=== {params.path} ===\n{result['stdout']}" + + +@mcp.tool( + name="ssh_tail_log", + annotations={ + "title": "Tail Remote Log File", + "readOnlyHint": True, + "destructiveHint": False, + "idempotentHint": True, + "openWorldHint": False, + }, +) +async def ssh_tail_log(params: TailLogInput) -> str: + """Read the last N lines of a log file on the remote host, with optional grep filter. + + Efficiently retrieves the tail of a log file for troubleshooting. + Combine with grep to filter for specific patterns like errors or warnings. + + Args: + params (TailLogInput): Validated input containing: + - path (str): Absolute path to the log file + - lines (Optional[int]): Number of lines from the end (default: 100) + - grep (Optional[str]): Optional pattern to filter lines + + Returns: + str: Matching log lines, or an error message. + + Examples: + - path="/var/log/syslog", lines=50 → last 50 syslog entries + - path="/var/log/middlewared/middlewared.log", grep="ERROR" → recent errors + - path="/var/log/nginx/error.log", lines=200, grep="502" → recent 502 errors + """ + if params.grep: + command = f"tail -n {params.lines} {_quote(params.path)} | grep -i {_quote(params.grep)} 2>&1" + else: + command = f"tail -n {params.lines} {_quote(params.path)} 2>&1" + + result = await _run(command, timeout=15) + + if result["exit_code"] != 0 and not result["stdout"]: + return f"Error reading {params.path}: {result['stderr']}" + + header = f"=== {params.path} (last {params.lines} lines" + if params.grep: + header += f", grep: {params.grep}" + header += ") ===\n" + + output = result["stdout"].strip() + if not output: + return header + "(no matching lines)" + + return header + output + + +@mcp.tool( + name="ssh_find_files", + annotations={ + "title": "Find Files on Remote Host", + "readOnlyHint": True, + "destructiveHint": False, + "idempotentHint": True, + "openWorldHint": False, + }, +) +async def ssh_find_files(params: FindFilesInput) -> str: + """Search for files matching a pattern on the remote host. + + Uses the `find` command to locate files by name pattern within + a directory tree. Useful for locating config files, logs, or datasets. + + Args: + params (FindFilesInput): Validated input containing: + - path (str): Root directory to search + - pattern (str): Filename pattern to match (glob-style) + - max_depth (Optional[int]): Max directory depth (default: 5) + - max_results (Optional[int]): Max results to return (default: 50) + + Returns: + str: Newline-separated list of matching file paths, or a message if none found. + + Examples: + - path="/etc", pattern="*.conf" → find all config files + - path="/mnt", pattern="*.log" → find log files on datasets + - path="/", pattern="smb.conf" → locate Samba config + """ + command = ( + f"find {_quote(params.path)} -maxdepth {params.max_depth} " + f"-name {_quote(params.pattern)} 2>/dev/null | head -n {params.max_results}" + ) + result = await _run(command, timeout=30) + + lines = [l for l in result["stdout"].strip().splitlines() if l] + + if not lines: + return f"No files matching '{params.pattern}' found under {params.path}" + + header = f"Found {len(lines)} file(s) matching '{params.pattern}' under {params.path}:\n" + return header + "\n".join(lines) + + +@mcp.tool( + name="ssh_disk_usage", + annotations={ + "title": "Check Remote Disk Usage", + "readOnlyHint": True, + "destructiveHint": False, + "idempotentHint": True, + "openWorldHint": False, + }, +) +async def ssh_disk_usage(params: DiskUsageInput) -> str: + """Check disk usage on the remote host. + + Returns disk usage for filesystems, with optional focus on a specific path. + For TrueNAS, use ZFS dataset paths like '/mnt/tank' for accurate pool usage. + + Args: + params (DiskUsageInput): Validated input containing: + - path (Optional[str]): Path to check (default: '/') + - human_readable (Optional[bool]): Use human-readable sizes (default: True) + + Returns: + str: Disk usage summary showing used, available, and total space. + + Examples: + - path="/" → overall system disk usage + - path="/mnt/tank" → usage for a specific ZFS pool/dataset + - path="/mnt" → usage across all mounted datasets + """ + flag = "-h" if params.human_readable else "" + command = f"df {flag} {_quote(params.path)} 2>&1" + result = await _run(command) + + if result["exit_code"] != 0: + return f"Error checking disk usage for {params.path}: {result['stderr'] or result['stdout']}" + + return f"=== Disk Usage: {params.path} ===\n{result['stdout']}" + + +# --------------------------------------------------------------------------- +# Utility +# --------------------------------------------------------------------------- + + +def _quote(s: str) -> str: + """Shell-quote a string to prevent injection.""" + # Replace single quotes with '\'' (end quote, literal quote, start quote) + return "'" + s.replace("'", "'\\''") + "'"