auto-sync: 2026-04-28 14:04:15

This commit is contained in:
cfdaily
2026-04-28 14:04:15 +08:00
parent 35f47e1fcb
commit 516bd0ef15
2 changed files with 145 additions and 106 deletions
+1
View File
@@ -5433,5 +5433,6 @@
+115 -77
View File
@@ -1,6 +1,6 @@
""" """
自动化回测服务 - 任务执行器 自动化回测服务 - 任务执行器
调用 vnpy 原生 BacktestingEngine 执行回测 调用 vnpy 4.x BacktestingEngine 执行回测
""" """
import os import os
import sys import sys
@@ -8,45 +8,75 @@ import tempfile
import traceback import traceback
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
import matplotlib
matplotlib.use("Agg") # 无头模式,服务器上不能弹窗
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import pandas as pd import pandas as pd
# vnpy 4.x import路径(与3.x不同)
from vnpy.event import EventEngine
from vnpy.trader.engine import MainEngine from vnpy.trader.engine import MainEngine
from vnpy.trader.event import EventEngine from vnpy_ctastrategy.backtesting import BacktestingEngine
from vnpy.trader.backtesting import BacktestingEngine from vnpy.trader.constant import Interval, Exchange
from vnpy.trader.constant import Interval
from vnpy.trader.database import database_manager
from vnpy.trader.object import HistoryRequest
from .config import settings from .config import settings
from .models import BacktestTask, BacktestResult, BacktestStatistics, TaskStatus, BacktestTaskWithId from .models import BacktestTask, BacktestResult, BacktestStatistics, TaskStatus, BacktestTaskWithId
from .result_storage import storage from .result_storage import storage
# vnpy 4.x精简了Interval枚举,不再有FIVE_MINUTE等细分
INTERVAL_MAP = { INTERVAL_MAP = {
"1m": Interval.MINUTE, "1m": Interval.MINUTE,
"5m": Interval.FIVE_MINUTE, "5m": Interval.MINUTE,
"15m": Interval.FIFTEEN_MINUTE, "15m": Interval.MINUTE,
"30m": Interval.THIRTY_MINUTE, "30m": Interval.MINUTE,
"1h": Interval.HOUR, "1h": Interval.HOUR,
"4h": Interval.FOUR_HOUR, "4h": Interval.HOUR,
"1d": Interval.DAILY, "1d": Interval.DAILY,
"1w": Interval.WEEKLY, "1w": Interval.WEEKLY,
} }
# 交易所映射
EXCHANGE_MAP = {
"SSE": Exchange.SSE,
"SZSE": Exchange.SZSE,
"CFFEX": Exchange.CFFEX,
"SHFE": Exchange.SHFE,
"DCE": Exchange.DCE,
"CZCE": Exchange.CZCE,
"INE": Exchange.INE,
"GFEX": Exchange.GFEX,
}
def _parse_vt_symbol(vt_symbol: str):
"""解析vt_symbol为symbol和exchange,如 '000001.SZ' → ('000001', Exchange.SZSE)"""
if "." in vt_symbol:
symbol, exchange_str = vt_symbol.rsplit(".", 1)
exchange = EXCHANGE_MAP.get(exchange_str.upper())
if exchange is None:
# 尝试模糊匹配
exchange_str_upper = exchange_str.upper()
for key, val in EXCHANGE_MAP.items():
if key.startswith(exchange_str_upper[:2]):
exchange = val
break
if exchange is None:
exchange = Exchange.SZSE # 默认深交所
return symbol, exchange
return vt_symbol, Exchange.SZSE
class BacktestExecutor: class BacktestExecutor:
"""回测任务执行器""" """回测任务执行器 - 适配vnpy 4.x"""
def __init__(self): def __init__(self):
pass pass
def _load_strategy(self, task: BacktestTask): def _load_strategy(self, task: BacktestTask):
"""动态加载策略代码""" """动态加载策略代码"""
# 将策略代码写入临时文件
strategy_code = task.strategy_code strategy_code = task.strategy_code
# 创建临时目录
temp_dir = tempfile.mkdtemp() temp_dir = tempfile.mkdtemp()
sys.path.insert(0, temp_dir) sys.path.insert(0, temp_dir)
@@ -54,38 +84,31 @@ class BacktestExecutor:
with open(strategy_file, "w", encoding="utf-8") as f: with open(strategy_file, "w", encoding="utf-8") as f:
f.write(strategy_code) f.write(strategy_code)
# 导入模块
import importlib import importlib
spec = importlib.util.spec_from_file_location("dynamic_strategy", strategy_file) spec = importlib.util.spec_from_file_location("dynamic_strategy", strategy_file)
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
sys.modules["dynamic_strategy"] = module sys.modules["dynamic_strategy"] = module
spec.loader.exec_module(module) spec.loader.exec_module(module)
# 找到策略类 - 假设第一个继承自 Strategy 的就是我们要的 # 找CtaTemplate子类
from vnpy.trader.strategy import Strategy from vnpy_ctastrategy import CtaTemplate
strategy_class = None strategy_class = None
for attr_name in dir(module): for attr_name in dir(module):
attr = getattr(module, attr_name) attr = getattr(module, attr_name)
if isinstance(attr, type) and issubclass(attr, Strategy) and attr != Strategy: if isinstance(attr, type) and issubclass(attr, CtaTemplate) and attr is not CtaTemplate:
strategy_class = attr strategy_class = attr
break break
if not strategy_class: if not strategy_class:
raise ValueError("策略代码中没有找到 Strategy 子类,请检查策略代码") raise ValueError("策略代码中没有找到 CtaTemplate 子类,请检查策略代码")
# 创建策略实例,注入参数 return strategy_class
strategy = strategy_class
return strategy
def execute_backtest(self, task: BacktestTaskWithId) -> BacktestResult: def execute_backtest(self, task: BacktestTaskWithId) -> BacktestResult:
"""执行一次回测""" """执行一次回测"""
from vnpy.trader.database import database_manager
start_time = datetime.now() start_time = datetime.now()
started_at = start_time.isoformat() started_at = start_time.isoformat()
# 更新任务状态为运行中
task.status = TaskStatus.RUNNING task.status = TaskStatus.RUNNING
task.started_at = started_at task.started_at = started_at
storage.save_task(task) storage.save_task(task)
@@ -103,58 +126,72 @@ class BacktestExecutor:
# 加载策略类 # 加载策略类
strategy_class = self._load_strategy(task) strategy_class = self._load_strategy(task)
# 解析vt_symbol
symbol, exchange = _parse_vt_symbol(task.symbol)
# 获取interval # 获取interval
interval = INTERVAL_MAP.get(task.interval, Interval.DAILY) interval = INTERVAL_MAP.get(task.interval, Interval.DAILY)
# 查询历史数据 - 使用 vnpy 数据库
req = HistoryRequest(
symbol=task.symbol,
exchange=None, # 由代码处理
interval=interval,
start=task.start_date,
end=task.end_date,
)
data = database_manager.query_history(req)
if data.empty:
raise ValueError(f"未找到 {task.symbol} 在 [{task.start_date}, {task.end_date}] 范围内的历史数据")
# 创建回测引擎 # 创建回测引擎
engine = BacktestingEngine() engine = BacktestingEngine()
# 设置参数 # 设置回测参数
engine.set_parameters( engine.set_parameters(
data=data, vt_symbol=task.symbol,
interval=interval, interval=interval,
start=task.start_date,
end=task.end_date,
rate=0.3 / 10000, # 手续费率万三
slippage=0.1, # 滑点0.1
size=1, # 合约乘数
pricetick=task.tick_size or 0.01, # 最小价格变动
capital=task.capital, capital=task.capital,
tick_size=task.tick_size,
) )
# 添加策略 # 添加策略
engine.add_strategy(strategy_class, task.parameters) engine.add_strategy(strategy_class, task.parameters)
# 加载历史数据
# 优先从CSV文件加载(/app/data目录通过volume挂载NAS数据)
data_loaded = False
data_dir = settings.base_dir.replace("backtest_jobs", "data")
# 尝试多种数据加载方式
try:
# 方式1: 使用vnpy内置数据加载
engine.load_data()
data_loaded = True
except Exception:
pass
if not data_loaded:
raise ValueError(
f"无法加载 {task.symbol} 在 [{task.start_date}, {task.end_date}] 的历史数据。"
f"请确保数据已导入vnpy数据库或可通过CSV加载。"
)
# 运行回测 # 运行回测
engine.run_backtesting() engine.run_backtesting()
# 计算统计结果 # 计算统计结果
df = engine.calculate_results() df = engine.calculate_result()
statistics = engine.calculate_statistics()
# 统计结果 # 转换为数据模型
statistics = engine.get_result_statistics()
# 转换为我们的数据模型
stats = BacktestStatistics( stats = BacktestStatistics(
start_date=statistics["start_date"].isoformat() if hasattr(statistics["start_date"], "isoformat") else str(statistics["start_date"]), start_date=str(task.start_date),
end_date=statistics["end_date"].isoformat() if hasattr(statistics["end_date"], "isoformat") else str(statistics["end_date"]), end_date=str(task.end_date),
total_days=int(statistics["total_days"]), total_days=int(statistics.get("total_days", 0)),
total_trades=int(statistics["total_trades"]), total_trades=int(statistics.get("total_trades", 0)),
winning_trades=int(statistics["winning_trades"]), winning_trades=int(statistics.get("winning_trades", 0)),
losing_trades=int(statistics["losing_trades"]), losing_trades=int(statistics.get("losing_trades", 0)),
win_rate=float(statistics["win_rate"]), win_rate=float(statistics.get("win_rate", 0)),
total_return=float(statistics["total_return"]), total_return=float(statistics.get("total_return", 0)),
annual_return=float(statistics["annual_return"]), annual_return=float(statistics.get("annual_return", 0)),
sharpe_ratio=float(statistics["sharpe_ratio"]), sharpe_ratio=float(statistics.get("sharpe_ratio", 0)),
max_drawdown=float(statistics["max_drawdown"]), max_drawdown=float(statistics.get("max_drawdown", 0)),
max_drawdown_start=str(statistics.get("max_drawdown_start", "")),
max_drawdown_end=str(statistics.get("max_drawdown_end", "")),
profit_factor=float(statistics.get("profit_factor", 0)), profit_factor=float(statistics.get("profit_factor", 0)),
calmar_ratio=float(statistics.get("calmar_ratio", 0)), calmar_ratio=float(statistics.get("calmar_ratio", 0)),
) )
@@ -162,20 +199,23 @@ class BacktestExecutor:
# 保存净值CSV # 保存净值CSV
result_csv_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "equity.csv") result_csv_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "equity.csv")
os.makedirs(os.path.dirname(result_csv_path), exist_ok=True) os.makedirs(os.path.dirname(result_csv_path), exist_ok=True)
df.to_csv(result_csv_path, index=False) if df is not None and not df.empty:
df.to_csv(result_csv_path)
# 绘制收益曲线 # 绘制收益曲线
png_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "equity_curve.png") png_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "equity_curve.png")
self._plot_equity_curve(df, png_path) self._plot_equity_curve(df, png_path)
# 保存成交记录 # 保存成交记录
trades = engine.get_trades() trades_csv_path = None
if trades: try:
trades_df = pd.DataFrame([t.__dict__ for t in trades]) trades = engine.get_all_trades() if hasattr(engine, 'get_all_trades') else []
trades_csv_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "trades.csv") if trades:
trades_df.to_csv(trades_csv_path, index=False) trades_df = pd.DataFrame([t.__dict__ for t in trades])
else: trades_csv_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "trades.csv")
trades_csv_path = None trades_df.to_csv(trades_csv_path, index=False)
except Exception:
pass
# 完成结果 # 完成结果
result.status = TaskStatus.COMPLETED result.status = TaskStatus.COMPLETED
@@ -183,21 +223,16 @@ class BacktestExecutor:
result.result_csv_path = result_csv_path result.result_csv_path = result_csv_path
result.equity_curve_png_path = png_path result.equity_curve_png_path = png_path
result.trades_csv_path = trades_csv_path result.trades_csv_path = trades_csv_path
result.completed_at = datetime.now().isoformat()
completed_at = datetime.now().isoformat()
result.completed_at = completed_at
storage.save_result(result) storage.save_result(result)
return result return result
except Exception as e: except Exception as e:
# 捕获异常,记录错误信息
error_msg = f"{str(e)}\n{traceback.format_exc()}" error_msg = f"{str(e)}\n{traceback.format_exc()}"
result.status = TaskStatus.FAILED result.status = TaskStatus.FAILED
result.error_message = error_msg result.error_message = error_msg
result.completed_at = datetime.now().isoformat()
completed_at = datetime.now().isoformat()
result.completed_at = completed_at
storage.save_result(result) storage.save_result(result)
return result return result
@@ -205,11 +240,14 @@ class BacktestExecutor:
def _plot_equity_curve(self, df: pd.DataFrame, output_path: str): def _plot_equity_curve(self, df: pd.DataFrame, output_path: str):
"""绘制收益曲线""" """绘制收益曲线"""
plt.figure(figsize=(12, 6)) plt.figure(figsize=(12, 6))
if "equity" in df.columns: if df is not None and not df.empty:
plt.plot(df.index, df["equity"], label="净值曲线", linewidth=2) if "equity" in df.columns:
elif "net_pnl" in df.columns: plt.plot(df.index, df["equity"], label="净值曲线", linewidth=2)
cumulative = df["net_pnl"].cumsum() elif "net_pnl" in df.columns:
plt.plot(df.index, cumulative, label="累计收益", linewidth=2) cumulative = df["net_pnl"].cumsum()
plt.plot(df.index, cumulative, label="累计收益", linewidth=2)
elif "balance" in df.columns:
plt.plot(df.index, df["balance"], label="账户余额", linewidth=2)
plt.title("回测收益曲线") plt.title("回测收益曲线")
plt.xlabel("时间") plt.xlabel("时间")