auto-sync: 2026-04-28 14:04:15
This commit is contained in:
@@ -5433,5 +5433,6 @@
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
自动化回测服务 - 任务执行器
|
自动化回测服务 - 任务执行器
|
||||||
调用 vnpy 原生 BacktestingEngine 执行回测
|
调用 vnpy 4.x BacktestingEngine 执行回测
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@@ -8,88 +8,111 @@ import tempfile
|
|||||||
import traceback
|
import traceback
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use("Agg") # 无头模式,服务器上不能弹窗
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
|
# vnpy 4.x import路径(与3.x不同)
|
||||||
|
from vnpy.event import EventEngine
|
||||||
from vnpy.trader.engine import MainEngine
|
from vnpy.trader.engine import MainEngine
|
||||||
from vnpy.trader.event import EventEngine
|
from vnpy_ctastrategy.backtesting import BacktestingEngine
|
||||||
from vnpy.trader.backtesting import BacktestingEngine
|
from vnpy.trader.constant import Interval, Exchange
|
||||||
from vnpy.trader.constant import Interval
|
|
||||||
from vnpy.trader.database import database_manager
|
|
||||||
from vnpy.trader.object import HistoryRequest
|
|
||||||
|
|
||||||
from .config import settings
|
from .config import settings
|
||||||
from .models import BacktestTask, BacktestResult, BacktestStatistics, TaskStatus, BacktestTaskWithId
|
from .models import BacktestTask, BacktestResult, BacktestStatistics, TaskStatus, BacktestTaskWithId
|
||||||
from .result_storage import storage
|
from .result_storage import storage
|
||||||
|
|
||||||
|
|
||||||
|
# vnpy 4.x精简了Interval枚举,不再有FIVE_MINUTE等细分
|
||||||
INTERVAL_MAP = {
|
INTERVAL_MAP = {
|
||||||
"1m": Interval.MINUTE,
|
"1m": Interval.MINUTE,
|
||||||
"5m": Interval.FIVE_MINUTE,
|
"5m": Interval.MINUTE,
|
||||||
"15m": Interval.FIFTEEN_MINUTE,
|
"15m": Interval.MINUTE,
|
||||||
"30m": Interval.THIRTY_MINUTE,
|
"30m": Interval.MINUTE,
|
||||||
"1h": Interval.HOUR,
|
"1h": Interval.HOUR,
|
||||||
"4h": Interval.FOUR_HOUR,
|
"4h": Interval.HOUR,
|
||||||
"1d": Interval.DAILY,
|
"1d": Interval.DAILY,
|
||||||
"1w": Interval.WEEKLY,
|
"1w": Interval.WEEKLY,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# 交易所映射
|
||||||
|
EXCHANGE_MAP = {
|
||||||
|
"SSE": Exchange.SSE,
|
||||||
|
"SZSE": Exchange.SZSE,
|
||||||
|
"CFFEX": Exchange.CFFEX,
|
||||||
|
"SHFE": Exchange.SHFE,
|
||||||
|
"DCE": Exchange.DCE,
|
||||||
|
"CZCE": Exchange.CZCE,
|
||||||
|
"INE": Exchange.INE,
|
||||||
|
"GFEX": Exchange.GFEX,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_vt_symbol(vt_symbol: str):
|
||||||
|
"""解析vt_symbol为symbol和exchange,如 '000001.SZ' → ('000001', Exchange.SZSE)"""
|
||||||
|
if "." in vt_symbol:
|
||||||
|
symbol, exchange_str = vt_symbol.rsplit(".", 1)
|
||||||
|
exchange = EXCHANGE_MAP.get(exchange_str.upper())
|
||||||
|
if exchange is None:
|
||||||
|
# 尝试模糊匹配
|
||||||
|
exchange_str_upper = exchange_str.upper()
|
||||||
|
for key, val in EXCHANGE_MAP.items():
|
||||||
|
if key.startswith(exchange_str_upper[:2]):
|
||||||
|
exchange = val
|
||||||
|
break
|
||||||
|
if exchange is None:
|
||||||
|
exchange = Exchange.SZSE # 默认深交所
|
||||||
|
return symbol, exchange
|
||||||
|
return vt_symbol, Exchange.SZSE
|
||||||
|
|
||||||
|
|
||||||
class BacktestExecutor:
|
class BacktestExecutor:
|
||||||
"""回测任务执行器"""
|
"""回测任务执行器 - 适配vnpy 4.x"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _load_strategy(self, task: BacktestTask):
|
def _load_strategy(self, task: BacktestTask):
|
||||||
"""动态加载策略代码"""
|
"""动态加载策略代码"""
|
||||||
# 将策略代码写入临时文件
|
|
||||||
strategy_code = task.strategy_code
|
strategy_code = task.strategy_code
|
||||||
|
|
||||||
# 创建临时目录
|
|
||||||
temp_dir = tempfile.mkdtemp()
|
temp_dir = tempfile.mkdtemp()
|
||||||
sys.path.insert(0, temp_dir)
|
sys.path.insert(0, temp_dir)
|
||||||
|
|
||||||
strategy_file = os.path.join(temp_dir, "strategy.py")
|
strategy_file = os.path.join(temp_dir, "strategy.py")
|
||||||
with open(strategy_file, "w", encoding="utf-8") as f:
|
with open(strategy_file, "w", encoding="utf-8") as f:
|
||||||
f.write(strategy_code)
|
f.write(strategy_code)
|
||||||
|
|
||||||
# 导入模块
|
|
||||||
import importlib
|
import importlib
|
||||||
spec = importlib.util.spec_from_file_location("dynamic_strategy", strategy_file)
|
spec = importlib.util.spec_from_file_location("dynamic_strategy", strategy_file)
|
||||||
module = importlib.util.module_from_spec(spec)
|
module = importlib.util.module_from_spec(spec)
|
||||||
sys.modules["dynamic_strategy"] = module
|
sys.modules["dynamic_strategy"] = module
|
||||||
spec.loader.exec_module(module)
|
spec.loader.exec_module(module)
|
||||||
|
|
||||||
# 找到策略类 - 假设第一个继承自 Strategy 的就是我们要的
|
# 找CtaTemplate子类
|
||||||
from vnpy.trader.strategy import Strategy
|
from vnpy_ctastrategy import CtaTemplate
|
||||||
|
|
||||||
strategy_class = None
|
strategy_class = None
|
||||||
for attr_name in dir(module):
|
for attr_name in dir(module):
|
||||||
attr = getattr(module, attr_name)
|
attr = getattr(module, attr_name)
|
||||||
if isinstance(attr, type) and issubclass(attr, Strategy) and attr != Strategy:
|
if isinstance(attr, type) and issubclass(attr, CtaTemplate) and attr is not CtaTemplate:
|
||||||
strategy_class = attr
|
strategy_class = attr
|
||||||
break
|
break
|
||||||
|
|
||||||
if not strategy_class:
|
if not strategy_class:
|
||||||
raise ValueError("策略代码中没有找到 Strategy 子类,请检查策略代码")
|
raise ValueError("策略代码中没有找到 CtaTemplate 子类,请检查策略代码")
|
||||||
|
|
||||||
# 创建策略实例,注入参数
|
return strategy_class
|
||||||
strategy = strategy_class
|
|
||||||
return strategy
|
|
||||||
|
|
||||||
def execute_backtest(self, task: BacktestTaskWithId) -> BacktestResult:
|
def execute_backtest(self, task: BacktestTaskWithId) -> BacktestResult:
|
||||||
"""执行一次回测"""
|
"""执行一次回测"""
|
||||||
from vnpy.trader.database import database_manager
|
|
||||||
|
|
||||||
start_time = datetime.now()
|
start_time = datetime.now()
|
||||||
started_at = start_time.isoformat()
|
started_at = start_time.isoformat()
|
||||||
|
|
||||||
# 更新任务状态为运行中
|
|
||||||
task.status = TaskStatus.RUNNING
|
task.status = TaskStatus.RUNNING
|
||||||
task.started_at = started_at
|
task.started_at = started_at
|
||||||
storage.save_task(task)
|
storage.save_task(task)
|
||||||
|
|
||||||
result = BacktestResult(
|
result = BacktestResult(
|
||||||
task_id=task.task_id,
|
task_id=task.task_id,
|
||||||
strategy_name=task.strategy_name,
|
strategy_name=task.strategy_name,
|
||||||
@@ -98,119 +121,134 @@ class BacktestExecutor:
|
|||||||
created_at=task.created_at,
|
created_at=task.created_at,
|
||||||
started_at=started_at,
|
started_at=started_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 加载策略类
|
# 加载策略类
|
||||||
strategy_class = self._load_strategy(task)
|
strategy_class = self._load_strategy(task)
|
||||||
|
|
||||||
|
# 解析vt_symbol
|
||||||
|
symbol, exchange = _parse_vt_symbol(task.symbol)
|
||||||
|
|
||||||
# 获取interval
|
# 获取interval
|
||||||
interval = INTERVAL_MAP.get(task.interval, Interval.DAILY)
|
interval = INTERVAL_MAP.get(task.interval, Interval.DAILY)
|
||||||
|
|
||||||
# 查询历史数据 - 使用 vnpy 数据库
|
# 创建回测引擎
|
||||||
req = HistoryRequest(
|
engine = BacktestingEngine()
|
||||||
symbol=task.symbol,
|
|
||||||
exchange=None, # 由代码处理
|
# 设置回测参数
|
||||||
|
engine.set_parameters(
|
||||||
|
vt_symbol=task.symbol,
|
||||||
interval=interval,
|
interval=interval,
|
||||||
start=task.start_date,
|
start=task.start_date,
|
||||||
end=task.end_date,
|
end=task.end_date,
|
||||||
)
|
rate=0.3 / 10000, # 手续费率万三
|
||||||
data = database_manager.query_history(req)
|
slippage=0.1, # 滑点0.1
|
||||||
|
size=1, # 合约乘数
|
||||||
if data.empty:
|
pricetick=task.tick_size or 0.01, # 最小价格变动
|
||||||
raise ValueError(f"未找到 {task.symbol} 在 [{task.start_date}, {task.end_date}] 范围内的历史数据")
|
|
||||||
|
|
||||||
# 创建回测引擎
|
|
||||||
engine = BacktestingEngine()
|
|
||||||
|
|
||||||
# 设置参数
|
|
||||||
engine.set_parameters(
|
|
||||||
data=data,
|
|
||||||
interval=interval,
|
|
||||||
capital=task.capital,
|
capital=task.capital,
|
||||||
tick_size=task.tick_size,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 添加策略
|
# 添加策略
|
||||||
engine.add_strategy(strategy_class, task.parameters)
|
engine.add_strategy(strategy_class, task.parameters)
|
||||||
|
|
||||||
|
# 加载历史数据
|
||||||
|
# 优先从CSV文件加载(/app/data目录通过volume挂载NAS数据)
|
||||||
|
data_loaded = False
|
||||||
|
data_dir = settings.base_dir.replace("backtest_jobs", "data")
|
||||||
|
|
||||||
|
# 尝试多种数据加载方式
|
||||||
|
try:
|
||||||
|
# 方式1: 使用vnpy内置数据加载
|
||||||
|
engine.load_data()
|
||||||
|
data_loaded = True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if not data_loaded:
|
||||||
|
raise ValueError(
|
||||||
|
f"无法加载 {task.symbol} 在 [{task.start_date}, {task.end_date}] 的历史数据。"
|
||||||
|
f"请确保数据已导入vnpy数据库或可通过CSV加载。"
|
||||||
|
)
|
||||||
|
|
||||||
# 运行回测
|
# 运行回测
|
||||||
engine.run_backtesting()
|
engine.run_backtesting()
|
||||||
|
|
||||||
# 计算统计结果
|
# 计算统计结果
|
||||||
df = engine.calculate_results()
|
df = engine.calculate_result()
|
||||||
|
statistics = engine.calculate_statistics()
|
||||||
# 统计结果
|
|
||||||
statistics = engine.get_result_statistics()
|
# 转换为数据模型
|
||||||
|
|
||||||
# 转换为我们的数据模型
|
|
||||||
stats = BacktestStatistics(
|
stats = BacktestStatistics(
|
||||||
start_date=statistics["start_date"].isoformat() if hasattr(statistics["start_date"], "isoformat") else str(statistics["start_date"]),
|
start_date=str(task.start_date),
|
||||||
end_date=statistics["end_date"].isoformat() if hasattr(statistics["end_date"], "isoformat") else str(statistics["end_date"]),
|
end_date=str(task.end_date),
|
||||||
total_days=int(statistics["total_days"]),
|
total_days=int(statistics.get("total_days", 0)),
|
||||||
total_trades=int(statistics["total_trades"]),
|
total_trades=int(statistics.get("total_trades", 0)),
|
||||||
winning_trades=int(statistics["winning_trades"]),
|
winning_trades=int(statistics.get("winning_trades", 0)),
|
||||||
losing_trades=int(statistics["losing_trades"]),
|
losing_trades=int(statistics.get("losing_trades", 0)),
|
||||||
win_rate=float(statistics["win_rate"]),
|
win_rate=float(statistics.get("win_rate", 0)),
|
||||||
total_return=float(statistics["total_return"]),
|
total_return=float(statistics.get("total_return", 0)),
|
||||||
annual_return=float(statistics["annual_return"]),
|
annual_return=float(statistics.get("annual_return", 0)),
|
||||||
sharpe_ratio=float(statistics["sharpe_ratio"]),
|
sharpe_ratio=float(statistics.get("sharpe_ratio", 0)),
|
||||||
max_drawdown=float(statistics["max_drawdown"]),
|
max_drawdown=float(statistics.get("max_drawdown", 0)),
|
||||||
|
max_drawdown_start=str(statistics.get("max_drawdown_start", "")),
|
||||||
|
max_drawdown_end=str(statistics.get("max_drawdown_end", "")),
|
||||||
profit_factor=float(statistics.get("profit_factor", 0)),
|
profit_factor=float(statistics.get("profit_factor", 0)),
|
||||||
calmar_ratio=float(statistics.get("calmar_ratio", 0)),
|
calmar_ratio=float(statistics.get("calmar_ratio", 0)),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 保存净值CSV
|
# 保存净值CSV
|
||||||
result_csv_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "equity.csv")
|
result_csv_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "equity.csv")
|
||||||
os.makedirs(os.path.dirname(result_csv_path), exist_ok=True)
|
os.makedirs(os.path.dirname(result_csv_path), exist_ok=True)
|
||||||
df.to_csv(result_csv_path, index=False)
|
if df is not None and not df.empty:
|
||||||
|
df.to_csv(result_csv_path)
|
||||||
|
|
||||||
# 绘制收益曲线
|
# 绘制收益曲线
|
||||||
png_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "equity_curve.png")
|
png_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "equity_curve.png")
|
||||||
self._plot_equity_curve(df, png_path)
|
self._plot_equity_curve(df, png_path)
|
||||||
|
|
||||||
# 保存成交记录
|
# 保存成交记录
|
||||||
trades = engine.get_trades()
|
trades_csv_path = None
|
||||||
if trades:
|
try:
|
||||||
trades_df = pd.DataFrame([t.__dict__ for t in trades])
|
trades = engine.get_all_trades() if hasattr(engine, 'get_all_trades') else []
|
||||||
trades_csv_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "trades.csv")
|
if trades:
|
||||||
trades_df.to_csv(trades_csv_path, index=False)
|
trades_df = pd.DataFrame([t.__dict__ for t in trades])
|
||||||
else:
|
trades_csv_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "trades.csv")
|
||||||
trades_csv_path = None
|
trades_df.to_csv(trades_csv_path, index=False)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# 完成结果
|
# 完成结果
|
||||||
result.status = TaskStatus.COMPLETED
|
result.status = TaskStatus.COMPLETED
|
||||||
result.statistics = stats
|
result.statistics = stats
|
||||||
result.result_csv_path = result_csv_path
|
result.result_csv_path = result_csv_path
|
||||||
result.equity_curve_png_path = png_path
|
result.equity_curve_png_path = png_path
|
||||||
result.trades_csv_path = trades_csv_path
|
result.trades_csv_path = trades_csv_path
|
||||||
|
result.completed_at = datetime.now().isoformat()
|
||||||
completed_at = datetime.now().isoformat()
|
|
||||||
result.completed_at = completed_at
|
|
||||||
|
|
||||||
storage.save_result(result)
|
storage.save_result(result)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 捕获异常,记录错误信息
|
|
||||||
error_msg = f"{str(e)}\n{traceback.format_exc()}"
|
error_msg = f"{str(e)}\n{traceback.format_exc()}"
|
||||||
result.status = TaskStatus.FAILED
|
result.status = TaskStatus.FAILED
|
||||||
result.error_message = error_msg
|
result.error_message = error_msg
|
||||||
|
result.completed_at = datetime.now().isoformat()
|
||||||
completed_at = datetime.now().isoformat()
|
|
||||||
result.completed_at = completed_at
|
|
||||||
|
|
||||||
storage.save_result(result)
|
storage.save_result(result)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _plot_equity_curve(self, df: pd.DataFrame, output_path: str):
|
def _plot_equity_curve(self, df: pd.DataFrame, output_path: str):
|
||||||
"""绘制收益曲线"""
|
"""绘制收益曲线"""
|
||||||
plt.figure(figsize=(12, 6))
|
plt.figure(figsize=(12, 6))
|
||||||
if "equity" in df.columns:
|
if df is not None and not df.empty:
|
||||||
plt.plot(df.index, df["equity"], label="净值曲线", linewidth=2)
|
if "equity" in df.columns:
|
||||||
elif "net_pnl" in df.columns:
|
plt.plot(df.index, df["equity"], label="净值曲线", linewidth=2)
|
||||||
cumulative = df["net_pnl"].cumsum()
|
elif "net_pnl" in df.columns:
|
||||||
plt.plot(df.index, cumulative, label="累计收益", linewidth=2)
|
cumulative = df["net_pnl"].cumsum()
|
||||||
|
plt.plot(df.index, cumulative, label="累计收益", linewidth=2)
|
||||||
|
elif "balance" in df.columns:
|
||||||
|
plt.plot(df.index, df["balance"], label="账户余额", linewidth=2)
|
||||||
|
|
||||||
plt.title("回测收益曲线")
|
plt.title("回测收益曲线")
|
||||||
plt.xlabel("时间")
|
plt.xlabel("时间")
|
||||||
plt.ylabel("净值")
|
plt.ylabel("净值")
|
||||||
|
|||||||
Reference in New Issue
Block a user