"""SSE (Server-Sent Events) + Hook 系统 SSE: 实时推送黑板事件变更 Hook: 可插拔的事件处理器(webhook / script / callback) """ from __future__ import annotations import asyncio import json import logging import uuid from datetime import datetime from enum import Enum from typing import Any, Dict, List, Optional logger = logging.getLogger("moziplus-v2.sse") class SSEEventType(str, Enum): TASK_CREATED = "task_created" TASK_UPDATED = "task_updated" TASK_COMPLETED = "task_completed" TASK_FAILED = "task_failed" OBSERVATION_ADDED = "observation_added" AGENT_SPAWNED = "agent_spawned" AGENT_COMPLETED = "agent_completed" DAEMON_TICK = "daemon_tick" REVIEW_RESULT = "review_result" HOOK_TRIGGERED = "hook_triggered" class SSEEvent: """SSE 事件""" def __init__( self, event_type: str, data: Any, event_id: Optional[str] = None, ): self.id = event_id or str(uuid.uuid4()) self.event_type = event_type self.data = data self.timestamp = datetime.utcnow().isoformat() def to_sse(self) -> str: """格式化为 SSE 协议文本""" lines = [f"id: {self.id}"] lines.append(f"event: {self.event_type}") lines.append( f"data: {json.dumps(self.data, ensure_ascii=False, default=str)}") return "\n".join(lines) + "\n\n" class SSEBroker: """SSE 事件代理 — 管理 subscriber 和事件分发""" def __init__(self): self._subscribers: Dict[str, asyncio.Queue] = {} self._event_history: List[SSEEvent] = [] self._max_history = 100 def subscribe(self, client_id: Optional[str] = None) -> tuple: """订阅 SSE 事件 Returns: (client_id, queue) """ cid = client_id or str(uuid.uuid4()) queue = asyncio.Queue(maxsize=100) self._subscribers[cid] = queue # 发送历史事件 for event in self._event_history: try: queue.put_nowait(event) except asyncio.QueueFull: break return cid, queue async def subscribe_async(self, client_id: Optional[str] = None) -> tuple: """异步订阅(在 async 上下文中调用)""" return self.subscribe(client_id) def unsubscribe(self, client_id: str) -> None: if client_id in self._subscribers: del self._subscribers[client_id] async def publish(self, event_type: str, data: Any) -> int: """发布事件到所有 subscriber Returns: delivered count """ event = SSEEvent(event_type, data) self._event_history.append(event) if len(self._event_history) > self._max_history: self._event_history = self._event_history[-self._max_history:] delivered = 0 for cid, queue in list(self._subscribers.items()): try: queue.put_nowait(event) delivered += 1 except asyncio.QueueFull: logger.warning("SSE queue full for client %s, dropping", cid) return delivered def publish_sync(self, event_type: str, data: Any) -> int: """同步版本(用于非 async 上下文)""" event = SSEEvent(event_type, data) self._event_history.append(event) if len(self._event_history) > self._max_history: self._event_history = self._event_history[-self._max_history:] delivered = 0 for cid, queue in list(self._subscribers.items()): try: queue.put_nowait(event) delivered += 1 except asyncio.QueueFull: pass return delivered @property def subscriber_count(self) -> int: return len(self._subscribers) @property def history(self) -> List[SSEEvent]: return list(self._event_history) class HookType(str, Enum): WEBHOOK = "webhook" SCRIPT = "script" CALLBACK = "callback" class Hook: """Hook 定义""" def __init__( self, hook_id: str, event_type: str, hook_type: str, config: Dict[str, Any], enabled: bool = True, ): self.hook_id = hook_id self.event_type = event_type self.hook_type = hook_type self.config = config self.enabled = enabled self.fire_count = 0 self.last_fired: Optional[str] = None def to_dict(self) -> Dict[str, Any]: return { "hook_id": self.hook_id, "event_type": self.event_type, "hook_type": self.hook_type, "config": self.config, "enabled": self.enabled, "fire_count": self.fire_count, } class HookManager: """Hook 管理器""" def __init__(self, timeout: float = 10.0): self._hooks: Dict[str, Hook] = {} self._timeout = timeout def register(self, hook: Hook) -> str: self._hooks[hook.hook_id] = hook return hook.hook_id def unregister(self, hook_id: str) -> bool: if hook_id in self._hooks: del self._hooks[hook_id] return True return False def get(self, hook_id: str) -> Optional[Hook]: return self._hooks.get(hook_id) def list_hooks( self, event_type: Optional[str] = None, ) -> List[Hook]: hooks = list(self._hooks.values()) if event_type: hooks = [h for h in hooks if h.event_type == event_type] return hooks async def fire(self, event_type: str, data: Any) -> List[Dict[str, Any]]: """触发匹配的 hooks Returns: [{hook_id, status, result/error}] """ results = [] for hook in self._hooks.values(): if not hook.enabled: continue if hook.event_type != "*" and hook.event_type != event_type: continue try: result = await self._execute_hook(hook, data) hook.fire_count += 1 hook.last_fired = datetime.utcnow().isoformat() results.append({ "hook_id": hook.hook_id, "status": "success", "result": result, }) except Exception as e: results.append({ "hook_id": hook.hook_id, "status": "error", "error": str(e), }) return results async def _execute_hook(self, hook: Hook, data: Any) -> Any: """执行单个 hook""" if hook.hook_type == HookType.WEBHOOK.value: return await self._fire_webhook(hook, data) if hook.hook_type == HookType.SCRIPT.value: return await self._fire_script(hook, data) if hook.hook_type == HookType.CALLBACK.value: return await self._fire_callback(hook, data) raise ValueError(f"Unknown hook type: {hook.hook_type}") async def _fire_webhook(self, hook: Hook, data: Any) -> Dict: """触发 webhook""" import urllib.request url = hook.config.get("url", "") payload = json.dumps({ "event": hook.event_type, "data": data, "timestamp": datetime.utcnow().isoformat(), }).encode() req = urllib.request.Request( url, data=payload, headers={"Content-Type": "application/json"}, method="POST", ) try: with urllib.request.urlopen(req, timeout=self._timeout) as resp: return {"status_code": resp.status} except Exception as e: logger.warning("Webhook %s failed: %s", hook.hook_id, e) raise async def _fire_script(self, hook: Hook, data: Any) -> Dict: """触发脚本""" script = hook.config.get("script", "") env_data = json.dumps(data) proc = await asyncio.create_subprocess_exec( script, env_data, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) try: stdout, stderr = await asyncio.wait_for( proc.communicate(), timeout=self._timeout ) return { "exit_code": proc.returncode, "stdout": stdout.decode()[:500], } except asyncio.TimeoutError: proc.kill() raise TimeoutError(f"Hook script timed out: {script}") async def _fire_callback(self, hook: Hook, data: Any) -> Any: """触发回调函数""" callback = hook.config.get("callback") if not callback or not callable(callback): raise ValueError("No callable callback configured") if asyncio.iscoroutinefunction(callback): return await callback(data) else: return callback(data) @property def hook_count(self) -> int: return len(self._hooks)