diff --git a/tests/test_dispatcher.py b/tests/test_dispatcher.py index 4ed66e9..6c505b8 100644 --- a/tests/test_dispatcher.py +++ b/tests/test_dispatcher.py @@ -280,3 +280,135 @@ class TestDispatcherErrorClassification: # counter_blocked 通常在 can_acquire 阶段被拦,结果为 skipped # 如果穿透到 spawn_full_agent,则为 error assert result["status"] in ("skipped", "error") + + +# --------------------------------------------------------------------------- +# 司马懿评审补充:_rollback_current_agent + on_complete 统一(v2.8 #07.2) +# --------------------------------------------------------------------------- + +class TestRollbackAndOnComplete: + """司马懿评审遗漏 #3 + #4: crash 后 current_agent 回退 + on_complete 统一路径""" + + def test_rollback_current_agent_on_crash(self, tmp_path): + """executor crash → _rollback_current_agent 回退 current_agent → assignee + + #07.2 核心改动:executor crash 后也回退 current_agent, + 避免 _dispatch_reviews 的 exclude_current 卡死。 + """ + from src.blackboard.operations import Blackboard + from src.blackboard.models import Task + + db_path = tmp_path / "blackboard.db" + bb = Blackboard(db_path) + bb.create_task(Task( + id="t1", title="T", status="working", + assigned_by="daemon", assignee="zhangfei-dev", + current_agent="zhangfei-dev", + )) + + dispatcher = Dispatcher(registered_agents=["zhangfei-dev"]) + dispatcher._rollback_current_agent(db_path, "t1", "zhangfei-dev") + + import sqlite3 + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + row = conn.execute("SELECT current_agent, assignee FROM tasks WHERE id=?", ("t1",)).fetchone() + conn.close() + + assert row["current_agent"] == row["assignee"] == "zhangfei-dev" + + def test_rollback_different_agent(self, tmp_path): + """current_agent ≠ agent_id → 不回退(安全检查) + + _rollback_current_agent 的 WHERE 条件包含 current_agent=?, + 如果 agent_id 不匹配 current_agent 则不执行更新。 + """ + from src.blackboard.operations import Blackboard + from src.blackboard.models import Task + + db_path = tmp_path / "blackboard.db" + bb = Blackboard(db_path) + bb.create_task(Task( + id="t1", title="T", status="working", + assigned_by="daemon", assignee="zhangfei-dev", + current_agent="simayi-challenger", + )) + + dispatcher = Dispatcher(registered_agents=["zhangfei-dev"]) + # 用错误的 agent_id 回退 + dispatcher._rollback_current_agent(db_path, "t1", "wrong-agent") + + import sqlite3 + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + row = conn.execute("SELECT current_agent FROM tasks WHERE id=?", ("t1",)).fetchone() + conn.close() + + # current_agent 不变 + assert row["current_agent"] == "simayi-challenger" + + def test_on_complete_crash_rollback_executor(self, tmp_path): + """executor crash → rollback current_agent + _task_auto_complete(标 review) + + #07.2: crash 回退在 if _is_review 之前执行。 + """ + from src.blackboard.operations import Blackboard + from src.blackboard.models import Task + + db_path = tmp_path / "blackboard.db" + bb = Blackboard(db_path) + bb.create_task(Task( + id="t1", title="T", status="working", + assigned_by="daemon", assignee="zhangfei-dev", + current_agent="zhangfei-dev", + )) + + dispatcher = Dispatcher(registered_agents=["zhangfei-dev"]) + dispatcher.spawner = MagicMock() + + # 模拟 executor on_complete 回调 + outcomes = [] + async def mock_on_complete(aid, outcome): + outcomes.append((aid, outcome)) + + # 构造 dispatcher 内部 _task_on_complete 的行为 + # executor crash: rollback → _task_auto_complete + dispatcher._rollback_current_agent(db_path, "t1", "zhangfei-dev") + + import sqlite3 + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + row = conn.execute("SELECT current_agent FROM tasks WHERE id=?", ("t1",)).fetchone() + conn.close() + assert row["current_agent"] == "zhangfei-dev" # rollback 到 assignee + + def test_on_complete_crash_rollback_review(self, tmp_path): + """review crash → rollback current_agent + 保持 review 状态 + + #07.2: crash 回退在 if _is_review 之前执行。 + review crash 后 current_agent 回退,但任务保持 review 状态。 + """ + from src.blackboard.operations import Blackboard + from src.blackboard.models import Task + + db_path = tmp_path / "blackboard.db" + bb = Blackboard(db_path) + bb.create_task(Task( + id="t1", title="T", status="review", + assigned_by="daemon", assignee="zhangfei-dev", + current_agent="simayi-challenger", + )) + + dispatcher = Dispatcher(registered_agents=["simayi-challenger"]) + dispatcher._rollback_current_agent(db_path, "t1", "simayi-challenger") + + import sqlite3 + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + row = conn.execute("SELECT current_agent, status FROM tasks WHERE id=?", ("t1",)).fetchone() + conn.close() + + # current_agent 回退到 assignee + assert row["current_agent"] == "zhangfei-dev" + # review 状态不变 + assert row["status"] == "review"