478 lines
17 KiB
Python
478 lines
17 KiB
Python
"""F2 测试:黑板核心(DB + CRUD + 状态机 + 并发)"""
|
||
|
||
import json
|
||
import threading
|
||
from pathlib import Path
|
||
from typing import List
|
||
|
||
import pytest
|
||
|
||
from src.blackboard.db import (
|
||
VALID_TRANSITIONS,
|
||
TERMINAL_STATUSES,
|
||
COMMENT_TYPES,
|
||
OUTPUT_TYPES,
|
||
REVIEW_TYPES,
|
||
VERDICT_TYPES,
|
||
init_db,
|
||
get_connection,
|
||
)
|
||
from src.blackboard.models import (
|
||
Task, Comment, Output, Decision, Observation,
|
||
Review, Experience,
|
||
)
|
||
from src.blackboard.operations import Blackboard
|
||
from src.blackboard.queries import Queries
|
||
|
||
|
||
@pytest.fixture
def tmp_db(tmp_path):
    """Provide a fresh blackboard backed by a temporary SQLite file.

    Returns a ``(Blackboard, Queries, db_path)`` tuple. The Blackboard is
    constructed first, then a Queries object over the same file.
    """
    path = tmp_path / "test.db"
    # Tuple elements are evaluated left to right, so Blackboard(path) runs
    # before Queries(path), as in a sequential construction.
    return Blackboard(path), Queries(path), path
|
||
|
||
|
||
# ===================================================================
|
||
# Schema 初始化
|
||
# ===================================================================
|
||
|
||
class TestSchema:
    """init_db creates every table, and get_connection applies the expected PRAGMAs."""

    @staticmethod
    def _init_and_fetch(db_path, sql, *, all_rows=False):
        """Initialise *db_path*, run *sql* on a fresh connection and return the result.

        Args:
            db_path: Path of the database file to create.
            sql: Statement to execute on a connection from ``get_connection``.
            all_rows: Return ``fetchall()`` instead of ``fetchone()``.

        The connection is closed in ``finally`` so that a failing query does
        not leave the SQLite file open (which can break tmp_path cleanup).
        """
        init_db(db_path)
        conn = get_connection(db_path)
        try:
            cursor = conn.execute(sql)
            return cursor.fetchall() if all_rows else cursor.fetchone()
        finally:
            conn.close()

    def test_init_creates_all_tables(self, tmp_path):
        rows = self._init_and_fetch(
            tmp_path / "new.db",
            "SELECT name FROM sqlite_master WHERE type='table'",
            all_rows=True,
        )
        tables = {r[0] for r in rows}
        expected = {
            "tasks", "comments", "outputs", "decisions",
            "observations", "events", "agents", "task_attempts",
            "reviews", "experiences", "experience_tags",
        }
        # Subset check: the schema may legitimately define additional tables.
        assert expected.issubset(tables)

    def test_wal_mode(self, tmp_path):
        row = self._init_and_fetch(tmp_path / "wal.db", "PRAGMA journal_mode")
        assert row[0] == "wal"

    def test_busy_timeout(self, tmp_path):
        row = self._init_and_fetch(tmp_path / "busy.db", "PRAGMA busy_timeout")
        assert row[0] == 5000

    def test_foreign_keys_on(self, tmp_path):
        row = self._init_and_fetch(tmp_path / "fk.db", "PRAGMA foreign_keys")
        assert row[0] == 1
|
||
|
||
|
||
# ===================================================================
|
||
# Task CRUD
|
||
# ===================================================================
|
||
|
||
class TestTaskCRUD:
    """Create / read / list behaviour for tasks."""

    def test_create_and_get(self, tmp_db):
        bb, _, _ = tmp_db
        task = Task(id="t1", title="Test Task", task_type="coding")
        bb.create_task(task)
        got = bb.get_task("t1")
        assert got is not None
        assert got.title == "Test Task"
        # New tasks start in the initial state of the state machine.
        assert got.status == "pending"

    def test_create_with_all_fields(self, tmp_db):
        bb, _, _ = tmp_db
        task = Task(
            id="t2", title="Full Task", description="desc",
            assignee="zhangfei-dev", assigned_by="pangtong-fujunshi",
            depends_on='["t1"]', parent_task="t0",
            priority=3, task_type="review", deadline="2026-06-01",
            risk_level="high", estimated_duration_minutes=60,
            must_haves='{"truths": ["a"], "artifacts": ["b"], "constraints": ["c"]}',
        )
        bb.create_task(task)
        got = bb.get_task("t2")
        # Guard first so a missing row fails clearly instead of with AttributeError.
        assert got is not None
        assert got.priority == 3
        assert got.risk_level == "high"
        assert got.assignee == "zhangfei-dev"
        # Round-trip the remaining scalar fields that were explicitly set.
        assert got.assigned_by == "pangtong-fujunshi"
        assert got.description == "desc"
        assert got.parent_task == "t0"
        assert got.task_type == "review"
        assert got.estimated_duration_minutes == 60
        # depends_on is passed in as a JSON-encoded string; compare decoded values.
        assert json.loads(got.depends_on) == ["t1"]

    def test_get_nonexistent(self, tmp_db):
        bb, _, _ = tmp_db
        assert bb.get_task("nope") is None

    def test_list_tasks(self, tmp_db):
        bb, _, _ = tmp_db
        for i in range(5):
            bb.create_task(Task(id=f"t{i}", title=f"Task {i}"))
        tasks = bb.list_tasks()
        assert len(tasks) == 5

    def test_list_tasks_by_status(self, tmp_db):
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="A"))
        bb.create_task(Task(id="t2", title="B"))
        # Setup must succeed, otherwise the filter below is not being exercised.
        assert bb.update_task_status("t1", "claimed", agent="agent1")
        pending = bb.list_tasks(status="pending")
        assert len(pending) == 1
        assert pending[0].id == "t2"

    def test_list_tasks_by_assignee(self, tmp_db):
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="A", assignee="agent1"))
        bb.create_task(Task(id="t2", title="B", assignee="agent2"))
        tasks = bb.list_tasks(assignee="agent1")
        assert len(tasks) == 1
        # The count alone would also pass if the wrong task were returned.
        assert tasks[0].id == "t1"
|
||
|
||
|
||
# ===================================================================
|
||
# 状态机
|
||
# ===================================================================
|
||
|
||
class TestStateMachine:
    """Transition rules enforced by ``Blackboard.update_task_status``.

    Setup transitions are asserted as well. If one silently failed, the
    negative assertions later in the test could pass for the wrong reason.
    """

    def test_valid_transitions(self, tmp_db):
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="SM Test"))

        # Happy path: pending → claimed → working → review → done
        assert bb.update_task_status("t1", "claimed", agent="a1")
        assert bb.update_task_status("t1", "working", agent="a1")
        assert bb.update_task_status("t1", "review", agent="a1")
        assert bb.update_task_status("t1", "done", agent="system")
        assert bb.get_task("t1").status == "done"

    def test_invalid_transition(self, tmp_db):
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Bad"))
        # pending → done skips the workflow and must be rejected...
        assert not bb.update_task_status("t1", "done")
        # ...without modifying the stored status.
        assert bb.get_task("t1").status == "pending"

    def test_terminal_state_blocked(self, tmp_db):
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Done"))
        assert bb.update_task_status("t1", "claimed", agent="a1")
        assert bb.update_task_status("t1", "working", agent="a1")
        assert bb.update_task_status("t1", "review", agent="a1")
        assert bb.update_task_status("t1", "done", agent="system")
        # done is terminal: no transition out of it is allowed.
        assert not bb.update_task_status("t1", "pending")
        assert bb.get_task("t1").status == "done"

    def test_failed_to_pending_retry(self, tmp_db):
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Retry"))
        assert bb.update_task_status("t1", "claimed", agent="a1")
        assert bb.update_task_status("t1", "working", agent="a1")
        assert bb.update_task_status("t1", "failed", agent="a1")
        # failed → pending is the retry path and bumps retry_count.
        assert bb.update_task_status("t1", "pending")
        task = bb.get_task("t1")
        assert task.status == "pending"
        assert task.retry_count == 1

    def test_blocked_to_pending(self, tmp_db):
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Blocked"))
        assert bb.update_task_status("t1", "claimed", agent="a1")
        assert bb.update_task_status("t1", "working", agent="a1")
        assert bb.update_task_status("t1", "blocked", agent="a1")
        assert bb.update_task_status("t1", "pending")
        assert bb.get_task("t1").status == "pending"

    def test_cancel_from_pending(self, tmp_db):
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Cancel"))
        assert bb.update_task_status("t1", "cancelled")
        # cancelled is terminal.
        assert not bb.update_task_status("t1", "pending")
        assert bb.get_task("t1").status == "cancelled"
|
||
|
||
|
||
# ===================================================================
|
||
# Claim(原子 CAS)
|
||
# ===================================================================
|
||
|
||
class TestClaim:
    """Atomic (compare-and-set) task claiming via ``Blackboard.claim_task``."""

    def test_claim_pending(self, tmp_db):
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Claim"))
        assert bb.claim_task("t1", "agent1")
        task = bb.get_task("t1")
        assert task.status == "claimed"
        assert task.assignee == "agent1"

    def test_claim_assigned_task(self, tmp_db):
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Assigned", assignee="agent1"))
        # The pre-assigned agent may claim its own task.
        assert bb.claim_task("t1", "agent1")
        assert bb.get_task("t1").status == "claimed"

    def test_cannot_claim_others_task(self, tmp_db):
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Other", assignee="agent1"))
        assert not bb.claim_task("t1", "agent2")
        # A rejected claim must leave the task untouched.
        task = bb.get_task("t1")
        assert task.status == "pending"
        assert task.assignee == "agent1"

    def test_cannot_claim_working(self, tmp_db):
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Working"))
        # Verify setup so the final assertion is not vacuous.
        assert bb.update_task_status("t1", "claimed", agent="a1")
        assert bb.update_task_status("t1", "working", agent="a1")
        assert not bb.claim_task("t1", "agent2")

    def test_concurrent_claim(self, tmp_db):
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Race"))
        agents = ("agent1", "agent2")
        results = {}
        errors = []
        # Line both claims up so they actually contend for the row.
        barrier = threading.Barrier(len(agents))

        def claim(agent):
            try:
                # One Blackboard per thread, as separate agent processes would have.
                bb2 = Blackboard(bb.db_path)
                barrier.wait(timeout=5)
                results[agent] = bb2.claim_task("t1", agent)
            except Exception as exc:
                # Exceptions in threads are otherwise only printed; record them
                # and release the peer so it does not wait on the barrier.
                errors.append(exc)
                barrier.abort()

        threads = [threading.Thread(target=claim, args=(a,)) for a in agents]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert errors == []
        # Both threads must have reported; a crashed thread would shrink the set.
        assert len(results) == len(agents)
        winners = [agent for agent, ok in results.items() if ok]
        # Exactly one claim wins, and the task records that agent.
        assert len(winners) == 1
        assert bb.get_task("t1").assignee == winners[0]
|
||
|
||
|
||
# ===================================================================
|
||
# Comment
|
||
# ===================================================================
|
||
|
||
class TestComment:
    """Adding, typing and filtering task comments."""

    def test_add_and_get(self, tmp_db):
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Comment"))
        bb.add_comment("t1", "agent1", "Hello", mentions=["agent2"])
        comments = bb.get_comments("t1")
        assert len(comments) == 1
        assert comments[0].body == "Hello"
        # Without an explicit comment_type, the comment is stored as "general".
        assert comments[0].comment_type == "general"
        # Mentions come back as a JSON-encoded list.
        assert json.loads(comments[0].mentions) == ["agent2"]

    def test_comment_types(self, tmp_db):
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Types"))
        for ct in COMMENT_TYPES:
            bb.add_comment("t1", "agent1", f"Type: {ct}", comment_type=ct)
        comments = bb.get_comments("t1")
        assert len(comments) == len(COMMENT_TYPES)

    def test_invalid_comment_type(self, tmp_db):
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Bad"))
        with pytest.raises(ValueError):
            bb.add_comment("t1", "a", "x", comment_type="invalid")

    def test_filter_by_type(self, tmp_db):
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Filter"))
        bb.add_comment("t1", "a", "general", comment_type="general")
        bb.add_comment("t1", "a", "handoff", comment_type="handoff")
        handoffs = bb.get_comments("t1", comment_type="handoff")
        assert len(handoffs) == 1
        # The count alone would pass even if the wrong comment were returned.
        assert handoffs[0].comment_type == "handoff"
        assert handoffs[0].body == "handoff"
|
||
|
||
|
||
# ===================================================================
|
||
# Output
|
||
# ===================================================================
|
||
|
||
class TestOutput:
    """Writing and reading task outputs, including output-type validation."""

    def test_write_and_get(self, tmp_db):
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Output"))
        bb.write_output("t1", "agent1", "code", "main.py",
                        summary="Main script")
        outputs = bb.get_outputs("t1")
        assert len(outputs) == 1
        assert outputs[0].title == "main.py"
        assert outputs[0].output_type == "code"

    def test_all_output_types(self, tmp_db):
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Types"))
        # Every declared output type must be accepted.
        for ot in OUTPUT_TYPES:
            bb.write_output("t1", "a", ot, f"file.{ot}")
        assert len(bb.get_outputs("t1")) == len(OUTPUT_TYPES)

    def test_invalid_output_type(self, tmp_db):
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Bad"))
        with pytest.raises(ValueError):
            bb.write_output("t1", "a", "invalid", "x")
|
||
|
||
|
||
# ===================================================================
|
||
# Decision + Observation
|
||
# ===================================================================
|
||
|
||
class TestDecision:
    """Decisions keep their rationale and the JSON-encoded list of alternatives."""

    def test_add_and_get(self, tmp_db):
        board, *_ = tmp_db
        board.create_task(Task(id="t1", title="Dec"))
        board.add_decision(
            "t1",
            "pangtong",
            "Use FastAPI",
            "Better async support",
            ["Flask", "Litestar"],
        )

        recorded = board.get_decisions("t1")
        assert len(recorded) == 1

        first = recorded[0]
        assert first.rationale == "Better async support"
        # Alternatives are read back as a JSON string and decoded for comparison.
        assert json.loads(first.alternatives) == ["Flask", "Litestar"]
|
||
|
||
|
||
class TestObservation:
    """Observations record a severity and can be filtered to unresolved ones."""

    def test_add_and_get(self, tmp_db):
        board, *_ = tmp_db
        board.create_task(Task(id="t1", title="Obs"))
        board.add_observation("t1", "agent1", "Potential issue", severity="warning")

        found = board.get_observations("t1")
        assert len(found) == 1
        assert found[0].severity == "warning"

    def test_unresolved_filter(self, tmp_db):
        board, *_ = tmp_db
        board.create_task(Task(id="t1", title="Obs"))
        board.add_observation("t1", "a", "unresolved", severity="blocking")

        # NOTE(review): only one observation exists (presumably unresolved on
        # creation), so this does not show that resolved ones are excluded.
        # Add a resolved observation once the resolve API is confirmed.
        open_items = board.get_observations("t1", unresolved_only=True)
        assert len(open_items) == 1
        assert open_items[0].severity == "blocking"
|
||
|
||
|
||
# ===================================================================
|
||
# Review
|
||
# ===================================================================
|
||
|
||
class TestReview:
    """Reviews round-trip their verdict and confidence."""

    def test_add_and_get(self, tmp_db):
        board, *_ = tmp_db
        board.create_task(Task(id="t1", title="Rev"))
        board.add_review(
            Review(
                id="rev-1",
                task_id="t1",
                reviewer="simayi",
                review_type="output_review",
                verdict="approved",
                confidence=0.9,
                summary="LGTM",
            )
        )

        stored = board.get_reviews("t1")
        assert len(stored) == 1
        only = stored[0]
        assert only.verdict == "approved"
        assert only.confidence == 0.9
|
||
|
||
|
||
# ===================================================================
|
||
# Experience
|
||
# ===================================================================
|
||
|
||
class TestExperience:
    """Experience records: tag-based lookup and usage counting."""

    def test_add_and_query(self, tmp_db):
        board, *_ = tmp_db
        board.add_experience(
            Experience(
                experience_id="exp-1",
                source="task_completion",
                summary="SQLite WAL works well",
                category="best_practice",
                created_by="pangtong",
                tags=["sqlite", "performance"],
            )
        )

        # Querying by one tag returns the record with all of its tags.
        matches = board.query_experiences(tags=["sqlite"])
        assert len(matches) == 1
        assert set(matches[0].tags) == {"sqlite", "performance"}

    def test_touch_increments(self, tmp_db):
        board, *_ = tmp_db
        board.add_experience(
            Experience(
                experience_id="exp-1",
                source="manual",
                summary="test",
                category="pattern",
                created_by="pangtong",
            )
        )

        for _ in range(2):
            board.touch_experience("exp-1")

        # Expected: usage_count starts at 0, so two touches give 2.
        assert board.query_experiences()[0].usage_count == 2
|
||
|
||
|
||
# ===================================================================
|
||
# Event
|
||
# ===================================================================
|
||
|
||
class TestEvent:
    """State transitions are recorded in the event log."""

    def test_events_written_on_transitions(self, tmp_db):
        board, *_ = tmp_db
        board.create_task(Task(id="t1", title="Events"))
        for status in ("claimed", "working"):
            board.update_task_status("t1", status, agent="a1")

        logged = board.get_events(task_id="t1")
        # Expect one event each for create, claimed and working; ">=" allows
        # any additional events the implementation may emit.
        assert len(logged) >= 3
|
||
|
||
|
||
# ===================================================================
|
||
# Queries
|
||
# ===================================================================
|
||
|
||
class TestQueries:
    """Read-side aggregate queries over the blackboard."""

    def test_task_summary(self, tmp_db):
        board, queries, _ = tmp_db
        for tid, title in (("t1", "A"), ("t2", "B")):
            board.create_task(Task(id=tid, title=title))
        board.update_task_status("t1", "claimed", agent="a1")

        counts = queries.task_summary()
        assert counts.get("pending") == 1
        assert counts.get("claimed") == 1

    def test_pending_dispatchable(self, tmp_db):
        board, queries, _ = tmp_db
        board.create_task(Task(id="t1", title="A"))
        board.create_task(Task(id="t2", title="B", depends_on='["t1"]'))

        ready = queries.pending_dispatchable()
        # t1 has no dependencies and is ready; t2 waits on the unfinished t1.
        assert len(ready) == 1
        assert ready[0].id == "t1"

    def test_blocked_tasks_with_deps(self, tmp_db):
        board, queries, _ = tmp_db
        board.create_task(Task(id="t1", title="A"))
        board.create_task(Task(id="t2", title="B", depends_on='["t1"]'))
        for status in ("claimed", "working", "blocked"):
            board.update_task_status("t2", status, agent="a1")

        blocked = queries.blocked_tasks_with_deps()
        assert len(blocked) == 1
        # t1 is still pending, so t2's dependencies are not all done.
        assert blocked[0]["all_deps_done"] is False
|
||
|
||
|
||
# ===================================================================
|
||
# 并发写入
|
||
# ===================================================================
|
||
|
||
class TestConcurrency:
    """Concurrent writers using separate Blackboard instances must not lose rows."""

    def test_concurrent_writes(self, tmp_db):
        bb, _, _ = tmp_db
        bb.create_task(Task(id="t1", title="Concurrent"))

        # Derive the expected totals from these two numbers so they cannot drift.
        n_agents, per_agent = 5, 10
        errors = []

        def write_comments(agent, count):
            try:
                bb2 = Blackboard(bb.db_path)
                for i in range(count):
                    bb2.add_comment("t1", agent, f"{agent}-{i}")
            except Exception as e:
                # Thread exceptions are otherwise only printed; collect them
                # so the test fails visibly.
                errors.append(e)

        threads = [
            threading.Thread(target=write_comments, args=(f"a{i}", per_agent))
            for i in range(n_agents)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert errors == []
        comments = bb.get_comments("t1")
        assert len(comments) == n_agents * per_agent
        # Every write landed exactly once: no lost and no duplicated comments.
        expected_bodies = {
            f"a{a}-{i}" for a in range(n_agents) for i in range(per_agent)
        }
        assert {c.body for c in comments} == expected_bodies
|