"""base_task_handler.py — Task type handler 基类。 收敛合理的共性能力(crash rollback + verify + mark + notify), 子类只实现差异点。 """ from __future__ import annotations import logging from dataclasses import dataclass from pathlib import Path from typing import Optional from src.daemon.prompt_composer import PromptContext, PromptComposer, PromptSection from src.blackboard.db import get_connection logger = logging.getLogger("moziplus-v2.handler") @dataclass class VerifyResult: """验证结果""" passed: bool reason: str # "has_output" / "no_reply" / "no_signal" / ... evidence: str # "output_count=1, comment_count=0" can_retry: bool = True retry_count: int = 0 class BaseTaskHandler: """所有 task type handler 的基类。 职责:L2 引擎注入层的业务逻辑——prompt 构建、完成验证、状态标记。 不管:进程生命周期、exit 分类、重试决策(这些归 spawner)。 """ # crash 类 outcome(进程级异常,需要 rollback) CRASH_OUTCOMES = frozenset({ "crashed", "compact_failed", "process_crash", "session_stuck", "compact_hanging", }) task_type: str = "" virtual_project: Optional[str] = None display_name: str = "" # 中文展示名(ticker 扫描日志用) # === 子类必须实现 === def build_prompt(self, context: PromptContext) -> str: """构建 L2 prompt(通过 PromptComposer 拼 section)。子类实现。""" raise NotImplementedError def verify_completion(self, task_id: str, db_path: Path) -> VerifyResult: """验证任务完成质量。每个 handler 自己的验证逻辑。子类实现。""" raise NotImplementedError def target_success_status(self) -> str: """验证通过后的目标状态。task='review', mail/toolchain='done'""" return "review" def get_sections(self) -> list[PromptSection]: """返回此 handler 的 prompt section 列表。子类实现。""" return [] # === 基类提供统一流程 === def pre_spawn(self, task_id: str, db_path: Path) -> bool: """spawn 前业务准备。默认 True。 mail/toolchain override 为 auto_working。""" return True def post_complete(self, task_id: str, agent_id: str, outcome: str, db_path: Path) -> None: """spawn 完成后的业务处理。统一 4 步流程: 1. crash 处理 → rollback current_agent 2. verify → 验证产出 3. mark → 标目标状态 4. notify → 失败时 on_failure """ # 1. crash 处理(基类提供,所有 handler 继承) if outcome in self.CRASH_OUTCOMES: self._rollback_current_agent(db_path, task_id, agent_id) return # 2. verify result = self.verify_completion(task_id, db_path) # 3. mark if result.passed: self._mark_task_status(db_path, task_id, self.target_success_status()) logger.info("Task %s: verify passed (%s), marked %s", task_id, result.reason, self.target_success_status()) else: # 4. notify self.on_failure(task_id, agent_id, db_path, result) def on_failure(self, task_id: str, agent_id: str, db_path: Path, verify: VerifyResult) -> None: """验证失败处理。默认:标 failed。子类可 override。""" self._mark_task_status(db_path, task_id, "failed") logger.info("Task %s: verify failed (%s), marked failed", task_id, verify.reason) def check_completion(self, task_id: str, db_path: Path) -> bool: """ticker 级别的完成检查。默认:False。""" return False # === 内部工具方法 === def _rollback_current_agent(self, db_path: Path, task_id: str, agent_id: str) -> None: """crash 后回退 current_agent → assignee,避免 exclude_current 卡死。 从 dispatcher._rollback_current_agent 迁移。""" try: conn = get_connection(db_path) try: conn.execute( "UPDATE tasks SET current_agent = " "(SELECT assignee FROM tasks WHERE id=?) " "WHERE id=? AND current_agent=?", (task_id, task_id, agent_id) ) conn.commit() finally: conn.close() logger.info("Task %s: rolled back current_agent from %s to assignee", task_id, agent_id) except Exception as e: logger.warning("Task %s: failed to rollback current_agent: %s", task_id, e) def _mark_task_status(self, db_path: Path, task_id: str, status: str) -> None: """更新任务状态 + 写审计事件(带 3 次重试,防 SQLite DB 锁)。""" for attempt in range(3): try: conn = get_connection(db_path) try: conn.execute("BEGIN IMMEDIATE") old_row = conn.execute( "SELECT status FROM tasks WHERE id=?", (task_id,) ).fetchone() old_status = old_row["status"] if old_row else "unknown" conn.execute( "UPDATE tasks SET status=?, updated_at=datetime('now') WHERE id=?", (status, task_id), ) conn.execute( "INSERT INTO events (task_id, agent, event_type, payload) " "VALUES (?, 'handler', 'status_change', ?)", (task_id, f'{{"from": "{old_status}", "to": "{status}", ' f'"source": "{self.task_type}_handler"}}'), ) conn.commit() return finally: conn.close() except Exception as e: logger.warning("Handler: mark %s → %s attempt %d failed: %s", task_id, status, attempt + 1, e) logger.error("Handler: mark %s → %s all 3 attempts failed", task_id, status) def _auto_mark_working(self, task_id: str, db_path: Path) -> bool: """pending → working(mail/toolchain 通用)。""" try: conn = get_connection(db_path) try: conn.execute("BEGIN IMMEDIATE") row = conn.execute( "SELECT status FROM tasks WHERE id=?", (task_id,)).fetchone() if not row or row["status"] not in ("pending", "claimed"): logger.warning("Task %s: cannot mark working (status=%s)", task_id, row["status"] if row else "not found") return False conn.execute( "UPDATE tasks SET status='working', updated_at=datetime('now') " "WHERE id=?", (task_id,)) conn.commit() logger.info("Task %s: auto-marked working", task_id) return True finally: conn.close() except Exception as e: logger.error("Task %s: failed to mark working: %s", task_id, e) return False