diff --git a/backend/main.py b/backend/main.py index 955b752..ffd26b4 100644 --- a/backend/main.py +++ b/backend/main.py @@ -213,8 +213,13 @@ def _get_claude_env(): def _build_claude_cmd(prompt: str, model: str = None, tools: str = None, system_prompt: str = None, permission_mode: str = "auto", - session_id: str = None, continue_session: bool = False) -> list: - """Build a claude CLI command with agent configuration""" + resume_session_id: str = None) -> list: + """Build a claude CLI command with agent configuration. + + Note: --session-id requires a valid UUID and creates a persistent session. + For chat continuity, use resume_session_id with --resume. + For one-shot tasks, omit resume_session_id. + """ cmd = ["claude", "-p"] if model: @@ -231,10 +236,9 @@ def _build_claude_cmd(prompt: str, model: str = None, tools: str = None, if system_prompt: cmd.extend(["--system-prompt", system_prompt]) - if continue_session and session_id: - cmd.extend(["--resume", session_id]) - elif session_id: - cmd.extend(["--session-id", session_id]) + # Resume an existing Claude session for multi-turn chat + if resume_session_id: + cmd.extend(["--resume", resume_session_id]) cmd.append(prompt) return cmd @@ -409,8 +413,10 @@ def _sync_run_task(task: Task): # ============ Chat Sessions ============ -# Active chat sessions: session_id -> process +# Active chat sessions: our_session_id -> process _chat_sessions: Dict[str, asyncio.subprocess.Process] = {} +# Mapping: our_session_id -> claude_session_id (for --resume) +_claude_session_map: Dict[str, str] = {} def save_chat_message(session_id: str, role: str, content: str, model: str = None, metadata: dict = None): @@ -426,6 +432,32 @@ def save_chat_message(session_id: str, role: str, content: str, model: str = Non conn.close() +def _get_claude_session_id(our_session_id: str) -> Optional[str]: + """Get the Claude CLI session ID mapped to our session ID""" + if our_session_id in _claude_session_map: + return _claude_session_map[our_session_id] + # Also check DB metadata + conn = sqlite3.connect(DB_PATH) + cursor = conn.cursor() + cursor.execute(""" + SELECT metadata FROM chat_messages + WHERE session_id = ? AND metadata IS NOT NULL + ORDER BY timestamp DESC LIMIT 1 + """, (our_session_id,)) + row = cursor.fetchone() + conn.close() + if row and row[0]: + try: + meta = json.loads(row[0]) + claude_sid = meta.get("claude_session_id") + if claude_sid: + _claude_session_map[our_session_id] = claude_sid + return claude_sid + except Exception: + pass + return None + + def get_chat_history(session_id: str, limit: int = 50) -> List[Dict]: """Get chat history for a session""" conn = sqlite3.connect(DB_PATH) @@ -684,16 +716,18 @@ async def auth_set_token(payload: TokenSubmit): if not token: return {"status": "error", "message": "Token cannot be empty"} + # Auto-detect token type from prefix if token.startswith("sk-ant-oat"): token_type = "oauth_token" elif token.startswith("sk-ant-api"): token_type = "api_key" + else: + return { + "status": "error", + "message": f"Invalid token format. Expected 'sk-ant-oat01-...' (setup token) or 'sk-ant-api03-...' (API key). Got: '{token[:12]}...'" + } - TOKEN_FILE.write_text(json.dumps({ - "type": token_type, "token": token, - "saved_at": datetime.now().isoformat() - })) - + # Set env vars BEFORE saving so we can test if token_type == "oauth_token": os.environ["CLAUDE_CODE_OAUTH_TOKEN"] = token os.environ.pop("ANTHROPIC_API_KEY", None) @@ -701,26 +735,74 @@ async def auth_set_token(payload: TokenSubmit): os.environ["ANTHROPIC_API_KEY"] = token os.environ.pop("CLAUDE_CODE_OAUTH_TOKEN", None) - # Quick verify + # Actually verify by making a real API call (not just checking auth status which only looks at env vars) try: env = _get_claude_env() proc = await asyncio.create_subprocess_exec( - "claude", "auth", "status", "--json", + "claude", "-p", "--output-format", "json", "--no-session-persistence", + "--max-budget-usd", "0.01", + "Reply with just the word VERIFIED", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.DEVNULL, env=env ) - stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=10) + stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=30) output = stdout.decode(errors="replace").strip() - data = json.loads(output) - if data.get("loggedIn"): - return { - "status": "logged_in", "message": "Token saved and verified!", - "account": data.get("email"), "auth_method": data.get("authMethod") - } - except Exception as e: - pass - return {"status": "token_saved", "message": "Token saved. Run a test chat to verify it works with the API."} + try: + data = json.loads(output) + result = data.get("result", "") + is_error = data.get("is_error", False) + + if is_error: + # Token was rejected — don't save it + os.environ.pop("CLAUDE_CODE_OAUTH_TOKEN", None) + os.environ.pop("ANTHROPIC_API_KEY", None) + return { + "status": "error", + "message": f"Token rejected by Claude API: {result}" + } + else: + # Token works! Save it. + TOKEN_FILE.write_text(json.dumps({ + "type": token_type, "token": token, + "saved_at": datetime.now().isoformat() + })) + return { + "status": "logged_in", + "message": "Token verified and saved! Claude is ready.", + "auth_method": token_type + } + except json.JSONDecodeError: + # Couldn't parse JSON but got some output + if proc.returncode == 0: + TOKEN_FILE.write_text(json.dumps({ + "type": token_type, "token": token, + "saved_at": datetime.now().isoformat() + })) + return { + "status": "logged_in", + "message": "Token appears to work. Saved!", + "auth_method": token_type + } + else: + os.environ.pop("CLAUDE_CODE_OAUTH_TOKEN", None) + os.environ.pop("ANTHROPIC_API_KEY", None) + return { + "status": "error", + "message": f"Token verification failed: {output or stderr.decode(errors='replace')}" + } + + except asyncio.TimeoutError: + # If it timed out, it probably got through auth at least + TOKEN_FILE.write_text(json.dumps({ + "type": token_type, "token": token, + "saved_at": datetime.now().isoformat() + })) + return {"status": "token_saved", "message": "Token saved but verification timed out. Try sending a chat message to confirm."} + except Exception as e: + os.environ.pop("CLAUDE_CODE_OAUTH_TOKEN", None) + os.environ.pop("ANTHROPIC_API_KEY", None) + return {"status": "error", "message": f"Token verification error: {str(e)}"} @app.post("/api/auth/logout") @@ -745,22 +827,39 @@ async def auth_logout(): @app.post("/api/chat/send") async def chat_send(msg: ChatMessage): - """Send a message to Claude and get a response (non-streaming)""" + """Send a message to Claude and get a response (non-streaming). + + Uses --output-format json to capture the Claude session_id for + conversation continuity via --resume on subsequent messages. + """ session_id = msg.session_id or str(uuid.uuid4()) # Save user message save_chat_message(session_id, "user", msg.message, model=msg.model) env = _get_claude_env() + + # Check if we have an existing Claude session to resume + claude_sid = _get_claude_session_id(session_id) + cmd = _build_claude_cmd( prompt=msg.message, model=msg.model, tools=msg.tools, system_prompt=msg.system_prompt, - session_id=session_id + resume_session_id=claude_sid # None for first message, session_id for follow-ups ) + # Insert --output-format json after "-p" so we can parse the session_id try: + p_idx = cmd.index("-p") + cmd.insert(p_idx + 1, "--output-format") + cmd.insert(p_idx + 2, "json") + except ValueError: + pass + + try: + logger.info(f"Chat send cmd: {' '.join(cmd[:6])}...") process = await asyncio.create_subprocess_exec( *cmd, stdout=asyncio.subprocess.PIPE, @@ -771,16 +870,35 @@ async def chat_send(msg: ChatMessage): stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=120) - output = stdout.decode(errors="replace").strip() if stdout else "" + raw_output = stdout.decode(errors="replace").strip() if stdout else "" error = stderr.decode(errors="replace").strip() if stderr else "" - # Filter warnings + # Filter warnings from stderr error_lines = [l for l in error.split("\n") if l.strip() and "stdin" not in l.lower() and "warning" not in l.lower()] filtered_error = "\n".join(error_lines) + # Try to parse JSON output for session_id and result + output = raw_output + try: + data = json.loads(raw_output) + output = data.get("result", raw_output) + # Capture Claude's session ID for future --resume calls + new_claude_sid = data.get("session_id") + if new_claude_sid: + _claude_session_map[session_id] = new_claude_sid + logger.info(f"Mapped session {session_id[:8]}... -> Claude {new_claude_sid[:8]}...") + # Check if Claude reported an error + if data.get("is_error"): + save_chat_message(session_id, "error", output, metadata={"claude_session_id": new_claude_sid}) + return {"session_id": session_id, "response": output, "status": "error"} + except (json.JSONDecodeError, TypeError): + pass + if process.returncode == 0 and output: - save_chat_message(session_id, "assistant", output, model=msg.model) + new_claude_sid = _claude_session_map.get(session_id) + save_chat_message(session_id, "assistant", output, model=msg.model, + metadata={"claude_session_id": new_claude_sid} if new_claude_sid else None) return { "session_id": session_id, "response": output, @@ -805,7 +923,12 @@ async def chat_send(msg: ChatMessage): @app.websocket("/api/chat/ws/{session_id}") async def chat_websocket(websocket: WebSocket, session_id: str): - """WebSocket endpoint for streaming chat""" + """WebSocket endpoint for streaming chat. + + Uses --output-format stream-json for streaming, and captures the + Claude session_id from the final 'result' message for --resume on + subsequent messages in the same chat session. + """ await websocket.accept() try: @@ -825,17 +948,30 @@ async def chat_websocket(websocket: WebSocket, session_id: str): save_chat_message(session_id, "user", user_message, model=model) env = _get_claude_env() + + # Check for existing Claude session to resume + claude_sid = _get_claude_session_id(session_id) + cmd = _build_claude_cmd( prompt=user_message, model=model, tools=tools, system_prompt=system_prompt, - session_id=session_id + resume_session_id=claude_sid ) + # Use stream-json for streaming output + try: + p_idx = cmd.index("-p") + cmd.insert(p_idx + 1, "--output-format") + cmd.insert(p_idx + 2, "stream-json") + except ValueError: + pass + await websocket.send_json({"type": "status", "content": "thinking"}) try: + logger.info(f"Chat WS cmd: {' '.join(cmd[:6])}...") process = await asyncio.create_subprocess_exec( *cmd, stdout=asyncio.subprocess.PIPE, @@ -846,20 +982,67 @@ async def chat_websocket(websocket: WebSocket, session_id: str): _chat_sessions[session_id] = process - # Stream stdout + # Stream stdout line by line (stream-json emits one JSON object per line) full_output = "" while True: - line = await asyncio.wait_for( - process.stdout.readline(), timeout=120 - ) + try: + line = await asyncio.wait_for( + process.stdout.readline(), timeout=180 + ) + except asyncio.TimeoutError: + await websocket.send_json({"type": "error", "content": "Response timeout (>180s)"}) + break if not line: break - text = line.decode(errors="replace") - full_output += text - await websocket.send_json({ - "type": "chunk", - "content": text - }) + text = line.decode(errors="replace").strip() + if not text: + continue + + # Try to parse stream-json events + try: + event = json.loads(text) + event_type = event.get("type", "") + + if event_type == "assistant": + # Assistant message with content blocks + content = event.get("message", {}).get("content", []) + for block in content: + if block.get("type") == "text": + chunk = block.get("text", "") + full_output += chunk + await websocket.send_json({"type": "chunk", "content": chunk}) + elif event_type == "content_block_delta": + delta = event.get("delta", {}) + if delta.get("type") == "text_delta": + chunk = delta.get("text", "") + full_output += chunk + await websocket.send_json({"type": "chunk", "content": chunk}) + elif event_type == "result": + # Final result — capture session_id for --resume + result_text = event.get("result", "") + new_claude_sid = event.get("session_id") + is_error = event.get("is_error", False) + + if new_claude_sid: + _claude_session_map[session_id] = new_claude_sid + logger.info(f"WS mapped {session_id[:8]}... -> Claude {new_claude_sid[:8]}...") + + if is_error: + error_content = result_text or full_output or "Claude returned an error" + save_chat_message(session_id, "error", error_content, + metadata={"claude_session_id": new_claude_sid}) + await websocket.send_json({"type": "error", "content": error_content}) + elif result_text and not full_output: + # If we got no streaming chunks but have a result + full_output = result_text + await websocket.send_json({"type": "chunk", "content": result_text}) + else: + # Unknown event type — might contain text content + pass + except json.JSONDecodeError: + # Not JSON — treat as raw text output + full_output += text + "\n" + await websocket.send_json({"type": "chunk", "content": text + "\n"}) await process.wait() @@ -867,32 +1050,35 @@ async def chat_websocket(websocket: WebSocket, session_id: str): stderr_text = stderr_data.decode(errors="replace") if stderr_data else "" if process.returncode == 0 and full_output.strip(): - save_chat_message(session_id, "assistant", full_output.strip(), model=model) + new_claude_sid = _claude_session_map.get(session_id) + save_chat_message(session_id, "assistant", full_output.strip(), model=model, + metadata={"claude_session_id": new_claude_sid} if new_claude_sid else None) await websocket.send_json({ "type": "done", "content": full_output.strip() }) - else: - error_msg = stderr_text.strip() or full_output.strip() or "No response" - error_lines = [l for l in error_msg.split("\n") + elif not full_output.strip(): + # No output at all + error_lines = [l for l in stderr_text.split("\n") if l.strip() and "stdin" not in l.lower()] - clean_error = "\n".join(error_lines) or error_msg + clean_error = "\n".join(error_lines) or "No response from Claude" save_chat_message(session_id, "error", clean_error) + await websocket.send_json({"type": "error", "content": clean_error}) + else: + # Had output but non-zero return code — still send done await websocket.send_json({ - "type": "error", - "content": clean_error + "type": "done", + "content": full_output.strip() }) - except asyncio.TimeoutError: - await websocket.send_json({"type": "error", "content": "Response timeout"}) except Exception as e: + logger.error(f"Chat WS error: {e}") await websocket.send_json({"type": "error", "content": str(e)}) finally: _chat_sessions.pop(session_id, None) except WebSocketDisconnect: logger.info(f"Chat WebSocket disconnected: {session_id}") - # Kill process if still running proc = _chat_sessions.pop(session_id, None) if proc and proc.returncode is None: try: