d58e38d58f
PR #14 从旧分支复制文件导致回退了 PR #10 的 lint 修复。 修复内容: - autoflake 移除未使用导入/变量 - autopep8 修复缩进/空格 - 手动修复 F821(pathlib→Path), F541(f-string), F841(未使用变量) - 所有修复均通过 flake8 --max-line-length=120 --extend-ignore=E501 检查 (0 errors)
149 lines
5.6 KiB
Python
149 lines
5.6 KiB
Python
"""ActiveAgentCounter — 并发控制 + 冷却机制
|
||
|
||
v2.1:per (agent, session) 粒度。三层控制:
|
||
- per session key: max_per_session(同 session 不能并发 spawn)
|
||
- per agent: max_concurrent_sessions(同 agent 最多 N 个不同 session)
|
||
- global: max_global(全局总并发上限)
|
||
|
||
asyncio Semaphore 实现。
|
||
延迟创建 Semaphore(兼容 Python 3.9 无 event loop 时的构造)。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import logging
|
||
import time
|
||
from typing import Dict, Optional
|
||
|
||
logger = logging.getLogger("moziplus-v2.counter")


class ActiveAgentCounter:
    """Async concurrency counter with a per-agent cooldown (v2.1: per-session granularity).

    Admission is gated by a cooldown check plus three limits:

    - per (agent, session) key: at most ``max_per_session`` concurrent holders;
    - per agent: at most ``max_concurrent_sessions`` active sessions;
    - global: at most ``max_global`` active holders overall.

    ``acquire`` does not queue callers: it returns ``False`` as soon as
    ``can_acquire`` reports that a limit (or the agent's cooldown) would be
    exceeded. Every successful ``acquire`` must be paired with exactly one
    ``release`` using the same ``(agent_id, session_id)``.
    """

    def __init__(self, max_global: int = 5, max_per_session: int = 1,
                 max_concurrent_sessions: int = 3,
                 default_cooldown_seconds: float = 120.0,
                 # Backward compatibility with older configurations
                 max_per_agent: Optional[int] = None):
        """Configure the limits.

        Args:
            max_global: Maximum number of slots held at once across all agents.
            max_per_session: Maximum concurrent holders of a single
                (agent, session) key.
            max_concurrent_sessions: Maximum number of active sessions per agent.
            default_cooldown_seconds: Duration used by ``set_cooldown`` when no
                explicit duration is passed.
            max_per_agent: Legacy option. When given and ``max_per_session`` is
                still 1 (its default), it is used as ``max_per_session``.
        """
        self._max_global = max_global
        self._max_per_session = max_per_session
        self._max_concurrent_sessions = max_concurrent_sessions
        self._default_cooldown_seconds = default_cooldown_seconds

        # Map the legacy max_per_agent onto max_per_session.
        if max_per_agent is not None and max_per_session == 1:
            self._max_per_session = max_per_agent

        # Semaphores are created lazily by _get_global_sem / _get_key_sem.
        self._global_sem: Optional[asyncio.Semaphore] = None
        self._per_key: Dict[str, asyncio.Semaphore] = {}
        # key -> number of current holders (normally only 0 or 1)
        self._active_keys: Dict[str, int] = {}
        # agent_id -> number of active sessions
        self._agent_active: Dict[str, int] = {}
        self._global_active: int = 0
        # Cooldown (per agent): agent_id -> time.time() value when it ends
        self._cooldown_until: Dict[str, float] = {}

    @staticmethod
    def _make_key(agent_id: str, session_id: str) -> str:
        """Build the per-session key ``"<agent_id>:<session_id>"``."""
        return f"{agent_id}:{session_id}"

    def _get_global_sem(self) -> asyncio.Semaphore:
        """Return the global semaphore, creating it on first use."""
        if self._global_sem is None:
            self._global_sem = asyncio.Semaphore(self._max_global)
        return self._global_sem

    def _get_key_sem(self, key: str) -> asyncio.Semaphore:
        """Return the semaphore for ``key``, creating it on first use."""
        if key not in self._per_key:
            self._per_key[key] = asyncio.Semaphore(self._max_per_session)
        return self._per_key[key]

    def is_cooling_down(self, agent_id: str) -> bool:
        """Return True while ``agent_id`` is inside its cooldown window.

        An expired or missing cooldown entry is removed as a side effect.
        """
        until = self._cooldown_until.get(agent_id)
        if until and time.time() < until:
            return True
        self._cooldown_until.pop(agent_id, None)
        return False

    def set_cooldown(self, agent_id: str,
                     seconds: Optional[float] = None) -> None:
        """Put ``agent_id`` into cooldown for ``seconds``.

        When ``seconds`` is None, ``default_cooldown_seconds`` is used
        (120 s unless configured otherwise).
        """
        cd = seconds if seconds is not None else self._default_cooldown_seconds
        self._cooldown_until[agent_id] = time.time() + cd
        logger.info("Cooldown set for %s: %.0fs (until %.0f)",
                    agent_id, cd, self._cooldown_until[agent_id])

    async def can_acquire(self, agent_id: str,
                          session_id: str = "main") -> bool:
        """Check whether a slot could be taken right now, without taking it.

        Checks run in order: cooldown, then global, then per-agent, then
        per-session key.
        """
        if self.is_cooling_down(agent_id):
            return False
        if self._global_active >= self._max_global:
            return False
        if self._agent_active.get(
                agent_id, 0) >= self._max_concurrent_sessions:
            return False
        key = self._make_key(agent_id, session_id)
        if self._active_keys.get(key, 0) >= self._max_per_session:
            return False
        return True

    async def acquire(self, agent_id: str, session_id: str = "main") -> bool:
        """Take a per-session slot, a per-agent count and a global slot.

        Returns:
            True if the slot was taken (the caller must later call
            ``release`` with the same arguments); False if ``can_acquire``
            rejected the request. Nothing is held in that case.
        """
        if not await self.can_acquire(agent_id, session_id):
            return False

        key = self._make_key(agent_id, session_id)
        # can_acquire just confirmed spare capacity and the semaphores are
        # released/acquired in lockstep with the counters, so these are
        # expected to succeed without waiting.
        await self._get_global_sem().acquire()
        await self._get_key_sem(key).acquire()

        self._global_active += 1
        self._active_keys[key] = self._active_keys.get(key, 0) + 1
        self._agent_active[agent_id] = self._agent_active.get(agent_id, 0) + 1
        return True

    def release(self, agent_id: str, session_id: str = "main") -> None:
        """Release a slot previously taken with ``acquire``.

        Releases the per-session key, the per-agent count and the global
        semaphore. A release for a key that currently holds no slot is
        ignored with a warning. This covers a double release, or a release
        after ``acquire`` returned False. Without the guard, such a release
        would free global and per-agent capacity owned by other sessions and
        allow over-admission.
        """
        key = self._make_key(agent_id, session_id)
        held = self._active_keys.get(key, 0)
        if held <= 0:
            logger.warning(
                "release() called for %s without an active slot; ignoring", key)
            return

        key_sem = self._per_key.get(key)
        if key_sem is not None:
            key_sem.release()
            # Drop the semaphore once its last holder is gone.
            if held <= 1:
                del self._per_key[key]

        if self._global_sem is not None:
            self._global_sem.release()
        self._global_active = max(0, self._global_active - 1)

        if held <= 1:
            del self._active_keys[key]
        else:
            self._active_keys[key] = held - 1

        remaining = self._agent_active.get(agent_id, 0) - 1
        if remaining > 0:
            self._agent_active[agent_id] = remaining
        else:
            self._agent_active.pop(agent_id, None)

    @property
    def global_active(self) -> int:
        """Number of slots currently held across all agents."""
        return self._global_active

    @property
    def max_global(self) -> int:
        """Configured global concurrency limit."""
        return self._max_global

    @property
    def active_agents(self) -> Dict[str, int]:
        """Return a copy of the per-agent active session counts."""
        return dict(self._agent_active)

    def is_near_limit(self, margin: int = 1) -> bool:
        """Return True when the global active count is within ``margin`` of the limit."""
        return self._global_active >= self._max_global - margin
|