initial-import: 2026-04-11 21:18:55
This commit is contained in:
@@ -0,0 +1,291 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
最终修复版本 v2 - 修复Interval枚举问题
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from vnpy.trader.constant import Interval
|
||||
|
||||
# 策略代码直接嵌入
|
||||
strategy_code = '''"""
|
||||
单票固定比例止损策略 - vnpy CTA回测
|
||||
"""
|
||||
|
||||
from vnpy_ctastrategy import (
|
||||
CtaTemplate,
|
||||
StopOrder,
|
||||
TickData,
|
||||
BarData,
|
||||
TradeData,
|
||||
OrderData,
|
||||
BarGenerator,
|
||||
ArrayManager,
|
||||
)
|
||||
from vnpy.trader.constant import Direction, Offset
|
||||
|
||||
|
||||
class SingleStockStopLossStrategy(CtaTemplate):
|
||||
"""单票固定比例止损策略 - 均线趋势跟踪+固定比例止损"""
|
||||
|
||||
author = "关羽 (云长)"
|
||||
|
||||
fast_window = 5
|
||||
slow_window = 20
|
||||
stop_loss_pct = 0.15
|
||||
|
||||
parameters = ["fast_window", "slow_window", "stop_loss_pct"]
|
||||
variables = ["fast_ma", "slow_ma", "cost_price", "in_position"]
|
||||
|
||||
def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
|
||||
super().__init__(cta_engine, strategy_name, vt_symbol, setting)
|
||||
self.bg = BarGenerator(self.on_bar)
|
||||
self.am = ArrayManager(max(self.slow_window + 10, 100))
|
||||
self.fast_ma = 0.0
|
||||
self.slow_ma = 0.0
|
||||
self.cost_price = 0.0
|
||||
self.in_position = False
|
||||
|
||||
def on_init(self):
|
||||
self.write_log(f"策略初始化,fast={self.fast_window}, slow={self.slow_window}, stop_loss={self.stop_loss_pct:.1%}")
|
||||
self.put_event()
|
||||
|
||||
def on_start(self):
|
||||
self.put_event()
|
||||
|
||||
def on_stop(self):
|
||||
self.put_event()
|
||||
|
||||
def on_bar(self, bar):
|
||||
self.am.update_bar(bar)
|
||||
if not self.am.inited:
|
||||
return
|
||||
|
||||
self.fast_ma = self.am.sma(self.fast_window)
|
||||
self.slow_ma = self.am.sma(self.slow_window)
|
||||
|
||||
have_signal = True
|
||||
if self.in_position and self.cost_price > 0:
|
||||
current_drawdown = (bar.close_price - self.cost_price) / self.cost_price
|
||||
if current_drawdown <= -self.stop_loss_pct:
|
||||
if self.pos > 0:
|
||||
self.sell(bar.close_price, self.pos)
|
||||
self.in_position = False
|
||||
have_signal = False
|
||||
|
||||
if have_signal:
|
||||
if not self.in_position:
|
||||
if self.fast_ma > self.slow_ma:
|
||||
self.buy(bar.close_price, 1)
|
||||
self.cost_price = bar.close_price
|
||||
self.in_position = True
|
||||
else:
|
||||
if self.fast_ma < self.slow_ma:
|
||||
if self.pos > 0:
|
||||
self.sell(bar.close_price, self.pos)
|
||||
self.in_position = False
|
||||
|
||||
self.put_event()
|
||||
|
||||
def on_trade(self, trade):
|
||||
self.put_event()
|
||||
|
||||
def on_order(self, order):
|
||||
self.put_event()
|
||||
|
||||
def on_stop_order(self, stop_order):
|
||||
self.put_event()
|
||||
'''
|
||||
|
||||
# 导入
|
||||
import sys
|
||||
import types
|
||||
|
||||
# 兼容性模块
|
||||
print("🔧 [TEST] 加载vnpy.app兼容性模块...")
|
||||
vnpy_app_module = types.ModuleType('vnpy.app')
|
||||
sys.modules['vnpy.app'] = vnpy_app_module
|
||||
submodules = ['cta_strategy', 'cta_backtester', 'data_manager']
|
||||
for name in submodules:
|
||||
full_name = f'vnpy.app.{name}'
|
||||
submodule = types.ModuleType(full_name)
|
||||
sys.modules[full_name] = submodule
|
||||
setattr(vnpy_app_module, name, submodule)
|
||||
|
||||
from vnpy_ctastrategy import CtaTemplate, CtaStrategyApp
|
||||
sys.modules['vnpy.app.cta_strategy'].CtaTemplate = CtaTemplate
|
||||
sys.modules['vnpy.app.cta_strategy'].CtaStrategyApp = CtaStrategyApp
|
||||
vnpy_app_module.CtaTemplate = CtaTemplate
|
||||
vnpy_app_module.CtaStrategyApp = CtaStrategyApp
|
||||
|
||||
from vnpy_ctabacktester import BacktesterEngine
|
||||
sys.modules['vnpy.app.cta_backtester'].BacktesterEngine = BacktesterEngine
|
||||
vnpy_app_module.BacktesterEngine = BacktesterEngine
|
||||
|
||||
print("✅ [TEST] vnpy.app兼容性模块加载完成!")
|
||||
|
||||
from vnpy.event import EventEngine
|
||||
from vnpy.trader.engine import MainEngine
|
||||
from datetime import datetime
|
||||
import traceback
|
||||
|
||||
def str_to_interval(interval_str: str) -> Interval:
|
||||
"""字符串转Interval枚举"""
|
||||
mapping = {
|
||||
"1m": Interval.MINUTE,
|
||||
"5m": Interval.FIVE_MINUTE,
|
||||
"15m": Interval.FIFTEEN_MINUTE,
|
||||
"30m": Interval.THIRTY_MINUTE,
|
||||
"1h": Interval.HOUR,
|
||||
"4h": Interval.FOUR_HOUR,
|
||||
"1d": Interval.DAILY,
|
||||
"1w": Interval.WEEKLY,
|
||||
"d": Interval.DAILY,
|
||||
"daily": Interval.DAILY,
|
||||
}
|
||||
return mapping.get(interval_str.lower(), Interval.DAILY)
|
||||
|
||||
def parse_date(date_int: int) -> datetime:
|
||||
"""将YYYYMMDD转为datetime"""
|
||||
s = str(date_int)
|
||||
year = int(s[:4])
|
||||
month = int(s[4:6])
|
||||
day = int(s[6:8])
|
||||
return datetime(year, month, day)
|
||||
|
||||
def test_run_strategy_backtest(strategy_code: str, symbol: str, interval: str, start: int, end: int, **kwargs):
|
||||
try:
|
||||
print(f"\n🚀 [TEST] 开始回测: {symbol} [{start} - {end}]")
|
||||
|
||||
local_vars = {}
|
||||
exec(strategy_code, globals(), local_vars)
|
||||
|
||||
strategy_classes = [
|
||||
v for k, v in local_vars.items()
|
||||
if isinstance(v, type) and issubclass(v, CtaTemplate) and v != CtaTemplate
|
||||
]
|
||||
|
||||
if not strategy_classes:
|
||||
return {"error": "未找到CtaTemplate子类"}
|
||||
|
||||
StrategyClass = strategy_classes[0]
|
||||
class_name = StrategyClass.__name__
|
||||
print(f"✅ [TEST] 找到策略类: {class_name}")
|
||||
|
||||
# 把策略类添加到全局
|
||||
globals()[class_name] = StrategyClass
|
||||
|
||||
# ============================================
|
||||
# 🔥 最终修复:完全按照vnpy 4.x官方签名
|
||||
# ============================================
|
||||
print(f"🔧 [TEST] 创建引擎...")
|
||||
event_engine = EventEngine()
|
||||
print(f"✅ [TEST] event_engine = EventEngine()")
|
||||
main_engine = MainEngine(event_engine)
|
||||
print(f"✅ [TEST] main_engine = MainEngine(event_engine)")
|
||||
|
||||
# ✅ 正确做法:直接实例化,参数正确
|
||||
print(f"🔧 [TEST] BacktesterEngine 需要 main_engine + event_engine,直接实例化")
|
||||
print(f"🔧 [TEST] backtester_engine = BacktesterEngine(main_engine, event_engine)")
|
||||
backtester_engine = BacktesterEngine(main_engine, event_engine)
|
||||
print(f"✅ [TEST] 实例化成功,类型 = {type(backtester_engine)}")
|
||||
|
||||
print(f"🔧 [TEST] backtester_engine.init_engine()")
|
||||
backtester_engine.init_engine()
|
||||
print(f"✅ [TEST] 初始化完成")
|
||||
# ============================================
|
||||
|
||||
# 转换参数
|
||||
start_dt = parse_date(start)
|
||||
end_dt = parse_date(end)
|
||||
interval_enum = str_to_interval(interval)
|
||||
|
||||
rate = kwargs.get("rate", 0.00003)
|
||||
slippage = kwargs.get("slippage", 0.2)
|
||||
size = kwargs.get("size", 1)
|
||||
pricetick = kwargs.get("pricetick", 0.2)
|
||||
capital = kwargs.get("capital", 1000000)
|
||||
setting = {
|
||||
"fast_window": 5,
|
||||
"slow_window": 20,
|
||||
"stop_loss_pct": 0.15
|
||||
}
|
||||
|
||||
print(f"✅ [TEST] 参数准备完成:")
|
||||
print(f" class_name: {class_name}")
|
||||
print(f" vt_symbol: {symbol}")
|
||||
print(f" interval: {interval} → {interval_enum}")
|
||||
print(f" start: {start_dt}")
|
||||
print(f" end: {end_dt}")
|
||||
print(f" rate: {rate}")
|
||||
print(f" slippage: {slippage}")
|
||||
print(f" size: {size}")
|
||||
print(f" pricetick: {pricetick}")
|
||||
print(f" capital: {capital}")
|
||||
print(f" setting: {setting}")
|
||||
|
||||
print(f"🔧 [TEST] 执行回测: backtester_engine.run_backtesting(...) 按照官方参数签名")
|
||||
# ✅ 完全按照官方签名传参
|
||||
backtester_engine.run_backtesting(
|
||||
class_name,
|
||||
symbol,
|
||||
interval_enum,
|
||||
start_dt,
|
||||
end_dt,
|
||||
rate,
|
||||
slippage,
|
||||
size,
|
||||
pricetick,
|
||||
capital,
|
||||
setting
|
||||
)
|
||||
|
||||
print(f"✅ [TEST] 回测完成!")
|
||||
|
||||
statistics = backtester_engine.get_result_statistics()
|
||||
print(f"✅ [TEST] 获取统计指标: {list(statistics.keys()) if statistics else '无'}")
|
||||
|
||||
daily_df = backtester_engine.get_daily_df()
|
||||
if daily_df is not None and hasattr(daily_df, 'to_dict'):
|
||||
daily_data = daily_df.to_dict(orient='records')
|
||||
else:
|
||||
daily_data = []
|
||||
|
||||
trades = backtester_engine.get_all_trades()
|
||||
trade_list = [t.__dict__ for t in trades] if trades else []
|
||||
|
||||
return {
|
||||
"statistics": statistics,
|
||||
"trades": trade_list,
|
||||
"daily_data": daily_data
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_info = {"error": str(e), "traceback": traceback.format_exc()}
|
||||
print(f"❌ [TEST] 回测错误: {error_info['error']}")
|
||||
print(error_info['traceback'])
|
||||
return error_info
|
||||
|
||||
if __name__ == '__main__':
|
||||
print("\n=== 开始最终修复测试 v2 (修复Interval枚举) ===")
|
||||
result = test_run_strategy_backtest(
|
||||
strategy_code=strategy_code,
|
||||
symbol="510300.SSE",
|
||||
interval="1d",
|
||||
start=20210101,
|
||||
end=20260301,
|
||||
rate=0.00003,
|
||||
slippage=0.002,
|
||||
size=10000,
|
||||
pricetick=0.001,
|
||||
capital=1000000,
|
||||
)
|
||||
|
||||
print("\n=== 测试结果 ===")
|
||||
if 'error' in result:
|
||||
print(f"❌ 测试失败: {result['error']}")
|
||||
else:
|
||||
print(f"✅ 测试成功!")
|
||||
print(f"📊 总收益率: {result['statistics'].get('total_return', 'N/A'):.2%}")
|
||||
print(f"📊 夏普比率: {result['statistics'].get('sharpe_ratio', 'N/A'):.2f}")
|
||||
print(f"📊 最大回撤: {result['statistics'].get('max_drawdown', 'N/A'):.2%}")
|
||||
print(f"💹 交易记录数量: {len(result['trades'])}")
|
||||
Reference in New Issue
Block a user