# Listing metadata: 241 lines, 7.4 KiB, Python
"""
BacktestRunner - backtest engine.

Glue layer: fetch data → generate signals → simulate trades → produce a report.

Ships with a simple built-in backtest engine that does not depend on vnpy,
which lowers the barrier to entry.
The vnpy entry point is kept so that advanced features can be switched in.

Usage:
    from data_platform import DataCatalog
    from data_platform.backtest_runner import BacktestRunner
    from data_platform.strategy_base import BaseStrategy

    class MyStrategy(BaseStrategy):
        def generate_signals(self, data):
            ...

    runner = BacktestRunner(DataCatalog())
    result = runner.run(MyStrategy(), "600519", "20250101", "20251231")
    print(result.summary())
"""
|
|
|
|
import logging
|
|
from dataclasses import dataclass, field
|
|
from typing import Optional, List, Dict
|
|
|
|
import pandas as pd
|
|
import numpy as np
|
|
|
|
from data_platform.catalog import DataCatalog
|
|
from data_platform.strategy_base import BaseStrategy
|
|
from data_platform.backtest_report import BacktestReport
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class Trade:
    """Record of a single trade.

    Attributes:
        entry_date: Timestamp at which the position was opened.
        exit_date: Timestamp at which the position was closed; may be None.
        entry_price: Price at which the position was opened.
        exit_price: Price at which the position was closed; may be None.
        direction: 1 for a long trade, -1 for a short trade.
        shares: Number of shares traded.
        profit: Profit of the trade in currency units; defaults to None.
        profit_pct: Profit of the trade as a fraction; defaults to None.
    """

    # Required fields (positional order is part of the constructor contract).
    entry_date: pd.Timestamp
    exit_date: Optional[pd.Timestamp]
    entry_price: float
    exit_price: Optional[float]
    direction: int
    shares: int

    # Optional result fields.
    profit: Optional[float] = field(default=None)
    profit_pct: Optional[float] = field(default=None)
|
|
|
|
|
|
@dataclass
class BacktestResult:
    """Outcome of one backtest run.

    Attributes:
        strategy_name: Name of the strategy that was tested.
        code: Stock code that was tested.
        start_date: First date of the tested period.
        end_date: Last date of the tested period.
        initial_capital: Capital at the start of the run.
        final_capital: Capital at the end of the run.
        total_return: Total return (rendered as a percentage by ``summary``).
        annual_return: Annualised return (rendered as a percentage).
        max_drawdown: Maximum drawdown (rendered as a percentage).
        sharpe_ratio: Sharpe ratio.
        win_rate: Win rate (rendered as a percentage).
        total_trades: Number of trades.
        trades: The individual trade records; defaults to an empty list.
        equity_curve: Equity series of the run; defaults to None.
    """

    # Identification of the run.
    strategy_name: str
    code: str
    start_date: pd.Timestamp
    end_date: pd.Timestamp

    # Capital and performance figures.
    initial_capital: float
    final_capital: float
    total_return: float
    annual_return: float
    max_drawdown: float
    sharpe_ratio: float
    win_rate: float
    total_trades: int

    # Detailed records.
    trades: List[Trade] = field(default_factory=list)
    equity_curve: Optional[pd.Series] = field(default=None)

    def summary(self) -> str:
        """Return a five-line, human-readable digest of the headline figures."""
        identity = f"策略: {self.strategy_name} | 股票: {self.code}\n"
        period = f"区间: {self.start_date.date()} ~ {self.end_date.date()}\n"
        returns = f"总收益率: {self.total_return:.2%} | 年化: {self.annual_return:.2%}\n"
        risk = f"最大回撤: {self.max_drawdown:.2%} | 夏普: {self.sharpe_ratio:.2f}\n"
        activity = f"胜率: {self.win_rate:.2%} | 交易次数: {self.total_trades}"
        return identity + period + returns + risk + activity
|
|
|
|
|
|
class BacktestRunner:
    """Single-stock, long-only backtest engine.

    One call runs the whole pipeline: fetch daily bars from the catalog, let
    the strategy produce signals, simulate the trades and assemble a
    ``BacktestResult``.

    Trades are filled at the ``close`` of the bar that carries the signal,
    adjusted by ``slippage``; ``commission_rate`` is charged on both the buy
    and the sell.

    Attributes:
        catalog: Data source; ``get_daily(code, start=..., end=...)`` is
            called on it.
        commission_rate: Proportional commission applied to each fill.
        slippage: Proportional price penalty applied to each fill.
    """

    # Minimum number of bars required before a backtest is attempted.
    _MIN_BARS = 20
    # Fraction of available cash committed on each entry.
    _CASH_USAGE = 0.99
    # Shares are bought in whole lots of this size.
    _LOT_SIZE = 100
    # Annualisation bases: the Sharpe ratio is scaled from per-bar returns,
    # the compound return from elapsed calendar days.
    _TRADING_DAYS_PER_YEAR = 252
    _CALENDAR_DAYS_PER_YEAR = 365

    def __init__(
        self,
        catalog: DataCatalog,
        commission_rate: float = 0.0003,
        slippage: float = 0.001,
    ):
        """Store the data source and the trading-cost parameters.

        Args:
            catalog: Data source used to fetch daily bars.
            commission_rate: Proportional commission per fill.
            slippage: Proportional slippage per fill.
        """
        self.catalog = catalog
        self.commission_rate = commission_rate
        self.slippage = slippage

    def run(
        self,
        strategy: BaseStrategy,
        code: str,
        start: str,
        end: str,
        initial_capital: float = 1_000_000,
    ) -> BacktestResult:
        """Run a backtest for a single stock.

        Args:
            strategy: Strategy instance (a ``BaseStrategy`` subclass).
            code: Stock code.
            start: Start date, ``YYYYMMDD``.
            end: End date, ``YYYYMMDD``.
            initial_capital: Starting cash.

        Returns:
            The assembled ``BacktestResult``. Its start/end dates are the
            first and last values of the signal frame's ``date`` column, not
            the requested ``start``/``end`` strings.

        Raises:
            ValueError: If fewer than 20 rows of data are available, or the
                strategy output has no ``signal`` column.
        """
        # 1. Fetch data.
        data = self.catalog.get_daily(code, start=start, end=end)
        if len(data) < self._MIN_BARS:
            raise ValueError(f"数据不足:{code} 仅 {len(data)} 行")

        # 2. Generate signals.
        signals = strategy.generate_signals(data)
        if "signal" not in signals.columns:
            raise ValueError(f"策略 {strategy.name} 未生成 signal 列")

        # 3. Simulate trading.
        trades, equity = self._simulate(signals, initial_capital)

        # 4. Compute metrics over the period actually covered by the signals.
        dates = signals["date"]
        return self._build_result(
            strategy.name,
            code,
            dates.iloc[0],
            dates.iloc[-1],
            initial_capital,
            equity,
            trades,
        )

    def run_batch(
        self,
        strategy: BaseStrategy,
        codes: List[str],
        start: str,
        end: str,
        initial_capital: float = 1_000_000,
    ) -> Dict[str, BacktestResult]:
        """Backtest several stocks with the same strategy instance.

        A failure on one code is logged as a warning and that code is left
        out of the returned mapping; the remaining codes are still processed.

        Args:
            strategy: Strategy instance reused for every code.
            codes: Stock codes to test.
            start: Start date, ``YYYYMMDD``.
            end: End date, ``YYYYMMDD``.
            initial_capital: Starting cash for each individual run.

        Returns:
            Mapping of stock code to its ``BacktestResult`` for every code
            that completed successfully.
        """
        results = {}
        for code in codes:
            try:
                results[code] = self.run(strategy, code, start, end, initial_capital)
            except Exception as e:
                # Deliberate best-effort: one bad code must not abort the batch.
                logger.warning("回测 %s 失败: %s", code, e)
        return results

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _simulate(self, data: pd.DataFrame, capital: float):
        """Replay the signals bar by bar.

        Reads the ``date``, ``close`` and ``signal`` columns of ``data``.
        ``signal == 1`` opens a long position when flat; ``signal == -1``
        closes an open position. A position still open on the last bar is
        marked to market in the equity series but not recorded as a trade.

        Args:
            data: Signal frame produced by the strategy.
            capital: Starting cash.

        Returns:
            ``(trades, equity)``: the list of closed ``Trade`` records and a
            ``pd.Series`` of end-of-bar equity indexed like ``data``. Each
            trade's ``profit`` and ``profit_pct`` are net of slippage and of
            the commission paid on both legs.
        """
        fee = self.commission_rate
        position = 0        # shares currently held
        entry_price = 0.0   # slippage-adjusted buy price of the open position
        entry_cost = 0.0    # cash actually spent to open it, commission included
        entry_date = None
        trades = []
        equity = []

        for date, price, signal in zip(data["date"], data["close"], data["signal"]):
            if signal == 1 and position == 0:
                # Buy: size the order from the committed cash, in whole lots.
                fill = price * (1 + self.slippage)
                shares = int(capital * self._CASH_USAGE / (fill * (1 + fee)))
                shares = shares // self._LOT_SIZE * self._LOT_SIZE
                if shares > 0:
                    position = shares
                    entry_price = fill
                    entry_date = date
                    entry_cost = shares * fill * (1 + fee)
                    capital -= entry_cost

            elif signal == -1 and position > 0:
                # Sell the whole position.
                fill = price * (1 - self.slippage)
                proceeds = position * fill * (1 - fee)
                capital += proceeds
                # Measure profit against the full cash outlay so the buy-side
                # commission is not silently dropped from the trade record.
                profit = proceeds - entry_cost
                trades.append(Trade(
                    entry_date=entry_date,
                    exit_date=date,
                    entry_price=entry_price,
                    exit_price=fill,
                    direction=1,
                    shares=position,
                    profit=profit,
                    profit_pct=profit / entry_cost,
                ))
                position = 0

            # End-of-bar equity: cash plus the position marked at the close.
            equity.append(capital + position * price)

        return trades, pd.Series(equity, index=data.index)

    def _build_result(self, name, code, start, end, capital, equity, trades):
        """Derive the performance metrics and pack them into a result.

        Args:
            name: Strategy name.
            code: Stock code.
            start: First date of the tested period.
            end: Last date of the tested period.
            capital: Initial capital.
            equity: Equity series returned by ``_simulate``.
            trades: Closed trades returned by ``_simulate``.

        Returns:
            A populated ``BacktestResult``.
        """
        final = float(equity.iloc[-1])
        total_return = final / capital - 1

        # ``.days`` counts calendar days, so the compound return is annualised
        # on a 365-day year; clamp to 1 to avoid a zero or negative span.
        days = max((end - start).days, 1)
        annual_return = (1 + total_return) ** (self._CALENDAR_DAYS_PER_YEAR / days) - 1

        # Maximum drawdown: deepest relative drop from the running peak.
        running_peak = equity.cummax()
        max_drawdown = float(((equity - running_peak) / running_peak).min())

        # Sharpe ratio from per-bar returns (no risk-free rate), annualised
        # with the trading-day count.
        daily_returns = equity.pct_change().dropna()
        volatility = daily_returns.std()
        if len(daily_returns) > 1 and volatility > 0:
            sharpe = float(
                daily_returns.mean() / volatility * np.sqrt(self._TRADING_DAYS_PER_YEAR)
            )
        else:
            sharpe = 0.0

        # Win rate: share of closed trades with a strictly positive profit.
        total = len(trades)
        wins = sum(1 for t in trades if t.profit is not None and t.profit > 0)

        return BacktestResult(
            strategy_name=name,
            code=code,
            start_date=start,
            end_date=end,
            initial_capital=capital,
            final_capital=final,
            total_return=total_return,
            annual_return=annual_return,
            max_drawdown=max_drawdown,
            sharpe_ratio=sharpe,
            win_rate=wins / total if total > 0 else 0.0,
            total_trades=total,
            trades=trades,
            equity_curve=equity,
        )
|