Files
sanguo_vnpy/archive/2026-04-29-cleanup/test/backtest/test_fixed.py
T
2026-04-29 20:15:25 +08:00

248 lines
8.6 KiB
Python

#!/usr/bin/env python3
"""
Smoke test for the fixed backtest wiring.

Embeds a single-stock fixed-percentage stop-loss CTA strategy as source text
(``strategy_code``), which ``test_run_strategy_backtest`` exec's and runs
through vnpy's BacktesterEngine.
"""
# Strategy source embedded directly; it is exec'd at runtime, so it must be
# valid, properly indented Python on its own.
strategy_code = '''"""
单票固定比例止损策略 - vnpy CTA回测
"""
from vnpy_ctastrategy import (
    CtaTemplate,
    StopOrder,
    TickData,
    BarData,
    TradeData,
    OrderData,
    BarGenerator,
    ArrayManager,
)
from vnpy.trader.constant import Direction, Offset


class SingleStockStopLossStrategy(CtaTemplate):
    """单票固定比例止损策略 - 均线趋势跟踪+固定比例止损"""

    author = "关羽 (云长)"

    fast_window = 5
    slow_window = 20
    stop_loss_pct = 0.15

    parameters = ["fast_window", "slow_window", "stop_loss_pct"]
    variables = ["fast_ma", "slow_ma", "cost_price", "in_position"]

    def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
        super().__init__(cta_engine, strategy_name, vt_symbol, setting)
        self.bg = BarGenerator(self.on_bar)
        self.am = ArrayManager(max(self.slow_window + 10, 100))
        self.fast_ma = 0.0
        self.slow_ma = 0.0
        # Price of the most recent long fill; reset to 0.0 once flat.
        self.cost_price = 0.0
        # Mirrors self.pos > 0; updated from fills, never from order submission.
        self.in_position = False

    def on_init(self):
        self.write_log(f"策略初始化,fast={self.fast_window}, slow={self.slow_window}, stop_loss={self.stop_loss_pct:.1%}")
        self.put_event()

    def on_start(self):
        self.put_event()

    def on_stop(self):
        self.put_event()

    def on_bar(self, bar: BarData):
        # Drop limit orders that are still unfilled so entries/exits cannot
        # pile up across bars; the signal is re-evaluated below.
        self.cancel_all()

        self.am.update_bar(bar)
        if not self.am.inited:
            return

        self.fast_ma = self.am.sma(self.fast_window)
        self.slow_ma = self.am.sma(self.slow_window)

        # Position state comes from actual fills (self.pos), so an unfilled
        # buy never makes the strategy believe it is long.
        self.in_position = self.pos > 0

        if self.in_position:
            # Fixed-percentage stop loss measured from the real fill price.
            if self.cost_price > 0:
                pnl_pct = (bar.close_price - self.cost_price) / self.cost_price
                if pnl_pct <= -self.stop_loss_pct:
                    self.sell(bar.close_price, abs(self.pos))
                    self.put_event()
                    return
            # Trend exit: fast MA below slow MA.
            if self.fast_ma < self.slow_ma:
                self.sell(bar.close_price, abs(self.pos))
        elif self.fast_ma > self.slow_ma:
            # Trend entry while flat: fast MA above slow MA.
            self.buy(bar.close_price, 1)

        self.put_event()

    def on_trade(self, trade: TradeData):
        # Record the entry price from the actual fill; clear it once flat.
        if trade.direction == Direction.LONG:
            self.cost_price = trade.price
        elif self.pos <= 0:
            self.cost_price = 0.0
        self.in_position = self.pos > 0
        self.put_event()

    def on_order(self, order: OrderData):
        self.put_event()

    def on_stop_order(self, stop_order: StopOrder):
        self.put_event()
'''
# 导入
import sys
import types
# 兼容性模块
print("🔧 [TEST] 加载vnpy.app兼容性模块...")
vnpy_app_module = types.ModuleType('vnpy.app')
sys.modules['vnpy.app'] = vnpy_app_module
submodules = ['cta_strategy', 'cta_backtester', 'data_manager']
for name in submodules:
full_name = f'vnpy.app.{name}'
submodule = types.ModuleType(full_name)
sys.modules[full_name] = submodule
setattr(vnpy_app_module, name, submodule)
from vnpy_ctastrategy import CtaTemplate, CtaStrategyApp
sys.modules['vnpy.app.cta_strategy'].CtaTemplate = CtaTemplate
sys.modules['vnpy.app.cta_strategy'].CtaStrategyApp = CtaStrategyApp
vnpy_app_module.CtaTemplate = CtaTemplate
vnpy_app_module.CtaStrategyApp = CtaStrategyApp
from vnpy_ctabacktester import BacktesterEngine
sys.modules['vnpy.app.cta_backtester'].BacktesterEngine = BacktesterEngine
vnpy_app_module.BacktesterEngine = BacktesterEngine
print("✅ [TEST] vnpy.app兼容性模块加载完成!")
from vnpy.event import EventEngine
from vnpy.trader.engine import MainEngine
import traceback
def test_run_strategy_backtest(strategy_code: str, symbol: str, interval: str, start: int, end: int, **kwargs):
try:
print(f"\n🚀 [TEST] 开始回测: {symbol} [{start} - {end}]")
local_vars = {}
exec(strategy_code, globals(), local_vars)
strategy_classes = [
v for k, v in local_vars.items()
if isinstance(v, type) and issubclass(v, CtaTemplate) and v != CtaTemplate
]
if not strategy_classes:
return {"error": "未找到CtaTemplate子类"}
StrategyClass = strategy_classes[0]
print(f"✅ [TEST] 找到策略类: {StrategyClass.__name__}")
# ============================================
# 🔥 修复后的正确代码
# ============================================
print(f"🔧 [TEST] 创建引擎...")
event_engine = EventEngine()
print(f"✅ [TEST] event_engine = EventEngine()")
main_engine = MainEngine(event_engine)
print(f"✅ [TEST] main_engine = MainEngine(event_engine)")
# ✅ 正确做法:BacktesterEngine.__init__ 需要 main_engine 和 event_engine
# ❌ add_app 内部调用 app_class() 不带参数,会报错
# ✅ 正确做法:自己实例化,然后 add_app
print(f"🔧 [TEST] BacktesterEngine 需要两个参数,手动实例化")
print(f"🔧 [TEST] backtester_engine = BacktesterEngine(main_engine, event_engine)")
backtester_engine = BacktesterEngine(main_engine, event_engine)
print(f"✅ [TEST] 实例化成功,类型 = {type(backtester_engine)}")
print(f"🔧 [TEST] main_engine.add_app(backtester_engine)")
main_engine.add_app(backtester_engine)
print(f"✅ [TEST] 添加到主引擎完成")
print(f"🔧 [TEST] backtester_engine.init_engine()")
backtester_engine.init_engine()
print(f"✅ [TEST] 初始化完成")
# ============================================
start_str = str(start)
if len(start_str) == 8:
start_str = f"{start_str[:4]}-{start_str[4:6]}-{start_str[6:8]}"
end_str = str(end)
if len(end_str) == 8:
end_str = f"{end_str[:4]}-{end_str[4:6]}-{end_str[6:8]}"
setting = {
"vt_symbol": symbol,
"interval": interval,
"start_date": start_str,
"end_date": end_str,
"rate": kwargs.get("rate", 0.00003),
"slippage": kwargs.get("slippage", 0.2),
"size": kwargs.get("size", 1),
"pricetick": kwargs.get("pricetick", 0.2),
"capital": kwargs.get("capital", 1000000.0),
}
print(f"✅ [TEST] 回测参数: {setting}")
print(f"🔧 [TEST] 执行回测: backtester_engine.run_backtesting(...)")
result = backtester_engine.run_backtesting(
strategy_class=StrategyClass,
setting=setting
)
print(f"✅ [TEST] 回测完成: result = backtester_engine.run_backtesting(...)")
statistics = backtester_engine.get_result_statistics()
print(f"✅ [TEST] 回测完成,统计指标: {list(statistics.keys()) if statistics else ''}")
daily_df = backtester_engine.get_daily_df()
if daily_df is not None and hasattr(daily_df, 'to_dict'):
daily_data = daily_df.to_dict(orient='records')
else:
daily_data = []
trades = backtester_engine.get_all_trades()
trade_list = [t.__dict__ for t in trades] if trades else []
return {
"statistics": statistics,
"trades": trade_list,
"daily_data": daily_data
}
except Exception as e:
error_info = {"error": str(e), "traceback": traceback.format_exc()}
print(f"❌ [TEST] 回测错误: {error_info['error']}")
print(error_info['traceback'])
return error_info
if __name__ == '__main__':
    print("\n=== 开始测试修复后的代码 ===")
    result = test_run_strategy_backtest(
        strategy_code=strategy_code,
        symbol="510300.SSE",
        # vnpy's Interval.DAILY value is "d"; "1d" is not a valid Interval.
        interval="d",
        start=20210101,
        end=20260301,
        rate=0.00003,
        slippage=0.002,
        size=10000,
        pricetick=0.001,
        capital=1000000,
    )
    print("\n=== 测试结果 ===")
    if 'error' in result:
        print(f"❌ 测试失败: {result['error']}")
    else:
        import numbers

        print("✅ 测试成功!")
        statistics = result.get('statistics') or {}

        def _fmt(key, spec, suffix=""):
            """Format statistics[key] with ``spec``; 'N/A' when missing or non-numeric."""
            value = statistics.get(key)
            if isinstance(value, numbers.Real):
                return f"{format(value, spec)}{suffix}"
            return "N/A"

        # NOTE(review): vnpy's calculate_statistics reports total_return and
        # max_ddpercent already in percent units and max_drawdown as an
        # account-currency amount, so ".2%" would misreport them.
        print(f"📊 总收益率: {_fmt('total_return', '.2f', '%')}")
        print(f"📊 夏普比率: {_fmt('sharpe_ratio', '.2f')}")
        print(f"📊 最大回撤: {_fmt('max_drawdown', ',.2f')}")
        print(f"📊 百分比最大回撤: {_fmt('max_ddpercent', '.2f', '%')}")
        print(f"💹 交易记录数量: {len(result.get('trades', []))}")