337 lines
12 KiB
Python
337 lines
12 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
本地直接测试回测 - 直接在容器内运行,找到崩溃原因
|
||
|
||
策略来自关羽将军:single_stock_stop_loss_final_correct.py
|
||
参数:510300.SSE,2021-01-01 ~ 2026-03-01,stop_loss_pct=0.15
|
||
"""
|
||
|
||
# 策略代码直接嵌入
|
||
strategy_code = '''"""
|
||
单票固定比例止损策略 - vnpy CTA回测
|
||
|
||
策略逻辑:
|
||
- 标的:沪深300ETF (510300.SSE)
|
||
- 简单均线趋势跟踪:金叉开多,死叉平多
|
||
- 开多后,如果价格从开仓价下跌超过X%,立即止损平仓
|
||
- 测试不同止损比例对策略绩效的影响
|
||
|
||
回测目标:验证不同止损比例对胜率、盈亏比、最大回撤的影响
|
||
"""
|
||
|
||
from vnpy_ctastrategy import (
|
||
CtaTemplate,
|
||
StopOrder,
|
||
TickData,
|
||
BarData,
|
||
TradeData,
|
||
OrderData,
|
||
BarGenerator,
|
||
ArrayManager,
|
||
)
|
||
from vnpy.trader.constant import Direction, Offset
|
||
|
||
|
||
class SingleStockStopLossStrategy(CtaTemplate):
|
||
"""单票固定比例止损策略 - 均线趋势跟踪+固定比例止损"""
|
||
|
||
author = "关羽 (云长)"
|
||
|
||
# 策略参数
|
||
fast_window = 5 # 短期均线窗口
|
||
slow_window = 20 # 长期均线窗口
|
||
stop_loss_pct = 0.15 # 止损比例,亏损超过这个比例止损
|
||
|
||
# 参数列表
|
||
parameters = ["fast_window", "slow_window", "stop_loss_pct"]
|
||
|
||
# 变量列表
|
||
variables = ["fast_ma", "slow_ma", "cost_price", "in_position"]
|
||
|
||
def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
|
||
"""初始化"""
|
||
super().__init__(cta_engine, strategy_name, vt_symbol, setting)
|
||
|
||
self.bg = BarGenerator(self.on_bar)
|
||
self.am = ArrayManager(max(self.slow_window + 10, 100))
|
||
|
||
# 均线数值
|
||
self.fast_ma = 0.0
|
||
self.slow_ma = 0.0
|
||
|
||
# 开仓成本
|
||
self.cost_price = 0.0
|
||
|
||
# 是否持仓
|
||
self.in_position = False
|
||
|
||
def on_init(self):
|
||
"""初始化策略"""
|
||
self.write_log(f"策略初始化,fast={self.fast_window}, slow={self.slow_window}, stop_loss={self.stop_loss_pct:.1%}")
|
||
self.put_event()
|
||
|
||
def on_start(self):
|
||
"""启动策略"""
|
||
self.put_event()
|
||
|
||
def on_stop(self):
|
||
"""停止策略"""
|
||
self.put_event()
|
||
|
||
def on_bar(self, bar):
|
||
"""K线更新"""
|
||
self.am.update_bar(bar)
|
||
|
||
if not self.am.inited:
|
||
return
|
||
|
||
# 计算均线
|
||
self.fast_ma = self.am.sma(self.fast_window)
|
||
self.slow_ma = self.am.sma(self.slow_window)
|
||
|
||
# 检查止损(只有持仓时才检查)
|
||
have_signal = True
|
||
if self.in_position and self.cost_price > 0:
|
||
current_drawdown = (bar.close_price - self.cost_price) / self.cost_price
|
||
|
||
if current_drawdown <= -self.stop_loss_pct:
|
||
# 触发止损,全部平仓
|
||
if self.pos > 0:
|
||
self.sell(bar.close_price, self.pos)
|
||
self.in_position = False
|
||
self.write_log(f"🔴 触发止损:成本{self.cost_price:.2f},当前{bar.close_price:.2f},回撤{current_drawdown:.1%},止损卖出")
|
||
have_signal = False
|
||
|
||
# 如果没有触发止损,继续处理信号
|
||
if have_signal:
|
||
# 均线金叉死叉信号
|
||
if not self.in_position:
|
||
# 金叉:短期上穿长期,开多
|
||
if self.fast_ma > self.slow_ma:
|
||
self.buy(bar.close_price, 1) # 1手
|
||
self.cost_price = bar.close_price
|
||
self.in_position = True
|
||
self.write_log(f"🟢 金叉开多:价格{bar.close_price:.2f},均线fast{self.fast_ma:.2f} slow{self.slow_ma:.2f}")
|
||
else:
|
||
# 死叉:短期下穿长期,平多
|
||
if self.fast_ma < self.slow_ma:
|
||
if self.pos > 0:
|
||
self.sell(bar.close_price, self.pos)
|
||
self.in_position = False
|
||
self.write_log(f"🔵 死叉平仓:价格{bar.close_price:.2f},均线fast{self.fast_ma:.2f} slow{self.slow_ma:.2f}")
|
||
|
||
self.put_event()
|
||
|
||
def on_trade(self, trade):
|
||
"""交易成交回调"""
|
||
self.put_event()
|
||
|
||
def on_order(self, order):
|
||
"""订单回调"""
|
||
self.put_event()
|
||
|
||
def on_stop_order(self, stop_order):
|
||
"""停止单回调"""
|
||
self.put_event()
|
||
|
||
|
||
if __name__ == "__main__":
|
||
# 测试
|
||
print("=== 单票固定比例止损策略 测试 ===")
|
||
print(f"默认参数:fast=5, slow=20, stop_loss=15%")
|
||
print("策略初始化完成,等待回测")
|
||
'''
|
||
|
||
# 导入我们的RPC函数,直接调用测试
|
||
import sys
|
||
import types
|
||
|
||
# ============================================
|
||
# 🔥 复制vnpy.app兼容性模块
|
||
# ============================================
|
||
print("🔧 [TEST] 加载vnpy.app兼容性模块...")
|
||
|
||
# 创建顶级模块
|
||
vnpy_app_module = types.ModuleType('vnpy.app')
|
||
sys.modules['vnpy.app'] = vnpy_app_module
|
||
|
||
# 创建子模块
|
||
submodules = ['cta_strategy', 'cta_backtester', 'data_manager']
|
||
for name in submodules:
|
||
full_name = f'vnpy.app.{name}'
|
||
submodule = types.ModuleType(full_name)
|
||
sys.modules[full_name] = submodule
|
||
setattr(vnpy_app_module, name, submodule)
|
||
|
||
# 从实际模块映射类
|
||
from vnpy_ctastrategy import CtaTemplate, CtaStrategyApp
|
||
sys.modules['vnpy.app.cta_strategy'].CtaTemplate = CtaTemplate
|
||
sys.modules['vnpy.app.cta_strategy'].CtaStrategyApp = CtaStrategyApp
|
||
vnpy_app_module.CtaTemplate = CtaTemplate
|
||
vnpy_app_module.CtaStrategyApp = CtaStrategyApp
|
||
|
||
from vnpy_ctabacktester import BacktesterEngine
|
||
sys.modules['vnpy.app.cta_backtester'].BacktesterEngine = BacktesterEngine
|
||
vnpy_app_module.BacktesterEngine = BacktesterEngine
|
||
|
||
print("✅ [TEST] vnpy.app兼容性模块加载完成!")
|
||
print(f" 确认: BacktesterEngine 的类型 = {type(BacktesterEngine)}, 是否是类 = {isinstance(BacktesterEngine, type)}")
|
||
|
||
# ============================================
|
||
# 兼容性修复完成,现在导入其他模块
|
||
# ============================================
|
||
|
||
from vnpy.event import EventEngine
|
||
from vnpy.trader.engine import MainEngine
|
||
import traceback
|
||
|
||
def test_run_strategy_backtest(strategy_code: str, symbol: str, interval: str, start: int, end: int, **kwargs):
|
||
"""直接测试run_strategy_backtest"""
|
||
try:
|
||
print(f"\n🚀 [TEST] 开始回测: {symbol} [{start} - {end}]")
|
||
|
||
# 动态加载策略
|
||
local_vars = {}
|
||
exec(strategy_code, globals(), local_vars)
|
||
|
||
# 查找CtaTemplate子类
|
||
strategy_classes = [
|
||
v for k, v in local_vars.items()
|
||
if isinstance(v, type) and issubclass(v, CtaTemplate) and v != CtaTemplate
|
||
]
|
||
|
||
if not strategy_classes:
|
||
return {
|
||
"error": "策略代码中未找到CtaTemplate子类",
|
||
"hint": "请确保策略继承自CtaTemplate"
|
||
}
|
||
|
||
StrategyClass = strategy_classes[0]
|
||
print(f"✅ [TEST] 找到策略类: {StrategyClass.__name__}")
|
||
|
||
# ============================================
|
||
# 创建引擎 - 按照vnpy 4.x最新规范
|
||
# ============================================
|
||
print(f"🔧 [TEST] 创建引擎...")
|
||
|
||
event_engine = EventEngine()
|
||
print(f"✅ [TEST] event_engine = EventEngine()")
|
||
|
||
main_engine = MainEngine(event_engine)
|
||
print(f"✅ [TEST] main_engine = MainEngine(event_engine)")
|
||
|
||
# ✅ vnpy 4.x 正确用法:add_app 添加类,MainEngine负责实例化
|
||
print(f"🔧 [TEST] main_engine.add_app(BacktesterEngine) // 添加类,不是实例")
|
||
print(f"🔧 [TEST] 确认: BacktesterEngine 的类型 = {type(BacktesterEngine)}, 是否是类 = {isinstance(BacktesterEngine, type)}")
|
||
main_engine.add_app(BacktesterEngine)
|
||
print(f"✅ [TEST] 添加到主引擎完成")
|
||
|
||
print(f"🔧 [TEST] backtester_engine = main_engine.get_app(BacktesterEngine)")
|
||
backtester_engine = main_engine.get_app(BacktesterEngine)
|
||
print(f"✅ [TEST] get_app 返回: 类型 = {type(backtester_engine)}, 是否是实例 = {not isinstance(backtester_engine, type)}")
|
||
|
||
# 双重保险:如果get_app返回的还是类,我们自己实例化
|
||
if isinstance(backtester_engine, type):
|
||
print(f"⚠️ [TEST] get_app 返回的还是类,手动实例化: backtester_engine = BacktesterEngine(main_engine, event_engine)")
|
||
backtester_engine = BacktesterEngine(main_engine, event_engine)
|
||
print(f"✅ [TEST] 手动实例化成功,现在是实例: 类型 = {type(backtester_engine)}")
|
||
|
||
print(f"🔧 [TEST] backtester_engine.init_engine()")
|
||
backtester_engine.init_engine()
|
||
print(f"✅ [TEST] 初始化完成")
|
||
# ============================================
|
||
# 修复完成
|
||
# ============================================
|
||
|
||
# 格式化日期
|
||
start_str = str(start)
|
||
if len(start_str) == 8:
|
||
start_str = f"{start_str[:4]}-{start_str[4:6]}-{start_str[6:8]}"
|
||
end_str = str(end)
|
||
if len(end_str) == 8:
|
||
end_str = f"{end_str[:4]}-{end_str[4:6]}-{end_str[6:8]}"
|
||
|
||
setting = {
|
||
"vt_symbol": symbol,
|
||
"interval": interval,
|
||
"start_date": start_str,
|
||
"end_date": end_str,
|
||
"rate": kwargs.get("rate", 0.00003),
|
||
"slippage": kwargs.get("slippage", 0.2),
|
||
"size": kwargs.get("size", 1),
|
||
"pricetick": kwargs.get("pricetick", 0.2),
|
||
"capital": kwargs.get("capital", 1000000.0),
|
||
}
|
||
|
||
print(f"✅ [TEST] 回测参数: {setting}")
|
||
|
||
# ============================================
|
||
# 🔥 这里确认:正确调用方法,不直接调用实例
|
||
# ============================================
|
||
print(f"🔧 [TEST] 执行回测: backtester_engine.run_backtesting(...)")
|
||
# ✅ 正确写法:调用方法,不直接调用实例
|
||
# ❌ 错误写法:result = backtester_engine(...)
|
||
# ✅ 正确写法:result = backtester_engine.run_backtesting(...)
|
||
result = backtester_engine.run_backtesting(
|
||
strategy_class=StrategyClass,
|
||
setting=setting
|
||
)
|
||
|
||
print(f"✅ [TEST] 回测完成: result = backtester_engine.run_backtesting(...)")
|
||
|
||
# 获取结果
|
||
statistics = backtester_engine.get_result_statistics()
|
||
print(f"✅ [TEST] 回测完成,统计指标: {list(statistics.keys()) if statistics else '无'}")
|
||
|
||
# 获取每日数据
|
||
daily_df = backtester_engine.get_daily_df()
|
||
if daily_df is not None and hasattr(daily_df, 'to_dict'):
|
||
daily_data = daily_df.to_dict(orient='records')
|
||
else:
|
||
daily_data = []
|
||
|
||
# 获取交易记录
|
||
trades = backtester_engine.get_all_trades()
|
||
trade_list = [t.__dict__ for t in trades] if trades else []
|
||
|
||
return {
|
||
"statistics": statistics,
|
||
"trades": trade_list,
|
||
"daily_data": daily_data
|
||
}
|
||
|
||
except Exception as e:
|
||
error_info = {
|
||
"error": str(e),
|
||
"traceback": traceback.format_exc()
|
||
}
|
||
print(f"❌ [TEST] 回测错误: {error_info['error']}")
|
||
print(error_info['traceback'])
|
||
return error_info
|
||
|
||
if __name__ == '__main__':
|
||
print("\n=== 开始本地测试 ===")
|
||
result = test_run_strategy_backtest(
|
||
strategy_code=strategy_code,
|
||
symbol="510300.SSE",
|
||
interval="1d",
|
||
start=20210101,
|
||
end=20260301,
|
||
rate=0.00003,
|
||
slippage=0.002,
|
||
size=10000,
|
||
pricetick=0.001,
|
||
capital=1000000,
|
||
)
|
||
|
||
print("\n=== 测试结果 ===")
|
||
if 'error' in result:
|
||
print(f"❌ 测试失败: {result['error']}")
|
||
print("\n完整traceback:")
|
||
print(result['traceback'])
|
||
else:
|
||
print(f"✅ 测试成功!")
|
||
print(f"📊 统计指标: {list(result['statistics'].keys())}")
|
||
print(f"💹 交易记录数量: {len(result['trades'])}")
|
||
print(f"📈 每日数据行数: {len(result['daily_data'])}")
|