d58e38d58f
PR #14 从旧分支复制文件导致回退了 PR #10 的 lint 修复。 修复内容: - autoflake 移除未使用导入/变量 - autopep8 修复缩进/空格 - 手动修复 F821(pathlib→Path), F541(f-string), F841(未使用变量) - 所有修复均通过 flake8 --max-line-length=120 --extend-ignore=E501 检查 (0 errors)
311 lines
8.9 KiB
Python
311 lines
8.9 KiB
Python
"""SSE (Server-Sent Events) + Hook 系统
|
||
|
||
SSE: 实时推送黑板事件变更
|
||
Hook: 可插拔的事件处理器(webhook / script / callback)
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import uuid
|
||
from datetime import datetime
|
||
from enum import Enum
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
|
||
# Module-level logger shared by the SSE broker and the hook manager;
# records are emitted under the fixed name "moziplus-v2.sse".
logger = logging.getLogger("moziplus-v2.sse")
|
||
|
||
|
||
class SSEEventType(str, Enum):
    """Event names that can be pushed over the SSE channel.

    The enum mixes in ``str`` so every member compares equal to its
    wire-format value (for example ``SSEEventType.TASK_CREATED == "task_created"``).
    """

    # Task lifecycle
    TASK_CREATED = "task_created"
    TASK_UPDATED = "task_updated"
    TASK_COMPLETED = "task_completed"
    TASK_FAILED = "task_failed"

    # Blackboard observations
    OBSERVATION_ADDED = "observation_added"

    # Agent lifecycle
    AGENT_SPAWNED = "agent_spawned"
    AGENT_COMPLETED = "agent_completed"

    # Daemon, review and hook notifications
    DAEMON_TICK = "daemon_tick"
    REVIEW_RESULT = "review_result"
    HOOK_TRIGGERED = "hook_triggered"
|
||
|
||
|
||
class SSEEvent:
    """A single server-sent event.

    Attributes:
        id: Event identifier; a random UUID4 string when none is supplied.
        event_type: Name written to the SSE ``event:`` field.
        data: Payload, JSON-encoded when the event is rendered.
        timestamp: ISO-8601 string of the (naive) UTC creation time.
    """

    def __init__(
        self,
        event_type: str,
        data: Any,
        event_id: Optional[str] = None,
    ):
        # A falsy id (None or "") is replaced by a generated one.
        self.id = event_id if event_id else str(uuid.uuid4())
        self.event_type = event_type
        self.data = data
        self.timestamp = datetime.utcnow().isoformat()

    def to_sse(self) -> str:
        """Render the event as an SSE protocol frame.

        Returns:
            The ``id``, ``event`` and ``data`` fields, one per line, followed
            by the blank line that terminates an SSE frame. Values that JSON
            cannot encode natively are passed through ``str``.
        """
        payload = json.dumps(self.data, ensure_ascii=False, default=str)
        fields = (
            ("id", self.id),
            ("event", self.event_type),
            ("data", payload),
        )
        frame = "".join(f"{name}: {value}\n" for name, value in fields)
        return frame + "\n"
|
||
|
||
|
||
class SSEBroker:
    """SSE event broker: tracks subscribers and fans events out to them.

    Each subscriber owns a bounded ``asyncio.Queue``. Published events are
    also kept in a bounded in-memory history that is replayed to every new
    subscriber.
    """

    def __init__(self):
        self._subscribers: Dict[str, asyncio.Queue] = {}
        self._event_history: List[SSEEvent] = []
        self._max_history = 100

    def subscribe(self, client_id: Optional[str] = None) -> tuple:
        """Register a subscriber and pre-load its queue with the history.

        Args:
            client_id: Identifier to register under; a UUID4 string is
                generated when it is omitted or empty. Re-using an id
                replaces the queue previously registered for it.

        Returns:
            ``(client_id, queue)``
        """
        cid = client_id or str(uuid.uuid4())
        queue: asyncio.Queue = asyncio.Queue(maxsize=100)
        self._subscribers[cid] = queue

        # Replay only the newest events that fit in the queue, so the replayed
        # backlog is contiguous with the live events that follow it.
        for event in self._event_history[-queue.maxsize:]:
            queue.put_nowait(event)

        return cid, queue

    async def subscribe_async(self, client_id: Optional[str] = None) -> tuple:
        """Coroutine wrapper around :meth:`subscribe` for async callers."""
        return self.subscribe(client_id)

    def unsubscribe(self, client_id: str) -> None:
        """Remove a subscriber; unknown ids are ignored."""
        self._subscribers.pop(client_id, None)

    async def publish(self, event_type: str, data: Any) -> int:
        """Publish an event to every subscriber.

        Returns:
            Number of subscribers whose queue accepted the event.
        """
        return self._record_and_dispatch(event_type, data)

    def publish_sync(self, event_type: str, data: Any) -> int:
        """Synchronous twin of :meth:`publish` for non-async callers.

        Returns:
            Number of subscribers whose queue accepted the event.
        """
        return self._record_and_dispatch(event_type, data)

    def _record_and_dispatch(self, event_type: str, data: Any) -> int:
        """Append the event to the history and enqueue it for all subscribers."""
        event = SSEEvent(event_type, data)
        self._event_history.append(event)

        # Trim by the computed excess; a plain ``[-max:]`` slice would keep
        # everything when the limit is 0.
        excess = len(self._event_history) - self._max_history
        if excess > 0:
            del self._event_history[:excess]

        delivered = 0
        # Iterate over a snapshot so a concurrent (un)subscribe cannot
        # invalidate the iteration.
        for cid, queue in list(self._subscribers.items()):
            try:
                queue.put_nowait(event)
            except asyncio.QueueFull:
                # Slow consumer: drop the event for this client only, but
                # make the loss visible in the log.
                logger.warning("SSE queue full for client %s, dropping", cid)
            else:
                delivered += 1

        return delivered

    @property
    def subscriber_count(self) -> int:
        """Number of currently registered subscribers."""
        return len(self._subscribers)

    @property
    def history(self) -> List[SSEEvent]:
        """Shallow copy of the retained event history, oldest first."""
        return list(self._event_history)
|
||
|
||
|
||
class HookType(str, Enum):
    """Kinds of hook handlers.

    The enum mixes in ``str`` so members compare equal to their plain
    string values.
    """

    WEBHOOK = "webhook"    # HTTP POST to a configured URL
    SCRIPT = "script"      # external program
    CALLBACK = "callback"  # in-process Python callable
|
||
|
||
|
||
class Hook:
    """Definition of a single hook plus its firing statistics.

    Attributes:
        hook_id: Identifier of the hook.
        event_type: Event name the hook is bound to.
        hook_type: Handler kind (string value).
        config: Handler-specific settings.
        enabled: Whether the hook is active.
        fire_count: Firing counter, starts at 0.
        last_fired: Timestamp string of the last firing, initially ``None``.
    """

    # Attributes exported by ``to_dict``, in output order.
    _EXPORTED = (
        "hook_id",
        "event_type",
        "hook_type",
        "config",
        "enabled",
        "fire_count",
    )

    def __init__(
        self,
        hook_id: str,
        event_type: str,
        hook_type: str,
        config: Dict[str, Any],
        enabled: bool = True,
    ):
        self.hook_id, self.event_type, self.hook_type = hook_id, event_type, hook_type
        self.config, self.enabled = config, enabled
        self.fire_count = 0
        self.last_fired: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the hook definition and its fire count as a plain dict.

        ``last_fired`` is not part of the output, and ``config`` is the same
        object held by the hook (not a copy).
        """
        return {name: getattr(self, name) for name in self._EXPORTED}
|
||
|
||
|
||
class HookManager:
    """Registry and dispatcher for hooks.

    Hooks are stored by ``hook_id``. ``fire`` runs every enabled hook whose
    ``event_type`` matches the fired event (or is the wildcard ``"*"``) and
    reports a per-hook outcome instead of raising.
    """

    def __init__(self, timeout: float = 10.0):
        """Create an empty manager.

        Args:
            timeout: Seconds allowed for a webhook request or a script run.
        """
        self._hooks: Dict[str, Hook] = {}
        self._timeout = timeout

    def register(self, hook: Hook) -> str:
        """Store ``hook`` under its id (replacing any previous one); return the id."""
        self._hooks[hook.hook_id] = hook
        return hook.hook_id

    def unregister(self, hook_id: str) -> bool:
        """Remove a hook; return ``True`` if it existed, ``False`` otherwise."""
        if hook_id not in self._hooks:
            return False
        del self._hooks[hook_id]
        return True

    def get(self, hook_id: str) -> Optional[Hook]:
        """Return the hook registered under ``hook_id``, or ``None``."""
        return self._hooks.get(hook_id)

    def list_hooks(
        self,
        event_type: Optional[str] = None,
    ) -> List[Hook]:
        """List registered hooks, optionally only those bound to ``event_type``.

        The filter is an exact match on ``Hook.event_type``; a falsy
        ``event_type`` disables filtering.
        """
        hooks = list(self._hooks.values())
        if event_type:
            hooks = [h for h in hooks if h.event_type == event_type]
        return hooks

    async def fire(self, event_type: str, data: Any) -> List[Dict[str, Any]]:
        """Run every enabled hook that matches ``event_type``.

        Hooks run sequentially. A failing hook does not stop the others; its
        exception is reported in the result list instead.

        Returns:
            ``[{hook_id, status, result/error}]`` with one entry per executed hook.
        """
        results: List[Dict[str, Any]] = []

        # Iterate over a snapshot: hooks run across ``await`` points and may
        # register or unregister hooks, which would otherwise raise
        # "dictionary changed size during iteration".
        for hook in list(self._hooks.values()):
            if self._hooks.get(hook.hook_id) is not hook:
                # Removed or replaced by an earlier hook during this dispatch.
                continue
            if not hook.enabled:
                continue
            if hook.event_type != "*" and hook.event_type != event_type:
                continue

            try:
                result = await self._execute_hook(hook, data)
            except Exception as e:
                results.append({
                    "hook_id": hook.hook_id,
                    "status": "error",
                    "error": str(e),
                })
            else:
                # Statistics are only updated for successful runs.
                hook.fire_count += 1
                hook.last_fired = datetime.utcnow().isoformat()
                results.append({
                    "hook_id": hook.hook_id,
                    "status": "success",
                    "result": result,
                })

        return results

    async def _execute_hook(self, hook: Hook, data: Any) -> Any:
        """Dispatch one hook to the handler for its ``hook_type``.

        Raises:
            ValueError: If ``hook_type`` is not a known handler kind.
        """
        # Compared with ``==`` (not a dict lookup) so that both plain strings
        # and str-enum members match.
        if hook.hook_type == HookType.WEBHOOK.value:
            return await self._fire_webhook(hook, data)

        if hook.hook_type == HookType.SCRIPT.value:
            return await self._fire_script(hook, data)

        if hook.hook_type == HookType.CALLBACK.value:
            return await self._fire_callback(hook, data)

        raise ValueError(f"Unknown hook type: {hook.hook_type}")

    async def _fire_webhook(self, hook: Hook, data: Any) -> Dict:
        """POST the event as JSON to ``hook.config["url"]``.

        Returns:
            ``{"status_code": <HTTP status>}``

        Raises:
            Exception: Whatever ``urllib`` or JSON encoding raises; request
                failures are logged before being re-raised.
        """
        import urllib.request

        # NOTE(review): urllib accepts non-HTTP schemes; confirm that the URL
        # always comes from trusted configuration.
        url = hook.config.get("url", "")
        # NOTE(review): "event" is the hook's own event_type, so wildcard
        # hooks report "*" rather than the fired event — confirm intended.
        payload = json.dumps({
            "event": hook.event_type,
            "data": data,
            "timestamp": datetime.utcnow().isoformat(),
        }).encode()

        req = urllib.request.Request(
            url,
            data=payload,
            headers={"Content-Type": "application/json"},
            method="POST",
        )

        def _post() -> Dict:
            with urllib.request.urlopen(req, timeout=self._timeout) as resp:
                return {"status_code": resp.status}

        try:
            # urlopen blocks; run it in the default executor so the event
            # loop stays responsive while the request is in flight.
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, _post)
        except Exception as e:
            logger.warning("Webhook %s failed: %s", hook.hook_id, e)
            raise

    async def _fire_script(self, hook: Hook, data: Any) -> Dict:
        """Run ``hook.config["script"]`` with the JSON-encoded data as its argument.

        Returns:
            ``{"exit_code": <int>, "stdout": <first 500 characters>}``

        Raises:
            TimeoutError: If the script does not finish within the timeout;
                the process is killed and reaped first.
        """
        script = hook.config.get("script", "")
        payload_arg = json.dumps(data)

        # exec form (no shell): the payload is passed as a single argv entry.
        proc = await asyncio.create_subprocess_exec(
            script,
            payload_arg,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        try:
            stdout, _stderr = await asyncio.wait_for(
                proc.communicate(), timeout=self._timeout
            )
        except asyncio.TimeoutError:
            try:
                proc.kill()
            except ProcessLookupError:
                pass  # the process exited between the timeout and the kill
            # Reap the child so it does not linger as a zombie.
            await proc.wait()
            raise TimeoutError(f"Hook script timed out: {script}") from None

        return {
            "exit_code": proc.returncode,
            # Script output is not guaranteed to be valid UTF-8.
            "stdout": stdout.decode(errors="replace")[:500],
        }

    async def _fire_callback(self, hook: Hook, data: Any) -> Any:
        """Call ``hook.config["callback"]`` with ``data``.

        Any awaitable returned by the callback is awaited, so coroutine
        functions, partials/lambdas wrapping them and objects with an async
        ``__call__`` are all supported.

        Raises:
            ValueError: If no callable is configured.
        """
        import inspect

        callback = hook.config.get("callback")
        if not callable(callback):
            raise ValueError("No callable callback configured")

        result = callback(data)
        if inspect.isawaitable(result):
            return await result
        return result

    @property
    def hook_count(self) -> int:
        """Number of registered hooks."""
        return len(self._hooks)
|