Files
sanguo_moziplus_v2/src/daemon/counter.py
T
cfdaily d58e38d58f
CI / lint (pull_request) Successful in 6s
CI / test (pull_request) Successful in 9s
CI / notify-on-failure (pull_request) Successful in 0s
fix(lint): 修复 PR #14 引入的 lint 回退 (119→0)
PR #14 从旧分支复制文件导致回退了 PR #10 的 lint 修复。
修复内容:
- autoflake 移除未使用导入/变量
- autopep8 修复缩进/空格
- 手动修复 F821(pathlib→Path), F541(f-string), F841(未使用变量)
- 所有修复均通过 flake8 --max-line-length=120 --extend-ignore=E501 检查 (0 errors)
2026-06-09 23:53:29 +08:00

149 lines
5.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""ActiveAgentCounter — concurrency control plus a cooldown mechanism.

v2.1: per (agent, session) granularity, with three levels of control:
- per session key: max_per_session (how many concurrent spawns of the same
  session are allowed; the default of 1 means no concurrency per session)
- per agent: max_concurrent_sessions (an agent may have at most N sessions
  active at once)
- global: max_global (overall cap on concurrent work)

Implemented with asyncio Semaphores.
The Semaphores are created lazily, so that the counter can be constructed on
Python 3.9 when no event loop exists yet.
"""
from __future__ import annotations
import asyncio
import logging
import time
from typing import Dict, Optional
logger = logging.getLogger("moziplus-v2.counter")


class ActiveAgentCounter:
    """Async concurrency counter with a per-agent cooldown (v2.1: per-session granularity).

    Admission is gated, in this order, by:

    - a per-agent cooldown (see ``set_cooldown``);
    - a global cap of ``max_global`` concurrent acquisitions;
    - a per-agent cap of ``max_concurrent_sessions`` concurrent acquisitions
      (equal to distinct sessions while ``max_per_session == 1``);
    - a per-key cap of ``max_per_session`` holders, where the key is
      ``"{agent_id}:{session_id}"``.

    The asyncio Semaphores are created lazily on first use, so the counter can
    be constructed without a running event loop (needed on Python 3.9).
    """

    def __init__(self, max_global: int = 5, max_per_session: int = 1,
                 max_concurrent_sessions: int = 3,
                 default_cooldown_seconds: float = 120.0,
                 # Backward compatibility with the old config key.
                 max_per_agent: Optional[int] = None):
        """
        Args:
            max_global: global cap on concurrent acquisitions.
            max_per_session: cap on concurrent holders of one (agent, session) key.
            max_concurrent_sessions: cap on concurrent acquisitions per agent.
            default_cooldown_seconds: duration used by ``set_cooldown`` when no
                explicit ``seconds`` is given.
            max_per_agent: legacy alias. It overrides ``max_per_session`` only
                when ``max_per_session`` is 1 (its default).
        """
        self._max_global = max_global
        self._max_per_session = max_per_session
        self._max_concurrent_sessions = max_concurrent_sessions
        self._default_cooldown_seconds = default_cooldown_seconds
        # Map the legacy max_per_agent onto max_per_session when the caller
        # did not set max_per_session away from its default.
        if max_per_agent is not None and max_per_session == 1:
            self._max_per_session = max_per_agent

        # Created lazily by _get_global_sem().
        self._global_sem: Optional[asyncio.Semaphore] = None
        # key -> semaphore for that (agent, session) key; created lazily.
        self._per_key: Dict[str, asyncio.Semaphore] = {}
        # key -> number of current holders (0 or 1 while max_per_session == 1).
        self._active_keys: Dict[str, int] = {}
        # agent_id -> number of current acquisitions for that agent.
        self._agent_active: Dict[str, int] = {}
        self._global_active: int = 0
        # agent_id -> wall-clock time (time.time()) at which its cooldown ends.
        self._cooldown_until: Dict[str, float] = {}

    @staticmethod
    def _make_key(agent_id: str, session_id: str) -> str:
        """Build the per-session key ``"{agent_id}:{session_id}"``."""
        return f"{agent_id}:{session_id}"

    def _get_global_sem(self) -> asyncio.Semaphore:
        """Return the global semaphore, creating it on first use."""
        if self._global_sem is None:
            self._global_sem = asyncio.Semaphore(self._max_global)
        return self._global_sem

    def _get_key_sem(self, key: str) -> asyncio.Semaphore:
        """Return the semaphore for ``key``, creating it on first use."""
        if key not in self._per_key:
            self._per_key[key] = asyncio.Semaphore(self._max_per_session)
        return self._per_key[key]

    def is_cooling_down(self, agent_id: str) -> bool:
        """Return True if ``agent_id`` is still inside its cooldown window.

        An expired (or absent) cooldown entry is removed as a side effect.
        """
        until = self._cooldown_until.get(agent_id)
        if until and time.time() < until:
            return True
        self._cooldown_until.pop(agent_id, None)
        return False

    def set_cooldown(self, agent_id: str,
                     seconds: Optional[float] = None) -> None:
        """Put ``agent_id`` on cooldown for ``seconds``.

        When ``seconds`` is None, ``default_cooldown_seconds`` is used
        (120 s unless configured otherwise).
        """
        cd = seconds if seconds is not None else self._default_cooldown_seconds
        self._cooldown_until[agent_id] = time.time() + cd
        logger.info("Cooldown set for %s: %.0fs (until %.0f)",
                    agent_id, cd, self._cooldown_until[agent_id])

    async def can_acquire(self, agent_id: str,
                          session_id: str = "main") -> bool:
        """Return whether ``acquire`` would currently succeed.

        Checks, in order: cooldown -> global -> per agent -> per session key.
        Nothing is reserved. An expired cooldown entry may be dropped.
        """
        if self.is_cooling_down(agent_id):
            return False
        if self._global_active >= self._max_global:
            return False
        if self._agent_active.get(
                agent_id, 0) >= self._max_concurrent_sessions:
            return False
        key = self._make_key(agent_id, session_id)
        if self._active_keys.get(key, 0) >= self._max_per_session:
            return False
        return True

    async def acquire(self, agent_id: str, session_id: str = "main") -> bool:
        """Reserve a global slot, a per-agent slot and a per-session-key slot.

        Returns False without reserving anything when ``can_acquire`` is False.
        Every True result must be paired with exactly one ``release`` call.
        """
        if not await self.can_acquire(agent_id, session_id):
            return False
        key = self._make_key(agent_id, session_id)
        await self._get_global_sem().acquire()
        await self._get_key_sem(key).acquire()
        self._global_active += 1
        self._active_keys[key] = self._active_keys.get(key, 0) + 1
        self._agent_active[agent_id] = self._agent_active.get(agent_id, 0) + 1
        return True

    def release(self, agent_id: str, session_id: str = "main") -> None:
        """Release the slots reserved by a successful ``acquire``.

        Releasing a key that is not currently held is a logged no-op. This
        covers a double release, or a release after ``acquire`` returned False.
        Such a call must not free global or per-agent slots owned by other
        sessions.
        """
        key = self._make_key(agent_id, session_id)
        holders = self._active_keys.get(key, 0)
        if holders <= 0:
            logger.warning(
                "release() for %s without a matching acquire; ignored", key)
            return

        key_sem = self._per_key.get(key)
        if key_sem is not None:
            key_sem.release()
            # Last holder gone: drop the per-key semaphore so the dict does
            # not grow with every session ever seen.
            if holders <= 1:
                del self._per_key[key]
        if self._global_sem is not None:
            self._global_sem.release()
        self._global_active = max(0, self._global_active - 1)

        if holders <= 1:
            del self._active_keys[key]
        else:
            self._active_keys[key] = holders - 1

        agent_count = self._agent_active.get(agent_id, 0)
        if agent_count <= 1:
            self._agent_active.pop(agent_id, None)
        else:
            self._agent_active[agent_id] = agent_count - 1

    @property
    def global_active(self) -> int:
        """Current number of global acquisitions."""
        return self._global_active

    @property
    def max_global(self) -> int:
        """Configured global cap."""
        return self._max_global

    @property
    def active_agents(self) -> Dict[str, int]:
        """Snapshot copy of agent_id -> number of current acquisitions."""
        return dict(self._agent_active)

    def is_near_limit(self, margin: int = 1) -> bool:
        """Return True when global usage is within ``margin`` of ``max_global``."""
        return self._global_active >= self._max_global - margin