diff --git a/src/daemon/prompt_composer.py b/src/daemon/prompt_composer.py new file mode 100644 index 0000000..1940f10 --- /dev/null +++ b/src/daemon/prompt_composer.py @@ -0,0 +1,127 @@ +""" +prompt_composer.py — PromptSection Protocol + PromptContext + PromptComposer + +拼装器:有序管理 prompt 段落,按优先级排序后合并为最终 prompt。 +""" + +import logging +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Protocol, runtime_checkable + +logger = logging.getLogger("moziplus-v2.prompt_composer") + +# --------------------------------------------------------------------------- +# Section 优先级范围约定 +# --------------------------------------------------------------------------- +PRIORITY_CONTEXT = 10 # 任务上下文 +PRIORITY_PRIOR = 20 # 前序信息 +PRIORITY_ROLE = 30 # 角色规范 +PRIORITY_API = 40 # API 操作指令 +PRIORITY_CONSTRAINTS = 50 # 硬约束 +PRIORITY_EXTENSION = 60 # 扩展段 + + +# --------------------------------------------------------------------------- +# PromptSection Protocol +# --------------------------------------------------------------------------- +@runtime_checkable +class PromptSection(Protocol): + """一个 prompt 段""" + + name: str # 段名(去重用,同名覆盖) + priority: int # 排序优先级(小数字=靠前) + + def render(self, context: "PromptContext") -> str: + """渲染此段的文本内容。返回空字符串表示不注入。""" + ... + + def should_include(self, context: "PromptContext") -> bool: + """是否注入此段(默认 True,条件段可覆盖)。""" + ... + + +# --------------------------------------------------------------------------- +# PromptContext 数据对象 +# --------------------------------------------------------------------------- +@dataclass +class PromptContext: + """Prompt 渲染的统一上下文""" + + task_id: str + title: str + description: str + must_haves: str + project_id: str + agent_id: str + + task: Optional[Dict] = None + role: str = "executor" + spawn_type: str = "executor" + + # mail 专用 + from_agent: str = "" + mail_type: str = "" # inform / request + + # toolchain 专用 + event_type: str = "" # ci_failure / review_request / ... + event_data: Dict = field(default_factory=dict) + + # 前序产出 + depends_on_outputs: Optional[List] = None + + +# --------------------------------------------------------------------------- +# PromptComposer 拼装器 +# --------------------------------------------------------------------------- +class PromptComposer: + """有序拼装 prompt sections""" + + SEPARATOR = "\n\n---\n\n" + TOKEN_BUDGET_WARN = 800 # token 预算警告阈值 + CHARS_PER_TOKEN = 3.5 # 估算比率 + + def __init__(self) -> None: + self._sections: List[Any] = [] # List[PromptSection] + + def add(self, section: Any) -> None: + """添加一个 section(同名覆盖)""" + self._sections = [s for s in self._sections if s.name != section.name] + self._sections.append(section) + + def add_many(self, sections: List[Any]) -> None: + """批量添加""" + for s in sections: + self.add(s) + + def compose(self, context: PromptContext) -> str: + """拼装最终 prompt + + 1. 过滤 should_include=False 的段 + 2. 按 priority 排序 + 3. 逐段 render + 4. 过滤空段 + 5. 用分隔符连接 + 6. Token 预算警告(不截断) + """ + active = [s for s in self._sections if s.should_include(context)] + active.sort(key=lambda s: s.priority) + + parts = [s.render(context) for s in active] + parts = [p for p in parts if p.strip()] + + result = self.SEPARATOR.join(parts) + + # Token 估算 + tokens = max(1, int(len(result) / self.CHARS_PER_TOKEN)) + logger.debug( + "Composed prompt from %d sections, %d tokens", + len(parts), tokens, + ) + + if tokens > self.TOKEN_BUDGET_WARN: + logger.warning( + "Prompt exceeds %d token budget: %d tokens (task_id=%s)", + self.TOKEN_BUDGET_WARN, tokens, context.task_id, + ) + + return result diff --git a/src/daemon/task_type_registry.py b/src/daemon/task_type_registry.py new file mode 100644 index 0000000..ba3da5a --- /dev/null +++ b/src/daemon/task_type_registry.py @@ -0,0 +1,137 @@ +""" +task_type_registry.py — Task type handler Protocol + Registry. + +启动时一次性加载 handler,运行时只读。 +零依赖:不导入项目内其他模块。 +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any, Dict, List, Optional, Protocol, runtime_checkable + +logger = logging.getLogger("moziplus-v2.registry") + + +# --------------------------------------------------------------------------- +# Protocol +# --------------------------------------------------------------------------- + +@runtime_checkable +class TaskTypeHandler(Protocol): + """所有 task type handler 的统一接口。""" + + # 属性(通过 __init__ 设置) + task_type: str # 类型标识:'task' | 'mail' | 'toolchain' + virtual_project: Optional[str] # 虚拟项目 ID,如 '_mail'、'_toolchain'。普通任务为 None + + def build_prompt( + self, + task_id: str, + title: str, + description: str, + must_haves: str, + project_id: str, + agent_id: str, + task: Optional[Dict] = None, + spawn_type: str = "executor", + spawner: Any = None, + ) -> str: + """构建 Agent prompt。""" + ... + + def build_api_section( + self, project_id: str, task_id: str, agent_id: str + ) -> str: + """构建 API 操作指令(success_status 等)。""" + ... + + def skip_guardrail(self, project_id: str) -> bool: + """是否跳过 guardrail 检查。""" + ... + + def pre_spawn( + self, task_id: str, db_path: Path, dispatcher: Any + ) -> Optional[Any]: + """spawn 前回调,返回 on_checks_passed 回调或 None。""" + ... + + def post_complete( + self, + task_id: str, + agent_id: str, + outcome: str, + db_path: Path, + must_haves: str, + dispatcher: Any, + ) -> None: + """spawn 完成后回调。""" + ... + + def build_retry_prompt( + self, + task_id: str, + agent_id: str, + retry_count: int, + max_retries: int, + retry_field: str, + task_info: Dict, + spawner: Any, + ) -> str: + """构建重试 prompt。""" + ... + + def check_completion(self, task_id: str, db_path: Path) -> bool: + """检查任务是否已完成(如 mail 的回复检查)。""" + ... + + def get_sections(self) -> list: + """返回此 handler 的 prompt section 列表。""" + ... + + +# --------------------------------------------------------------------------- +# Registry +# --------------------------------------------------------------------------- + +class TaskTypeRegistry: + """Task type handler 注册表。启动时一次性加载,运行时只读。""" + + _handlers: Dict[str, TaskTypeHandler] = {} + + @classmethod + def register(cls, handler: TaskTypeHandler) -> None: + """注册一个 handler。启动时调用一次。""" + if handler.task_type in cls._handlers: + raise ValueError(f"Task type '{handler.task_type}' already registered") + cls._handlers[handler.task_type] = handler + vp = getattr(handler, "virtual_project", None) + logger.info("Registered task type handler: %s (virtual_project=%s)", handler.task_type, vp) + + @classmethod + def get_by_project(cls, project_id: str) -> Optional[TaskTypeHandler]: + """通过 project_id 查找 handler(匹配 virtual_project)。""" + for h in cls._handlers.values(): + if h.virtual_project == project_id: + return h + return None + + @classmethod + def get(cls, task_type: str) -> Optional[TaskTypeHandler]: + """通过 task_type 标识查找 handler。""" + return cls._handlers.get(task_type) + + @classmethod + def virtual_projects(cls) -> list[str]: + """返回所有已注册的虚拟项目 ID(ticker 自动发现用)。""" + return [ + h.virtual_project + for h in cls._handlers.values() + if h.virtual_project is not None + ] + + @classmethod + def clear(cls) -> None: + """清空注册表(仅测试用)。""" + cls._handlers = {}