Files
sanguo_vnpy/scripts/backtest-service/executor.py
T
2026-04-12 10:19:17 +08:00

225 lines
8.0 KiB
Python

"""
自动化回测服务 - 任务执行器
调用 vnpy 原生 BacktestingEngine 执行回测
"""
import os
import sys
import tempfile
import traceback
from datetime import datetime
from typing import Optional
import matplotlib.pyplot as plt
import pandas as pd
from vnpy.trader.engine import MainEngine
from vnpy.trader.event import EventEngine
from vnpy.trader.backtesting import BacktestingEngine
from vnpy.trader.constant import Interval
from vnpy.trader.database import database_manager
from vnpy.trader.object import HistoryRequest
from .config import settings
from .models import BacktestTask, BacktestResult, BacktestStatistics, TaskStatus, BacktestTaskWithId
from .result_storage import storage
# Translates the short interval codes carried by a task ("1m", "1h", ...)
# into the corresponding ``Interval`` members used by the backtesting engine.
INTERVAL_MAP = {
    "1m": Interval.MINUTE,
    "5m": Interval.FIVE_MINUTE,
    "15m": Interval.FIFTEEN_MINUTE,
    "30m": Interval.THIRTY_MINUTE,
    "1h": Interval.HOUR,
    "4h": Interval.FOUR_HOUR,
    "1d": Interval.DAILY,
    "1w": Interval.WEEKLY,
}
class BacktestExecutor:
    """Backtest task executor.

    Loads the strategy code attached to a task, runs it through
    ``BacktestingEngine`` and persists the outcome (statistics, equity CSV,
    equity-curve PNG and trade list) through the module-level ``storage``.
    """

    def __init__(self):
        pass

    def _load_strategy(self, task: BacktestTask):
        """Dynamically load the strategy class contained in ``task.strategy_code``.

        The code is written to ``strategy.py`` inside a fresh temporary
        directory and executed as the module ``dynamic_strategy``.

        Args:
            task: Task whose ``strategy_code`` attribute holds Python source.

        Returns:
            The strategy *class* (not an instance). Classes defined in the
            submitted code are preferred; if only imported ``Strategy``
            subclasses are present, the first of those (in ``dir()`` order)
            is returned.

        Raises:
            ValueError: If the code exposes no ``Strategy`` subclass.
            Exception: Anything raised while executing the strategy code
                propagates unchanged.
        """
        # ``import importlib`` alone does not guarantee that the ``util``
        # submodule is bound, so import it explicitly.
        import importlib.util

        from vnpy.trader.strategy import Strategy

        module_name = "dynamic_strategy"

        # NOTE(review): the temporary directory is deliberately left on disk;
        # removing it would strip source lines from tracebacks of strategy
        # errors. Consider a periodic cleanup if disk usage matters.
        temp_dir = tempfile.mkdtemp()
        strategy_file = os.path.join(temp_dir, "strategy.py")
        with open(strategy_file, "w", encoding="utf-8") as f:
            f.write(task.strategy_code)

        spec = importlib.util.spec_from_file_location(module_name, strategy_file)
        module = importlib.util.module_from_spec(spec)

        # The directory is only needed on sys.path while the module body runs;
        # leaving it there would grow sys.path by one entry per task.
        sys.path.insert(0, temp_dir)
        sys.modules[module_name] = module
        try:
            spec.loader.exec_module(module)
        except BaseException:
            # Do not leave a half-initialised module registered.
            if sys.modules.get(module_name) is module:
                del sys.modules[module_name]
            raise
        finally:
            try:
                sys.path.remove(temp_dir)
            except ValueError:
                # The strategy code altered sys.path itself; nothing to undo.
                pass

        # Collect every Strategy subclass visible in the module namespace.
        candidates = []
        for attr_name in dir(module):
            attr = getattr(module, attr_name)
            if (
                isinstance(attr, type)
                and issubclass(attr, Strategy)
                and attr is not Strategy
            ):
                candidates.append(attr)

        if not candidates:
            raise ValueError("策略代码中没有找到 Strategy 子类,请检查策略代码")

        # Prefer classes defined by the submitted code over base classes that
        # it merely imported (those can sort first alphabetically).
        defined_here = [cls for cls in candidates if cls.__module__ == module_name]
        return (defined_here or candidates)[0]

    @staticmethod
    def _format_date(value) -> str:
        """Return ``value.isoformat()`` when available, otherwise ``str(value)``."""
        if hasattr(value, "isoformat"):
            return value.isoformat()
        return str(value)

    def execute_backtest(self, task: BacktestTaskWithId) -> BacktestResult:
        """Execute one backtest and persist its result.

        The task is first marked ``RUNNING`` and saved. Any exception raised
        while loading the strategy, querying data, running the engine or
        writing artefacts is caught and recorded on the returned result
        (status ``FAILED`` plus message and traceback) instead of propagating.

        Args:
            task: The task to run.

        Returns:
            A ``BacktestResult`` with status ``COMPLETED`` or ``FAILED``; it
            has already been passed to ``storage.save_result``.
        """
        started_at = datetime.now().isoformat()

        # Mark the task as running.
        task.status = TaskStatus.RUNNING
        task.started_at = started_at
        storage.save_task(task)

        result = BacktestResult(
            task_id=task.task_id,
            strategy_name=task.strategy_name,
            status=TaskStatus.RUNNING,
            result_csv_path="",
            created_at=task.created_at,
            started_at=started_at,
        )

        try:
            strategy_class = self._load_strategy(task)

            # Unknown interval codes fall back to daily bars.
            interval = INTERVAL_MAP.get(task.interval, Interval.DAILY)

            # Query historical data from the vnpy database.
            req = HistoryRequest(
                symbol=task.symbol,
                exchange=None,  # handled by the code (per the original note)
                interval=interval,
                start=task.start_date,
                end=task.end_date,
            )
            data = database_manager.query_history(req)

            # Accept a DataFrame (``.empty``) as well as None or a plain
            # sequence so a missing result yields the intended error message.
            if data is None:
                no_data = True
            elif hasattr(data, "empty"):
                no_data = data.empty
            else:
                no_data = len(data) == 0
            if no_data:
                raise ValueError(f"未找到 {task.symbol} 在 [{task.start_date}, {task.end_date}] 范围内的历史数据")

            engine = BacktestingEngine()
            engine.set_parameters(
                data=data,
                interval=interval,
                capital=task.capital,
                tick_size=task.tick_size,
            )
            engine.add_strategy(strategy_class, task.parameters)
            engine.run_backtesting()

            df = engine.calculate_results()
            statistics = engine.get_result_statistics()

            # Convert the engine statistics into our own data model.
            stats = BacktestStatistics(
                start_date=self._format_date(statistics["start_date"]),
                end_date=self._format_date(statistics["end_date"]),
                total_days=int(statistics["total_days"]),
                total_trades=int(statistics["total_trades"]),
                winning_trades=int(statistics["winning_trades"]),
                losing_trades=int(statistics["losing_trades"]),
                win_rate=float(statistics["win_rate"]),
                total_return=float(statistics["total_return"]),
                annual_return=float(statistics["annual_return"]),
                sharpe_ratio=float(statistics["sharpe_ratio"]),
                max_drawdown=float(statistics["max_drawdown"]),
                profit_factor=float(statistics.get("profit_factor", 0)),
                calmar_ratio=float(statistics.get("calmar_ratio", 0)),
            )

            # Save the equity CSV.
            result_csv_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "equity.csv")
            os.makedirs(os.path.dirname(result_csv_path), exist_ok=True)
            df.to_csv(result_csv_path, index=False)

            # Draw the equity curve.
            png_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "equity_curve.png")
            self._plot_equity_curve(df, png_path)

            # Save the trade records, if any.
            trades_csv_path: Optional[str] = None
            trades = engine.get_trades()
            if trades:
                trades_df = pd.DataFrame([t.__dict__ for t in trades])
                trades_csv_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "trades.csv")
                trades_df.to_csv(trades_csv_path, index=False)

            # Fill in the completed result.
            result.status = TaskStatus.COMPLETED
            result.statistics = stats
            result.result_csv_path = result_csv_path
            result.equity_curve_png_path = png_path
            result.trades_csv_path = trades_csv_path
            result.completed_at = datetime.now().isoformat()
            storage.save_result(result)
            return result
        except Exception as e:
            # Service boundary: record the failure on the result rather than
            # letting it propagate to the caller.
            result.status = TaskStatus.FAILED
            result.error_message = f"{str(e)}\n{traceback.format_exc()}"
            result.completed_at = datetime.now().isoformat()
            storage.save_result(result)
            return result

    def _plot_equity_curve(self, df: pd.DataFrame, output_path: str):
        """Plot the equity curve of ``df`` and save it as an image.

        Plots the ``equity`` column when present; otherwise the cumulative sum
        of ``net_pnl``. If neither column exists an empty chart is saved.

        Args:
            df: Result frame returned by the engine.
            output_path: Destination file path of the image.
        """
        fig, ax = plt.subplots(figsize=(12, 6))
        try:
            plotted = True
            if "equity" in df.columns:
                ax.plot(df.index, df["equity"], label="净值曲线", linewidth=2)
            elif "net_pnl" in df.columns:
                ax.plot(df.index, df["net_pnl"].cumsum(), label="累计收益", linewidth=2)
            else:
                plotted = False

            ax.set_title("回测收益曲线")
            ax.set_xlabel("时间")
            ax.set_ylabel("净值")
            ax.grid(True, alpha=0.3)
            if plotted:
                # A legend without labelled artists only produces a warning.
                ax.legend()
            fig.tight_layout()
            fig.savefig(output_path, dpi=150)
        finally:
            # Always release the figure, even when plotting or saving fails.
            plt.close(fig)
# Module-level BacktestExecutor instance, created when this module is imported.
executor = BacktestExecutor()