diff --git a/tests/test_dispatcher.py b/tests/test_dispatcher.py index 6c505b8..eb46bd7 100644 --- a/tests/test_dispatcher.py +++ b/tests/test_dispatcher.py @@ -303,17 +303,24 @@ class TestRollbackAndOnComplete: bb.create_task(Task( id="t1", title="T", status="working", assigned_by="daemon", assignee="zhangfei-dev", - current_agent="zhangfei-dev", )) + # 通过状态变更设置 current_agent(create_task 不持久化 current_agent) + conn = bb._conn() + try: + conn.execute("UPDATE tasks SET current_agent=? WHERE id=?", ("zhangfei-dev", "t1")) + conn.commit() + finally: + conn.close() + dispatcher = Dispatcher(registered_agents=["zhangfei-dev"]) dispatcher._rollback_current_agent(db_path, "t1", "zhangfei-dev") import sqlite3 - conn = sqlite3.connect(str(db_path)) - conn.row_factory = sqlite3.Row - row = conn.execute("SELECT current_agent, assignee FROM tasks WHERE id=?", ("t1",)).fetchone() - conn.close() + conn2 = sqlite3.connect(str(db_path)) + conn2.row_factory = sqlite3.Row + row = conn2.execute("SELECT current_agent, assignee FROM tasks WHERE id=?", ("t1",)).fetchone() + conn2.close() assert row["current_agent"] == row["assignee"] == "zhangfei-dev" @@ -331,18 +338,25 @@ class TestRollbackAndOnComplete: bb.create_task(Task( id="t1", title="T", status="working", assigned_by="daemon", assignee="zhangfei-dev", - current_agent="simayi-challenger", )) + # 设置 current_agent 为 simayi-challenger + conn = bb._conn() + try: + conn.execute("UPDATE tasks SET current_agent=? WHERE id=?", ("simayi-challenger", "t1")) + conn.commit() + finally: + conn.close() + dispatcher = Dispatcher(registered_agents=["zhangfei-dev"]) # 用错误的 agent_id 回退 dispatcher._rollback_current_agent(db_path, "t1", "wrong-agent") import sqlite3 - conn = sqlite3.connect(str(db_path)) - conn.row_factory = sqlite3.Row - row = conn.execute("SELECT current_agent FROM tasks WHERE id=?", ("t1",)).fetchone() - conn.close() + conn2 = sqlite3.connect(str(db_path)) + conn2.row_factory = sqlite3.Row + row = conn2.execute("SELECT current_agent FROM tasks WHERE id=?", ("t1",)).fetchone() + conn2.close() # current_agent 不变 assert row["current_agent"] == "simayi-challenger" @@ -360,26 +374,26 @@ class TestRollbackAndOnComplete: bb.create_task(Task( id="t1", title="T", status="working", assigned_by="daemon", assignee="zhangfei-dev", - current_agent="zhangfei-dev", )) + # 设置 current_agent + conn = bb._conn() + try: + conn.execute("UPDATE tasks SET current_agent=? WHERE id=?", ("zhangfei-dev", "t1")) + conn.commit() + finally: + conn.close() + dispatcher = Dispatcher(registered_agents=["zhangfei-dev"]) - dispatcher.spawner = MagicMock() - # 模拟 executor on_complete 回调 - outcomes = [] - async def mock_on_complete(aid, outcome): - outcomes.append((aid, outcome)) - - # 构造 dispatcher 内部 _task_on_complete 的行为 - # executor crash: rollback → _task_auto_complete + # executor crash: rollback current_agent dispatcher._rollback_current_agent(db_path, "t1", "zhangfei-dev") import sqlite3 - conn = sqlite3.connect(str(db_path)) - conn.row_factory = sqlite3.Row - row = conn.execute("SELECT current_agent FROM tasks WHERE id=?", ("t1",)).fetchone() - conn.close() + conn2 = sqlite3.connect(str(db_path)) + conn2.row_factory = sqlite3.Row + row = conn2.execute("SELECT current_agent FROM tasks WHERE id=?", ("t1",)).fetchone() + conn2.close() assert row["current_agent"] == "zhangfei-dev" # rollback 到 assignee def test_on_complete_crash_rollback_review(self, tmp_path): @@ -396,17 +410,24 @@ class TestRollbackAndOnComplete: bb.create_task(Task( id="t1", title="T", status="review", assigned_by="daemon", assignee="zhangfei-dev", - current_agent="simayi-challenger", )) + # 设置 current_agent 为 reviewer + conn = bb._conn() + try: + conn.execute("UPDATE tasks SET current_agent=? WHERE id=?", ("simayi-challenger", "t1")) + conn.commit() + finally: + conn.close() + dispatcher = Dispatcher(registered_agents=["simayi-challenger"]) dispatcher._rollback_current_agent(db_path, "t1", "simayi-challenger") import sqlite3 - conn = sqlite3.connect(str(db_path)) - conn.row_factory = sqlite3.Row - row = conn.execute("SELECT current_agent, status FROM tasks WHERE id=?", ("t1",)).fetchone() - conn.close() + conn2 = sqlite3.connect(str(db_path)) + conn2.row_factory = sqlite3.Row + row = conn2.execute("SELECT current_agent, status FROM tasks WHERE id=?", ("t1",)).fetchone() + conn2.close() # current_agent 回退到 assignee assert row["current_agent"] == "zhangfei-dev"