diff --git a/tests/test_sse.py b/tests/test_sse.py new file mode 100644 index 0000000..1530397 --- /dev/null +++ b/tests/test_sse.py @@ -0,0 +1,259 @@ +"""F17 SSE + Hook 单元测试 + +按 test-plan-v2.6.md §F17: +- T1: SSE 事件推送(P0) +- T2: Hook 注册/触发(P0) +- T3: 回调 Hook(P0) +- T4: 错误处理(P1) +""" + +import asyncio +import json +import pytest +from unittest.mock import AsyncMock, MagicMock + +from src.daemon.sse import ( + Hook, + HookManager, + HookType, + SSEBroker, + SSEEvent, + SSEEventType, +) + + +# --------------------------------------------------------------------------- +# SSE +# --------------------------------------------------------------------------- + +class TestSSEEvent: + def test_to_sse_format(self): + event = SSEEvent("task_created", {"task_id": "t1"}) + sse = event.to_sse() + assert sse.startswith("id: ") + assert "event: task_created" in sse + assert '"task_id": "t1"' in sse + assert sse.endswith("\n\n") + + def test_custom_event_id(self): + event = SSEEvent("test", {}, event_id="my-id") + assert event.id == "my-id" + assert "id: my-id" in event.to_sse() + + +class TestSSEBroker: + def test_subscribe_returns_queue(self): + broker = SSEBroker() + cid, queue = broker.subscribe() + assert cid + assert isinstance(queue, asyncio.Queue) + + def test_publish_to_subscriber(self): + broker = SSEBroker() + cid, queue = broker.subscribe() + + delivered = asyncio.run(broker.publish("task_created", {"id": "t1"})) + assert delivered == 1 + + event = queue.get_nowait() + assert event.event_type == "task_created" + assert event.data["id"] == "t1" + + def test_unsubscribe(self): + broker = SSEBroker() + cid, _ = broker.subscribe() + assert broker.subscriber_count == 1 + + broker.unsubscribe(cid) + assert broker.subscriber_count == 0 + + def test_publish_no_subscribers(self): + broker = SSEBroker() + delivered = asyncio.run(broker.publish("test", {})) + assert delivered == 0 + + def test_history_kept(self): + broker = SSEBroker() + asyncio.run(broker.publish("e1", {"a": 1})) + asyncio.run(broker.publish("e2", {"b": 2})) + + assert len(broker.history) == 2 + assert broker.history[0].event_type == "e1" + + def test_history_replays_to_new_subscriber(self): + broker = SSEBroker() + asyncio.run(broker.publish("e1", {"x": 1})) + + cid, queue = broker.subscribe() + event = queue.get_nowait() + assert event.event_type == "e1" + + def test_history_max(self): + broker = SSEBroker() + broker._max_history = 3 + for i in range(5): + asyncio.run(broker.publish(f"e{i}", {})) + assert len(broker.history) == 3 + + def test_publish_sync(self): + broker = SSEBroker() + cid, queue = broker.subscribe() + delivered = broker.publish_sync("tick", {"n": 1}) + assert delivered == 1 + event = queue.get_nowait() + assert event.data["n"] == 1 + + def test_multiple_subscribers(self): + broker = SSEBroker() + c1, q1 = broker.subscribe() + c2, q2 = broker.subscribe() + + asyncio.run(broker.publish("test", {"v": 42})) + assert q1.get_nowait().data["v"] == 42 + assert q2.get_nowait().data["v"] == 42 + + +# --------------------------------------------------------------------------- +# Hook +# --------------------------------------------------------------------------- + +class TestHookManager: + def test_register_and_get(self): + hm = HookManager() + hook = Hook("h1", "task_created", HookType.WEBHOOK.value, + {"url": "http://example.com"}) + hm.register(hook) + assert hm.get("h1") is not None + assert hm.hook_count == 1 + + def test_unregister(self): + hm = HookManager() + hm.register(Hook("h1", "*", HookType.CALLBACK.value, {})) + assert hm.unregister("h1") is True + assert hm.hook_count == 0 + + def test_list_hooks_by_event(self): + hm = HookManager() + hm.register(Hook("h1", "task_created", HookType.WEBHOOK.value, {})) + hm.register(Hook("h2", "task_updated", HookType.WEBHOOK.value, {})) + + created = hm.list_hooks(event_type="task_created") + assert len(created) == 1 + assert created[0].hook_id == "h1" + + def test_fire_matching_hook(self): + results = [] + + async def callback(data): + results.append(data) + return "ok" + + hm = HookManager() + hm.register(Hook("h1", "task_created", HookType.CALLBACK.value, + {"callback": callback})) + + fire_results = asyncio.run(hm.fire("task_created", {"task_id": "t1"})) + assert len(fire_results) == 1 + assert fire_results[0]["status"] == "success" + assert len(results) == 1 + + def test_fire_wildcard_hook(self): + results = [] + + async def callback(data): + results.append(data) + + hm = HookManager() + hm.register(Hook("h1", "*", HookType.CALLBACK.value, + {"callback": callback})) + + asyncio.run(hm.fire("any_event", {"x": 1})) + assert len(results) == 1 + + def test_fire_no_match(self): + hm = HookManager() + hm.register(Hook("h1", "task_created", HookType.CALLBACK.value, + {"callback": lambda d: None})) + + results = asyncio.run(hm.fire("task_updated", {})) + assert len(results) == 0 + + def test_fire_disabled_hook(self): + hm = HookManager() + hm.register(Hook("h1", "*", HookType.CALLBACK.value, + {"callback": lambda d: d}, enabled=False)) + + results = asyncio.run(hm.fire("test", {})) + assert len(results) == 0 + + def test_sync_callback(self): + results = [] + + def sync_callback(data): + results.append(data) + return "sync_ok" + + hm = HookManager() + hm.register(Hook("h1", "test", HookType.CALLBACK.value, + {"callback": sync_callback})) + + fire_results = asyncio.run(hm.fire("test", {"v": 1})) + assert fire_results[0]["status"] == "success" + assert results[0]["v"] == 1 + + def test_hook_fire_count(self): + async def cb(data): + pass + + hm = HookManager() + hm.register(Hook("h1", "test", HookType.CALLBACK.value, + {"callback": cb})) + + asyncio.run(hm.fire("test", {})) + asyncio.run(hm.fire("test", {})) + assert hm.get("h1").fire_count == 2 + assert hm.get("h1").last_fired is not None + + +# --------------------------------------------------------------------------- +# T4: 错误处理 +# --------------------------------------------------------------------------- + +class TestHookErrors: + def test_webhook_error_handled(self): + hm = HookManager() + hm.register(Hook("h1", "test", HookType.WEBHOOK.value, + {"url": "http://nonexistent.invalid/hook"})) + + results = asyncio.run(hm.fire("test", {})) + assert len(results) == 1 + assert results[0]["status"] == "error" + + def test_script_error_handled(self): + hm = HookManager() + hm.register(Hook("h1", "test", HookType.SCRIPT.value, + {"script": "/nonexistent/script.sh"})) + + results = asyncio.run(hm.fire("test", {})) + assert len(results) == 1 + assert results[0]["status"] == "error" + + def test_callback_error_handled(self): + def bad_callback(data): + raise RuntimeError("callback error") + + hm = HookManager() + hm.register(Hook("h1", "test", HookType.CALLBACK.value, + {"callback": bad_callback})) + + results = asyncio.run(hm.fire("test", {})) + assert results[0]["status"] == "error" + assert "callback error" in results[0]["error"] + + def test_no_callable_callback(self): + hm = HookManager() + hm.register(Hook("h1", "test", HookType.CALLBACK.value, + {"callback": None})) + + results = asyncio.run(hm.fire("test", {})) + assert results[0]["status"] == "error"