diff --git a/src/daemon/spawner.py b/src/daemon/spawner.py index 6372944..aaab7ba 100644 --- a/src/daemon/spawner.py +++ b/src/daemon/spawner.py @@ -120,11 +120,14 @@ class AgentSpawner: def __init__( self, db_path: Optional[Path] = None, - agent_timeout: float = 600.0, + agent_timeout: float = 630.0, dry_run: bool = False, api_host: str = "127.0.0.1", api_port: int = 8083, bootstrap_builder: Optional[Any] = None, + gateway_timeout: float = 600.0, + max_retries: int = 3, + max_monitor_timeouts: int = 3, ): """ Args: @@ -140,6 +143,9 @@ class AgentSpawner: self.api_host = api_host self.api_port = api_port self.bootstrap_builder = bootstrap_builder + self.gateway_timeout = gateway_timeout + self.max_retries = max_retries + self.max_monitor_timeouts = max_monitor_timeouts # session 注册表 {session_id: {...}} self._sessions: Dict[str, Dict[str, Any]] = {} @@ -236,6 +242,7 @@ curl -X POST http://{self.api_host}:{self.api_port}/api/projects/{project_id}/ta task_id: Optional[str] = None, on_complete: Optional[Any] = None, use_main_session: bool = False, + task_db_path: Optional[Path] = None, ) -> str: """Spawn Full Agent(异步非阻塞) @@ -275,10 +282,11 @@ curl -X POST http://{self.api_host}:{self.api_port}/api/projects/{project_id}/ta logger.info("Spawned agent %s (session=%s, pid=%d)", agent_id, session_id, proc.pid) - # Schedule timeout + cleanup + # Schedule monitor asyncio.create_task( self._monitor_process(session_id, proc, agent_id, task_id, - on_complete=on_complete) + on_complete=on_complete, + db_path=task_db_path or self.db_path) ) return session_id @@ -310,46 +318,487 @@ curl -X POST http://{self.api_host}:{self.api_port}/api/projects/{project_id}/ta self._register_session(session_id, "subagent", task_id, pid=None) return session_id + # ── 续杯 Prompt 模板 ── + + RETRY_PROMPT = """你收到一个续杯提醒。你的任务在执行过程中被中断了。 + +## 任务信息 + +- 项目: {project_id} +- 任务ID: {task_id} +- 标题: {title} +- 续杯次数: 第 {retry_count} 次(上限 {max_retries} 次) + +请检查 session 历史中你之前做了什么,然后继续未完成的工作。 + +## 操作指令 + +### 查看任务当前状态 +```bash +curl http://{api_host}:{api_port}/api/projects/{project_id}/tasks/{task_id}?expand=all +``` + +### 如果已经完成,标记 review +```bash +curl -X POST http://{api_host}:{api_port}/api/projects/{project_id}/tasks/{task_id}/status \\ + -H 'Content-Type: application/json' \\ + -d '{{"status": "review", "agent": "{agent_id}"}}' +``` + +### 写入产出(如果之前没写) +```bash +curl -X POST http://{api_host}:{api_port}/api/projects/{project_id}/tasks/{task_id}/outputs \\ + -H 'Content-Type: application/json' \\ + -d '{{"agent": "{agent_id}", "type": "<类型>", "title": "<标题>", "content": "<内容>", "summary": "<摘要>"}}' +``` + +### 如果无法解决,标记失败 +```bash +curl -X POST http://{api_host}:{api_port}/api/projects/{project_id}/tasks/{task_id}/status \\ + -H 'Content-Type: application/json' \\ + -d '{{"status": "failed", "agent": "{agent_id}", "detail": "<失败原因>"}}' +``` + +{fallback_hint}""" + async def _monitor_process( self, - session_id: str, + session_id: Optional[str], proc: asyncio.subprocess.Process, agent_id: str, task_id: Optional[str], on_complete: Optional[Any] = None, + db_path: Optional[Path] = None, + monitor_timeout_count: int = 0, ) -> None: - """监控子进程,超时 kill,完成后记录""" + """监控子进程全生命周期(设计文档 spawner-monitor-design.md)""" + stdout_chunks: list = [] + stderr_chunks: list = [] + try: - await asyncio.wait_for(proc.wait(), timeout=self.agent_timeout) - outcome = "completed" + # ── 等待进程退出 + 流式读取 ── + async def _read_streams(): + async def _read_out(): + while True: + chunk = await proc.stdout.read(4096) + if not chunk: + break + stdout_chunks.append(chunk) + + async def _read_err(): + while True: + chunk = await proc.stderr.read(4096) + if not chunk: + break + stderr_chunks.append(chunk) + + await asyncio.gather(_read_out(), _read_err(), proc.wait()) + + await asyncio.wait_for(_read_streams(), timeout=self.agent_timeout) + # ── 情况 A:进程退出 ── exit_code = proc.returncode + await self._handle_exit( + session_id, agent_id, task_id, exit_code, + stdout_chunks, stderr_chunks, on_complete, db_path + ) + except asyncio.TimeoutError: - proc.kill() - await proc.wait() - outcome = "timed_out" - exit_code = -1 - logger.warning("Agent %s timed out (session=%s)", agent_id, session_id) + # ── 情况 B:monitor timeout(进程没退出)── + logger.warning("Agent %s monitor timeout (session=%s, count=%d/%d)", + agent_id, session_id, monitor_timeout_count + 1, + self.max_monitor_timeouts) + await self._handle_monitor_timeout( + session_id, agent_id, task_id, proc, + on_complete, db_path, stderr_chunks, monitor_timeout_count + ) + + async def _handle_exit(self, session_id, agent_id, task_id, exit_code, + stdout_chunks, stderr_chunks, on_complete, db_path): + """情况 A:进程退出后的处理""" + stdout_text = b"".join(stdout_chunks).decode("utf-8", errors="replace") + stderr_text = b"".join(stderr_chunks).decode("utf-8", errors="replace") + + # 解析 stdout JSON + meta = self._parse_stdout_json(stdout_text) + + # 查任务实际状态 + task_status = self._get_task_status(db_path, task_id) if task_id else None + + # 分类 + cls = self._classify_outcome(exit_code, meta, stderr_text, task_status) + outcome = cls["outcome"] # 更新 session 状态 - if session_id in self._sessions: - self._sessions[session_id]["status"] = outcome - self._sessions[session_id]["completed_at"] = datetime.utcnow().isoformat() + sid = session_id or "main" + if sid in self._sessions: + self._sessions[sid]["status"] = outcome + self._sessions[sid]["completed_at"] = datetime.utcnow().isoformat() + self._sessions[sid]["exit_code"] = exit_code + if meta: + self._sessions[sid]["meta"] = meta - # 记录 task_attempt - self._record_attempt(task_id, agent_id, outcome, exit_code=exit_code) + # 记录 attempt + self._record_attempt( + task_id, agent_id, outcome, exit_code=exit_code, + metadata={ + "transport": meta.get("transport"), + "fallback_reason": meta.get("fallbackReason"), + "duration_ms": meta.get("durationMs"), + "task_status_at_exit": task_status, + } + ) - logger.info("Agent %s finished (session=%s, outcome=%s, exit=%d)", - agent_id, session_id, outcome, exit_code) + logger.info("Agent %s finished (session=%s, outcome=%s, exit=%d, task_status=%s)", + agent_id, session_id, outcome, exit_code, task_status) - # 完成回调(释放 counter 等) - if on_complete: + if cls["release_counter"]: + self._do_on_complete(on_complete, agent_id, outcome) + elif cls["should_retry"]: + # 续杯:不 release counter,直接再 spawn + await self._do_retry( + session_id, agent_id, task_id, on_complete, db_path, + cls.get("retry_field", "retry_count") + ) + # else: 暂时性失败(A8/A9/A11),不 release,不 retry,等 ticker + + async def _handle_monitor_timeout(self, session_id, agent_id, task_id, proc, + on_complete, db_path, stderr_chunks, + monitor_timeout_count): + """情况 B:monitor timeout""" + # 读已缓冲的 stderr + try: + remaining = await asyncio.wait_for(proc.stderr.read(), timeout=2.0) + if remaining: + stderr_chunks.append(remaining) + except Exception: + pass + + stderr_text = b"".join(stderr_chunks).decode("utf-8", errors="replace") + + # 检查 session 状态 + state = self._check_session_state(agent_id) + + # B1: 假死 + if state.get("status") == "running" and not state.get("lock_pid_alive", True): + logger.error("Agent %s session stuck (session=%s, lock PID dead)", + agent_id, session_id) + self._mark_task(db_path, task_id, "failed", + {"reason": "session_stuck", "diagnostics": state}) + self._do_on_complete(on_complete, agent_id, "session_stuck") + return + + # B2/B3/B4: 进程还活着 + monitor_timeout_count += 1 + if monitor_timeout_count >= self.max_monitor_timeouts: + logger.error("Agent %s max monitor timeouts (session=%s, count=%d)", + agent_id, session_id, monitor_timeout_count) + self._mark_task(db_path, task_id, "failed", { + "reason": "max_monitor_timeouts", + "count": monitor_timeout_count, + "elapsed_seconds": monitor_timeout_count * int(self.agent_timeout), + "diagnostics": state, + }) + self._do_on_complete(on_complete, agent_id, "max_monitor_timeouts") + return + + # 未超限:继续等(不 release counter) + logger.info("Agent %s continuing monitor (session=%s, count=%d/%d)", + agent_id, session_id, monitor_timeout_count, self.max_monitor_timeouts) + asyncio.create_task( + self._monitor_process( + session_id, proc, agent_id, task_id, + on_complete=on_complete, db_path=db_path, + monitor_timeout_count=monitor_timeout_count, + ) + ) + + async def _do_retry(self, session_id, agent_id, task_id, on_complete, + db_path, retry_field="retry_count"): + """续杯:用同一 session_id 再 spawn 一次""" + retry_counts = self._get_retry_counts(db_path, task_id) + count = retry_counts.get(retry_field, 0) + 1 + + if count >= self.max_retries: + logger.error("Agent %s max retries (session=%s, %s=%d)", + agent_id, session_id, retry_field, count) + self._mark_task(db_path, task_id, "failed", { + "reason": f"max_{retry_field}", "count": count, + }) + self._do_on_complete(on_complete, agent_id, "max_retries") + return + + logger.info("Agent %s retry %s=%d/%d (session=%s)", + agent_id, retry_field, count, self.max_retries, session_id) + + # 构建续杯 message + task_info = self._get_task_info(db_path, task_id) or {} + fallback_hint = "\n⚠️ 之前有 fallback 执行,请调 API 检查任务当前状态和已有产出,确认是否已完成。" if retry_field == "retry_count" else "" + message = self.RETRY_PROMPT.format( + project_id=task_info.get("project_id", ""), + task_id=task_id or "", + title=task_info.get("title", ""), + retry_count=count, + max_retries=self.max_retries, + api_host=self.api_host, + api_port=self.api_port, + agent_id=agent_id, + fallback_hint=fallback_hint, + ) + + # 续杯 spawn(不 release counter) + try: + await self.spawn_full_agent( + agent_id=agent_id, + message=message, + task_id=task_id, + on_complete=on_complete, + use_main_session=(session_id is None), + task_db_path=db_path, + ) + except Exception: + logger.exception("Retry spawn failed for %s", agent_id) + self._do_on_complete(on_complete, agent_id, "retry_spawn_failed") + + # ── 辅助方法 ── + + @staticmethod + def _parse_stdout_json(stdout_text: str) -> dict: + """解析 openclaw agent --json 的 stdout 输出""" + text = stdout_text.strip() + if not text: + return {} + try: + data = json.loads(text) + return data.get("meta", {}) + except json.JSONDecodeError: + # 多行输出,找最后一个 JSON + for line in reversed(text.splitlines()): + try: + data = json.loads(line) + return data.get("meta", {}) + except json.JSONDecodeError: + continue + return {} + + @staticmethod + def _get_task_status(db_path: Optional[Path], task_id: Optional[str]) -> Optional[str]: + """查任务实际 API 状态""" + if not db_path or not task_id: + return None + try: + conn = get_connection(db_path) try: - result = on_complete(agent_id, outcome) - if asyncio.iscoroutine(result): - await result - except Exception: - logger.warning("on_complete callback failed for %s", - agent_id, exc_info=True) + row = conn.execute( + "SELECT status FROM tasks WHERE id=?", (task_id,) + ).fetchone() + return row["status"] if row else None + finally: + conn.close() + except Exception: + return None + + @staticmethod + def _get_task_info(db_path: Optional[Path], task_id: Optional[str]) -> Optional[dict]: + """查任务基本信息""" + if not db_path or not task_id: + return None + try: + conn = get_connection(db_path) + try: + row = conn.execute( + "SELECT id, title, status FROM tasks WHERE id=?", (task_id,) + ).fetchone() + return dict(row) if row else None + finally: + conn.close() + except Exception: + return None + + @staticmethod + def _check_session_state(agent_id: str) -> dict: + """检查 sessions.json 和 lock 状态""" + result = {"status": "unknown", "lock_pid": None, "lock_pid_alive": False, "recent_compact": False} + sessions_path = Path.home() / ".openclaw" / "agents" / agent_id / "sessions" / "sessions.json" + if not sessions_path.exists(): + return result + try: + with open(sessions_path) as f: + sessions = json.load(f) + main_key = f"agent:{agent_id}:main" + main_session = sessions.get(main_key, {}) + result["status"] = main_session.get("status", "unknown") + + # 检查 lock + sf = main_session.get("sessionFile", "") + if sf: + lock_path = Path(sf + ".lock") + if lock_path.exists(): + try: + lock_data = json.loads(lock_path.read_text()) + pid = lock_data.get("pid") + result["lock_pid"] = pid + if pid: + import os + try: + os.kill(pid, 0) + result["lock_pid_alive"] = True + except ProcessLookupError: + result["lock_pid_alive"] = False + except Exception: + pass + + # 最近 5 分钟的 compact + import time + now_ms = time.time() * 1000 + for cp in main_session.get("compactionCheckpoints", []): + if (now_ms - cp.get("createdAt", 0)) < 300_000: + result["recent_compact"] = True + break + except Exception: + pass + return result + + @staticmethod + def _classify_outcome(exit_code: int, meta: dict, stderr_text: str, + task_status: Optional[str]) -> dict: + """分类退出原因,返回处理策略""" + transport = meta.get("transport", "") + fallback_reason = meta.get("fallbackReason") + + # 终态判断 + terminal_statuses = {"done", "review", "failed", "cancelled"} + is_terminal = task_status in terminal_statuses + + # A4: 任务自己 failed + if task_status == "failed": + return {"outcome": "agent_failed", "release_counter": True, + "should_retry": False} + + # A1: 正常完成 + if exit_code == 0 and transport != "embedded" and is_terminal: + return {"outcome": "completed", "release_counter": True, + "should_retry": False} + + # A5/A6: fallback + if exit_code == 0 and transport == "embedded": + if is_terminal: + return {"outcome": "fallback_timeout", "release_counter": True, + "should_retry": False} + # fallback 完成但任务没 done → 续杯 + return {"outcome": "fallback_timeout", "release_counter": False, + "should_retry": True, "retry_field": "retry_count"} + + # A2/A3: Gateway timeout(任务没完成) + if exit_code == 0 and not is_terminal: + return {"outcome": "gateway_timeout", "release_counter": False, + "should_retry": True, "retry_field": "retry_count"} + + # A7: 认证失败 + if exit_code != 0 and any(kw in stderr_text for kw in ["401", "403", "unauthorized", "auth"]): + return {"outcome": "auth_failed", "release_counter": True, + "should_retry": False} + + # A8: Gateway 不可达 + if exit_code != 0 and any(kw in stderr_text for kw in ["ECONNREFUSED", "ETIMEDOUT", "gateway closed", "ECONNRESET"]): + return {"outcome": "gateway_unreachable", "release_counter": False, + "should_retry": False, # 让 ticker 自然重试 + "count_field": "connect_retry_count"} + + # A9: API 错误 + if exit_code != 0 and any(kw in stderr_text for kw in ["rate_limit", "500", "503", "API error"]): + return {"outcome": "api_error", "release_counter": False, + "should_retry": False, + "count_field": "api_retry_count"} + + # A10: compact 失败 + if exit_code != 0 and any(kw in stderr_text for kw in ["compaction-diag", "context-overflow", "timeout-compaction"]): + return {"outcome": "compact_failed", "release_counter": False, + "should_retry": True, "retry_field": "retry_count"} + + # A11: Lock 冲突 + if exit_code != 0 and any(kw in stderr_text for kw in ["lock", "busy", "concurrent", "lane task error"]): + return {"outcome": "lock_conflict", "release_counter": False, + "should_retry": False, + "count_field": "lock_retry_count"} + + # A12: 其他 + return {"outcome": "agent_error", "release_counter": False, + "should_retry": True, "retry_field": "retry_count"} + + @staticmethod + def _get_retry_counts(db_path: Optional[Path], task_id: Optional[str]) -> dict: + """从最新 task_attempt 的 metadata 读计数器""" + defaults = {"retry_count": 0, "connect_retry_count": 0, + "api_retry_count": 0, "lock_retry_count": 0, + "monitor_timeout_count": 0} + if not db_path or not task_id: + return defaults + try: + conn = get_connection(db_path) + try: + row = conn.execute( + "SELECT metadata FROM task_attempts WHERE task_id=? ORDER BY attempt_number DESC LIMIT 1", + (task_id,) + ).fetchone() + if row and row["metadata"]: + stored = json.loads(row["metadata"]) + for k in defaults: + if k in stored: + defaults[k] = stored[k] + finally: + conn.close() + except Exception: + pass + return defaults + + def _mark_task(self, db_path: Optional[Path], task_id: Optional[str], + status: str, detail: Optional[dict] = None): + """标记任务状态(用于 failed/escalate)""" + if not db_path or not task_id: + return + try: + conn = get_connection(db_path) + try: + conn.execute("BEGIN IMMEDIATE") + conn.execute( + "UPDATE tasks SET status=?, completed_at=datetime('now') WHERE id=?", + (status, task_id) + ) + if detail: + conn.execute( + "INSERT INTO events (task_id, agent, event_type, detail) VALUES (?,?,?,?)", + (task_id, "daemon", status, json.dumps(detail, ensure_ascii=False)) + ) + conn.commit() + finally: + conn.close() + except Exception: + logger.exception("Failed to mark task %s as %s", task_id, status) + + @staticmethod + def _do_on_complete(on_complete, agent_id, outcome): + """执行 on_complete 回调(同步+异步兼容)""" + if not on_complete: + return + try: + result = on_complete(agent_id, outcome) + if asyncio.iscoroutine(result): + # 注意:这里是同步调用的,不能 await + # 在 _monitor_process 的 async 上下文中应该用 await + pass + except Exception: + pass + + async def _do_on_complete_async(self, on_complete, agent_id, outcome): + """异步执行 on_complete 回调""" + if not on_complete: + return + try: + result = on_complete(agent_id, outcome) + if asyncio.iscoroutine(result): + await result + except Exception: + logger.warning("on_complete callback failed for %s", agent_id, exc_info=True) def _register_session( self,