106 lines
3.7 KiB
Python
106 lines
3.7 KiB
Python
"""ActiveAgentCounter — 并发控制 + 冷却机制
|
|
|
|
全局上限 + per-agent 串行,asyncio Semaphore 实现。
|
|
延迟创建 Semaphore(兼容 Python 3.9 无 event loop 时的构造)。
|
|
v2.7.2:新增 cooldown 机制(429/API 错误后冷却期)。
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
from typing import Dict, Optional
|
|
|
|
logger = logging.getLogger("moziplus-v2.counter")
|
|
|
|
|
|
class ActiveAgentCounter:
    """Async concurrency counter with a per-agent cooldown mechanism.

    Tracks how many slots are in use globally and per agent, refusing new
    acquisitions once ``max_global`` or ``max_per_agent`` is reached, or
    while the agent is inside a cooldown window set via ``set_cooldown``.

    Semaphores are created lazily (on first use) rather than in
    ``__init__``, so constructing the counter does not require an event
    loop.
    """

    def __init__(self, max_global: int = 5, max_per_agent: int = 1,
                 default_cooldown_seconds: float = 120.0):
        """Initialize the counter.

        Args:
            max_global: Maximum number of slots held across all agents.
            max_per_agent: Maximum number of slots held by a single agent.
            default_cooldown_seconds: Cooldown length used by
                ``set_cooldown`` when no explicit duration is given.
        """
        self._max_global = max_global
        self._max_per_agent = max_per_agent
        self._default_cooldown_seconds = default_cooldown_seconds
        # Created on first use; see _get_global_sem().
        self._global_sem: Optional[asyncio.Semaphore] = None
        self._per_agent: Dict[str, asyncio.Semaphore] = {}
        # agent_id -> number of slots currently held by that agent.
        self._active: Dict[str, int] = {}
        self._global_active: int = 0
        # v2.7.2: cooldown mechanism.
        # agent_id -> time.time() timestamp at which the cooldown expires.
        self._cooldown_until: Dict[str, float] = {}

    def _get_global_sem(self) -> asyncio.Semaphore:
        """Return the global semaphore, creating it on first use."""
        if self._global_sem is None:
            self._global_sem = asyncio.Semaphore(self._max_global)
        return self._global_sem

    def _get_agent_sem(self, agent_id: str) -> asyncio.Semaphore:
        """Return the semaphore for ``agent_id``, creating it on first use."""
        if agent_id not in self._per_agent:
            self._per_agent[agent_id] = asyncio.Semaphore(self._max_per_agent)
        return self._per_agent[agent_id]

    def is_cooling_down(self, agent_id: str) -> bool:
        """Return True if ``agent_id`` is still inside its cooldown window.

        An expired cooldown entry is removed as a side effect.
        """
        until = self._cooldown_until.get(agent_id)
        if until and time.time() < until:
            return True
        # Cooldown has elapsed (or was never set): drop the stale entry.
        self._cooldown_until.pop(agent_id, None)
        return False

    def set_cooldown(self, agent_id: str, seconds: Optional[float] = None) -> None:
        """Start a cooldown for ``agent_id``.

        Args:
            agent_id: Agent to put into cooldown.
            seconds: Cooldown length; when None, the default configured in
                the constructor is used.
        """
        cd = seconds if seconds is not None else self._default_cooldown_seconds
        self._cooldown_until[agent_id] = time.time() + cd
        logger.info("Cooldown set for %s: %.0fs (until %.0f)",
                    agent_id, cd, self._cooldown_until[agent_id])

    async def can_acquire(self, agent_id: str) -> bool:
        """Return True if ``agent_id`` could take a slot right now.

        False when the agent is cooling down, the global cap is reached,
        or the agent already holds ``max_per_agent`` slots.
        """
        if self.is_cooling_down(agent_id):
            return False
        if self._global_active >= self._max_global:
            return False
        return self._active.get(agent_id, 0) < self._max_per_agent

    async def acquire(self, agent_id: str) -> bool:
        """Try to take one slot for ``agent_id``.

        Returns:
            True if a slot was taken (the caller must later call
            ``release``); False if ``can_acquire`` refused the request.
        """
        if not await self.can_acquire(agent_id):
            return False

        global_sem = self._get_global_sem()
        await global_sem.acquire()
        try:
            await self._get_agent_sem(agent_id).acquire()
        except BaseException:
            # Do not leak the global permit if the per-agent acquire fails
            # (e.g. the task is cancelled while waiting).
            global_sem.release()
            raise

        self._global_active += 1
        self._active[agent_id] = self._active.get(agent_id, 0) + 1
        return True

    def release(self, agent_id: str) -> None:
        """Give back one slot previously taken by ``agent_id``.

        A release for an agent that currently holds no slot (never
        acquired, or already fully released) is ignored with a warning.
        Honoring it would inflate the semaphores and under-count
        ``global_active``, letting later acquisitions exceed the limits.
        """
        held = self._active.get(agent_id, 0)
        if held <= 0:
            logger.warning(
                "release() for %s without a matching acquire; ignored",
                agent_id)
            return

        agent_sem = self._per_agent.get(agent_id)
        if agent_sem is not None:
            agent_sem.release()
        if self._global_sem is not None:
            self._global_sem.release()

        self._global_active = max(0, self._global_active - 1)
        if held == 1:
            # Last slot for this agent: drop the entry entirely.
            del self._active[agent_id]
        else:
            self._active[agent_id] = held - 1

    @property
    def global_active(self) -> int:
        """Number of slots currently held across all agents."""
        return self._global_active

    @property
    def max_global(self) -> int:
        """Configured global slot limit."""
        return self._max_global

    @property
    def active_agents(self) -> Dict[str, int]:
        """Snapshot copy of the per-agent held-slot counts."""
        return dict(self._active)

    def is_near_limit(self, margin: int = 1) -> bool:
        """Return True if the global count is within ``margin`` of the cap."""
        return self._global_active >= self._max_global - margin
|