auto-sync: 2026-04-29 20:14:41
This commit is contained in:
@@ -0,0 +1,149 @@
|
||||
# 自动化回测服务 - 使用说明
|
||||
|
||||
## 概述
|
||||
|
||||
基于 vnpy 原生 `BacktestingEngine` 封装的 RESTful API 自动化回测服务。
|
||||
|
||||
支持:
|
||||
- 通过 API 提交回测任务
|
||||
- 自动排队控制并发
|
||||
- 使用 vnpy 原生回测引擎执行
|
||||
- 保存回测结果(CSV + 图表 + JSON)
|
||||
- 查询任务状态和结果
|
||||
|
||||
## 架构设计
|
||||
|
||||
严格遵循 vnpy 原生设计,不修改核心,只做外层封装:
|
||||
|
||||
```
|
||||
[API 服务] ← 接收任务
|
||||
↓
|
||||
[任务队列 + 进程池] ← 控制并发
|
||||
↓
|
||||
[BacktestingEngine (vnpy 原生)] ← 执行回测
|
||||
↓
|
||||
[文件存储] ← 保存结果
|
||||
↓
|
||||
[API 查询] ← 返回结果
|
||||
```
|
||||
|
||||
## 启动方式
|
||||
|
||||
```bash
|
||||
# 手动启动
|
||||
cd /app/scripts/backtest-service
|
||||
python main.py
|
||||
|
||||
# 后台运行
|
||||
nohup python main.py > backtest-service.log 2>&1 &
|
||||
|
||||
# 查看日志
|
||||
tail -f backtest-service.log
|
||||
```
|
||||
|
||||
## 访问地址
|
||||
|
||||
启动后访问:
|
||||
- 服务地址:http://container-ip:8088
|
||||
- API 文档:http://container-ip:8088/docs → 可以直接在网页上测试API
|
||||
|
||||
## API 接口说明
|
||||
|
||||
| 接口 | 方法 | 路径 | 说明 |
|
||||
|------|------|------|------|
|
||||
| 提交回测 | POST | `/api/backtest/submit` | 提交新的回测任务 |
|
||||
| 任务列表 | GET | `/api/backtest/list` | 列出任务,支持分页和状态过滤 |
|
||||
| 任务状态 | GET | `/api/backtest/status/{task_id}` | 查询任务状态 |
|
||||
| 回测结果 | GET | `/api/backtest/result/{task_id}` | 获取完整回测结果 |
|
||||
| 删除任务 | DELETE | `/api/backtest/delete/{task_id}` | 删除任务 |
|
||||
| 健康检查 | GET | `/api/backtest/health` | 查看服务状态,返回任务统计 |
|
||||
|
||||
## 配置
|
||||
|
||||
可以通过环境变量覆盖默认配置:
|
||||
|
||||
| 环境变量 | 说明 | 默认值 |
|
||||
|----------|------|--------|
|
||||
| `MAX_WORKERS` | 最大并发回测数 | 2 |
|
||||
| `HOST` | 监听地址 | 0.0.0.0 |
|
||||
| `PORT` | 监听端口 | 8088 |
|
||||
| `BASE_DIR` | 任务存储根目录 | /app/backtest_jobs |
|
||||
| `DEBUG` | 调试模式 | True |
|
||||
|
||||
## 使用示例
|
||||
|
||||
### 1. 提交回测
|
||||
|
||||
```bash
|
||||
curl -X POST http://127.0.0.1:8088/api/backtest/submit \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"strategy_name": "双均线策略",
|
||||
"strategy_code": "完整的策略Python代码...",
|
||||
"symbol": "IF888.CFFEX",
|
||||
"interval": "1h",
|
||||
"start_date": "2020-01-01",
|
||||
"end_date": "2025-01-01",
|
||||
"parameters": {
|
||||
"fast_window": 5,
|
||||
"slow_window": 20
|
||||
},
|
||||
"capital": 1000000
|
||||
}'
|
||||
```
|
||||
|
||||
响应:
|
||||
```json
|
||||
{
|
||||
"code": 0,
|
||||
"msg": "任务提交成功",
|
||||
"data": {
|
||||
"task_id": "a1b2c3d4...",
|
||||
"status": "pending",
|
||||
"created_at": "2026-04-12T10:00:00+08:00"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 2. 查询任务状态
|
||||
|
||||
```bash
|
||||
curl http://127.0.0.1:8088/api/backtest/status/a1b2c3d4
|
||||
```
|
||||
|
||||
### 3. 获取回测结果
|
||||
|
||||
```bash
|
||||
curl http://127.0.0.1:8088/api/backtest/result/a1b2c3d4
|
||||
```
|
||||
|
||||
## 结果存储结构
|
||||
|
||||
```
|
||||
/app/backtest_jobs/
|
||||
├── pending/ # 等待执行
|
||||
├── running/ # 执行中
|
||||
├── completed/ # 已完成
|
||||
│ └── <task_id>/
|
||||
│ ├── task.json # 任务信息
|
||||
│ ├── result.json # 回测结果
|
||||
│ ├── equity.csv # 每日净值
|
||||
│ ├── equity_curve.png # 收益曲线图
|
||||
│ └── trades.csv # 成交记录
|
||||
└── failed/ # 执行失败
|
||||
└── <task_id>/
|
||||
├── task.json
|
||||
└── result.json # 包含错误信息
|
||||
```
|
||||
|
||||
## 设计原则
|
||||
|
||||
1. **不改动 vnpy 核心**:完全复用原生 `BacktestingEngine`
|
||||
2. **轻量级**:只用 Python 标准库,不引入额外第三方任务队列
|
||||
3. **隔离性**:每个回测在独立进程,一个失败不影响其他
|
||||
4. **可配置并发**:根据 CPU 性能调整 `MAX_WORKERS`,避免资源耗尽
|
||||
|
||||
## 作者
|
||||
|
||||
三国量化团队 姜维 伯约
|
||||
2026-04-12
|
||||
Executable
+92
@@ -0,0 +1,92 @@
|
||||
"""
|
||||
自动化回测服务 - API路由
|
||||
"""
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Query
|
||||
from .models import (
|
||||
BacktestTask,
|
||||
BacktestTaskWithId,
|
||||
BacktestResult,
|
||||
TaskListResponse,
|
||||
ApiResponse,
|
||||
HealthCheckResponse,
|
||||
TaskStatus,
|
||||
)
|
||||
from .task_queue import task_queue
|
||||
from .result_storage import storage
|
||||
from .executor import executor
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/submit", summary="提交回测任务")
|
||||
def submit_task(task: BacktestTask) -> ApiResponse[BacktestTaskWithId]:
|
||||
"""提交一个新的回测任务"""
|
||||
task_with_id = task_queue.submit_task(task)
|
||||
storage.save_task(task_with_id)
|
||||
return ApiResponse(
|
||||
code=0,
|
||||
msg="任务提交成功",
|
||||
data=task_with_id
|
||||
)
|
||||
|
||||
|
||||
@router.get("/list", summary="列出回测任务")
|
||||
def list_tasks(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页数量"),
|
||||
status: Optional[str] = Query(None, description="状态过滤 pending/running/completed/failed")
|
||||
) -> ApiResponse[TaskListResponse]:
|
||||
"""列出回测任务,支持分页和状态过滤"""
|
||||
result = task_queue.list_tasks(page, page_size, status)
|
||||
return ApiResponse(
|
||||
code=0,
|
||||
msg="success",
|
||||
data=TaskListResponse(**result)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/status/{task_id}", summary="查询任务状态")
|
||||
def get_status(task_id: str) -> ApiResponse[Optional[BacktestTaskWithId]]:
|
||||
"""查询单个任务状态(从磁盘查找)"""
|
||||
task = storage.find_task(task_id)
|
||||
if not task:
|
||||
return ApiResponse(code=404, msg="任务不存在", data=None)
|
||||
return ApiResponse(code=0, msg="success", data=task)
|
||||
|
||||
|
||||
@router.get("/result/{task_id}", summary="获取回测结果")
|
||||
def get_result(task_id: str) -> ApiResponse[Optional[BacktestResult]]:
|
||||
"""获取回测完整结果(从磁盘查找)"""
|
||||
result = storage.find_result(task_id)
|
||||
if not result:
|
||||
task = storage.find_task(task_id)
|
||||
if not task:
|
||||
return ApiResponse(code=404, msg="任务不存在", data=None)
|
||||
return ApiResponse(code=0, msg="任务尚未完成", data=None)
|
||||
return ApiResponse(code=0, msg="success", data=result)
|
||||
|
||||
|
||||
@router.delete("/delete/{task_id}", summary="删除回测任务")
|
||||
def delete_task(task_id: str) -> ApiResponse[None]:
|
||||
"""删除一个回测任务"""
|
||||
# TODO: 实现物理删除
|
||||
# 现在只返回成功,后续实现
|
||||
return ApiResponse(code=0, msg="删除成功(待实现物理删除)", data=None)
|
||||
|
||||
|
||||
@router.get("/health", summary="健康检查")
|
||||
def health_check() -> ApiResponse[HealthCheckResponse]:
|
||||
"""服务健康检查,返回任务统计信息"""
|
||||
return ApiResponse(
|
||||
code=0,
|
||||
msg="ok",
|
||||
data=HealthCheckResponse(
|
||||
pending_count=len(task_queue.pending_tasks),
|
||||
running_count=len(task_queue.running_tasks),
|
||||
completed_count=len(task_queue.completed_tasks),
|
||||
failed_count=len(task_queue.failed_tasks),
|
||||
max_workers=task_queue.max_workers
|
||||
)
|
||||
)
|
||||
Executable
+32
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
自动化回测服务 - 配置
|
||||
"""
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""回测服务配置"""
|
||||
# 服务监听地址
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8088
|
||||
|
||||
# 最大并发回测数(根据CPU核数调整,NAS赛扬建议2)
|
||||
max_workers: int = 2
|
||||
|
||||
# 回测任务存储根目录
|
||||
base_dir: str = "/app/backtest_jobs"
|
||||
|
||||
# CORS 配置 - 开发环境允许所有来源
|
||||
cors_allow_all: bool = True
|
||||
|
||||
# 调试模式
|
||||
debug: bool = True
|
||||
|
||||
# 允许策略代码中导入模块
|
||||
allow_imports: bool = True
|
||||
|
||||
# 最大回测时间限制(秒),防止无限循环
|
||||
max_execution_time: int = 3600 # 1小时
|
||||
|
||||
|
||||
settings = Settings()
|
||||
Executable
+262
@@ -0,0 +1,262 @@
|
||||
"""
|
||||
自动化回测服务 - 任务执行器
|
||||
调用 vnpy 4.x BacktestingEngine 执行回测
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
import matplotlib
|
||||
matplotlib.use("Agg") # 无头模式,服务器上不能弹窗
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
|
||||
# vnpy 4.x import路径(与3.x不同)
|
||||
from vnpy.event import EventEngine
|
||||
from vnpy.trader.engine import MainEngine
|
||||
from vnpy_ctastrategy.backtesting import BacktestingEngine
|
||||
from vnpy.trader.constant import Interval, Exchange
|
||||
|
||||
from .config import settings
|
||||
from .models import BacktestTask, BacktestResult, BacktestStatistics, TaskStatus, BacktestTaskWithId
|
||||
from .result_storage import storage
|
||||
|
||||
|
||||
# vnpy 4.x精简了Interval枚举,不再有FIVE_MINUTE等细分
|
||||
INTERVAL_MAP = {
|
||||
"1m": Interval.MINUTE,
|
||||
"5m": Interval.MINUTE,
|
||||
"15m": Interval.MINUTE,
|
||||
"30m": Interval.MINUTE,
|
||||
"1h": Interval.HOUR,
|
||||
"4h": Interval.HOUR,
|
||||
"1d": Interval.DAILY,
|
||||
"1w": Interval.WEEKLY,
|
||||
}
|
||||
|
||||
# 交易所映射
|
||||
EXCHANGE_MAP = {
|
||||
"SSE": Exchange.SSE,
|
||||
"SZSE": Exchange.SZSE,
|
||||
"CFFEX": Exchange.CFFEX,
|
||||
"SHFE": Exchange.SHFE,
|
||||
"DCE": Exchange.DCE,
|
||||
"CZCE": Exchange.CZCE,
|
||||
"INE": Exchange.INE,
|
||||
"GFEX": Exchange.GFEX,
|
||||
}
|
||||
|
||||
|
||||
def _parse_vt_symbol(vt_symbol: str):
|
||||
"""解析vt_symbol为symbol和exchange,如 '000001.SZ' → ('000001', Exchange.SZSE)"""
|
||||
if "." in vt_symbol:
|
||||
symbol, exchange_str = vt_symbol.rsplit(".", 1)
|
||||
exchange = EXCHANGE_MAP.get(exchange_str.upper())
|
||||
if exchange is None:
|
||||
# 尝试模糊匹配
|
||||
exchange_str_upper = exchange_str.upper()
|
||||
for key, val in EXCHANGE_MAP.items():
|
||||
if key.startswith(exchange_str_upper[:2]):
|
||||
exchange = val
|
||||
break
|
||||
if exchange is None:
|
||||
exchange = Exchange.SZSE # 默认深交所
|
||||
return symbol, exchange
|
||||
return vt_symbol, Exchange.SZSE
|
||||
|
||||
|
||||
class BacktestExecutor:
|
||||
"""回测任务执行器 - 适配vnpy 4.x"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def _load_strategy(self, task: BacktestTask):
|
||||
"""动态加载策略代码"""
|
||||
strategy_code = task.strategy_code
|
||||
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
sys.path.insert(0, temp_dir)
|
||||
|
||||
strategy_file = os.path.join(temp_dir, "strategy.py")
|
||||
with open(strategy_file, "w", encoding="utf-8") as f:
|
||||
f.write(strategy_code)
|
||||
|
||||
import importlib
|
||||
spec = importlib.util.spec_from_file_location("dynamic_strategy", strategy_file)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules["dynamic_strategy"] = module
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# 找CtaTemplate子类
|
||||
from vnpy_ctastrategy import CtaTemplate
|
||||
strategy_class = None
|
||||
for attr_name in dir(module):
|
||||
attr = getattr(module, attr_name)
|
||||
if isinstance(attr, type) and issubclass(attr, CtaTemplate) and attr is not CtaTemplate:
|
||||
strategy_class = attr
|
||||
break
|
||||
|
||||
if not strategy_class:
|
||||
raise ValueError("策略代码中没有找到 CtaTemplate 子类,请检查策略代码")
|
||||
|
||||
return strategy_class
|
||||
|
||||
def execute_backtest(self, task: BacktestTaskWithId) -> BacktestResult:
|
||||
"""执行一次回测"""
|
||||
start_time = datetime.now()
|
||||
started_at = start_time.isoformat()
|
||||
|
||||
task.status = TaskStatus.RUNNING
|
||||
task.started_at = started_at
|
||||
storage.save_task(task)
|
||||
|
||||
result = BacktestResult(
|
||||
task_id=task.task_id,
|
||||
strategy_name=task.strategy_name,
|
||||
status=TaskStatus.RUNNING,
|
||||
result_csv_path="",
|
||||
created_at=task.created_at,
|
||||
started_at=started_at,
|
||||
)
|
||||
|
||||
try:
|
||||
# 加载策略类
|
||||
strategy_class = self._load_strategy(task)
|
||||
|
||||
# 解析vt_symbol
|
||||
symbol, exchange = _parse_vt_symbol(task.symbol)
|
||||
|
||||
# 获取interval
|
||||
interval = INTERVAL_MAP.get(task.interval, Interval.DAILY)
|
||||
|
||||
# 创建回测引擎
|
||||
engine = BacktestingEngine()
|
||||
|
||||
# 设置回测参数
|
||||
engine.set_parameters(
|
||||
vt_symbol=task.symbol,
|
||||
interval=interval,
|
||||
start=task.start_date,
|
||||
end=task.end_date,
|
||||
rate=0.3 / 10000, # 手续费率万三
|
||||
slippage=0.1, # 滑点0.1
|
||||
size=1, # 合约乘数
|
||||
pricetick=task.tick_size or 0.01, # 最小价格变动
|
||||
capital=task.capital,
|
||||
)
|
||||
|
||||
# 添加策略
|
||||
engine.add_strategy(strategy_class, task.parameters)
|
||||
|
||||
# 加载历史数据
|
||||
# 优先从CSV文件加载(/app/data目录通过volume挂载NAS数据)
|
||||
data_loaded = False
|
||||
data_dir = settings.base_dir.replace("backtest_jobs", "data")
|
||||
|
||||
# 尝试多种数据加载方式
|
||||
try:
|
||||
# 方式1: 使用vnpy内置数据加载
|
||||
engine.load_data()
|
||||
data_loaded = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not data_loaded:
|
||||
raise ValueError(
|
||||
f"无法加载 {task.symbol} 在 [{task.start_date}, {task.end_date}] 的历史数据。"
|
||||
f"请确保数据已导入vnpy数据库或可通过CSV加载。"
|
||||
)
|
||||
|
||||
# 运行回测
|
||||
engine.run_backtesting()
|
||||
|
||||
# 计算统计结果
|
||||
df = engine.calculate_result()
|
||||
statistics = engine.calculate_statistics()
|
||||
|
||||
# 转换为数据模型
|
||||
stats = BacktestStatistics(
|
||||
start_date=str(task.start_date),
|
||||
end_date=str(task.end_date),
|
||||
total_days=int(statistics.get("total_days", 0)),
|
||||
total_trades=int(statistics.get("total_trades", 0)),
|
||||
winning_trades=int(statistics.get("winning_trades", 0)),
|
||||
losing_trades=int(statistics.get("losing_trades", 0)),
|
||||
win_rate=float(statistics.get("win_rate", 0)),
|
||||
total_return=float(statistics.get("total_return", 0)),
|
||||
annual_return=float(statistics.get("annual_return", 0)),
|
||||
sharpe_ratio=float(statistics.get("sharpe_ratio", 0)),
|
||||
max_drawdown=float(statistics.get("max_drawdown", 0)),
|
||||
max_drawdown_start=str(statistics.get("max_drawdown_start", "")),
|
||||
max_drawdown_end=str(statistics.get("max_drawdown_end", "")),
|
||||
profit_factor=float(statistics.get("profit_factor", 0)),
|
||||
calmar_ratio=float(statistics.get("calmar_ratio", 0)),
|
||||
)
|
||||
|
||||
# 保存净值CSV
|
||||
result_csv_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "equity.csv")
|
||||
os.makedirs(os.path.dirname(result_csv_path), exist_ok=True)
|
||||
if df is not None and not df.empty:
|
||||
df.to_csv(result_csv_path)
|
||||
|
||||
# 绘制收益曲线
|
||||
png_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "equity_curve.png")
|
||||
self._plot_equity_curve(df, png_path)
|
||||
|
||||
# 保存成交记录
|
||||
trades_csv_path = None
|
||||
try:
|
||||
trades = engine.get_all_trades() if hasattr(engine, 'get_all_trades') else []
|
||||
if trades:
|
||||
trades_df = pd.DataFrame([t.__dict__ for t in trades])
|
||||
trades_csv_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "trades.csv")
|
||||
trades_df.to_csv(trades_csv_path, index=False)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 完成结果
|
||||
result.status = TaskStatus.COMPLETED
|
||||
result.statistics = stats
|
||||
result.result_csv_path = result_csv_path
|
||||
result.equity_curve_png_path = png_path
|
||||
result.trades_csv_path = trades_csv_path
|
||||
result.completed_at = datetime.now().isoformat()
|
||||
|
||||
storage.save_result(result)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"{str(e)}\n{traceback.format_exc()}"
|
||||
result.status = TaskStatus.FAILED
|
||||
result.error_message = error_msg
|
||||
result.completed_at = datetime.now().isoformat()
|
||||
|
||||
storage.save_result(result)
|
||||
return result
|
||||
|
||||
def _plot_equity_curve(self, df: pd.DataFrame, output_path: str):
|
||||
"""绘制收益曲线"""
|
||||
plt.figure(figsize=(12, 6))
|
||||
if df is not None and not df.empty:
|
||||
if "equity" in df.columns:
|
||||
plt.plot(df.index, df["equity"], label="净值曲线", linewidth=2)
|
||||
elif "net_pnl" in df.columns:
|
||||
cumulative = df["net_pnl"].cumsum()
|
||||
plt.plot(df.index, cumulative, label="累计收益", linewidth=2)
|
||||
elif "balance" in df.columns:
|
||||
plt.plot(df.index, df["balance"], label="账户余额", linewidth=2)
|
||||
|
||||
plt.title("回测收益曲线")
|
||||
plt.xlabel("时间")
|
||||
plt.ylabel("净值")
|
||||
plt.grid(True, alpha=0.3)
|
||||
plt.legend()
|
||||
plt.tight_layout()
|
||||
plt.savefig(output_path, dpi=150)
|
||||
plt.close()
|
||||
|
||||
|
||||
executor = BacktestExecutor()
|
||||
Executable
+72
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
自动化回测服务 - 主入口
|
||||
启动 FastAPI 服务,接受回测任务提交,执行回测,返回结果
|
||||
遵循 vnpy 原生设计,只做外层封装
|
||||
"""
|
||||
import uvicorn
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from .config import settings
|
||||
from .api import router
|
||||
from .task_queue import task_queue
|
||||
from .models import ApiResponse, HealthCheckResponse
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期:启动时开启worker线程,关闭时停止"""
|
||||
# 启动
|
||||
task_queue.start_worker_pool()
|
||||
print(f"✅ 回测服务启动 (max_workers={settings.max_workers})")
|
||||
yield
|
||||
# 关闭
|
||||
task_queue.close_worker_pool()
|
||||
print("回测服务已停止")
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="sanguo 自动化回测服务",
|
||||
description="基于 vnpy 原生 BacktestingEngine 的自动化回测API服务",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# CORS 配置
|
||||
if settings.cors_allow_all:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 注册API路由
|
||||
app.include_router(router, tags=["backtest"])
|
||||
|
||||
|
||||
@app.get("/api/backtest/health", summary="服务健康检查", response_model=ApiResponse[HealthCheckResponse])
|
||||
def health():
|
||||
"""服务健康检查,返回当前任务统计"""
|
||||
return ApiResponse(
|
||||
code=0,
|
||||
msg="ok",
|
||||
data=HealthCheckResponse(
|
||||
pending_count=len(task_queue.pending_tasks),
|
||||
running_count=len(task_queue.running_tasks),
|
||||
completed_count=len(task_queue.completed_tasks),
|
||||
failed_count=len(task_queue.failed_tasks),
|
||||
max_workers=settings.max_workers
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(
|
||||
"backtest_service.main:app",
|
||||
host=settings.host,
|
||||
port=settings.port,
|
||||
reload=settings.debug,
|
||||
)
|
||||
Executable
+97
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
自动化回测服务 - 数据模型
|
||||
"""
|
||||
from enum import Enum
|
||||
from datetime import date, datetime
|
||||
from typing import Dict, Optional, Any, List, Generic, TypeVar
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
"""任务状态枚举"""
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class BacktestTask(BaseModel):
|
||||
"""回测任务请求"""
|
||||
strategy_name: str = Field(..., description="策略名称")
|
||||
strategy_code: str = Field(..., description="策略完整Python代码")
|
||||
symbol: str = Field(..., description="交易品种,例如 000001.SSE")
|
||||
interval: str = Field(..., description="K线周期,例如 1d、1h")
|
||||
start_date: date = Field(..., description="回测开始日期")
|
||||
end_date: date = Field(..., description="回测结束日期")
|
||||
parameters: Dict[str, Any] = Field(default_factory=dict, description="策略参数字典")
|
||||
capital: float = Field(default=1_000_000, description="起始资金")
|
||||
tick_size: Optional[float] = Field(None, description="最小价格变动,不指定则自动获取")
|
||||
|
||||
|
||||
class BacktestTaskWithId(BacktestTask):
|
||||
"""带ID和状态的回测任务"""
|
||||
task_id: str = Field(..., description="任务唯一ID")
|
||||
status: TaskStatus = Field(..., description="任务状态")
|
||||
created_at: str = Field(..., description="创建时间 ISO格式")
|
||||
started_at: Optional[str] = Field(None, description="开始时间")
|
||||
completed_at: Optional[str] = Field(None, description="完成时间")
|
||||
|
||||
|
||||
class BacktestStatistics(BaseModel):
|
||||
"""回测结果统计"""
|
||||
start_date: str
|
||||
end_date: str
|
||||
total_days: int
|
||||
total_trades: int
|
||||
winning_trades: int
|
||||
losing_trades: int
|
||||
win_rate: float
|
||||
total_return: float # 总收益率
|
||||
annual_return: float # 年化收益率
|
||||
sharpe_ratio: float # 夏普比率
|
||||
max_drawdown: float # 最大回撤
|
||||
max_drawdown_start: Optional[str] = None
|
||||
max_drawdown_end: Optional[str] = None
|
||||
profit_factor: float # 收益因子(总盈利/总亏损)
|
||||
calmar_ratio: float # 卡玛比率(年化收益/最大回撤)
|
||||
|
||||
|
||||
class BacktestResult(BaseModel):
|
||||
"""完整回测结果"""
|
||||
task_id: str
|
||||
strategy_name: str
|
||||
status: TaskStatus
|
||||
statistics: Optional[BacktestStatistics] = None
|
||||
result_csv_path: str # 每日净值CSV路径
|
||||
equity_curve_png_path: Optional[str] = None # 收益曲线图片路径
|
||||
trades_csv_path: Optional[str] = None # 成交记录CSV路径
|
||||
error_message: Optional[str] = None # 失败时的错误信息
|
||||
created_at: str
|
||||
started_at: Optional[str] = None
|
||||
completed_at: Optional[str] = None
|
||||
|
||||
|
||||
class TaskListResponse(BaseModel):
|
||||
"""任务列表响应"""
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
tasks: List[BacktestTaskWithId]
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
class ApiResponse(BaseModel, Generic[T]):
|
||||
"""通用API响应包装"""
|
||||
code: int = Field(0, description="0表示成功,非0表示错误")
|
||||
msg: str = Field("success", description="响应消息")
|
||||
data: Optional[T] = Field(None, description="响应数据")
|
||||
|
||||
|
||||
class HealthCheckResponse(BaseModel):
|
||||
"""健康检查响应"""
|
||||
pending_count: int
|
||||
running_count: int
|
||||
completed_count: int
|
||||
failed_count: int
|
||||
max_workers: int
|
||||
Executable
+95
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
自动化回测服务 - 数据模型
|
||||
"""
|
||||
from enum import Enum
|
||||
from datetime import date, datetime
|
||||
from typing import Dict, Optional, Any, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
"""任务状态枚举"""
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class BacktestTask(BaseModel):
|
||||
"""回测任务请求"""
|
||||
strategy_name: str = Field(..., description="策略名称")
|
||||
strategy_code: str = Field(..., description="策略完整Python代码")
|
||||
symbol: str = Field(..., description="交易品种,例如 000001.SSE")
|
||||
interval: str = Field(..., description="K线周期,例如 1d、1h")
|
||||
start_date: date = Field(..., description="回测开始日期")
|
||||
end_date: date = Field(..., description="回测结束日期")
|
||||
parameters: Dict[str, Any] = Field(default_factory=dict, description="策略参数字典")
|
||||
capital: float = Field(default=1_000_000, description="起始资金")
|
||||
tick_size: Optional[float] = Field(None, description="最小价格变动,不指定则自动获取")
|
||||
|
||||
|
||||
class BacktestTaskWithId(BacktestTask):
|
||||
"""带ID和状态的回测任务"""
|
||||
task_id: str = Field(..., description="任务唯一ID")
|
||||
status: TaskStatus = Field(..., description="任务状态")
|
||||
created_at: str = Field(..., description="创建时间 ISO格式")
|
||||
started_at: Optional[str] = Field(None, description="开始时间")
|
||||
completed_at: Optional[str] = Field(None, description="完成时间")
|
||||
|
||||
|
||||
class BacktestStatistics(BaseModel):
|
||||
"""回测结果统计"""
|
||||
start_date: str
|
||||
end_date: str
|
||||
total_days: int
|
||||
total_trades: int
|
||||
winning_trades: int
|
||||
losing_trades: int
|
||||
win_rate: float
|
||||
total_return: float # 总收益率
|
||||
annual_return: float # 年化收益率
|
||||
sharpe_ratio: float # 夏普比率
|
||||
max_drawdown: float # 最大回撤
|
||||
max_drawdown_start: Optional[str] = None
|
||||
max_drawdown_end: Optional[str] = None
|
||||
profit_factor: float # 收益因子(总盈利/总亏损)
|
||||
calmar_ratio: float # 卡玛比率(年化收益/最大回撤)
|
||||
|
||||
|
||||
class BacktestResult(BaseModel):
|
||||
"""完整回测结果"""
|
||||
task_id: str
|
||||
strategy_name: str
|
||||
status: TaskStatus
|
||||
statistics: Optional[BacktestStatistics] = None
|
||||
result_csv_path: str # 每日净值CSV路径
|
||||
equity_curve_png_path: Optional[str] = None # 收益曲线图片路径
|
||||
trades_csv_path: Optional[str] = None # 成交记录CSV路径
|
||||
error_message: Optional[str] = None # 失败时的错误信息
|
||||
created_at: str
|
||||
started_at: Optional[str] = None
|
||||
completed_at: Optional[str] = None
|
||||
|
||||
|
||||
class TaskListResponse(BaseModel):
|
||||
"""任务列表响应"""
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
tasks: List[BacktestTaskWithId]
|
||||
|
||||
|
||||
class ApiResponse[T](BaseModel):
|
||||
"""通用API响应包装"""
|
||||
code: int = Field(0, description="0表示成功,非0表示错误")
|
||||
msg: str = Field("success", description="响应消息")
|
||||
data: Optional[T] = Field(None, description="响应数据")
|
||||
|
||||
|
||||
class HealthCheckResponse(BaseModel):
|
||||
"""健康检查响应"""
|
||||
pending_count: int
|
||||
running_count: int
|
||||
completed_count: int
|
||||
failed_count: int
|
||||
max_workers: int
|
||||
Executable
+101
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
自动化回测服务 - 结果存储
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
from datetime import date, datetime
|
||||
from typing import Optional
|
||||
from .models import BacktestTaskWithId, BacktestResult
|
||||
from .config import settings
|
||||
|
||||
|
||||
def _json_serial(obj):
|
||||
"""JSON序列化辅助:处理date/datetime"""
|
||||
if isinstance(obj, (date, datetime)):
|
||||
return obj.isoformat()
|
||||
raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
|
||||
|
||||
|
||||
class ResultStorage:
|
||||
"""结果存储管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.base_dir = settings.base_dir
|
||||
self._ensure_dirs()
|
||||
|
||||
def _ensure_dirs(self):
|
||||
"""确保目录结构存在"""
|
||||
for status_dir in ["pending", "running", "completed", "failed"]:
|
||||
path = os.path.join(self.base_dir, status_dir)
|
||||
os.makedirs(path, exist_ok=True)
|
||||
|
||||
def _task_dir(self, task_id: str, status: str) -> str:
|
||||
"""获取任务目录"""
|
||||
return os.path.join(self.base_dir, status, task_id)
|
||||
|
||||
def save_task(self, task: BacktestTaskWithId) -> None:
|
||||
"""保存任务信息"""
|
||||
task_dir = self._task_dir(task.task_id, task.status)
|
||||
os.makedirs(task_dir, exist_ok=True)
|
||||
|
||||
info_file = os.path.join(task_dir, "task.json")
|
||||
with open(info_file, "w", encoding="utf-8") as f:
|
||||
json.dump(task.model_dump(), f, indent=2, ensure_ascii=False, default=_json_serial)
|
||||
|
||||
def load_task(self, task_id: str, status: str) -> Optional[BacktestTaskWithId]:
|
||||
"""加载任务信息"""
|
||||
task_dir = self._task_dir(task_id, status)
|
||||
info_file = os.path.join(task_dir, "task.json")
|
||||
|
||||
if not os.path.exists(info_file):
|
||||
return None
|
||||
|
||||
with open(info_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
return BacktestTaskWithId(**data)
|
||||
|
||||
def save_result(self, result: BacktestResult) -> None:
|
||||
"""保存回测结果"""
|
||||
task_dir = self._task_dir(result.task_id, result.status)
|
||||
os.makedirs(task_dir, exist_ok=True)
|
||||
|
||||
result_file = os.path.join(task_dir, "result.json")
|
||||
with open(result_file, "w", encoding="utf-8") as f:
|
||||
json.dump(result.model_dump(), f, indent=2, ensure_ascii=False, default=_json_serial)
|
||||
|
||||
def load_result(self, task_id: str, status: str) -> Optional[BacktestResult]:
|
||||
"""加载回测结果"""
|
||||
task_dir = self._task_dir(task_id, status)
|
||||
result_file = os.path.join(task_dir, "result.json")
|
||||
|
||||
if not os.path.exists(result_file):
|
||||
return None
|
||||
|
||||
with open(result_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
return BacktestResult(**data)
|
||||
|
||||
def find_task(self, task_id: str) -> Optional[BacktestTaskWithId]:
|
||||
"""在所有状态目录中查找任务"""
|
||||
for status_dir in ["running", "failed", "completed", "pending"]:
|
||||
task = self.load_task(task_id, status_dir)
|
||||
if task:
|
||||
return task
|
||||
return None
|
||||
|
||||
def find_result(self, task_id: str) -> Optional[BacktestResult]:
|
||||
"""在所有状态目录中查找结果"""
|
||||
for status_dir in ["failed", "completed", "running", "pending"]:
|
||||
result = self.load_result(task_id, status_dir)
|
||||
if result:
|
||||
return result
|
||||
return None
|
||||
|
||||
def get_task_path(self, task_id: str, status: str, filename: str) -> str:
|
||||
"""获取任务文件路径"""
|
||||
return os.path.join(self._task_dir(task_id, status), filename)
|
||||
|
||||
|
||||
storage = ResultStorage()
|
||||
Executable
+141
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
自动化回测服务 - 任务队列
|
||||
简单后台线程调度:submit后自动触发执行,同一时间只跑一个回测
|
||||
"""
|
||||
import os
|
||||
import uuid
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict
|
||||
|
||||
from .config import settings
|
||||
from .models import TaskStatus, BacktestTask, BacktestTaskWithId
|
||||
from .result_storage import storage
|
||||
|
||||
|
||||
class TaskQueue:
|
||||
"""任务队列管理器 - 后台线程调度"""
|
||||
|
||||
def __init__(self):
|
||||
self.max_workers = settings.max_workers
|
||||
self.pending_tasks: List[str] = []
|
||||
self.running_tasks: List[str] = []
|
||||
self.completed_tasks: List[str] = []
|
||||
self.failed_tasks: List[str] = []
|
||||
self._worker_thread: Optional[threading.Thread] = None
|
||||
self._stop_event = threading.Event()
|
||||
|
||||
def _generate_task_id(self) -> str:
|
||||
return str(uuid.uuid4()).replace("-", "")
|
||||
|
||||
def submit_task(self, task: BacktestTask) -> BacktestTaskWithId:
|
||||
"""提交新任务到队列"""
|
||||
task_id = self._generate_task_id()
|
||||
now = datetime.now().isoformat()
|
||||
|
||||
task_with_id = BacktestTaskWithId(
|
||||
task_id=task_id,
|
||||
status=TaskStatus.PENDING,
|
||||
created_at=now,
|
||||
**task.model_dump()
|
||||
)
|
||||
|
||||
storage.save_task(task_with_id)
|
||||
self.pending_tasks.append(task_id)
|
||||
return task_with_id
|
||||
|
||||
def list_tasks(self, page: int = 1, page_size: int = 10, status: Optional[str] = None) -> Dict:
|
||||
if status == "pending":
|
||||
task_ids = self.pending_tasks
|
||||
elif status == "running":
|
||||
task_ids = self.running_tasks
|
||||
elif status == "completed":
|
||||
task_ids = self.completed_tasks
|
||||
elif status == "failed":
|
||||
task_ids = self.failed_tasks
|
||||
else:
|
||||
task_ids = (
|
||||
self.pending_tasks +
|
||||
self.running_tasks +
|
||||
self.completed_tasks +
|
||||
self.failed_tasks
|
||||
)
|
||||
|
||||
total = len(task_ids)
|
||||
start = (page - 1) * page_size
|
||||
end = start + page_size
|
||||
|
||||
result = []
|
||||
for task_id in task_ids[start:end]:
|
||||
for status_dir in ["pending", "running", "completed", "failed"]:
|
||||
task = storage.load_task(task_id, status_dir)
|
||||
if task:
|
||||
result.append(task)
|
||||
break
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"tasks": result
|
||||
}
|
||||
|
||||
def get_task(self, task_id: str) -> Optional[BacktestTaskWithId]:
|
||||
for status_dir in ["pending", "running", "completed", "failed"]:
|
||||
task = storage.load_task(task_id, status_dir)
|
||||
if task:
|
||||
return task
|
||||
return None
|
||||
|
||||
def _worker_loop(self):
|
||||
"""后台工作线程:循环检查pending任务并执行"""
|
||||
from .executor import executor
|
||||
|
||||
while not self._stop_event.is_set():
|
||||
# 检查是否有pending任务且当前无running任务
|
||||
if self.pending_tasks and not self.running_tasks:
|
||||
task_id = self.pending_tasks.pop(0)
|
||||
self.running_tasks.append(task_id)
|
||||
|
||||
# 在后台线程中执行回测
|
||||
try:
|
||||
task = storage.load_task(task_id, "pending")
|
||||
if task:
|
||||
# 移动到running目录
|
||||
storage.save_task(task)
|
||||
result = executor.execute_backtest(task)
|
||||
|
||||
# 从running移到completed/failed
|
||||
self.running_tasks.remove(task_id)
|
||||
if result.status == TaskStatus.COMPLETED:
|
||||
self.completed_tasks.append(task_id)
|
||||
else:
|
||||
self.failed_tasks.append(task_id)
|
||||
except Exception as e:
|
||||
print(f"任务执行异常: {e}\n{traceback.format_exc()}")
|
||||
if task_id in self.running_tasks:
|
||||
self.running_tasks.remove(task_id)
|
||||
self.failed_tasks.append(task_id)
|
||||
|
||||
# 等待1秒再检查
|
||||
self._stop_event.wait(1.0)
|
||||
|
||||
def start_worker_pool(self):
|
||||
"""启动后台工作线程"""
|
||||
if self._worker_thread is None or not self._worker_thread.is_alive():
|
||||
self._stop_event.clear()
|
||||
self._worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
|
||||
self._worker_thread.start()
|
||||
print(f"工作线程已启动 (max_workers={self.max_workers})")
|
||||
|
||||
def close_worker_pool(self):
|
||||
"""停止工作线程"""
|
||||
self._stop_event.set()
|
||||
if self._worker_thread and self._worker_thread.is_alive():
|
||||
self._worker_thread.join(timeout=5)
|
||||
self._worker_thread = None
|
||||
|
||||
|
||||
task_queue = TaskQueue()
|
||||
Reference in New Issue
Block a user