"""
prompt_composer.py — PromptSection Protocol + PromptContext + PromptComposer

拼装器:有序管理 prompt 段落,按优先级排序后合并为最终 prompt。
Composer: keeps prompt sections in order, sorts them by priority, and merges
them into the final prompt.
"""
|
|
|
|
import logging
|
|
from dataclasses import dataclass, field
|
|
from typing import Dict, List, Optional, Protocol, runtime_checkable
|
|
|
|
logger = logging.getLogger(name="moziplus-v2.prompt_composer")


# ---------------------------------------------------------------------------
# Section priority bands (smaller value = placed earlier in the final prompt)
# ---------------------------------------------------------------------------
PRIORITY_CONTEXT: int = 10  # task context
PRIORITY_PRIOR: int = 20  # preceding / prior information
PRIORITY_ROLE: int = 30  # role specification
PRIORITY_API: int = 40  # API operation instructions
PRIORITY_CONSTRAINTS: int = 50  # hard constraints
PRIORITY_EXTENSION: int = 60  # extension sections
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# PromptSection Protocol
# ---------------------------------------------------------------------------
@runtime_checkable
class PromptSection(Protocol):
    """One section of a composed prompt.

    Any object exposing ``name``, ``priority``, ``render`` and
    ``should_include`` satisfies this protocol structurally. A class may also
    inherit from it explicitly to pick up the default ``should_include``.

    Note: ``@runtime_checkable`` makes ``isinstance`` verify only that the
    members exist, not their signatures or types.
    """

    # Section name, used for de-duplication: a section added later with the
    # same name replaces the earlier one.
    name: str
    # Sort key; smaller values are placed earlier in the prompt.
    priority: int

    def render(self, context: "PromptContext") -> str:
        """Render this section's text.

        Args:
            context: the shared rendering context.

        Returns:
            The section text. An empty string means the section is not
            injected into the final prompt.
        """
        ...

    def should_include(self, context: "PromptContext") -> bool:
        """Decide whether this section is injected.

        Defaults to ``True``; conditional sections override this.

        Args:
            context: the shared rendering context.

        Returns:
            True to include the section, False to skip it.
        """
        # A concrete default is required here: with a bare ``...`` body,
        # explicit subclasses that do not override this method would get
        # ``None`` (falsy) and be silently excluded, contradicting the
        # documented default of True.
        return True
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# PromptContext data object
# ---------------------------------------------------------------------------
@dataclass
class PromptContext:
    """Unified context handed to every section when rendering a prompt.

    The first six fields are required. Every other field has a default.
    Field order defines the positional constructor signature, so it must
    not be changed.
    """

    # --- required task identity ---
    task_id: str
    title: str
    description: str
    must_haves: str
    project_id: str
    agent_id: str

    # --- optional task / agent details ---
    task: Optional[Dict] = field(default=None)
    role: str = field(default="executor")
    spawn_type: str = field(default="executor")

    # --- mail-only fields ---
    from_agent: str = field(default="")
    mail_type: str = field(default="")  # "inform" / "request"

    # --- toolchain-only fields ---
    event_type: str = field(default="")  # e.g. "ci_failure", "review_request", ...
    event_data: Dict = field(default_factory=dict)  # a fresh dict per instance

    # --- prior outputs (presumably from tasks this one depends on) ---
    depends_on_outputs: Optional[List] = field(default=None)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# PromptComposer (assembler)
# ---------------------------------------------------------------------------
class PromptComposer:
    """Collects prompt sections and joins them in priority order."""

    SEPARATOR = "\n\n---\n\n"  # inserted between consecutive rendered sections
    TOKEN_BUDGET_WARN = 800  # estimated-token count above which a warning is logged
    CHARS_PER_TOKEN = 3.5  # rough characters-per-token ratio for the estimate

    def __init__(self) -> None:
        self._sections: List[PromptSection] = []

    def add(self, section: PromptSection) -> None:
        """Register ``section``, replacing any existing section with the same name.

        The new section is always appended at the end, so among sections of
        equal priority it sorts after those registered before it.
        """
        survivors = []
        for existing in self._sections:
            if existing.name != section.name:
                survivors.append(existing)
        survivors.append(section)
        self._sections = survivors

    def add_many(self, sections: List[PromptSection]) -> None:
        """Register every entry of ``sections``, in order, via ``add``."""
        for each in sections:
            self.add(each)

    def compose(self, context: PromptContext) -> str:
        """Build the final prompt text for ``context``.

        Steps:
          1. keep only sections whose ``should_include(context)`` is truthy;
          2. order them by ascending ``priority`` (a stable sort, so ties
             keep registration order);
          3. render every kept section;
          4. drop renders that are empty or whitespace-only;
          5. join the remainder with ``SEPARATOR``;
          6. log a warning when the estimated token count exceeds
             ``TOKEN_BUDGET_WARN``. The prompt is never truncated.

        Args:
            context: the shared rendering context passed to every section.

        Returns:
            The joined prompt (an empty string when nothing survives).
        """
        ordered = sorted(
            [sec for sec in self._sections if sec.should_include(context)],
            key=lambda sec: sec.priority,
        )
        # Render everything first, then filter, so every kept section's
        # render() is called before any result is inspected.
        rendered = [sec.render(context) for sec in ordered]
        kept = [text for text in rendered if text.strip()]
        prompt = self.SEPARATOR.join(kept)

        # Character-ratio estimate; never reports fewer than 1 token.
        estimated = max(1, int(len(prompt) / self.CHARS_PER_TOKEN))
        logger.debug(
            "Composed prompt from %d sections, %d tokens",
            len(kept), estimated,
        )

        if estimated <= self.TOKEN_BUDGET_WARN:
            return prompt

        logger.warning(
            "Prompt exceeds %d token budget: %d tokens (task_id=%s)",
            self.TOKEN_BUDGET_WARN, estimated, context.task_id,
        )
        return prompt
|