Files
sanguo_vnpy/scripts/backtest-service/task_queue.py
T
2026-04-12 10:18:57 +08:00

117 lines
3.6 KiB
Python

"""
自动化回测服务 - 任务队列
"""
import multiprocessing.pool
import os
import uuid
from datetime import datetime
from multiprocessing import Pool
from typing import List, Optional, Dict

from .config import settings
from .models import TaskStatus, BacktestTask, BacktestTaskWithId
from .result_storage import storage
class TaskQueue:
    """In-memory index of backtest task IDs, grouped by status.

    Task payloads are persisted via ``storage`` (``save_task`` / ``load_task``);
    this class only keeps ordered lists of task IDs per status directory
    ("pending", "running", "completed", "failed") and owns an optional
    ``multiprocessing`` worker pool.
    """

    def __init__(self):
        # Number of worker processes used when the pool is started.
        self.max_workers = settings.max_workers
        # Task-ID lists, one per status. ``submit_task`` appends to
        # ``pending_tasks``; the other lists are presumably maintained by code
        # outside this class (NOTE(review): verify where tasks change status).
        self.pending_tasks: List[str] = []
        self.running_tasks: List[str] = []
        self.completed_tasks: List[str] = []
        self.failed_tasks: List[str] = []
        # ``multiprocessing.Pool`` is a factory function, not a class, so the
        # annotation names the actual pool type it returns.
        self._pool: Optional[multiprocessing.pool.Pool] = None

    def _status_lists(self) -> List[tuple]:
        """Return ``(status_dir, task_id_list)`` pairs in lookup-priority order."""
        return [
            ("pending", self.pending_tasks),
            ("running", self.running_tasks),
            ("completed", self.completed_tasks),
            ("failed", self.failed_tasks),
        ]

    def _generate_task_id(self) -> str:
        """Generate a unique task ID: a random UUID4 as 32 lowercase hex chars."""
        return uuid.uuid4().hex

    def submit_task(self, task: BacktestTask) -> BacktestTaskWithId:
        """Register a new task as pending and persist it.

        Args:
            task: the task definition; its ``model_dump()`` fields are copied
                into the stored record.

        Returns:
            The stored record, carrying a fresh ``task_id``, status
            ``TaskStatus.PENDING`` and a local-time ISO-8601 ``created_at``.
        """
        task_id = self._generate_task_id()
        now = datetime.now().isoformat()
        task_with_id = BacktestTaskWithId(
            task_id=task_id,
            status=TaskStatus.PENDING,
            created_at=now,
            **task.model_dump()
        )
        storage.save_task(task_with_id)
        self.pending_tasks.append(task_id)
        return task_with_id

    def list_tasks(self, page: int = 1, page_size: int = 10, status: Optional[str] = None) -> Dict:
        """List tasks with pagination and optional status filtering.

        Args:
            page: 1-based page number; values below 1 are treated as 1.
            page_size: maximum number of tasks per page; negative values are
                treated as 0.
            status: "pending", "running", "completed" or "failed" restricts the
                listing to that status; any other value (including None) lists
                tasks of every status, in that same order.

        Returns:
            A dict with ``total`` (number of matching task IDs), the effective
            ``page`` and ``page_size``, and ``tasks``: the objects returned by
            ``storage.load_task`` for this page. IDs whose load returns a falsy
            value are skipped, so ``tasks`` may be shorter than ``page_size``.
        """
        # Negative values would yield negative slice bounds, which silently
        # return items from the end of the list.
        page = max(page, 1)
        page_size = max(page_size, 0)

        status_lists = self._status_lists()
        selected = [pair for pair in status_lists if pair[0] == status]
        if not selected:
            # None or an unrecognized status: list every status.
            selected = status_lists

        # Pair each ID with the directory it was selected from, so loading a
        # page does not re-scan all four lists for every ID.
        entries = [
            (task_id, status_dir)
            for status_dir, task_ids in selected
            for task_id in task_ids
        ]

        start = (page - 1) * page_size
        tasks = []
        for task_id, status_dir in entries[start:start + page_size]:
            task = storage.load_task(task_id, status_dir)
            if task:
                tasks.append(task)

        return {
            "total": len(entries),
            "page": page,
            "page_size": page_size,
            "tasks": tasks
        }

    def get_task(self, task_id: str) -> Optional[BacktestTaskWithId]:
        """Return the stored task for ``task_id``, or None if it is not indexed.

        The status lists are searched in pending → running → completed →
        failed order, and the task is loaded from the first match's directory.
        """
        for status_dir, task_ids in self._status_lists():
            if task_id in task_ids:
                return storage.load_task(task_id, status_dir)
        return None

    def start_worker_pool(self):
        """Create the worker process pool (``max_workers`` processes) if absent."""
        if self._pool is None:
            self._pool = Pool(processes=self.max_workers)

    def close_worker_pool(self):
        """Close the worker pool and wait for its processes to exit.

        ``_pool`` is reset even if ``close``/``join`` raises, so a later
        ``start_worker_pool`` call can create a fresh pool.
        """
        if self._pool is None:
            return
        try:
            self._pool.close()
            self._pool.join()
        finally:
            self._pool = None
# Module-level TaskQueue instance, constructed once when this module is imported.
task_queue: TaskQueue = TaskQueue()