diff --git a/src/daemon/sse.py b/src/daemon/sse.py new file mode 100644 index 0000000..d7cf626 --- /dev/null +++ b/src/daemon/sse.py @@ -0,0 +1,308 @@ +"""SSE (Server-Sent Events) + Hook 系统 + +SSE: 实时推送黑板事件变更 +Hook: 可插拔的事件处理器(webhook / script / callback) +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import subprocess +import uuid +from datetime import datetime +from enum import Enum +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Set + +from src.blackboard.models import Event + +logger = logging.getLogger("moziplus-v2.sse") + + +class SSEEventType(str, Enum): + TASK_CREATED = "task_created" + TASK_UPDATED = "task_updated" + TASK_COMPLETED = "task_completed" + TASK_FAILED = "task_failed" + OBSERVATION_ADDED = "observation_added" + AGENT_SPAWNED = "agent_spawned" + AGENT_COMPLETED = "agent_completed" + DAEMON_TICK = "daemon_tick" + REVIEW_RESULT = "review_result" + HOOK_TRIGGERED = "hook_triggered" + + +class SSEEvent: + """SSE 事件""" + + def __init__( + self, + event_type: str, + data: Any, + event_id: Optional[str] = None, + ): + self.id = event_id or str(uuid.uuid4()) + self.event_type = event_type + self.data = data + self.timestamp = datetime.utcnow().isoformat() + + def to_sse(self) -> str: + """格式化为 SSE 协议文本""" + lines = [f"id: {self.id}"] + lines.append(f"event: {self.event_type}") + lines.append(f"data: {json.dumps(self.data, ensure_ascii=False, default=str)}") + return "\n".join(lines) + "\n\n" + + +class SSEBroker: + """SSE 事件代理 — 管理 subscriber 和事件分发""" + + def __init__(self): + self._subscribers: Dict[str, asyncio.Queue] = {} + self._event_history: List[SSEEvent] = [] + self._max_history = 100 + + def subscribe(self, client_id: Optional[str] = None) -> tuple: + """订阅 SSE 事件 + + Returns: + (client_id, queue) + """ + cid = client_id or str(uuid.uuid4()) + queue: asyncio.Queue = asyncio.Queue(maxsize=100) + self._subscribers[cid] = queue + + # 发送历史事件 + for event in self._event_history: + try: + queue.put_nowait(event) + except asyncio.QueueFull: + break + + return cid, queue + + def unsubscribe(self, client_id: str) -> None: + if client_id in self._subscribers: + del self._subscribers[client_id] + + async def publish(self, event_type: str, data: Any) -> int: + """发布事件到所有 subscriber + + Returns: + delivered count + """ + event = SSEEvent(event_type, data) + self._event_history.append(event) + if len(self._event_history) > self._max_history: + self._event_history = self._event_history[-self._max_history:] + + delivered = 0 + for cid, queue in list(self._subscribers.items()): + try: + queue.put_nowait(event) + delivered += 1 + except asyncio.QueueFull: + logger.warning("SSE queue full for client %s, dropping", cid) + + return delivered + + def publish_sync(self, event_type: str, data: Any) -> int: + """同步版本(用于非 async 上下文)""" + event = SSEEvent(event_type, data) + self._event_history.append(event) + if len(self._event_history) > self._max_history: + self._event_history = self._event_history[-self._max_history:] + + delivered = 0 + for cid, queue in list(self._subscribers.items()): + try: + queue.put_nowait(event) + delivered += 1 + except asyncio.QueueFull: + pass + return delivered + + @property + def subscriber_count(self) -> int: + return len(self._subscribers) + + @property + def history(self) -> List[SSEEvent]: + return list(self._event_history) + + +class HookType(str, Enum): + WEBHOOK = "webhook" + SCRIPT = "script" + CALLBACK = "callback" + + +class Hook: + """Hook 定义""" + + def __init__( + self, + hook_id: str, + event_type: str, + hook_type: str, + config: Dict[str, Any], + enabled: bool = True, + ): + self.hook_id = hook_id + self.event_type = event_type + self.hook_type = hook_type + self.config = config + self.enabled = enabled + self.fire_count = 0 + self.last_fired: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + return { + "hook_id": self.hook_id, + "event_type": self.event_type, + "hook_type": self.hook_type, + "config": self.config, + "enabled": self.enabled, + "fire_count": self.fire_count, + } + + +class HookManager: + """Hook 管理器""" + + def __init__(self, timeout: float = 10.0): + self._hooks: Dict[str, Hook] = {} + self._timeout = timeout + + def register(self, hook: Hook) -> str: + self._hooks[hook.hook_id] = hook + return hook.hook_id + + def unregister(self, hook_id: str) -> bool: + if hook_id in self._hooks: + del self._hooks[hook_id] + return True + return False + + def get(self, hook_id: str) -> Optional[Hook]: + return self._hooks.get(hook_id) + + def list_hooks( + self, + event_type: Optional[str] = None, + ) -> List[Hook]: + hooks = list(self._hooks.values()) + if event_type: + hooks = [h for h in hooks if h.event_type == event_type] + return hooks + + async def fire(self, event_type: str, data: Any) -> List[Dict[str, Any]]: + """触发匹配的 hooks + + Returns: + [{hook_id, status, result/error}] + """ + results = [] + + for hook in self._hooks.values(): + if not hook.enabled: + continue + if hook.event_type != "*" and hook.event_type != event_type: + continue + + try: + result = await self._execute_hook(hook, data) + hook.fire_count += 1 + hook.last_fired = datetime.utcnow().isoformat() + results.append({ + "hook_id": hook.hook_id, + "status": "success", + "result": result, + }) + except Exception as e: + results.append({ + "hook_id": hook.hook_id, + "status": "error", + "error": str(e), + }) + + return results + + async def _execute_hook(self, hook: Hook, data: Any) -> Any: + """执行单个 hook""" + if hook.hook_type == HookType.WEBHOOK.value: + return await self._fire_webhook(hook, data) + + if hook.hook_type == HookType.SCRIPT.value: + return await self._fire_script(hook, data) + + if hook.hook_type == HookType.CALLBACK.value: + return await self._fire_callback(hook, data) + + raise ValueError(f"Unknown hook type: {hook.hook_type}") + + async def _fire_webhook(self, hook: Hook, data: Any) -> Dict: + """触发 webhook""" + import urllib.request + + url = hook.config.get("url", "") + payload = json.dumps({ + "event": hook.event_type, + "data": data, + "timestamp": datetime.utcnow().isoformat(), + }).encode() + + req = urllib.request.Request( + url, + data=payload, + headers={"Content-Type": "application/json"}, + method="POST", + ) + + try: + with urllib.request.urlopen(req, timeout=self._timeout) as resp: + return {"status_code": resp.status} + except Exception as e: + logger.warning("Webhook %s failed: %s", hook.hook_id, e) + raise + + async def _fire_script(self, hook: Hook, data: Any) -> Dict: + """触发脚本""" + script = hook.config.get("script", "") + env_data = json.dumps(data) + + proc = await asyncio.create_subprocess_exec( + script, + env_data, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + try: + stdout, stderr = await asyncio.wait_for( + proc.communicate(), timeout=self._timeout + ) + return { + "exit_code": proc.returncode, + "stdout": stdout.decode()[:500], + } + except asyncio.TimeoutError: + proc.kill() + raise TimeoutError(f"Hook script timed out: {script}") + + async def _fire_callback(self, hook: Hook, data: Any) -> Any: + """触发回调函数""" + callback = hook.config.get("callback") + if not callback or not callable(callback): + raise ValueError("No callable callback configured") + + if asyncio.iscoroutinefunction(callback): + return await callback(data) + else: + return callback(data) + + @property + def hook_count(self) -> int: + return len(self._hooks)