# Source file: sanguo_vnpy/scripts/backtest-service/executor.py
# Snapshot: 2026-04-28 14:04:15 +08:00 (263 lines, 9.2 KiB, Python, executable file)

"""
自动化回测服务 - 任务执行器
调用 vnpy 4.x BacktestingEngine 执行回测
"""
import os
import sys
import tempfile
import traceback
from datetime import datetime
from typing import Optional
import matplotlib
matplotlib.use("Agg") # 无头模式,服务器上不能弹窗
import matplotlib.pyplot as plt
import pandas as pd
# vnpy 4.x import路径(与3.x不同)
from vnpy.event import EventEngine
from vnpy.trader.engine import MainEngine
from vnpy_ctastrategy.backtesting import BacktestingEngine
from vnpy.trader.constant import Interval, Exchange
from .config import settings
from .models import BacktestTask, BacktestResult, BacktestStatistics, TaskStatus, BacktestTaskWithId
from .result_storage import storage
# vnpy 4.x trimmed the Interval enum; finer-grained members such as
# FIVE_MINUTE no longer exist, so every multi-minute request string maps onto
# Interval.MINUTE and every multi-hour string onto Interval.HOUR.
INTERVAL_MAP = {
    **dict.fromkeys(("1m", "5m", "15m", "30m"), Interval.MINUTE),
    **dict.fromkeys(("1h", "4h"), Interval.HOUR),
    "1d": Interval.DAILY,
    "1w": Interval.WEEKLY,
}
# Exchange lookup table: upper-case exchange code -> vnpy Exchange member.
# Each key is also the name of the Exchange member it maps to.
EXCHANGE_MAP = {
    code: Exchange[code]
    for code in (
        "SSE",
        "SZSE",
        "CFFEX",
        "SHFE",
        "DCE",
        "CZCE",
        "INE",
        "GFEX",
    )
}
def _parse_vt_symbol(vt_symbol: str):
    """Split ``vt_symbol`` into ``(symbol, Exchange)``.

    The text after the last ``.`` is treated as the exchange code, e.g.
    ``'000001.SZ'`` -> ``('000001', Exchange.SZSE)``.  Resolution order:

    1. exact (case-insensitive) key of ``EXCHANGE_MAP``;
    2. a well-known data-vendor suffix (``SH``/``SS``/``XSHG`` -> SSE,
       ``SZ``/``XSHE`` -> SZSE);
    3. the first ``EXCHANGE_MAP`` key that starts with the first two
       characters of the suffix;
    4. ``Exchange.SZSE`` as the default.

    Args:
        vt_symbol: Symbol string, with or without a ``.exchange`` suffix.

    Returns:
        A ``(symbol, exchange)`` tuple.  Without a ``.`` the whole input is
        returned as the symbol together with ``Exchange.SZSE``.
    """
    if "." not in vt_symbol:
        return vt_symbol, Exchange.SZSE

    # Vendor suffixes that are not EXCHANGE_MAP keys.  They must be resolved
    # before the prefix match below, which would otherwise send "SH" to SHFE
    # (the first key starting with "SH") instead of the Shanghai Stock Exchange.
    suffix_aliases = {
        "SH": Exchange.SSE,
        "SS": Exchange.SSE,
        "XSHG": Exchange.SSE,
        "SZ": Exchange.SZSE,
        "XSHE": Exchange.SZSE,
    }

    symbol, exchange_str = vt_symbol.rsplit(".", 1)
    suffix = exchange_str.strip().upper()

    exchange = EXCHANGE_MAP.get(suffix)
    if exchange is None:
        exchange = suffix_aliases.get(suffix)
    if exchange is None and suffix:
        # Fuzzy match on the first two characters.  An empty suffix is skipped
        # on purpose: "" is a prefix of every key and would pick the first
        # entry of the map rather than the documented default.
        prefix = suffix[:2]
        exchange = next(
            (member for code, member in EXCHANGE_MAP.items() if code.startswith(prefix)),
            None,
        )
    if exchange is None:
        exchange = Exchange.SZSE  # default: Shenzhen Stock Exchange
    return symbol, exchange
class BacktestExecutor:
    """Backtest task executor built on the vnpy 4.x ``BacktestingEngine``.

    ``execute_backtest`` is the entry point; the other methods are private
    helpers for loading the strategy class, normalising the symbol and
    rendering the equity curve.
    """

    def __init__(self):
        pass

    def _load_strategy(self, task: BacktestTask):
        """Load the strategy class from the source code carried by ``task``.

        ``task.strategy_code`` is written to a temporary ``strategy.py``,
        executed as module ``dynamic_strategy`` and scanned for a
        ``CtaTemplate`` subclass.

        Args:
            task: Task whose ``strategy_code`` attribute holds Python source.

        Returns:
            The strategy class.  Classes defined in the submitted code are
            preferred over ``CtaTemplate`` subclasses it merely imports.

        Raises:
            ValueError: If the code defines no ``CtaTemplate`` subclass.
            Exception: Whatever the strategy code raises while executing.
        """
        # Function-scope imports, matching the style already used in this
        # method.  ``importlib.util`` is a submodule and must be imported
        # explicitly; ``import importlib`` alone does not guarantee it.
        import importlib.util
        import shutil

        from vnpy_ctastrategy import CtaTemplate

        # SECURITY: this executes task-supplied Python source in-process with
        # the privileges of the service.  Only accept strategy code from
        # trusted submitters, or run this executor inside a sandbox.
        temp_dir = tempfile.mkdtemp()
        sys.path.insert(0, temp_dir)
        try:
            strategy_file = os.path.join(temp_dir, "strategy.py")
            with open(strategy_file, "w", encoding="utf-8") as f:
                f.write(task.strategy_code)
            spec = importlib.util.spec_from_file_location("dynamic_strategy", strategy_file)
            module = importlib.util.module_from_spec(spec)
            sys.modules["dynamic_strategy"] = module
            spec.loader.exec_module(module)
        finally:
            # Undo the per-task side effects so a long-running service does
            # not accumulate sys.path entries and temp directories.  The
            # module object stays alive through sys.modules; tracebacks raised
            # later from strategy code may omit source text because the file
            # is gone.
            if temp_dir in sys.path:
                sys.path.remove(temp_dir)
            shutil.rmtree(temp_dir, ignore_errors=True)

        # Collect every CtaTemplate subclass visible in the module namespace.
        candidates = [
            attr
            for attr in (getattr(module, attr_name) for attr_name in dir(module))
            if isinstance(attr, type) and issubclass(attr, CtaTemplate) and attr is not CtaTemplate
        ]
        if not candidates:
            raise ValueError("策略代码中没有找到 CtaTemplate 子类,请检查策略代码")

        # dir() is alphabetical, so an imported base strategy could sort ahead
        # of the class the user actually wrote; prefer locally defined classes.
        defined_here = [cls for cls in candidates if cls.__module__ == module.__name__]
        return (defined_here or candidates)[0]

    @staticmethod
    def _resolve_vt_symbol(raw_symbol: str) -> str:
        """Return ``raw_symbol`` in the ``symbol.EXCHANGE`` form used by vnpy.

        A symbol whose suffix is already a valid ``Exchange`` value is
        returned unchanged.  Anything else (vendor suffixes such as ``.SZ``,
        lower-case codes, or no suffix at all) goes through
        ``_parse_vt_symbol`` and is rebuilt from the resolved exchange.
        """
        _, separator, suffix = raw_symbol.rpartition(".")
        if separator:
            try:
                Exchange(suffix)
            except ValueError:
                pass
            else:
                return raw_symbol
        symbol, exchange = _parse_vt_symbol(raw_symbol)
        return f"{symbol}.{exchange.value}"

    def execute_backtest(self, task: BacktestTaskWithId) -> BacktestResult:
        """Run one backtest and persist its outcome.

        The task is marked RUNNING and saved.  The strategy is then loaded,
        the engine configured and run, and the equity CSV, equity-curve PNG
        and (best effort) trades CSV are written.  Any exception raised
        inside the run is captured into a FAILED result instead of
        propagating.

        Args:
            task: The task to execute; ``status`` and ``started_at`` are
                updated in place.

        Returns:
            A ``BacktestResult`` with status COMPLETED (statistics and file
            paths filled in) or FAILED (``error_message`` holds the message
            and traceback).  The result is saved through ``storage`` in both
            cases.
        """
        started_at = datetime.now().isoformat()
        task.status = TaskStatus.RUNNING
        task.started_at = started_at
        storage.save_task(task)
        # NOTE(review): only the result record is marked COMPLETED/FAILED
        # below; the task record saved above keeps status RUNNING unless
        # storage or the caller updates it - confirm.
        result = BacktestResult(
            task_id=task.task_id,
            strategy_name=task.strategy_name,
            status=TaskStatus.RUNNING,
            result_csv_path="",
            created_at=task.created_at,
            started_at=started_at,
        )
        try:
            strategy_class = self._load_strategy(task)

            # Hand the engine a normalised "symbol.EXCHANGE" string; the raw
            # task symbol may carry a vendor suffix such as ".SZ".
            vt_symbol = self._resolve_vt_symbol(task.symbol)

            # Unknown interval strings fall back to daily bars.
            interval = INTERVAL_MAP.get(task.interval, Interval.DAILY)

            engine = BacktestingEngine()
            engine.set_parameters(
                vt_symbol=vt_symbol,
                interval=interval,
                start=task.start_date,
                end=task.end_date,
                rate=0.3 / 10000,  # commission rate: 0.3 per 10,000
                slippage=0.1,  # slippage
                size=1,  # contract multiplier
                pricetick=task.tick_size or 0.01,  # minimum price increment
                capital=task.capital,
            )
            engine.add_strategy(strategy_class, task.parameters)

            # Load history through the engine's built-in loader.  The CSV
            # route mentioned in the error message is not implemented here.
            load_error = None
            try:
                engine.load_data()
            except Exception as exc:
                # Keep the cause so it shows up in the failure report.
                load_error = exc
            # load_data() can return normally without finding any bars, so
            # also inspect the loaded data when the engine exposes it.
            history = getattr(engine, "history_data", None)
            data_loaded = load_error is None and (history is None or len(history) > 0)
            if not data_loaded:
                raise ValueError(
                    f"无法加载 {task.symbol} 在 [{task.start_date}, {task.end_date}] 的历史数据。"
                    f"请确保数据已导入vnpy数据库或可通过CSV加载。"
                ) from load_error

            engine.run_backtesting()
            df = engine.calculate_result()
            statistics = engine.calculate_statistics()

            # vnpy reports the trade count as "total_trade_count"; fall back
            # to it when "total_trades" is absent.
            # NOTE(review): winning_trades, losing_trades, win_rate,
            # max_drawdown_start/end, profit_factor and calmar_ratio may not
            # be produced by calculate_statistics() either and would then stay
            # at their defaults - confirm against the engine's output.
            total_trades = statistics.get("total_trades", statistics.get("total_trade_count", 0))
            stats = BacktestStatistics(
                start_date=str(task.start_date),
                end_date=str(task.end_date),
                total_days=int(statistics.get("total_days", 0)),
                total_trades=int(total_trades),
                winning_trades=int(statistics.get("winning_trades", 0)),
                losing_trades=int(statistics.get("losing_trades", 0)),
                win_rate=float(statistics.get("win_rate", 0)),
                total_return=float(statistics.get("total_return", 0)),
                annual_return=float(statistics.get("annual_return", 0)),
                sharpe_ratio=float(statistics.get("sharpe_ratio", 0)),
                max_drawdown=float(statistics.get("max_drawdown", 0)),
                max_drawdown_start=str(statistics.get("max_drawdown_start", "")),
                max_drawdown_end=str(statistics.get("max_drawdown_end", "")),
                profit_factor=float(statistics.get("profit_factor", 0)),
                calmar_ratio=float(statistics.get("calmar_ratio", 0)),
            )

            # Save the equity table as CSV (skipped when there is no data).
            result_csv_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "equity.csv")
            os.makedirs(os.path.dirname(result_csv_path), exist_ok=True)
            if df is not None and not df.empty:
                df.to_csv(result_csv_path)

            # Render the equity curve.
            png_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "equity_curve.png")
            self._plot_equity_curve(df, png_path)

            # Exporting trades is best effort: a failure must not fail the
            # backtest, but it is logged instead of being silently dropped.
            trades_csv_path = None
            try:
                trades = engine.get_all_trades() if hasattr(engine, 'get_all_trades') else []
                if trades:
                    trades_df = pd.DataFrame([t.__dict__ for t in trades])
                    trades_csv_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "trades.csv")
                    trades_df.to_csv(trades_csv_path, index=False)
            except Exception:
                import logging

                logging.getLogger(__name__).warning(
                    "Failed to export trades for task %s", task.task_id, exc_info=True
                )

            result.status = TaskStatus.COMPLETED
            result.statistics = stats
            result.result_csv_path = result_csv_path
            result.equity_curve_png_path = png_path
            result.trades_csv_path = trades_csv_path
            result.completed_at = datetime.now().isoformat()
            storage.save_result(result)
            return result
        except Exception as e:
            # Top-level boundary: turn any failure into a FAILED result that
            # carries the message and full traceback.
            error_msg = f"{str(e)}\n{traceback.format_exc()}"
            result.status = TaskStatus.FAILED
            result.error_message = error_msg
            result.completed_at = datetime.now().isoformat()
            storage.save_result(result)
            return result

    def _plot_equity_curve(self, df: pd.DataFrame, output_path: str):
        """Render the equity curve of a backtest to an image file.

        The first column found among ``equity``, ``net_pnl`` (plotted as a
        running sum) and ``balance`` is drawn against the frame's index.  A
        ``None`` or empty frame still produces an image, with axes only.

        Args:
            df: Result frame from the engine; may be ``None`` or empty.
            output_path: Destination path of the image.
        """
        # (column, legend label, plot the running sum instead of raw values)
        curve_specs = (
            ("equity", "净值曲线", False),
            ("net_pnl", "累计收益", True),
            ("balance", "账户余额", False),
        )
        fig = plt.figure(figsize=(12, 6))
        try:
            plotted = False
            if df is not None and not df.empty:
                for column, label, cumulative in curve_specs:
                    if column in df.columns:
                        values = df[column].cumsum() if cumulative else df[column]
                        plt.plot(df.index, values, label=label, linewidth=2)
                        plotted = True
                        break
            plt.title("回测收益曲线")
            plt.xlabel("时间")
            plt.ylabel("净值")
            plt.grid(True, alpha=0.3)
            if plotted:
                # A legend without any labelled line only triggers a warning.
                plt.legend()
            plt.tight_layout()
            plt.savefig(output_path, dpi=150)
        finally:
            # Always release the figure, even when plotting or saving fails,
            # so repeated runs do not accumulate open figures.
            plt.close(fig)
# Module-level BacktestExecutor instance, created when this module is imported.
executor = BacktestExecutor()