diff --git a/src/daemon/guardrails.py b/src/daemon/guardrails.py new file mode 100644 index 0000000..8412b58 --- /dev/null +++ b/src/daemon/guardrails.py @@ -0,0 +1,138 @@ +"""安全红线引擎 — PRD §10.1 六条红线的检查与拦截""" + +from __future__ import annotations + +import logging +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional + +import yaml + +logger = logging.getLogger("moziplus-v2.guardrails") + + +@dataclass +class GuardrailViolation: + """红线违规记录""" + rule_id: str + rule_name: str + severity: str # critical / warning + message: str + action: str # block_and_notify / pause_and_notify / terminate_and_escalate / pause_and_escalate + + +class GuardrailEngine: + """安全红线检查引擎""" + + def __init__(self, config_path: Optional[Path] = None): + self.rules: List[Dict[str, Any]] = [] + self.settings: Dict[str, Any] = {"enabled": True} + if config_path and config_path.exists(): + self.load(config_path) + + def load(self, config_path: Path) -> None: + """加载 guardrails.yaml""" + with open(config_path) as f: + data = yaml.safe_load(f) + self.rules = data.get("rules", []) + self.settings = data.get("settings", {"enabled": True}) + logger.info("Loaded %d guardrail rules from %s", len(self.rules), config_path) + + def check_task(self, task: Any) -> List[GuardrailViolation]: + """检查 Task 是否触犯安全红线(调度前调用)""" + if not self.settings.get("enabled", True): + return [] + + violations = [] + task_title = getattr(task, "title", "") or "" + task_desc = getattr(task, "description", "") or "" + task_type = getattr(task, "task_type", "") or "" + must_haves = getattr(task, "must_haves", {}) or {} + must_haves_str = str(must_haves) + + for rule in self.rules: + rule_id = rule["id"] + triggers = rule.get("triggers", []) + + for trigger in triggers: + matched = False + + # 模式匹配 + pattern = trigger.get("pattern") + if pattern: + check_fields = { + "task_title": task_title, + "task_description": task_desc, + "must_haves": must_haves_str, + } + fields = trigger.get("in", []) + regex = re.compile(pattern, re.IGNORECASE) + for f in fields: + if f in check_fields and regex.search(check_fields[f]): + matched = True + break + + # task_type 匹配 + trigger_type = trigger.get("task_type") + if trigger_type and task_type == trigger_type: + matched = True + + if matched: + violations.append(GuardrailViolation( + rule_id=rule_id, + rule_name=rule["name"], + severity=rule.get("severity", "warning"), + message=rule.get("message", f"安全红线触发: {rule['name']}"), + action=rule.get("action", "block_and_notify"), + )) + break # 每条规则只触发一次 + + if violations: + logger.warning("Guardrail violations for task '%s': %s", + task_title, [v.rule_id for v in violations]) + + return violations + + def check_token_usage(self, token_count: int) -> Optional[GuardrailViolation]: + """检查 Token 消耗是否超标""" + if not self.settings.get("enabled", True): + return None + + for rule in self.rules: + if rule["id"] != "high_token_usage": + continue + threshold = rule.get("triggers", [{}])[0].get("token_threshold", 100000) + if token_count > threshold: + return GuardrailViolation( + rule_id=rule["id"], + rule_name=rule["name"], + severity=rule.get("severity", "warning"), + message=rule.get("message", f"Token消耗超过{threshold}"), + action=rule.get("action", "pause_and_notify"), + ) + return None + + def check_consecutive_failure(self, failure_count: int) -> Optional[GuardrailViolation]: + """检查连续失败次数""" + if not self.settings.get("enabled", True): + return None + + for rule in self.rules: + if rule["id"] != "consecutive_failure": + continue + threshold = rule.get("triggers", [{}])[0].get("consecutive_failures", 3) + if failure_count >= threshold: + return GuardrailViolation( + rule_id=rule["id"], + rule_name=rule["name"], + severity=rule.get("severity", "warning"), + message=rule.get("message", f"连续失败{failure_count}次"), + action=rule.get("action", "pause_and_escalate"), + ) + return None + + @property + def is_enabled(self) -> bool: + return self.settings.get("enabled", True)