#!/usr/bin/env python3
"""
Final fix v4: register the strategy class with vnpy's BacktesterEngine.

Standalone smoke test that runs a CTA strategy, supplied as source code,
through ``vnpy_ctabacktester.BacktesterEngine``:

1. install a ``vnpy.app`` compatibility shim for legacy import paths;
2. compile the strategy source and pick the CtaTemplate subclass it defines;
3. register that class with the backtester and run one backtest;
4. return statistics, trades and daily results as plain Python objects.
"""
import sys
import traceback
import types
from datetime import datetime

from vnpy.trader.constant import Interval

# Strategy source embedded directly; compiled at runtime by
# test_run_strategy_backtest().
strategy_code = '''"""
单票固定比例止损策略 - vnpy CTA回测
"""
from vnpy_ctastrategy import (
    CtaTemplate,
    StopOrder,
    TickData,
    BarData,
    TradeData,
    OrderData,
    BarGenerator,
    ArrayManager,
)
from vnpy.trader.constant import Direction, Offset


class SingleStockStopLossStrategy(CtaTemplate):
    """单票固定比例止损策略 - 均线趋势跟踪+固定比例止损"""

    author = "关羽 (云长)"

    fast_window = 5
    slow_window = 20
    stop_loss_pct = 0.15

    parameters = ["fast_window", "slow_window", "stop_loss_pct"]
    variables = ["fast_ma", "slow_ma", "cost_price", "in_position"]

    def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
        super().__init__(cta_engine, strategy_name, vt_symbol, setting)
        self.bg = BarGenerator(self.on_bar)
        self.am = ArrayManager(max(self.slow_window + 10, 100))

        self.fast_ma = 0.0
        self.slow_ma = 0.0
        self.cost_price = 0.0
        self.in_position = False

    def on_init(self):
        self.write_log(f"策略初始化,fast={self.fast_window}, slow={self.slow_window}, stop_loss={self.stop_loss_pct:.1%}")
        self.put_event()

    def on_start(self):
        self.put_event()

    def on_stop(self):
        self.put_event()

    def on_bar(self, bar):
        # Orders are limit orders at the bar close; drop any that did not
        # fill on the previous bar so they cannot pile up or fill late.
        self.cancel_all()

        self.am.update_bar(bar)
        if not self.am.inited:
            return

        self.fast_ma = self.am.sma(self.fast_window)
        self.slow_ma = self.am.sma(self.slow_window)

        # Position state comes from self.pos (actual fills), not from
        # whether an order was merely sent.
        if self.pos > 0:
            stop_hit = False
            if self.cost_price > 0:
                pnl_pct = (bar.close_price - self.cost_price) / self.cost_price
                stop_hit = pnl_pct <= -self.stop_loss_pct
            # Exit on the fixed-percentage stop loss or a bearish MA cross.
            if stop_hit or self.fast_ma < self.slow_ma:
                self.sell(bar.close_price, self.pos)
        elif self.pos == 0 and self.fast_ma > self.slow_ma:
            self.buy(bar.close_price, 1)

        self.put_event()

    def on_trade(self, trade):
        # Use the real fill price as the cost basis for the stop loss.
        if trade.direction == Direction.LONG:
            self.cost_price = trade.price
        self.in_position = self.pos > 0
        if not self.in_position:
            self.cost_price = 0.0
        self.put_event()

    def on_order(self, order):
        self.put_event()

    def on_stop_order(self, stop_order):
        self.put_event()
'''

# ---------------------------------------------------------------------------
# vnpy.app compatibility shim
# ---------------------------------------------------------------------------
# Register placeholder ``vnpy.app`` / ``vnpy.app.<name>`` modules and fill
# them from the standalone vnpy_* packages so code written against the old
# ``vnpy.app.cta_strategy`` / ``vnpy.app.cta_backtester`` paths still imports.
print("🔧 [TEST] 加载vnpy.app兼容性模块...")
vnpy_app_module = types.ModuleType('vnpy.app')
sys.modules['vnpy.app'] = vnpy_app_module
submodules = ['cta_strategy', 'cta_backtester', 'data_manager']
for name in submodules:
    full_name = f'vnpy.app.{name}'
    submodule = types.ModuleType(full_name)
    sys.modules[full_name] = submodule
    setattr(vnpy_app_module, name, submodule)

from vnpy_ctastrategy import CtaTemplate, CtaStrategyApp  # noqa: E402

sys.modules['vnpy.app.cta_strategy'].CtaTemplate = CtaTemplate
sys.modules['vnpy.app.cta_strategy'].CtaStrategyApp = CtaStrategyApp
vnpy_app_module.CtaTemplate = CtaTemplate
vnpy_app_module.CtaStrategyApp = CtaStrategyApp

from vnpy_ctabacktester import BacktesterEngine  # noqa: E402

sys.modules['vnpy.app.cta_backtester'].BacktesterEngine = BacktesterEngine
vnpy_app_module.BacktesterEngine = BacktesterEngine
print("✅ [TEST] vnpy.app兼容性模块加载完成!")

from vnpy.event import EventEngine  # noqa: E402
from vnpy.trader.engine import MainEngine  # noqa: E402

# Accepted interval spellings, matched case-insensitively.
_INTERVAL_ALIASES = {
    "1m": Interval.MINUTE,
    "min": Interval.MINUTE,
    "hour": Interval.HOUR,
    "1h": Interval.HOUR,
    "d": Interval.DAILY,
    "1d": Interval.DAILY,
    "daily": Interval.DAILY,
    "w": Interval.WEEKLY,
    "1w": Interval.WEEKLY,
    "weekly": Interval.WEEKLY,
}

# Strategy parameters used when the caller does not pass ``setting=``.
_DEFAULT_SETTING = {"fast_window": 5, "slow_window": 20, "stop_loss_pct": 0.15}

# Statistics that vnpy's calculate_statistics() already reports in percent
# units (scaled by 100), so they must not be formatted with ``%`` again.
_PERCENT_STAT_KEYS = frozenset(
    {"total_return", "annual_return", "daily_return", "return_std", "max_ddpercent"}
)

# ``__name__`` given to the strategy's exec namespace; classes defined by the
# strategy code receive it as their ``__module__``.
_STRATEGY_MODULE_NAME = "_dynamic_strategy"


def str_to_interval(interval_str: str) -> Interval:
    """Convert an interval string such as ``"1d"``, ``"1h"`` or ``"1m"`` to an Interval.

    Matching ignores case and surrounding whitespace. Unknown strings fall
    back to ``Interval.DAILY`` (as before) but now print a warning instead of
    silently changing the bar period.
    """
    interval = _INTERVAL_ALIASES.get(interval_str.strip().lower())
    if interval is None:
        print(f"⚠️ [TEST] 未知周期 {interval_str!r},按日线(Interval.DAILY)处理")
        return Interval.DAILY
    return interval


def parse_date(date_int: int) -> datetime:
    """Convert a ``YYYYMMDD`` integer (or digit string) to a midnight datetime.

    Raises:
        ValueError: if the value is not exactly 8 digits or is not a valid date.
    """
    s = str(date_int)
    if len(s) != 8 or not s.isdigit():
        raise ValueError(f"expected a YYYYMMDD date, got {date_int!r}")
    return datetime(int(s[:4]), int(s[4:6]), int(s[6:8]))


def _load_strategy_class(strategy_code: str) -> "type | None":
    """Compile *strategy_code*; return the first CtaTemplate subclass it defines, or None.

    The code runs in ONE fresh dict used as both globals and locals, so names
    it imports (BarGenerator, ArrayManager, ...) are visible to its methods at
    call time. With separate globals/locals dicts those imports end up only in
    the locals dict, and instantiating the strategy raises NameError.
    Only classes defined by the code itself are considered, so imported
    templates are never picked by mistake.

    SECURITY: exec() runs arbitrary code; only pass trusted strategy source.
    """
    namespace = {"__name__": _STRATEGY_MODULE_NAME}
    exec(strategy_code, namespace)
    for obj in namespace.values():
        if (
            isinstance(obj, type)
            and issubclass(obj, CtaTemplate)
            and obj is not CtaTemplate
            and obj.__module__ == _STRATEGY_MODULE_NAME
        ):
            return obj
    return None


def _collect_results(backtester_engine) -> dict:
    """Read statistics, trades and daily results from a finished backtester run."""
    statistics = backtester_engine.get_result_statistics()

    daily_df = backtester_engine.get_result_df()
    # The daily DataFrame is indexed by date; reset_index() keeps that column.
    daily_data = (
        daily_df.reset_index().to_dict(orient="records")
        if daily_df is not None
        else []
    )

    trades = backtester_engine.get_all_trades()
    trade_list = [vars(t) for t in trades] if trades else []

    return {"statistics": statistics, "trades": trade_list, "daily_data": daily_data}


def test_run_strategy_backtest(strategy_code: str, symbol: str, interval: str,
                               start: int, end: int, **kwargs):
    """Run one CTA backtest for *strategy_code* and return its results.

    Args:
        strategy_code: Python source defining a CtaTemplate subclass.
        symbol: vt_symbol, e.g. ``"510300.SSE"``.
        interval: interval string understood by str_to_interval().
        start, end: backtest window as ``YYYYMMDD`` integers.
        **kwargs: optional ``rate`` (0.00003), ``slippage`` (0.2), ``size`` (1),
            ``pricetick`` (0.2), ``capital`` (1000000) and ``setting``
            (strategy parameter dict; defaults to _DEFAULT_SETTING).

    Returns:
        On success ``{"statistics": dict, "trades": list[dict],
        "daily_data": list[dict]}``; on failure a dict with ``"error"`` and,
        when an exception was raised, ``"traceback"``.
    """
    main_engine = None
    try:
        print(f"\n🚀 [TEST] 开始回测: {symbol} [{start} - {end}]")

        StrategyClass = _load_strategy_class(strategy_code)
        if StrategyClass is None:
            return {"error": "未找到CtaTemplate子类"}
        class_name = StrategyClass.__name__
        print(f"✅ [TEST] 找到策略类: {class_name}")

        print("🔧 [TEST] 创建引擎...")
        event_engine = EventEngine()
        print("✅ [TEST] event_engine = EventEngine()")
        main_engine = MainEngine(event_engine)
        print("✅ [TEST] main_engine = MainEngine(event_engine)")

        print("🔧 [TEST] backtester_engine = BacktesterEngine(main_engine, event_engine)")
        backtester_engine = BacktesterEngine(main_engine, event_engine)
        print(f"✅ [TEST] 实例化成功,类型 = {type(backtester_engine)}")

        print("🔧 [TEST] backtester_engine.init_engine()")
        backtester_engine.init_engine()
        print("✅ [TEST] 初始化完成")

        # BacktesterEngine has no add_strategy(); its run_backtesting() looks
        # the class up by name in the ``classes`` dict, so register it there.
        print(f"🔧 [TEST] 注册策略类: backtester_engine.classes[{class_name!r}] = {StrategyClass}")
        backtester_engine.classes[class_name] = StrategyClass
        print("✅ [TEST] 添加策略类完成")

        start_dt = parse_date(start)
        end_dt = parse_date(end)
        interval_enum = str_to_interval(interval)

        rate = kwargs.get("rate", 0.00003)
        slippage = kwargs.get("slippage", 0.2)
        size = kwargs.get("size", 1)
        pricetick = kwargs.get("pricetick", 0.2)
        capital = kwargs.get("capital", 1000000)
        # Copy so neither the default nor the caller's dict is shared with vnpy.
        setting = dict(kwargs.get("setting") or _DEFAULT_SETTING)

        print("✅ [TEST] 参数准备完成:")
        print(f"   class_name: {class_name}")
        print(f"   vt_symbol: {symbol}")
        print(f"   interval: {interval} → {interval_enum}")
        print(f"   start: {start_dt}")
        print(f"   end: {end_dt}")
        print(f"   rate: {rate}")
        print(f"   slippage: {slippage}")
        print(f"   size: {size}")
        print(f"   pricetick: {pricetick}")
        print(f"   capital: {capital}")
        print(f"   setting: {setting}")

        print("🔧 [TEST] 执行回测: backtester_engine.run_backtesting(...) 按照官方参数签名")
        # run_backtesting() takes the interval as its string value (it compares
        # against Interval.TICK.value to choose tick vs. bar mode).
        backtester_engine.run_backtesting(
            class_name, symbol, interval_enum.value, start_dt, end_dt,
            rate, slippage, size, pricetick, capital, setting,
        )

        # run_backtesting() only logs failures (empty history, exception in the
        # strategy) and leaves the result unset, so check for that explicitly.
        if backtester_engine.get_result_df() is None:
            return {"error": "回测未生成结果(历史数据为空、无成交或回测异常,详见回测引擎日志)"}
        print("✅ [TEST] 回测完成!")

        results = _collect_results(backtester_engine)
        statistics = results["statistics"]
        print(f"✅ [TEST] 获取统计指标: {list(statistics.keys()) if statistics else '无'}")
        return results

    except Exception as e:
        error_info = {"error": str(e), "traceback": traceback.format_exc()}
        print(f"❌ [TEST] 回测错误: {error_info['error']}")
        print(error_info['traceback'])
        return error_info

    finally:
        # MainEngine starts the EventEngine's worker threads; stop them so the
        # process can exit once the backtest is done.
        if main_engine is not None:
            main_engine.close()


def _format_stat(key: str, value) -> str:
    """Format one statistics value for display.

    Percent-type statistics are already scaled by 100 in vnpy, so they get a
    literal ``%`` suffix; other floats (capital, PnL, ratios) are plain numbers.
    """
    if isinstance(value, float):
        if key in _PERCENT_STAT_KEYS:
            return f"{value:.2f}%"
        return f"{value:,.4f}"
    return str(value)


if __name__ == '__main__':
    print("\n=== 开始最终修复测试 v4 (添加策略类到BacktesterEngine) ===")
    result = test_run_strategy_backtest(
        strategy_code=strategy_code,
        symbol="510300.SSE",
        interval="1d",
        start=20210101,
        end=20260301,
        rate=0.00003,
        slippage=0.002,
        size=10000,
        pricetick=0.001,
        capital=1000000,
    )

    print("\n=== 测试结果 ===")
    if 'error' in result:
        print(f"❌ 测试失败: {result['error']}")
    else:
        print("✅ 测试成功!")
        if result['statistics']:
            for key, value in result['statistics'].items():
                print(f"📊 {key}: {_format_stat(key, value)}")
        print(f"💹 交易记录数量: {len(result['trades'])}")
        print(f"📈 每日数据行数: {len(result['daily_data'])}")