d58e38d58f
PR #14 从旧分支复制文件导致回退了 PR #10 的 lint 修复。 修复内容: - autoflake 移除未使用导入/变量 - autopep8 修复缩进/空格 - 手动修复 F821(pathlib→Path), F541(f-string), F841(未使用变量) - 所有修复均通过 flake8 --max-line-length=120 --extend-ignore=E501 检查 (0 errors)
149 lines
5.2 KiB
Python
149 lines
5.2 KiB
Python
"""安全红线引擎 — PRD §10.1 六条红线的检查与拦截"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import re
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import yaml
|
|
|
|
logger = logging.getLogger("moziplus-v2.guardrails")
|
|
|
|
|
|
@dataclass
class GuardrailViolation:
    """Record describing one triggered guardrail ("red line") rule.

    Attributes:
        rule_id: Identifier of the rule that fired.
        rule_name: Human-readable name of the rule.
        severity: Severity label of the rule.
        message: Text explaining the violation.
        action: Name of the action to take in response.
    """

    rule_id: str
    rule_name: str
    # Severity label: critical / warning
    severity: str
    message: str
    # Response action: block_and_notify / pause_and_notify /
    # terminate_and_escalate / pause_and_escalate
    action: str
|
|
|
|
|
|
class GuardrailEngine:
    """Safety guardrail ("red line") checking engine.

    The engine holds a list of rule dicts (``self.rules``) and a settings
    dict (``self.settings``), normally read from a guardrails YAML file, and
    evaluates tasks, token consumption and failure streaks against them.
    When ``settings["enabled"]`` is falsy every check reports no violation.
    """

    def __init__(self, config_path: Optional[Path] = None):
        """Create an engine and load ``config_path`` if that file exists.

        Args:
            config_path: Location of the guardrails YAML file. A plain string
                path is accepted as well. When omitted, or when the file does
                not exist, the engine starts with no rules and is enabled.
        """
        self.rules: List[Dict[str, Any]] = []
        self.settings: Dict[str, Any] = {"enabled": True}
        if config_path:
            # Coerce so that callers passing a str do not fail on ``.exists()``.
            config_path = Path(config_path)
            if config_path.exists():
                self.load(config_path)

    def load(self, config_path: Path) -> None:
        """Load rules and settings from a guardrails YAML file.

        An empty document, or null ``rules`` / ``settings`` entries, fall back
        to "no rules" and ``{"enabled": True}`` respectively.

        Args:
            config_path: Path of the YAML file to read.

        Raises:
            ValueError: If the top-level YAML value is not a mapping.
        """
        # Explicit UTF-8: the config carries non-ASCII text and must not
        # depend on the platform's default encoding.
        with open(config_path, encoding="utf-8") as f:
            data = yaml.safe_load(f)

        if data is None:
            # safe_load returns None for an empty document.
            data = {}
        if not isinstance(data, dict):
            raise ValueError(
                f"Guardrail config {config_path} must contain a mapping at top level")

        self.rules = data.get("rules") or []
        self.settings = data.get("settings") or {"enabled": True}
        logger.info(
            "Loaded %d guardrail rules from %s", len(self.rules), config_path)

    def check_task(self, task: Any) -> List[GuardrailViolation]:
        """Check a task against all rules (intended to run before scheduling).

        The task's ``title``, ``description``, ``task_type`` and
        ``must_haves`` attributes are read with ``getattr``; missing or falsy
        values are treated as empty. Each rule contributes at most one
        violation, produced by its first matching trigger.

        Args:
            task: Any object exposing the attributes listed above.

        Returns:
            The violations found, in rule order; empty when guardrails are
            disabled or nothing matched.
        """
        if not self.is_enabled:
            return []

        task_title = getattr(task, "title", "") or ""
        task_type = getattr(task, "task_type", "") or ""
        # Built once: the searchable text does not change between triggers.
        searchable = {
            "task_title": task_title,
            "task_description": getattr(task, "description", "") or "",
            "must_haves": str(getattr(task, "must_haves", {}) or {}),
        }

        violations = []
        for rule in self.rules:
            rule_id = rule["id"]  # a rule without an id is a config error
            for trigger in rule.get("triggers") or []:
                if self._trigger_matches(trigger, searchable, task_type):
                    violations.append(self._violation(
                        rule,
                        f"安全红线触发: {rule['name']}",
                        "block_and_notify",
                    ))
                    break  # each rule fires at most once
            del rule_id

        if violations:
            logger.warning("Guardrail violations for task '%s': %s",
                           task_title, [v.rule_id for v in violations])

        return violations

    def check_token_usage(
            self, token_count: int) -> Optional[GuardrailViolation]:
        """Check whether token consumption exceeds the configured limit.

        Only rules whose id is ``high_token_usage`` are considered. The limit
        is the ``token_threshold`` of the rule's triggers (default 100000).

        Args:
            token_count: Number of tokens consumed so far.

        Returns:
            A violation when ``token_count`` is strictly greater than the
            threshold, otherwise ``None``.
        """
        if not self.is_enabled:
            return None

        for rule in self.rules:
            if rule["id"] != "high_token_usage":
                continue
            threshold = self._trigger_setting(rule, "token_threshold", 100000)
            if token_count > threshold:
                return self._violation(
                    rule, f"Token消耗超过{threshold}", "pause_and_notify")
        return None

    def check_consecutive_failure(
            self, failure_count: int) -> Optional[GuardrailViolation]:
        """Check whether the consecutive-failure count reached the limit.

        Only rules whose id is ``consecutive_failure`` are considered. The
        limit is the ``consecutive_failures`` of the rule's triggers
        (default 3).

        Args:
            failure_count: Number of failures in a row.

        Returns:
            A violation when ``failure_count`` is greater than or equal to
            the threshold, otherwise ``None``.
        """
        if not self.is_enabled:
            return None

        for rule in self.rules:
            if rule["id"] != "consecutive_failure":
                continue
            threshold = self._trigger_setting(rule, "consecutive_failures", 3)
            if failure_count >= threshold:
                return self._violation(
                    rule, f"连续失败{failure_count}次", "pause_and_escalate")
        return None

    @property
    def is_enabled(self) -> bool:
        """Whether guardrail checking is switched on (defaults to ``True``)."""
        return bool(self.settings.get("enabled", True))

    @staticmethod
    def _trigger_matches(trigger: Dict[str, Any],
                         searchable: Dict[str, str],
                         task_type: str) -> bool:
        """Tell whether one trigger matches by regex pattern or by task type."""
        pattern = trigger.get("pattern")
        if pattern:
            regex = re.compile(pattern, re.IGNORECASE)
            field_names = trigger.get("in") or []
            if isinstance(field_names, str):
                # A bare string would otherwise be iterated char by char.
                field_names = [field_names]
            if any(name in searchable and regex.search(searchable[name])
                   for name in field_names):
                return True

        wanted_type = trigger.get("task_type")
        return bool(wanted_type) and task_type == wanted_type

    @staticmethod
    def _trigger_setting(rule: Dict[str, Any], key: str, default: Any) -> Any:
        """Return ``key`` from the first trigger of ``rule`` that defines it."""
        for trigger in rule.get("triggers") or []:
            if isinstance(trigger, dict) and key in trigger:
                return trigger[key]
        return default

    @staticmethod
    def _violation(rule: Dict[str, Any], default_message: str,
                   default_action: str) -> GuardrailViolation:
        """Build a violation from ``rule``, filling in the given defaults."""
        return GuardrailViolation(
            rule_id=rule["id"],
            rule_name=rule["name"],
            severity=rule.get("severity", "warning"),
            message=rule.get("message", default_message),
            action=rule.get("action", default_action),
        )
|