Files
sanguo_moziplus_v2/src/daemon/sse.py
T
cfdaily d58e38d58f
CI / lint (pull_request) Successful in 6s
CI / test (pull_request) Successful in 9s
CI / notify-on-failure (pull_request) Successful in 0s
fix(lint): 修复 PR #14 引入的 lint 回退 (119→0)
PR #14 从旧分支复制文件导致回退了 PR #10 的 lint 修复。
修复内容:
- autoflake 移除未使用导入/变量
- autopep8 修复缩进/空格
- 手动修复 F821(pathlib→Path), F541(f-string), F841(未使用变量)
- 所有修复均通过 flake8 --max-line-length=120 --extend-ignore=E501 检查 (0 errors)
2026-06-09 23:53:29 +08:00

311 lines
8.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""SSE (Server-Sent Events) + hook system.

SSE: pushes blackboard event changes to subscribers in real time.
Hook: pluggable event handlers (webhook / script / callback).
"""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional

# Module logger; the name is a fixed string rather than derived from __name__.
logger = logging.getLogger("moziplus-v2.sse")
class SSEEventType(str, Enum):
    """Known SSE event type names.

    Members subclass ``str``, so they compare equal to their plain string
    values. NOTE(review): usage is not visible in this file; presumably these
    are passed as ``event_type`` to ``SSEBroker.publish`` — confirm with callers.
    Beware that on Python 3.11+ ``f"{member}"`` renders ``"SSEEventType.NAME"``,
    not the value; use ``member.value`` when building strings.
    """
    # Task lifecycle
    TASK_CREATED = "task_created"
    TASK_UPDATED = "task_updated"
    TASK_COMPLETED = "task_completed"
    TASK_FAILED = "task_failed"
    # Observations
    OBSERVATION_ADDED = "observation_added"
    # Agent lifecycle
    AGENT_SPAWNED = "agent_spawned"
    AGENT_COMPLETED = "agent_completed"
    # Daemon / review / hook notifications
    DAEMON_TICK = "daemon_tick"
    REVIEW_RESULT = "review_result"
    HOOK_TRIGGERED = "hook_triggered"
class SSEEvent:
    """A single Server-Sent Event.

    Attributes:
        id: the given ``event_id``, or a random UUID4 string when none is given.
        event_type: value written to the SSE ``event:`` field.
        data: payload, JSON-encoded into the ``data:`` field by :meth:`to_sse`.
        timestamp: creation time as a naive ISO-8601 string in UTC.
    """

    def __init__(
        self,
        event_type: str,
        data: Any,
        event_id: Optional[str] = None,
    ):
        self.id = event_id or str(uuid.uuid4())
        self.event_type = event_type
        self.data = data
        self.timestamp = self._utcnow_iso()

    @staticmethod
    def _utcnow_iso() -> str:
        """Return the current UTC time as a naive ISO-8601 string.

        Produces the same format as ``datetime.utcnow().isoformat()`` without
        using that method, which is deprecated since Python 3.12.
        """
        from datetime import timezone

        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    @staticmethod
    def _field_value(value: Any) -> str:
        """Render *value* for a single-line SSE field (``id:`` / ``event:``).

        Enum members are rendered by their value, because ``f"{member}"`` of a
        ``str``-mixin enum yields ``"Cls.MEMBER"`` on Python 3.11+. CR and LF
        are removed because either one would end the field early and corrupt
        the event stream.
        """
        if isinstance(value, Enum):
            value = value.value
        return str(value).replace("\r", "").replace("\n", "")

    def to_sse(self) -> str:
        """Serialize to SSE wire format.

        Returns:
            ``id``, ``event`` and one ``data`` line, terminated by a blank line.
        """
        # json.dumps escapes newlines inside strings, so the payload is always
        # a single line; default=str stringifies non-JSON-native values.
        payload = json.dumps(self.data, ensure_ascii=False, default=str)
        lines = [
            f"id: {self._field_value(self.id)}",
            f"event: {self._field_value(self.event_type)}",
            f"data: {payload}",
        ]
        return "\n".join(lines) + "\n\n"
class SSEBroker:
    """SSE event broker: keeps per-subscriber queues and fans events out to them.

    Every published event is also appended to a bounded history, which is
    replayed to each new subscriber.
    """

    def __init__(self, max_history: int = 100, queue_maxsize: int = 100):
        """
        Args:
            max_history: number of most recent events kept for replay
                (0 keeps none).
            queue_maxsize: capacity of each subscriber queue (asyncio.Queue
                semantics: 0 or less means unbounded). When a subscriber's
                queue is full, new events for it are dropped and logged.
        """
        self._subscribers: Dict[str, asyncio.Queue] = {}
        self._event_history: List[SSEEvent] = []
        self._max_history = max_history
        self._queue_maxsize = queue_maxsize

    def subscribe(self, client_id: Optional[str] = None) -> tuple:
        """Register a subscriber and pre-fill its queue with recent history.

        Re-using an existing *client_id* replaces that client's queue.

        Returns:
            (client_id, queue): the given id or a new UUID4 string, and the
            ``asyncio.Queue`` the subscriber should read events from.
        """
        cid = client_id or str(uuid.uuid4())
        queue: asyncio.Queue = asyncio.Queue(maxsize=self._queue_maxsize)
        self._subscribers[cid] = queue
        backlog = self._event_history
        if self._queue_maxsize > 0:
            # Replay the newest events that fit, rather than the oldest ones.
            backlog = backlog[-self._queue_maxsize:]
        for event in backlog:
            queue.put_nowait(event)
        return cid, queue

    async def subscribe_async(self, client_id: Optional[str] = None) -> tuple:
        """Async wrapper around :meth:`subscribe` (for use inside a coroutine)."""
        return self.subscribe(client_id)

    def unsubscribe(self, client_id: str) -> None:
        """Remove *client_id*'s queue; unknown ids are ignored."""
        self._subscribers.pop(client_id, None)

    async def publish(self, event_type: str, data: Any) -> int:
        """Publish an event to every subscriber.

        Returns:
            Number of subscribers the event was queued for.
        """
        return self._broadcast(event_type, data)

    def publish_sync(self, event_type: str, data: Any) -> int:
        """Synchronous variant of :meth:`publish` for non-async code.

        NOTE(review): asyncio.Queue is not thread-safe. If this is called from a
        thread other than the event loop's, it should go through
        ``loop.call_soon_threadsafe`` instead — confirm with callers.

        Returns:
            Number of subscribers the event was queued for.
        """
        return self._broadcast(event_type, data)

    def _broadcast(self, event_type: str, data: Any) -> int:
        """Create an event, record it in history, and queue it for every subscriber."""
        event = SSEEvent(event_type, data)
        self._remember(event)
        delivered = 0
        # Iterate over a snapshot so the subscriber dict may change meanwhile.
        for cid, queue in list(self._subscribers.items()):
            try:
                queue.put_nowait(event)
            except asyncio.QueueFull:
                logger.warning("SSE queue full for client %s, dropping", cid)
            else:
                delivered += 1
        return delivered

    def _remember(self, event: SSEEvent) -> None:
        """Append *event* to history and trim it to the ``max_history`` newest."""
        history = self._event_history
        history.append(event)
        overflow = len(history) - self._max_history
        if overflow > 0:
            del history[:overflow]

    @property
    def subscriber_count(self) -> int:
        """Number of currently registered subscribers."""
        return len(self._subscribers)

    @property
    def history(self) -> List[SSEEvent]:
        """A copy of the retained event history, oldest first."""
        return list(self._event_history)
class HookType(str, Enum):
    """Supported hook kinds.

    ``HookManager._execute_hook`` compares ``Hook.hook_type`` against these
    values (``.value``) to choose how a hook is run.
    """
    # POST a JSON payload to config["url"]
    WEBHOOK = "webhook"
    # Run the executable at config["script"] with the JSON-encoded data as argv[1]
    SCRIPT = "script"
    # Call the Python callable at config["callback"] with the event data
    CALLBACK = "callback"
class Hook:
    """A registered hook definition.

    Attributes:
        hook_id: unique key under which HookManager stores the hook.
        event_type: event type to react to; HookManager.fire also treats "*"
            as a wildcard matching every event.
        hook_type: one of HookType's values ("webhook", "script", "callback").
        config: type-specific settings; HookManager reads "url", "script" or
            "callback" from it.
        enabled: disabled hooks are skipped by HookManager.fire.
        fire_count: number of successful runs (maintained by HookManager).
        last_fired: naive UTC ISO timestamp of the last successful run, or None.
    """

    def __init__(
        self,
        hook_id: str,
        event_type: str,
        hook_type: str,
        config: Dict[str, Any],
        enabled: bool = True,
    ):
        self.hook_id = hook_id
        self.event_type = event_type
        self.hook_type = hook_type
        self.config = config
        self.enabled = enabled
        self.fire_count = 0
        self.last_fired: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the hook's definition and run bookkeeping as a plain dict.

        NOTE(review): ``config`` is returned as-is; for callback hooks it holds
        a callable, so the result is not necessarily JSON-serializable.
        """
        return {
            "hook_id": self.hook_id,
            "event_type": self.event_type,
            "hook_type": self.hook_type,
            "config": self.config,
            "enabled": self.enabled,
            "fire_count": self.fire_count,
            # Previously tracked but never exposed.
            "last_fired": self.last_fired,
        }
class HookManager:
    """Hook registry and dispatcher.

    Hooks are stored by ``hook_id``. :meth:`fire` runs every enabled hook whose
    ``event_type`` equals the fired event type or is the wildcard ``"*"``.
    """

    def __init__(self, timeout: float = 10.0):
        """
        Args:
            timeout: seconds allowed for one webhook request or one script run.
        """
        self._hooks: Dict[str, Hook] = {}
        self._timeout = timeout

    def register(self, hook: Hook) -> str:
        """Store *hook*, replacing any hook with the same id, and return its id."""
        self._hooks[hook.hook_id] = hook
        return hook.hook_id

    def unregister(self, hook_id: str) -> bool:
        """Remove the hook with *hook_id*; return True if it was registered."""
        if hook_id in self._hooks:
            del self._hooks[hook_id]
            return True
        return False

    def get(self, hook_id: str) -> Optional[Hook]:
        """Return the hook registered under *hook_id*, or None."""
        return self._hooks.get(hook_id)

    def list_hooks(
        self,
        event_type: Optional[str] = None,
    ) -> List[Hook]:
        """Return registered hooks, optionally filtered by exact ``event_type``.

        A falsy *event_type* returns all hooks. Wildcard ("*") hooks are
        included only when *event_type* is literally "*".
        """
        hooks = list(self._hooks.values())
        if event_type:
            hooks = [h for h in hooks if h.event_type == event_type]
        return hooks

    async def fire(self, event_type: str, data: Any) -> List[Dict[str, Any]]:
        """Run every enabled hook matching *event_type*, one after another.

        A failing hook does not stop the others; its error is logged and
        reported in the result list.

        Returns:
            One dict per executed hook, either
            ``{"hook_id", "status": "success", "result"}`` or
            ``{"hook_id", "status": "error", "error"}``.
        """
        results: List[Dict[str, Any]] = []
        # Iterate over a snapshot: while a hook is awaited, other code (or the
        # hook itself) may register/unregister hooks, which would otherwise
        # raise "dictionary changed size during iteration".
        for hook in list(self._hooks.values()):
            if not hook.enabled or hook.event_type not in ("*", event_type):
                continue
            try:
                result = await self._execute_hook(hook, data, event_type)
            except Exception as e:
                logger.warning("Hook %s failed: %s", hook.hook_id, e)
                results.append({
                    "hook_id": hook.hook_id,
                    "status": "error",
                    "error": str(e),
                })
                continue
            # Only successful runs update the hook's bookkeeping.
            hook.fire_count += 1
            hook.last_fired = self._utcnow_iso()
            results.append({
                "hook_id": hook.hook_id,
                "status": "success",
                "result": result,
            })
        return results

    async def _execute_hook(
        self,
        hook: Hook,
        data: Any,
        event_type: Optional[str] = None,
    ) -> Any:
        """Run *hook* according to its ``hook_type``.

        Args:
            event_type: the event actually being fired. It is forwarded to
                webhooks so that wildcard ("*") hooks report the real event.
                Defaults to ``hook.event_type``.

        Raises:
            ValueError: ``hook_type`` is not one of HookType's values.
        """
        if hook.hook_type == HookType.WEBHOOK.value:
            return await self._fire_webhook(hook, data, event_type)
        if hook.hook_type == HookType.SCRIPT.value:
            return await self._fire_script(hook, data)
        if hook.hook_type == HookType.CALLBACK.value:
            return await self._fire_callback(hook, data)
        raise ValueError(f"Unknown hook type: {hook.hook_type}")

    async def _fire_webhook(
        self,
        hook: Hook,
        data: Any,
        event_type: Optional[str] = None,
    ) -> Dict:
        """POST ``{"event", "data", "timestamp"}`` as JSON to ``config["url"]``.

        The blocking HTTP request runs in the default executor so it does not
        stall the event loop.

        Returns:
            ``{"status_code": <HTTP status>}``.

        Raises:
            ValueError: the URL is missing or not http(s). This also stops
                ``urlopen`` from opening ``file:`` and similar URLs.
            urllib.error.URLError / HTTPError: the request failed.
        """
        import urllib.parse
        import urllib.request

        url = str(hook.config.get("url") or "")
        if urllib.parse.urlsplit(url).scheme.lower() not in ("http", "https"):
            raise ValueError(
                f"Webhook {hook.hook_id} needs an http(s) url, got {url!r}")
        payload = json.dumps({
            "event": event_type if event_type is not None else hook.event_type,
            "data": data,
            "timestamp": self._utcnow_iso(),
        }, default=str).encode()
        req = urllib.request.Request(
            url,
            data=payload,
            headers={"Content-Type": "application/json"},
            method="POST",
        )

        def _post() -> int:
            with urllib.request.urlopen(req, timeout=self._timeout) as resp:
                return resp.status

        loop = asyncio.get_running_loop()
        status = await loop.run_in_executor(None, _post)
        return {"status_code": status}

    async def _fire_script(self, hook: Hook, data: Any) -> Dict:
        """Run ``config["script"]`` with the JSON-encoded data as its only argument.

        The script is executed directly (no shell).

        Returns:
            ``{"exit_code", "stdout", "stderr"}``; output is decoded leniently
            and truncated to 500 characters each. A non-zero exit code is still
            returned, not raised.

        Raises:
            ValueError: no script is configured.
            TimeoutError: the script ran longer than the manager's timeout. The
                process is then killed and reaped.
        """
        script = hook.config.get("script", "")
        if not script:
            raise ValueError(f"Hook {hook.hook_id} has no script configured")
        env_data = json.dumps(data, default=str)
        proc = await asyncio.create_subprocess_exec(
            script,
            env_data,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        try:
            stdout, stderr = await asyncio.wait_for(
                proc.communicate(), timeout=self._timeout
            )
        except asyncio.TimeoutError:
            try:
                proc.kill()
            except ProcessLookupError:
                pass  # exited on its own just after the timeout fired
            await proc.wait()  # reap the child so it does not linger as a zombie
            raise TimeoutError(f"Hook script timed out: {script}") from None
        return {
            "exit_code": proc.returncode,
            "stdout": stdout.decode(errors="replace")[:500],
            "stderr": stderr.decode(errors="replace")[:500],
        }

    async def _fire_callback(self, hook: Hook, data: Any) -> Any:
        """Call ``config["callback"](data)`` and return its result.

        If the call returns an awaitable (a coroutine function, or a plain
        function or ``functools.partial`` that returns a coroutine), it is awaited.

        Raises:
            ValueError: no callable is configured.
        """
        import inspect

        callback = hook.config.get("callback")
        if not callback or not callable(callback):
            raise ValueError("No callable callback configured")
        result = callback(data)
        if inspect.isawaitable(result):
            result = await result
        return result

    @staticmethod
    def _utcnow_iso() -> str:
        """Return the current UTC time as a naive ISO-8601 string.

        Same format as ``datetime.utcnow().isoformat()``, which is deprecated
        since Python 3.12.
        """
        from datetime import timezone

        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    @property
    def hook_count(self) -> int:
        """Number of registered hooks (enabled or not)."""
        return len(self._hooks)