Files
sanguo_moziplus_v2/tests/test_blackboard.py
T
2026-05-17 00:42:49 +08:00

478 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""F2 测试:黑板核心(DB + CRUD + 状态机 + 并发)"""
import json
import threading
from pathlib import Path
from typing import List
import pytest
from src.blackboard.db import (
VALID_TRANSITIONS,
TERMINAL_STATUSES,
COMMENT_TYPES,
OUTPUT_TYPES,
REVIEW_TYPES,
VERDICT_TYPES,
init_db,
get_connection,
)
from src.blackboard.models import (
Task, Comment, Output, Decision, Observation,
Review, Experience,
)
from src.blackboard.operations import Blackboard
from src.blackboard.queries import Queries
@pytest.fixture
def tmp_db(tmp_path):
    """Build a throwaway blackboard stored under pytest's ``tmp_path``.

    Returns:
        A ``(Blackboard, Queries, db_path)`` tuple; both objects are
        constructed from the same database path.
    """
    database_file = tmp_path / "test.db"
    board = Blackboard(database_file)
    queries = Queries(database_file)
    return board, queries, database_file
# ===================================================================
# Schema 初始化
# ===================================================================
class TestSchema:
    """Schema and connection-setting checks on a freshly initialised database."""

    def test_init_creates_all_tables(self, tmp_path):
        """After ``init_db`` every expected table is listed in ``sqlite_master``."""
        db_path = tmp_path / "new.db"
        init_db(db_path)
        conn = get_connection(db_path)
        try:
            rows = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table'"
            ).fetchall()
        finally:
            # Close even if the query raises so the connection is never leaked.
            conn.close()
        tables = {r[0] for r in rows}
        expected = {
            "tasks", "comments", "outputs", "decisions",
            "observations", "events", "agents", "task_attempts",
            "reviews", "experiences", "experience_tags",
        }
        # Extra tables are tolerated; only missing ones fail the test.
        assert expected.issubset(tables)

    def test_wal_mode(self, tmp_path):
        """``PRAGMA journal_mode`` reports ``wal`` on a connection from ``get_connection``."""
        db_path = tmp_path / "wal.db"
        init_db(db_path)
        conn = get_connection(db_path)
        try:
            row = conn.execute("PRAGMA journal_mode").fetchone()
        finally:
            conn.close()
        assert row[0] == "wal"

    def test_busy_timeout(self, tmp_path):
        """``PRAGMA busy_timeout`` reports 5000 on a connection from ``get_connection``."""
        db_path = tmp_path / "busy.db"
        init_db(db_path)
        conn = get_connection(db_path)
        try:
            row = conn.execute("PRAGMA busy_timeout").fetchone()
        finally:
            conn.close()
        assert row[0] == 5000

    def test_foreign_keys_on(self, tmp_path):
        """``PRAGMA foreign_keys`` reports 1 on a connection from ``get_connection``."""
        db_path = tmp_path / "fk.db"
        init_db(db_path)
        conn = get_connection(db_path)
        try:
            row = conn.execute("PRAGMA foreign_keys").fetchone()
        finally:
            conn.close()
        assert row[0] == 1
# ===================================================================
# Task CRUD
# ===================================================================
class TestTaskCRUD:
    """Creating, fetching and listing tasks through the blackboard."""

    def test_create_and_get(self, tmp_db):
        """A created task can be read back and its status is ``pending``."""
        board, _, _ = tmp_db
        board.create_task(Task(id="t1", title="Test Task", task_type="coding"))
        fetched = board.get_task("t1")
        assert fetched is not None
        assert fetched.title == "Test Task"
        assert fetched.status == "pending"

    def test_create_with_all_fields(self, tmp_db):
        """A task built with many optional fields round-trips selected values."""
        board, _, _ = tmp_db
        fields = {
            "id": "t2",
            "title": "Full Task",
            "description": "desc",
            "assignee": "zhangfei-dev",
            "assigned_by": "pangtong-fujunshi",
            "depends_on": '["t1"]',
            "parent_task": "t0",
            "priority": 3,
            "task_type": "review",
            "deadline": "2026-06-01",
            "risk_level": "high",
            "estimated_duration_minutes": 60,
            "must_haves": '{"truths": ["a"], "artifacts": ["b"], "constraints": ["c"]}',
        }
        board.create_task(Task(**fields))
        fetched = board.get_task("t2")
        assert fetched.priority == 3
        assert fetched.risk_level == "high"
        assert fetched.assignee == "zhangfei-dev"

    def test_get_nonexistent(self, tmp_db):
        """Looking up an unknown id yields ``None``."""
        board, _, _ = tmp_db
        missing = board.get_task("nope")
        assert missing is None

    def test_list_tasks(self, tmp_db):
        """``list_tasks`` with no filter returns all five created tasks."""
        board, _, _ = tmp_db
        for index in range(5):
            new_task = Task(id=f"t{index}", title=f"Task {index}")
            board.create_task(new_task)
        assert len(board.list_tasks()) == 5

    def test_list_tasks_by_status(self, tmp_db):
        """Filtering by ``pending`` excludes the task that was moved to ``claimed``."""
        board, _, _ = tmp_db
        for task_id, title in (("t1", "A"), ("t2", "B")):
            board.create_task(Task(id=task_id, title=title))
        board.update_task_status("t1", "claimed", agent="agent1")
        still_pending = board.list_tasks(status="pending")
        assert len(still_pending) == 1
        assert still_pending[0].id == "t2"

    def test_list_tasks_by_assignee(self, tmp_db):
        """Filtering by assignee returns only that agent's single task."""
        board, _, _ = tmp_db
        for task_id, title, owner in (("t1", "A", "agent1"), ("t2", "B", "agent2")):
            board.create_task(Task(id=task_id, title=title, assignee=owner))
        mine = board.list_tasks(assignee="agent1")
        assert len(mine) == 1
# ===================================================================
# 状态机
# ===================================================================
class TestStateMachine:
    """Task status transitions accepted and rejected by ``update_task_status``."""

    def test_valid_transitions(self, tmp_db):
        """pending -> claimed -> working -> review -> done is accepted at every step."""
        board, _, _ = tmp_db
        board.create_task(Task(id="t1", title="SM Test"))
        happy_path = (
            ("claimed", "a1"),
            ("working", "a1"),
            ("review", "a1"),
            ("done", "system"),
        )
        for status, actor in happy_path:
            assert board.update_task_status("t1", status, agent=actor)

    def test_invalid_transition(self, tmp_db):
        """Jumping straight from pending to done is refused."""
        board, _, _ = tmp_db
        board.create_task(Task(id="t1", title="Bad"))
        accepted = board.update_task_status("t1", "done")
        assert not accepted

    def test_terminal_state_blocked(self, tmp_db):
        """Once a task is done, moving it back to pending is refused."""
        board, _, _ = tmp_db
        board.create_task(Task(id="t1", title="Done"))
        for status in ("claimed", "working", "review"):
            board.update_task_status("t1", status, agent="a1")
        board.update_task_status("t1", "done", agent="system")
        # "done" is terminal, so no further transition may succeed.
        reopened = board.update_task_status("t1", "pending")
        assert not reopened

    def test_failed_to_pending_retry(self, tmp_db):
        """failed -> pending is accepted and leaves ``retry_count`` at 1."""
        board, _, _ = tmp_db
        board.create_task(Task(id="t1", title="Retry"))
        for status in ("claimed", "working", "failed"):
            board.update_task_status("t1", status, agent="a1")
        assert board.update_task_status("t1", "pending")
        retried = board.get_task("t1")
        assert retried.retry_count == 1

    def test_blocked_to_pending(self, tmp_db):
        """working -> blocked and blocked -> pending are both accepted."""
        board, _, _ = tmp_db
        board.create_task(Task(id="t1", title="Blocked"))
        for status in ("claimed", "working"):
            board.update_task_status("t1", status, agent="a1")
        assert board.update_task_status("t1", "blocked", agent="a1")
        assert board.update_task_status("t1", "pending")

    def test_cancel_from_pending(self, tmp_db):
        """A pending task can be cancelled, after which it cannot return to pending."""
        board, _, _ = tmp_db
        board.create_task(Task(id="t1", title="Cancel"))
        assert board.update_task_status("t1", "cancelled")
        # "cancelled" is terminal as well.
        revived = board.update_task_status("t1", "pending")
        assert not revived
# ===================================================================
# Claim(原子 CAS)
# ===================================================================
class TestClaim:
    """Behaviour of ``claim_task``, including a two-thread race."""

    def test_claim_pending(self, tmp_db):
        """Claiming a pending task succeeds and records status and assignee."""
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Claim"))
        assert bb.claim_task("t1", "agent1")
        task = bb.get_task("t1")
        assert task.status == "claimed"
        assert task.assignee == "agent1"

    def test_claim_assigned_task(self, tmp_db):
        """The agent a task is already assigned to may claim it."""
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Assigned", assignee="agent1"))
        assert bb.claim_task("t1", "agent1")

    def test_cannot_claim_others_task(self, tmp_db):
        """An agent cannot claim a task assigned to a different agent."""
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Other", assignee="agent1"))
        assert not bb.claim_task("t1", "agent2")

    def test_cannot_claim_working(self, tmp_db):
        """A task that has already moved to ``working`` cannot be claimed."""
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Working"))
        bb.update_task_status("t1", "claimed", agent="a1")
        bb.update_task_status("t1", "working", agent="a1")
        assert not bb.claim_task("t1", "agent2")

    def test_concurrent_claim(self, tmp_db):
        """Two threads claim the same task; exactly one claim may succeed.

        Exceptions raised inside a worker thread are collected and asserted
        on, because an exception in a thread does not fail the test by itself
        and could otherwise leave a single successful result that passes.
        """
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Race"))
        results = []
        errors = []

        def claim(agent):
            try:
                # Each worker builds its own Blackboard on the shared DB path.
                bb2 = Blackboard(bb.db_path)
                results.append(bb2.claim_task("t1", agent))
            except Exception as exc:  # surfaced via the assertion below
                errors.append(exc)

        threads = [
            threading.Thread(target=claim, args=(agent,))
            for agent in ("agent1", "agent2")
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert errors == []
        # Both workers must have reported before the winner is counted.
        assert len(results) == 2
        # Only one should succeed
        assert sum(1 for r in results if r) == 1
# ===================================================================
# Comment
# ===================================================================
class TestComment:
    """Adding comments to a task and reading them back."""

    def test_add_and_get(self, tmp_db):
        """A comment round-trips its body, default type and JSON-encoded mentions."""
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Comment"))
        # The return value is not inspected here, so it is not bound to a name.
        bb.add_comment("t1", "agent1", "Hello", mentions=["agent2"])
        comments = bb.get_comments("t1")
        assert len(comments) == 1
        assert comments[0].body == "Hello"
        assert comments[0].comment_type == "general"
        # ``mentions`` is read back as a JSON string, hence the decode.
        assert json.loads(comments[0].mentions) == ["agent2"]

    def test_comment_types(self, tmp_db):
        """Every value in ``COMMENT_TYPES`` is accepted and stored."""
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Types"))
        for ct in COMMENT_TYPES:
            bb.add_comment("t1", "agent1", f"Type: {ct}", comment_type=ct)
        comments = bb.get_comments("t1")
        assert len(comments) == len(COMMENT_TYPES)

    def test_invalid_comment_type(self, tmp_db):
        """An unknown comment type raises ``ValueError``."""
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Bad"))
        with pytest.raises(ValueError):
            bb.add_comment("t1", "a", "x", comment_type="invalid")

    def test_filter_by_type(self, tmp_db):
        """``get_comments`` with ``comment_type`` returns only the matching comment."""
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Filter"))
        bb.add_comment("t1", "a", "general", comment_type="general")
        bb.add_comment("t1", "a", "handoff", comment_type="handoff")
        handoffs = bb.get_comments("t1", comment_type="handoff")
        assert len(handoffs) == 1
# ===================================================================
# Output
# ===================================================================
class TestOutput:
    """Writing task outputs and reading them back."""

    def test_write_and_get(self, tmp_db):
        """A written output round-trips its title and output type."""
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Output"))
        # The return value is not inspected here, so it is not bound to a name.
        bb.write_output("t1", "agent1", "code", "main.py",
                        summary="Main script")
        outputs = bb.get_outputs("t1")
        assert len(outputs) == 1
        assert outputs[0].title == "main.py"
        assert outputs[0].output_type == "code"

    def test_all_output_types(self, tmp_db):
        """Every value in ``OUTPUT_TYPES`` is accepted and stored."""
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Types"))
        for ot in OUTPUT_TYPES:
            bb.write_output("t1", "a", ot, f"file.{ot}")
        assert len(bb.get_outputs("t1")) == len(OUTPUT_TYPES)

    def test_invalid_output_type(self, tmp_db):
        """An unknown output type raises ``ValueError``."""
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Bad"))
        with pytest.raises(ValueError):
            bb.write_output("t1", "a", "invalid", "x")
# ===================================================================
# Decision + Observation
# ===================================================================
class TestDecision:
    """Recording a decision on a task and reading it back."""

    def test_add_and_get(self, tmp_db):
        """A decision round-trips its rationale and JSON-encoded alternatives."""
        board, _, _ = tmp_db
        board.create_task(Task(id="t1", title="Dec"))
        rejected_options = ["Flask", "Litestar"]
        board.add_decision(
            "t1",
            "pangtong",
            "Use FastAPI",
            "Better async support",
            rejected_options,
        )
        stored = board.get_decisions("t1")
        assert len(stored) == 1
        only = stored[0]
        assert only.rationale == "Better async support"
        # ``alternatives`` is read back as a JSON string, hence the decode.
        assert json.loads(only.alternatives) == ["Flask", "Litestar"]
class TestObservation:
    """Recording observations on a task and reading them back."""

    def test_add_and_get(self, tmp_db):
        """An observation round-trips its severity."""
        board, _, _ = tmp_db
        board.create_task(Task(id="t1", title="Obs"))
        board.add_observation(
            "t1", "agent1", "Potential issue", severity="warning"
        )
        stored = board.get_observations("t1")
        assert len(stored) == 1
        assert stored[0].severity == "warning"

    def test_unresolved_filter(self, tmp_db):
        """A newly added observation is returned when ``unresolved_only=True``."""
        board, _, _ = tmp_db
        board.create_task(Task(id="t1", title="Obs"))
        board.add_observation("t1", "a", "unresolved", severity="blocking")
        open_items = board.get_observations("t1", unresolved_only=True)
        assert len(open_items) == 1
        assert open_items[0].severity == "blocking"
# ===================================================================
# Review
# ===================================================================
class TestReview:
    """Attaching a review to a task and reading it back."""

    def test_add_and_get(self, tmp_db):
        """A review round-trips its verdict and confidence."""
        board, _, _ = tmp_db
        board.create_task(Task(id="t1", title="Rev"))
        board.add_review(
            Review(
                id="rev-1",
                task_id="t1",
                reviewer="simayi",
                review_type="output_review",
                verdict="approved",
                confidence=0.9,
                summary="LGTM",
            )
        )
        stored = board.get_reviews("t1")
        assert len(stored) == 1
        only = stored[0]
        assert only.verdict == "approved"
        assert only.confidence == 0.9
# ===================================================================
# Experience
# ===================================================================
class TestExperience:
    """Storing experiences, querying them by tag and counting usage."""

    def test_add_and_query(self, tmp_db):
        """Querying by one tag finds the experience with its full tag set."""
        board, _, _ = tmp_db
        lesson = Experience(
            experience_id="exp-1",
            source="task_completion",
            summary="SQLite WAL works well",
            category="best_practice",
            created_by="pangtong",
            tags=["sqlite", "performance"],
        )
        board.add_experience(lesson)
        matches = board.query_experiences(tags=["sqlite"])
        assert len(matches) == 1
        # Compared as a set, so the order of the returned tags is not checked.
        assert set(matches[0].tags) == {"sqlite", "performance"}

    def test_touch_increments(self, tmp_db):
        """Two ``touch_experience`` calls leave ``usage_count`` at 2."""
        board, _, _ = tmp_db
        lesson = Experience(
            experience_id="exp-1",
            source="manual",
            summary="test",
            category="pattern",
            created_by="pangtong",
        )
        board.add_experience(lesson)
        for _ in range(2):
            board.touch_experience("exp-1")
        matches = board.query_experiences()
        assert matches[0].usage_count == 2
# ===================================================================
# Event
# ===================================================================
class TestEvent:
    """Events recorded for a task as it is created and changes status."""

    def test_events_written_on_transitions(self, tmp_db):
        """Creating a task and making two transitions yields at least three events."""
        board, _, _ = tmp_db
        board.create_task(Task(id="t1", title="Events"))
        for status in ("claimed", "working"):
            board.update_task_status("t1", status, agent="a1")
        recorded = board.get_events(task_id="t1")
        # One event each for create, claimed and working; extras are tolerated.
        assert len(recorded) >= 3
# ===================================================================
# Queries
# ===================================================================
class TestQueries:
    """Read-side helpers exposed by ``Queries``."""

    def test_task_summary(self, tmp_db):
        """``task_summary`` counts one pending and one claimed task."""
        board, queries, _ = tmp_db
        for task_id, title in (("t1", "A"), ("t2", "B")):
            board.create_task(Task(id=task_id, title=title))
        board.update_task_status("t1", "claimed", agent="a1")
        counts = queries.task_summary()
        assert counts.get("pending") == 1
        assert counts.get("claimed") == 1

    def test_pending_dispatchable(self, tmp_db):
        """Only the task without an unfinished dependency is dispatchable."""
        board, queries, _ = tmp_db
        board.create_task(Task(id="t1", title="A"))
        board.create_task(Task(id="t2", title="B", depends_on='["t1"]'))
        ready = queries.pending_dispatchable()
        # t1 has no dependencies; t2 waits on t1 and must be excluded.
        assert len(ready) == 1
        assert ready[0].id == "t1"

    def test_blocked_tasks_with_deps(self, tmp_db):
        """A blocked task whose dependency is not done reports ``all_deps_done`` False."""
        board, queries, _ = tmp_db
        board.create_task(Task(id="t1", title="A"))
        board.create_task(Task(id="t2", title="B", depends_on='["t1"]'))
        for status in ("claimed", "working", "blocked"):
            board.update_task_status("t2", status, agent="a1")
        stuck = queries.blocked_tasks_with_deps()
        assert len(stuck) == 1
        assert stuck[0]["all_deps_done"] is False
# ===================================================================
# 并发写入
# ===================================================================
class TestConcurrency:
    """Several threads writing comments to the same task at once."""

    def test_concurrent_writes(self, tmp_db):
        """Five threads add ten comments each; none fails and all 50 are stored."""
        board, _, _ = tmp_db
        board.create_task(Task(id="t1", title="Concurrent"))
        failures = []

        def write_comments(agent, count):
            # Each worker builds its own Blackboard on the shared DB path.
            try:
                worker_board = Blackboard(board.db_path)
                for seq in range(count):
                    worker_board.add_comment("t1", agent, f"{agent}-{seq}")
            except Exception as exc:
                failures.append(exc)

        workers = []
        for worker_no in range(5):
            workers.append(
                threading.Thread(target=write_comments, args=(f"a{worker_no}", 10))
            )
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        assert not failures
        stored = board.get_comments("t1")
        assert len(stored) == 50