117 lines
3.6 KiB
Python
117 lines
3.6 KiB
Python
"""
Automated backtesting service - task queue.

自动化回测服务 - 任务队列
"""
|
|
import os
|
|
import uuid
|
|
from datetime import datetime
|
|
from typing import List, Optional, Dict
|
|
from multiprocessing import Pool
|
|
from .config import settings
|
|
from .models import TaskStatus, BacktestTask, BacktestTaskWithId
|
|
from .result_storage import storage
|
|
|
|
|
|
class TaskQueue:
    """Task queue manager (任务队列管理器).

    Tracks backtest task IDs in four in-memory lists, one per state
    (pending / running / completed / failed), and delegates persistence of
    the task records to ``storage``. It also owns an optional
    ``multiprocessing`` pool whose lifecycle is controlled explicitly via
    ``start_worker_pool`` / ``close_worker_pool``.
    """

    def __init__(self):
        # Pool size, read from settings once at construction time.
        self.max_workers = settings.max_workers
        # Task IDs grouped by state, kept in the order they were added.
        self.pending_tasks: List[str] = []
        self.running_tasks: List[str] = []
        self.completed_tasks: List[str] = []
        self.failed_tasks: List[str] = []
        # Worker pool; created lazily by start_worker_pool().
        self._pool: Optional[Pool] = None

    def _generate_task_id(self) -> str:
        """Return a new unique task ID: a UUID4 as 32 lowercase hex chars (no dashes)."""
        # ``UUID.hex`` is exactly ``str(uuid)`` with the dashes removed.
        return uuid.uuid4().hex

    def _status_lists(self) -> List[tuple]:
        """Return ``(status_dir, task_id_list)`` pairs in a fixed order.

        The order pending, running, completed, failed is both the lookup
        priority used by ``get_task`` and the listing order used by
        ``list_tasks``. ``status_dir`` is the state name passed to
        ``storage.load_task``.
        """
        return [
            ("pending", self.pending_tasks),
            ("running", self.running_tasks),
            ("completed", self.completed_tasks),
            ("failed", self.failed_tasks),
        ]

    def submit_task(self, task: BacktestTask) -> BacktestTaskWithId:
        """Submit a new task to the queue.

        Assigns a fresh task ID, marks the task PENDING with the current
        local time as ``created_at``, persists it via ``storage.save_task``
        and records its ID in ``pending_tasks``.

        Args:
            task: The task definition to enqueue.

        Returns:
            The persisted task, including its generated ID and status.
        """
        task_id = self._generate_task_id()
        # Naive local timestamp in ISO-8601 format (no timezone offset).
        now = datetime.now().isoformat()

        # NOTE(review): if ``task.model_dump()`` ever contains "task_id",
        # "status" or "created_at", this call raises TypeError (duplicate
        # keyword argument) — confirm BacktestTask has no such fields.
        task_with_id = BacktestTaskWithId(
            task_id=task_id,
            status=TaskStatus.PENDING,
            created_at=now,
            **task.model_dump()
        )

        # Persist first so the ID is never tracked without a stored record.
        storage.save_task(task_with_id)
        self.pending_tasks.append(task_id)
        return task_with_id

    def list_tasks(self, page: int = 1, page_size: int = 10, status: Optional[str] = None) -> Dict:
        """List tasks with pagination and optional status filtering.

        Args:
            page: 1-based page number. Values below 1 yield an empty page.
            page_size: Number of tasks per page. Values below 1 yield an
                empty page.
            status: "pending", "running", "completed" or "failed" restricts
                the listing to that state. Any other value (including None)
                lists tasks from all states, in the order pending, running,
                completed, failed.

        Returns:
            A dict with keys "total" (number of task IDs matching the filter,
            before pagination), "page", "page_size" and "tasks" (the loaded
            task objects for this page; IDs for which ``storage.load_task``
            returns a falsy value are omitted).
        """
        status_lists = self._status_lists()
        selected = [pair for pair in status_lists if status == pair[0]]
        if not selected:
            # Missing or unrecognised status: list every state.
            selected = status_lists

        # Pair each ID with the state it was found under so it can be loaded
        # directly, instead of re-searching all four lists per ID.
        candidates = [
            (status_dir, task_id)
            for status_dir, task_ids in selected
            for task_id in task_ids
        ]
        total = len(candidates)

        if page < 1 or page_size < 1:
            # Non-positive values would produce negative slice bounds, which
            # Python interprets as offsets from the end of the list.
            page_items = []
        else:
            start = (page - 1) * page_size
            page_items = candidates[start:start + page_size]

        result = []
        for status_dir, task_id in page_items:
            task = storage.load_task(task_id, status_dir)
            if task:
                result.append(task)

        return {
            "total": total,
            "page": page,
            "page_size": page_size,
            "tasks": result
        }

    def get_task(self, task_id: str) -> Optional[BacktestTaskWithId]:
        """Return the task with ``task_id``, or None if it is not tracked.

        States are searched in the order pending, running, completed,
        failed. The task is loaded from the first state list that contains
        the ID, and the result of ``storage.load_task`` is returned as-is.
        """
        for status_dir, task_ids in self._status_lists():
            if task_id in task_ids:
                return storage.load_task(task_id, status_dir)
        return None

    def start_worker_pool(self):
        """Start the worker process pool (``max_workers`` processes) if not already running."""
        if self._pool is None:
            self._pool = Pool(processes=self.max_workers)

    def close_worker_pool(self):
        """Close the worker pool, if any.

        Stops accepting new work, waits for the worker processes to exit and
        clears the reference so a later ``start_worker_pool`` creates a new
        pool.
        """
        if self._pool is not None:
            self._pool.close()
            self._pool.join()
            self._pool = None
|
|
|
|
|
|
# Shared queue instance, created once when this module is first imported.
task_queue: TaskQueue = TaskQueue()
|