feat: Step 1 — TaskTypeRegistry + PromptComposer 基础设施
- task_type_registry.py: TaskTypeHandler Protocol (10方法+2属性) + TaskTypeRegistry 注册表 - prompt_composer.py: PromptSection Protocol + PromptContext dataclass + PromptComposer 拼装器 - 零依赖,纯新增文件,不影响现有功能
This commit is contained in:
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
prompt_composer.py — PromptSection Protocol + PromptContext + PromptComposer
|
||||
|
||||
拼装器:有序管理 prompt 段落,按优先级排序后合并为最终 prompt。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
||||
|
||||
logger = logging.getLogger("moziplus-v2.prompt_composer")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Section 优先级范围约定
|
||||
# ---------------------------------------------------------------------------
|
||||
PRIORITY_CONTEXT = 10 # 任务上下文
|
||||
PRIORITY_PRIOR = 20 # 前序信息
|
||||
PRIORITY_ROLE = 30 # 角色规范
|
||||
PRIORITY_API = 40 # API 操作指令
|
||||
PRIORITY_CONSTRAINTS = 50 # 硬约束
|
||||
PRIORITY_EXTENSION = 60 # 扩展段
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PromptSection Protocol
|
||||
# ---------------------------------------------------------------------------
|
||||
@runtime_checkable
|
||||
class PromptSection(Protocol):
|
||||
"""一个 prompt 段"""
|
||||
|
||||
name: str # 段名(去重用,同名覆盖)
|
||||
priority: int # 排序优先级(小数字=靠前)
|
||||
|
||||
def render(self, context: "PromptContext") -> str:
|
||||
"""渲染此段的文本内容。返回空字符串表示不注入。"""
|
||||
...
|
||||
|
||||
def should_include(self, context: "PromptContext") -> bool:
|
||||
"""是否注入此段(默认 True,条件段可覆盖)。"""
|
||||
...
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PromptContext 数据对象
|
||||
# ---------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class PromptContext:
|
||||
"""Prompt 渲染的统一上下文"""
|
||||
|
||||
task_id: str
|
||||
title: str
|
||||
description: str
|
||||
must_haves: str
|
||||
project_id: str
|
||||
agent_id: str
|
||||
|
||||
task: Optional[Dict] = None
|
||||
role: str = "executor"
|
||||
spawn_type: str = "executor"
|
||||
|
||||
# mail 专用
|
||||
from_agent: str = ""
|
||||
mail_type: str = "" # inform / request
|
||||
|
||||
# toolchain 专用
|
||||
event_type: str = "" # ci_failure / review_request / ...
|
||||
event_data: Dict = field(default_factory=dict)
|
||||
|
||||
# 前序产出
|
||||
depends_on_outputs: Optional[List] = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PromptComposer 拼装器
|
||||
# ---------------------------------------------------------------------------
|
||||
class PromptComposer:
|
||||
"""有序拼装 prompt sections"""
|
||||
|
||||
SEPARATOR = "\n\n---\n\n"
|
||||
TOKEN_BUDGET_WARN = 800 # token 预算警告阈值
|
||||
CHARS_PER_TOKEN = 3.5 # 估算比率
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._sections: List[Any] = [] # List[PromptSection]
|
||||
|
||||
def add(self, section: Any) -> None:
|
||||
"""添加一个 section(同名覆盖)"""
|
||||
self._sections = [s for s in self._sections if s.name != section.name]
|
||||
self._sections.append(section)
|
||||
|
||||
def add_many(self, sections: List[Any]) -> None:
|
||||
"""批量添加"""
|
||||
for s in sections:
|
||||
self.add(s)
|
||||
|
||||
def compose(self, context: PromptContext) -> str:
|
||||
"""拼装最终 prompt
|
||||
|
||||
1. 过滤 should_include=False 的段
|
||||
2. 按 priority 排序
|
||||
3. 逐段 render
|
||||
4. 过滤空段
|
||||
5. 用分隔符连接
|
||||
6. Token 预算警告(不截断)
|
||||
"""
|
||||
active = [s for s in self._sections if s.should_include(context)]
|
||||
active.sort(key=lambda s: s.priority)
|
||||
|
||||
parts = [s.render(context) for s in active]
|
||||
parts = [p for p in parts if p.strip()]
|
||||
|
||||
result = self.SEPARATOR.join(parts)
|
||||
|
||||
# Token 估算
|
||||
tokens = max(1, int(len(result) / self.CHARS_PER_TOKEN))
|
||||
logger.debug(
|
||||
"Composed prompt from %d sections, %d tokens",
|
||||
len(parts), tokens,
|
||||
)
|
||||
|
||||
if tokens > self.TOKEN_BUDGET_WARN:
|
||||
logger.warning(
|
||||
"Prompt exceeds %d token budget: %d tokens (task_id=%s)",
|
||||
self.TOKEN_BUDGET_WARN, tokens, context.task_id,
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
task_type_registry.py — Task type handler Protocol + Registry.
|
||||
|
||||
启动时一次性加载 handler,运行时只读。
|
||||
零依赖:不导入项目内其他模块。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
||||
|
||||
logger = logging.getLogger("moziplus-v2.registry")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Protocol
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@runtime_checkable
|
||||
class TaskTypeHandler(Protocol):
|
||||
"""所有 task type handler 的统一接口。"""
|
||||
|
||||
# 属性(通过 __init__ 设置)
|
||||
task_type: str # 类型标识:'task' | 'mail' | 'toolchain'
|
||||
virtual_project: Optional[str] # 虚拟项目 ID,如 '_mail'、'_toolchain'。普通任务为 None
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
task_id: str,
|
||||
title: str,
|
||||
description: str,
|
||||
must_haves: str,
|
||||
project_id: str,
|
||||
agent_id: str,
|
||||
task: Optional[Dict] = None,
|
||||
spawn_type: str = "executor",
|
||||
spawner: Any = None,
|
||||
) -> str:
|
||||
"""构建 Agent prompt。"""
|
||||
...
|
||||
|
||||
def build_api_section(
|
||||
self, project_id: str, task_id: str, agent_id: str
|
||||
) -> str:
|
||||
"""构建 API 操作指令(success_status 等)。"""
|
||||
...
|
||||
|
||||
def skip_guardrail(self, project_id: str) -> bool:
|
||||
"""是否跳过 guardrail 检查。"""
|
||||
...
|
||||
|
||||
def pre_spawn(
|
||||
self, task_id: str, db_path: Path, dispatcher: Any
|
||||
) -> Optional[Any]:
|
||||
"""spawn 前回调,返回 on_checks_passed 回调或 None。"""
|
||||
...
|
||||
|
||||
def post_complete(
|
||||
self,
|
||||
task_id: str,
|
||||
agent_id: str,
|
||||
outcome: str,
|
||||
db_path: Path,
|
||||
must_haves: str,
|
||||
dispatcher: Any,
|
||||
) -> None:
|
||||
"""spawn 完成后回调。"""
|
||||
...
|
||||
|
||||
def build_retry_prompt(
|
||||
self,
|
||||
task_id: str,
|
||||
agent_id: str,
|
||||
retry_count: int,
|
||||
max_retries: int,
|
||||
retry_field: str,
|
||||
task_info: Dict,
|
||||
spawner: Any,
|
||||
) -> str:
|
||||
"""构建重试 prompt。"""
|
||||
...
|
||||
|
||||
def check_completion(self, task_id: str, db_path: Path) -> bool:
|
||||
"""检查任务是否已完成(如 mail 的回复检查)。"""
|
||||
...
|
||||
|
||||
def get_sections(self) -> list:
|
||||
"""返回此 handler 的 prompt section 列表。"""
|
||||
...
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TaskTypeRegistry:
|
||||
"""Task type handler 注册表。启动时一次性加载,运行时只读。"""
|
||||
|
||||
_handlers: Dict[str, TaskTypeHandler] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, handler: TaskTypeHandler) -> None:
|
||||
"""注册一个 handler。启动时调用一次。"""
|
||||
if handler.task_type in cls._handlers:
|
||||
raise ValueError(f"Task type '{handler.task_type}' already registered")
|
||||
cls._handlers[handler.task_type] = handler
|
||||
vp = getattr(handler, "virtual_project", None)
|
||||
logger.info("Registered task type handler: %s (virtual_project=%s)", handler.task_type, vp)
|
||||
|
||||
@classmethod
|
||||
def get_by_project(cls, project_id: str) -> Optional[TaskTypeHandler]:
|
||||
"""通过 project_id 查找 handler(匹配 virtual_project)。"""
|
||||
for h in cls._handlers.values():
|
||||
if h.virtual_project == project_id:
|
||||
return h
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get(cls, task_type: str) -> Optional[TaskTypeHandler]:
|
||||
"""通过 task_type 标识查找 handler。"""
|
||||
return cls._handlers.get(task_type)
|
||||
|
||||
@classmethod
|
||||
def virtual_projects(cls) -> list[str]:
|
||||
"""返回所有已注册的虚拟项目 ID(ticker 自动发现用)。"""
|
||||
return [
|
||||
h.virtual_project
|
||||
for h in cls._handlers.values()
|
||||
if h.virtual_project is not None
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def clear(cls) -> None:
|
||||
"""清空注册表(仅测试用)。"""
|
||||
cls._handlers = {}
|
||||
Reference in New Issue
Block a user