diff --git a/backend/main.py b/backend/main.py index 1c01fa9..955b752 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,11 +1,11 @@ """ Claude Persistent Agent - FastAPI Backend -Main application with task scheduling and management +Main application with task scheduling, live chat, and agent orchestration """ -from fastapi import FastAPI, WebSocket, HTTPException, BackgroundTasks +from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, BackgroundTasks from fastapi.staticfiles import StaticFiles -from fastapi.responses import FileResponse +from fastapi.responses import FileResponse, StreamingResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from typing import Optional, List, Dict, Any @@ -26,7 +26,7 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Initialize FastAPI app -app = FastAPI(title="Claude Persistent Agent", version="1.0.0") +app = FastAPI(title="Claude Persistent Agent", version="2.0.0") # Add CORS middleware app.add_middleware( @@ -41,7 +41,8 @@ app.add_middleware( DB_PATH = Path("/app/data/tasks.db") LOGS_PATH = Path("/app/logs") TASKS_PATH = Path("/app/tasks") -TOKEN_FILE = Path("/app/data/.claude_token") # persist token across restarts +TOKEN_FILE = Path("/app/data/.claude_token") +CHAT_HISTORY_PATH = Path("/app/data/chat_history.json") # Ensure directories exist DB_PATH.parent.mkdir(parents=True, exist_ok=True) @@ -49,30 +50,63 @@ LOGS_PATH.mkdir(parents=True, exist_ok=True) TASKS_PATH.mkdir(parents=True, exist_ok=True) +# ============ Models ============ + class Task(BaseModel): id: Optional[str] = None name: str description: Optional[str] = None prompt: str - schedule_type: str # "once", "recurring", "manual" - schedule_value: Optional[str] = None # cron expression or ISO datetime + schedule_type: str = "manual" # "once", "recurring", "manual" + schedule_value: Optional[str] = None enabled: bool = True created_at: Optional[str] = None last_run: Optional[str] = None next_run: Optional[str] = None - status: str = "idle" # idle, running, completed, failed + status: str = "idle" + # Agent orchestration fields + agent_model: Optional[str] = None # "sonnet", "opus", "haiku" + agent_tools: Optional[str] = None # comma-separated: "Bash,Read,Write,Edit" + agent_mcp_servers: Optional[str] = None # comma-separated MCP server names + agent_system_prompt: Optional[str] = None # custom system prompt + agent_max_turns: Optional[int] = None # max conversation turns + agent_permission_mode: str = "auto" # "auto", "acceptEdits", "plan" + agent_timeout: int = 300 # seconds class TaskRun(BaseModel): task_id: str run_id: str - status: str # running, completed, failed + status: str output: Optional[str] = None error: Optional[str] = None started_at: str completed_at: Optional[str] = None +class ChatMessage(BaseModel): + message: str + model: Optional[str] = None + system_prompt: Optional[str] = None + tools: Optional[str] = None + session_id: Optional[str] = None + + +class TokenSubmit(BaseModel): + token: str + token_type: str = "oauth_token" + + +class McpServerAdd(BaseModel): + name: str + server_type: str = "sse" + url: Optional[str] = None + command: Optional[str] = None + args: Optional[List[str]] = None + + +# ============ Database ============ + def init_db(): """Initialize SQLite database""" conn = sqlite3.connect(DB_PATH) @@ -90,7 +124,14 @@ def init_db(): created_at TEXT, last_run TEXT, next_run TEXT, - status TEXT DEFAULT 'idle' + status TEXT DEFAULT 'idle', + agent_model TEXT, + agent_tools TEXT, + agent_mcp_servers TEXT, + agent_system_prompt TEXT, + agent_max_turns INTEGER, + agent_permission_mode TEXT DEFAULT 'auto', + agent_timeout INTEGER DEFAULT 300 ) """) @@ -107,14 +148,57 @@ def init_db(): ) """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS chat_messages ( + id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + role TEXT NOT NULL, + content TEXT NOT NULL, + timestamp TEXT NOT NULL, + model TEXT, + metadata TEXT + ) + """) + + # Migration: add new columns if they don't exist + try: + cursor.execute("ALTER TABLE tasks ADD COLUMN agent_model TEXT") + except sqlite3.OperationalError: + pass + try: + cursor.execute("ALTER TABLE tasks ADD COLUMN agent_tools TEXT") + except sqlite3.OperationalError: + pass + try: + cursor.execute("ALTER TABLE tasks ADD COLUMN agent_mcp_servers TEXT") + except sqlite3.OperationalError: + pass + try: + cursor.execute("ALTER TABLE tasks ADD COLUMN agent_system_prompt TEXT") + except sqlite3.OperationalError: + pass + try: + cursor.execute("ALTER TABLE tasks ADD COLUMN agent_max_turns INTEGER") + except sqlite3.OperationalError: + pass + try: + cursor.execute("ALTER TABLE tasks ADD COLUMN agent_permission_mode TEXT DEFAULT 'auto'") + except sqlite3.OperationalError: + pass + try: + cursor.execute("ALTER TABLE tasks ADD COLUMN agent_timeout INTEGER DEFAULT 300") + except sqlite3.OperationalError: + pass + conn.commit() conn.close() +# ============ Auth Helpers ============ + def _get_claude_env(): """Get environment with auth tokens set for Claude CLI""" env = os.environ.copy() - # Load saved token if it exists if TOKEN_FILE.exists(): try: saved = json.loads(TOKEN_FILE.read_text()) @@ -127,110 +211,144 @@ def _get_claude_env(): return env +def _build_claude_cmd(prompt: str, model: str = None, tools: str = None, + system_prompt: str = None, permission_mode: str = "auto", + session_id: str = None, continue_session: bool = False) -> list: + """Build a claude CLI command with agent configuration""" + cmd = ["claude", "-p"] + + if model: + cmd.extend(["--model", model]) + + if tools: + cmd.extend(["--allowedTools", tools]) + else: + cmd.extend(["--allowedTools", "Bash,Read,Write,Edit"]) + + if permission_mode: + cmd.extend(["--permission-mode", permission_mode]) + + if system_prompt: + cmd.extend(["--system-prompt", system_prompt]) + + if continue_session and session_id: + cmd.extend(["--resume", session_id]) + elif session_id: + cmd.extend(["--session-id", session_id]) + + cmd.append(prompt) + return cmd + + +# ============ Task Execution ============ + async def run_claude_task(task: Task, run_id: str): - """Execute a Claude Code task""" + """Execute a Claude Code task as an agent""" started = datetime.now().isoformat() try: env = _get_claude_env() - # Run Claude Code in print mode (non-interactive) + cmd = _build_claude_cmd( + prompt=task.prompt, + model=task.agent_model, + tools=task.agent_tools, + system_prompt=task.agent_system_prompt, + permission_mode=task.agent_permission_mode + ) + + timeout = task.agent_timeout or 300 + + logger.info(f"Running task {task.name}: {' '.join(cmd)}") + process = await asyncio.create_subprocess_exec( - "claude", "-p", "--allowedTools", "Bash,Read,Write,Edit", - "--permission-mode", "auto", - task.prompt, + *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, + stdin=asyncio.subprocess.DEVNULL, # Fix stdin warning env=env ) - stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=300) + stdout, stderr = await asyncio.wait_for( + process.communicate(), timeout=timeout + ) output = stdout.decode(errors="replace") if stdout else "" error = stderr.decode(errors="replace") if stderr else "" + # Filter out non-critical warnings from stderr + error_lines = [l for l in error.split("\n") + if l.strip() and "stdin" not in l.lower() and "warning" not in l.lower()] + filtered_error = "\n".join(error_lines) + status = "completed" if process.returncode == 0 else "failed" - # Save run result save_task_run(TaskRun( - task_id=task.id, - run_id=run_id, - status=status, + task_id=task.id, run_id=run_id, status=status, output=output, - error=error if status == "failed" else None, - started_at=started, - completed_at=datetime.now().isoformat() + error=filtered_error if status == "failed" else None, + started_at=started, completed_at=datetime.now().isoformat() )) - update_task_status(task.id, status) except asyncio.TimeoutError: save_task_run(TaskRun( - task_id=task.id, - run_id=run_id, - status="failed", - error="Task timeout (>5 minutes)", - started_at=started, - completed_at=datetime.now().isoformat() + task_id=task.id, run_id=run_id, status="failed", + error=f"Task timeout (>{task.agent_timeout or 300}s)", + started_at=started, completed_at=datetime.now().isoformat() )) update_task_status(task.id, "failed") except Exception as e: save_task_run(TaskRun( - task_id=task.id, - run_id=run_id, - status="failed", + task_id=task.id, run_id=run_id, status="failed", error=str(e), - started_at=started, - completed_at=datetime.now().isoformat() + started_at=started, completed_at=datetime.now().isoformat() )) update_task_status(task.id, "failed") +# ============ DB Operations ============ + def save_task(task: Task): - """Save task to database""" conn = sqlite3.connect(DB_PATH) cursor = conn.cursor() - if not task.id: task.id = str(uuid.uuid4()) - if not task.created_at: task.created_at = datetime.now().isoformat() cursor.execute(""" INSERT OR REPLACE INTO tasks - (id, name, description, prompt, schedule_type, schedule_value, enabled, created_at, status) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + (id, name, description, prompt, schedule_type, schedule_value, enabled, + created_at, status, agent_model, agent_tools, agent_mcp_servers, + agent_system_prompt, agent_max_turns, agent_permission_mode, agent_timeout) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( task.id, task.name, task.description, task.prompt, task.schedule_type, task.schedule_value, task.enabled, - task.created_at, task.status + task.created_at, task.status, + task.agent_model, task.agent_tools, task.agent_mcp_servers, + task.agent_system_prompt, task.agent_max_turns, + task.agent_permission_mode, task.agent_timeout )) - conn.commit() conn.close() return task def save_task_run(run: TaskRun): - """Save task run result""" conn = sqlite3.connect(DB_PATH) cursor = conn.cursor() - cursor.execute(""" INSERT INTO task_runs (run_id, task_id, status, output, error, started_at, completed_at) VALUES (?, ?, ?, ?, ?, ?, ?) - """, ( - run.run_id, run.task_id, run.status, run.output, - run.error, run.started_at, run.completed_at - )) - + """, (run.run_id, run.task_id, run.status, run.output, + run.error, run.started_at, run.completed_at)) conn.commit() conn.close() def update_task_status(task_id: str, status: str): - """Update task status""" conn = sqlite3.connect(DB_PATH) cursor = conn.cursor() cursor.execute("UPDATE tasks SET status = ?, last_run = ? WHERE id = ?", @@ -240,79 +358,118 @@ def update_task_status(task_id: str, status: str): def get_task(task_id: str) -> Optional[Task]: - """Get task from database""" conn = sqlite3.connect(DB_PATH) + conn.row_factory = sqlite3.Row cursor = conn.cursor() cursor.execute("SELECT * FROM tasks WHERE id = ?", (task_id,)) row = cursor.fetchone() conn.close() - if row: - return Task( - id=row[0], name=row[1], description=row[2], prompt=row[3], - schedule_type=row[4], schedule_value=row[5], enabled=row[6], - created_at=row[7], last_run=row[8], next_run=row[9], status=row[10] - ) + return Task(**dict(row)) return None def get_all_tasks() -> List[Task]: - """Get all tasks""" conn = sqlite3.connect(DB_PATH) + conn.row_factory = sqlite3.Row cursor = conn.cursor() cursor.execute("SELECT * FROM tasks ORDER BY created_at DESC") rows = cursor.fetchall() conn.close() - - return [ - Task( - id=row[0], name=row[1], description=row[2], prompt=row[3], - schedule_type=row[4], schedule_value=row[5], enabled=row[6], - created_at=row[7], last_run=row[8], next_run=row[9], status=row[10] - ) - for row in rows - ] + return [Task(**dict(row)) for row in rows] -# Initialize scheduler +# ============ Scheduler ============ + scheduler = BackgroundScheduler() def schedule_task(task: Task): - """Schedule a task with APScheduler""" if not task.enabled or task.schedule_type == "manual": return - if task.schedule_type == "recurring" and task.schedule_value: try: scheduler.add_job( - run_task_job, - CronTrigger.from_crontab(task.schedule_value), - id=task.id, - args=[task], - replace_existing=True + _sync_run_task, CronTrigger.from_crontab(task.schedule_value), + id=task.id, args=[task], replace_existing=True ) logger.info(f"Scheduled task {task.name} with cron: {task.schedule_value}") except Exception as e: logger.error(f"Failed to schedule task {task.name}: {e}") -async def run_task_job(task: Task): - """Background job to run a task""" +def _sync_run_task(task: Task): + """Sync wrapper for scheduled tasks (APScheduler doesn't support async natively)""" + loop = asyncio.new_event_loop() run_id = str(uuid.uuid4()) update_task_status(task.id, "running") - await run_claude_task(task, run_id) + loop.run_until_complete(run_claude_task(task, run_id)) + loop.close() + + +# ============ Chat Sessions ============ + +# Active chat sessions: session_id -> process +_chat_sessions: Dict[str, asyncio.subprocess.Process] = {} + + +def save_chat_message(session_id: str, role: str, content: str, model: str = None, metadata: dict = None): + """Save a chat message to the database""" + conn = sqlite3.connect(DB_PATH) + cursor = conn.cursor() + cursor.execute(""" + INSERT INTO chat_messages (id, session_id, role, content, timestamp, model, metadata) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, (str(uuid.uuid4()), session_id, role, content, + datetime.now().isoformat(), model, json.dumps(metadata) if metadata else None)) + conn.commit() + conn.close() + + +def get_chat_history(session_id: str, limit: int = 50) -> List[Dict]: + """Get chat history for a session""" + conn = sqlite3.connect(DB_PATH) + conn.row_factory = sqlite3.Row + cursor = conn.cursor() + cursor.execute(""" + SELECT * FROM chat_messages WHERE session_id = ? + ORDER BY timestamp ASC LIMIT ? + """, (session_id, limit)) + rows = cursor.fetchall() + conn.close() + return [dict(row) for row in rows] + + +def list_chat_sessions() -> List[Dict]: + """List all chat sessions with latest message""" + conn = sqlite3.connect(DB_PATH) + conn.row_factory = sqlite3.Row + cursor = conn.cursor() + cursor.execute(""" + SELECT session_id, + COUNT(*) as message_count, + MIN(timestamp) as started_at, + MAX(timestamp) as last_message, + (SELECT content FROM chat_messages cm2 + WHERE cm2.session_id = cm.session_id AND cm2.role = 'user' + ORDER BY timestamp ASC LIMIT 1) as first_message + FROM chat_messages cm + GROUP BY session_id + ORDER BY MAX(timestamp) DESC + LIMIT 20 + """) + rows = cursor.fetchall() + conn.close() + return [dict(row) for row in rows] # ============ API Routes ============ @app.on_event("startup") async def startup(): - """Initialize on startup""" init_db() scheduler.start() - # Load saved token into environment on startup if TOKEN_FILE.exists(): try: saved = json.loads(TOKEN_FILE.read_text()) @@ -325,29 +482,33 @@ async def startup(): except Exception as e: logger.error(f"Failed to load saved token: {e}") - # Schedule existing tasks for task in get_all_tasks(): if task.enabled: schedule_task(task) - logger.info("Claude Persistent Agent started") + logger.info("Claude Persistent Agent v2.0 started") @app.on_event("shutdown") async def shutdown(): - """Cleanup on shutdown""" scheduler.shutdown() + # Kill any active chat sessions + for sid, proc in _chat_sessions.items(): + try: + proc.kill() + except Exception: + pass @app.get("/health") async def health(): - """Health check endpoint""" - return {"status": "healthy", "timestamp": datetime.now().isoformat()} + return {"status": "healthy", "timestamp": datetime.now().isoformat(), "version": "2.0.0"} +# ---- Task CRUD ---- + @app.post("/api/tasks") async def create_task(task: Task) -> Task: - """Create a new task""" saved_task = save_task(task) schedule_task(saved_task) return saved_task @@ -355,13 +516,11 @@ async def create_task(task: Task) -> Task: @app.get("/api/tasks") async def list_tasks() -> List[Task]: - """List all tasks""" return get_all_tasks() @app.get("/api/tasks/{task_id}") async def get_task_endpoint(task_id: str) -> Task: - """Get a specific task""" task = get_task(task_id) if not task: raise HTTPException(status_code=404, detail="Task not found") @@ -370,7 +529,6 @@ async def get_task_endpoint(task_id: str) -> Task: @app.put("/api/tasks/{task_id}") async def update_task_endpoint(task_id: str, task: Task) -> Task: - """Update a task""" task.id = task_id saved_task = save_task(task) schedule_task(saved_task) @@ -379,37 +537,34 @@ async def update_task_endpoint(task_id: str, task: Task) -> Task: @app.delete("/api/tasks/{task_id}") async def delete_task_endpoint(task_id: str): - """Delete a task""" conn = sqlite3.connect(DB_PATH) cursor = conn.cursor() cursor.execute("DELETE FROM tasks WHERE id = ?", (task_id,)) conn.commit() conn.close() - - if scheduler.get_job(task_id): - scheduler.remove_job(task_id) - + try: + if scheduler.get_job(task_id): + scheduler.remove_job(task_id) + except Exception: + pass return {"status": "deleted"} @app.post("/api/tasks/{task_id}/run") async def run_task_manual(task_id: str, background_tasks: BackgroundTasks): - """Manually trigger a task""" task = get_task(task_id) if not task: raise HTTPException(status_code=404, detail="Task not found") - run_id = str(uuid.uuid4()) update_task_status(task_id, "running") background_tasks.add_task(run_claude_task, task, run_id) - return {"run_id": run_id, "status": "started"} @app.get("/api/tasks/{task_id}/runs") async def get_task_runs(task_id: str) -> List[Dict]: - """Get runs for a task""" conn = sqlite3.connect(DB_PATH) + conn.row_factory = sqlite3.Row cursor = conn.cursor() cursor.execute( "SELECT * FROM task_runs WHERE task_id = ? ORDER BY started_at DESC LIMIT 50", @@ -417,24 +572,13 @@ async def get_task_runs(task_id: str) -> List[Dict]: ) rows = cursor.fetchall() conn.close() + return [dict(row) for row in rows] - return [ - { - "run_id": row[0], - "task_id": row[1], - "status": row[2], - "output": row[3], - "error": row[4], - "started_at": row[5], - "completed_at": row[6] - } - for row in rows - ] +# ---- System Info ---- @app.get("/api/system/info") async def system_info() -> Dict: - """Get system information""" tasks = get_all_tasks() conn = sqlite3.connect(DB_PATH) cursor = conn.cursor() @@ -449,7 +593,7 @@ async def system_info() -> Dict: conn.close() return { "app_name": "Claude Persistent Agent", - "version": "1.0.0", + "version": "2.0.0", "uptime": datetime.now().isoformat(), "scheduler_running": scheduler.running, "task_count": len(tasks), @@ -457,39 +601,26 @@ async def system_info() -> Dict: "completed_runs": completed_runs, "failed_runs": failed_runs, "running_runs": running_runs, + "active_chat_sessions": len(_chat_sessions), } @app.get("/api/system/usage") async def usage_stats() -> Dict: - """Get Claude API usage stats from session files if available""" usage = { - "models_used": [], - "session_count": 0, - "last_reset": None, - "next_reset": None, + "models_used": [], "session_count": 0, + "last_reset": None, "next_reset": None, "note": "Usage data sourced from ~/.claude session cache" } - try: claude_dir = Path("/root/.claude") sessions = list(claude_dir.glob("**/session*.json")) + list(claude_dir.glob("**/*.jsonl")) usage["session_count"] = len(sessions) - - stats_file = claude_dir / "usage_stats.json" - if stats_file.exists(): - with open(stats_file) as f: - saved = json.load(f) - usage.update(saved) - else: - now = datetime.now() - if now.day <= 1: - reset = now.replace(day=1, hour=0, minute=0, second=0) - else: - next_month = (now.replace(day=1) + timedelta(days=32)).replace(day=1) - reset = next_month.replace(hour=0, minute=0, second=0) - usage["next_reset"] = reset.isoformat() - usage["days_until_reset"] = (reset - now).days + now = datetime.now() + next_month = (now.replace(day=1) + timedelta(days=32)).replace(day=1) + reset = next_month.replace(hour=0, minute=0, second=0) + usage["next_reset"] = reset.isoformat() + usage["days_until_reset"] = (reset - now).days except Exception as e: usage["error"] = str(e) @@ -501,16 +632,13 @@ async def usage_stats() -> Dict: usage["claude_runs_total"] = row[0] usage["first_run"] = row[1] usage["last_run"] = row[2] - return usage -# ============ Auth Endpoints ============ - +# ---- Auth ---- @app.get("/api/auth/status") async def auth_status(): - """Check Claude auth status""" account = None auth_method = None status = "logged_out" @@ -529,11 +657,11 @@ async def auth_status(): proc = await asyncio.create_subprocess_exec( "claude", "auth", "status", "--json", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, + stdin=asyncio.subprocess.DEVNULL, env=env ) stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=10) output = stdout.decode(errors="replace").strip() - logger.info(f"claude auth status: {output}") data = json.loads(output) if data.get("loggedIn"): status = "logged_in" @@ -543,42 +671,29 @@ async def auth_status(): logger.error(f"Auth status check failed: {e}") return { - "status": status, - "account": account, - "auth_method": auth_method, - "has_saved_token": has_saved_token, - "token_type": token_type + "status": status, "account": account, "auth_method": auth_method, + "has_saved_token": has_saved_token, "token_type": token_type } -class TokenSubmit(BaseModel): - token: str - token_type: str = "oauth_token" # "oauth_token" or "api_key" - - @app.post("/api/auth/token") async def auth_set_token(payload: TokenSubmit): - """Save an auth token (OAuth setup token or API key)""" token = payload.token.strip() token_type = payload.token_type if not token: return {"status": "error", "message": "Token cannot be empty"} - # Auto-detect token type if token.startswith("sk-ant-oat"): token_type = "oauth_token" elif token.startswith("sk-ant-api"): token_type = "api_key" - # Save to file for persistence TOKEN_FILE.write_text(json.dumps({ - "type": token_type, - "token": token, + "type": token_type, "token": token, "saved_at": datetime.now().isoformat() })) - # Set in current process environment if token_type == "oauth_token": os.environ["CLAUDE_CODE_OAUTH_TOKEN"] = token os.environ.pop("ANTHROPIC_API_KEY", None) @@ -586,67 +701,238 @@ async def auth_set_token(payload: TokenSubmit): os.environ["ANTHROPIC_API_KEY"] = token os.environ.pop("CLAUDE_CODE_OAUTH_TOKEN", None) - # Verify auth works + # Quick verify try: env = _get_claude_env() proc = await asyncio.create_subprocess_exec( "claude", "auth", "status", "--json", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, - env=env + stdin=asyncio.subprocess.DEVNULL, env=env ) stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=10) output = stdout.decode(errors="replace").strip() data = json.loads(output) if data.get("loggedIn"): return { - "status": "logged_in", - "message": "Token saved and verified!", - "account": data.get("email"), - "auth_method": data.get("authMethod") - } - else: - return { - "status": "token_saved", - "message": "Token saved but auth status shows not logged in. The token may still work for API calls.", - "raw": output + "status": "logged_in", "message": "Token saved and verified!", + "account": data.get("email"), "auth_method": data.get("authMethod") } except Exception as e: - return { - "status": "token_saved", - "message": f"Token saved. Could not verify: {e}" - } + pass + + return {"status": "token_saved", "message": "Token saved. Run a test chat to verify it works with the API."} @app.post("/api/auth/logout") async def auth_logout(): - """Clear saved auth token""" if TOKEN_FILE.exists(): TOKEN_FILE.unlink() os.environ.pop("CLAUDE_CODE_OAUTH_TOKEN", None) os.environ.pop("ANTHROPIC_API_KEY", None) - try: proc = await asyncio.create_subprocess_exec( "claude", "auth", "logout", - stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE + stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, + stdin=asyncio.subprocess.DEVNULL ) await asyncio.wait_for(proc.communicate(), timeout=10) except Exception: pass - return {"status": "logged_out"} -# ============ MCP Server Management ============ +# ---- Chat ---- + +@app.post("/api/chat/send") +async def chat_send(msg: ChatMessage): + """Send a message to Claude and get a response (non-streaming)""" + session_id = msg.session_id or str(uuid.uuid4()) + + # Save user message + save_chat_message(session_id, "user", msg.message, model=msg.model) + + env = _get_claude_env() + cmd = _build_claude_cmd( + prompt=msg.message, + model=msg.model, + tools=msg.tools, + system_prompt=msg.system_prompt, + session_id=session_id + ) + + try: + process = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + stdin=asyncio.subprocess.DEVNULL, + env=env + ) + + stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=120) + + output = stdout.decode(errors="replace").strip() if stdout else "" + error = stderr.decode(errors="replace").strip() if stderr else "" + + # Filter warnings + error_lines = [l for l in error.split("\n") + if l.strip() and "stdin" not in l.lower() and "warning" not in l.lower()] + filtered_error = "\n".join(error_lines) + + if process.returncode == 0 and output: + save_chat_message(session_id, "assistant", output, model=msg.model) + return { + "session_id": session_id, + "response": output, + "status": "ok" + } + else: + error_msg = filtered_error or output or "No response from Claude" + save_chat_message(session_id, "error", error_msg) + return { + "session_id": session_id, + "response": error_msg, + "status": "error" + } + + except asyncio.TimeoutError: + save_chat_message(session_id, "error", "Response timeout (>120s)") + return {"session_id": session_id, "response": "Response timeout (>120s)", "status": "error"} + except Exception as e: + save_chat_message(session_id, "error", str(e)) + return {"session_id": session_id, "response": str(e), "status": "error"} + + +@app.websocket("/api/chat/ws/{session_id}") +async def chat_websocket(websocket: WebSocket, session_id: str): + """WebSocket endpoint for streaming chat""" + await websocket.accept() + + try: + while True: + data = await websocket.receive_text() + msg = json.loads(data) + + user_message = msg.get("message", "") + model = msg.get("model") + tools = msg.get("tools") + system_prompt = msg.get("system_prompt") + + if not user_message: + await websocket.send_json({"type": "error", "content": "Empty message"}) + continue + + save_chat_message(session_id, "user", user_message, model=model) + + env = _get_claude_env() + cmd = _build_claude_cmd( + prompt=user_message, + model=model, + tools=tools, + system_prompt=system_prompt, + session_id=session_id + ) + + await websocket.send_json({"type": "status", "content": "thinking"}) + + try: + process = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + stdin=asyncio.subprocess.DEVNULL, + env=env + ) + + _chat_sessions[session_id] = process + + # Stream stdout + full_output = "" + while True: + line = await asyncio.wait_for( + process.stdout.readline(), timeout=120 + ) + if not line: + break + text = line.decode(errors="replace") + full_output += text + await websocket.send_json({ + "type": "chunk", + "content": text + }) + + await process.wait() + + stderr_data = await process.stderr.read() + stderr_text = stderr_data.decode(errors="replace") if stderr_data else "" + + if process.returncode == 0 and full_output.strip(): + save_chat_message(session_id, "assistant", full_output.strip(), model=model) + await websocket.send_json({ + "type": "done", + "content": full_output.strip() + }) + else: + error_msg = stderr_text.strip() or full_output.strip() or "No response" + error_lines = [l for l in error_msg.split("\n") + if l.strip() and "stdin" not in l.lower()] + clean_error = "\n".join(error_lines) or error_msg + save_chat_message(session_id, "error", clean_error) + await websocket.send_json({ + "type": "error", + "content": clean_error + }) + + except asyncio.TimeoutError: + await websocket.send_json({"type": "error", "content": "Response timeout"}) + except Exception as e: + await websocket.send_json({"type": "error", "content": str(e)}) + finally: + _chat_sessions.pop(session_id, None) + + except WebSocketDisconnect: + logger.info(f"Chat WebSocket disconnected: {session_id}") + # Kill process if still running + proc = _chat_sessions.pop(session_id, None) + if proc and proc.returncode is None: + try: + proc.kill() + except Exception: + pass + + +@app.get("/api/chat/sessions") +async def get_chat_sessions(): + """List all chat sessions""" + return list_chat_sessions() + + +@app.get("/api/chat/history/{session_id}") +async def get_chat_session_history(session_id: str, limit: int = 50): + """Get chat history for a session""" + return get_chat_history(session_id, limit) + + +@app.delete("/api/chat/sessions/{session_id}") +async def delete_chat_session(session_id: str): + """Delete a chat session""" + conn = sqlite3.connect(DB_PATH) + cursor = conn.cursor() + cursor.execute("DELETE FROM chat_messages WHERE session_id = ?", (session_id,)) + conn.commit() + conn.close() + return {"status": "deleted"} + + +# ---- MCP Servers ---- @app.get("/api/mcp/servers") async def list_mcp_servers(): - """List configured MCP servers""" try: proc = await asyncio.create_subprocess_exec( "claude", "mcp", "list", - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, + stdin=asyncio.subprocess.DEVNULL ) stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=10) output = stdout.decode(errors="replace").strip() @@ -654,7 +940,6 @@ async def list_mcp_servers(): servers = [] if "No MCP servers configured" in output: return {"servers": [], "raw": output} - try: data = json.loads(output) if isinstance(data, list): @@ -668,23 +953,13 @@ async def list_mcp_servers(): parts = line.split() if len(parts) >= 1: servers.append({"name": parts[0], "details": " ".join(parts[1:])}) - return {"servers": servers, "raw": output} except Exception as e: return {"servers": [], "error": str(e)} -class McpServerAdd(BaseModel): - name: str - server_type: str = "sse" # sse | stdio - url: Optional[str] = None - command: Optional[str] = None - args: Optional[List[str]] = None - - @app.post("/api/mcp/servers") async def add_mcp_server(server: McpServerAdd): - """Add an MCP server""" try: cmd = ["claude", "mcp", "add", server.name] if server.server_type == "sse" and server.url: @@ -695,30 +970,24 @@ async def add_mcp_server(server: McpServerAdd): cmd.extend(server.args) proc = await asyncio.create_subprocess_exec( - *cmd, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, + stdin=asyncio.subprocess.DEVNULL ) stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=15) output = stdout.decode(errors="replace") + stderr.decode(errors="replace") - - return { - "status": "ok" if proc.returncode == 0 else "error", - "message": output.strip(), - "name": server.name - } + return {"status": "ok" if proc.returncode == 0 else "error", + "message": output.strip(), "name": server.name} except Exception as e: return {"status": "error", "message": str(e)} @app.delete("/api/mcp/servers/{name}") async def remove_mcp_server(name: str): - """Remove an MCP server""" try: proc = await asyncio.create_subprocess_exec( "claude", "mcp", "remove", name, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, + stdin=asyncio.subprocess.DEVNULL ) stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=10) output = stdout.decode(errors="replace") + stderr.decode(errors="replace") @@ -727,7 +996,7 @@ async def remove_mcp_server(name: str): return {"status": "error", "message": str(e)} -# Serve static frontend +# Serve static frontend — MUST be last app.mount("/", StaticFiles(directory="/app/frontend/dist", html=True), name="static")