#!/usr/bin/env python3
"""
简单测试 - 缩短回测区间(约1年: 2025-01 至 2026-03)验证交易生成
Simple test: backtest a short (~1-year) window to check that trades are generated.
"""
# Compatibility shim: register empty placeholder modules under the legacy
# ``vnpy.app.cta_strategy`` / ``vnpy.app.cta_backtester`` paths BEFORE the
# real packages are imported, then (further below) publish the relocated
# classes on those placeholders.
# NOTE(review): presumably this lets strategy files still written against the
# old ``vnpy.app.*`` import paths load — confirm which callers rely on it.
import sys
import types

vnpy_app = types.ModuleType('vnpy.app')
sys.modules['vnpy.app'] = vnpy_app
for _sub in ('cta_strategy', 'cta_backtester'):
    _qualified = f'vnpy.app.{_sub}'
    _placeholder = types.ModuleType(_qualified)
    sys.modules[_qualified] = _placeholder
    setattr(vnpy_app, _sub, _placeholder)

from vnpy_ctastrategy import ArrayManager, BarGenerator, CtaTemplate
from vnpy_ctabacktester import BacktesterEngine
from vnpy.event import EventEngine
from vnpy.trader.engine import MainEngine
from vnpy.trader.constant import Direction, Exchange
from datetime import datetime

# Expose the relocated classes under their legacy module paths.
for _legacy_path, _attr, _obj in (
    ('vnpy.app.cta_strategy', 'CtaTemplate', CtaTemplate),
    ('vnpy.app.cta_backtester', 'BacktesterEngine', BacktesterEngine),
):
    setattr(sys.modules[_legacy_path], _attr, _obj)
# Strategy
class SimpleTestStrategy(CtaTemplate):
    """Minimal CTA strategy used to check that the backtester generates trades.

    Once the ArrayManager holds enough bars, it sends a single buy order priced
    at the bar's close and then keeps the position; there is no exit logic.
    """

    author = "test"

    # Volume of the single entry order.
    fixed_size = 10000
    # Days of history requested by on_init (load_bar) to warm up indicators.
    init_days = 1000

    parameters = ["fixed_size", "init_days"]
    variables = ["in_position"]

    def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
        super().__init__(cta_engine, strategy_name, vt_symbol, setting)

        # Bar generator wired to on_bar; nothing in this class feeds it
        # (there is no on_tick), so it is currently unused.
        self.bg = BarGenerator(self.on_bar)
        self.am = ArrayManager(100)

        # True once the entry order has actually been sent (not necessarily
        # filled yet).
        self.in_position = False

    def on_init(self):
        """Request warm-up history; each loaded bar is replayed through on_bar."""
        self.load_bar(self.init_days)
        print("策略初始化完成")

    def on_bar(self, bar):
        """Update the indicator buffer and enter exactly once while trading."""
        self.am.update_bar(bar)
        if not self.am.inited:
            return

        if not self.in_position:
            # buy() returns the ids of the orders it sent. vnpy returns an
            # empty list while the strategy is not trading yet (e.g. bars
            # replayed by load_bar during on_init), so only commit the flag
            # when an order really went out; otherwise warm-up would set it
            # and the backtest would never trade.
            vt_orderids = self.buy(bar.close_price, self.fixed_size)
            if vt_orderids:
                self.in_position = True
                print(f"买入 @ {bar.close_price:.2f}")

        self.put_event()
# Initialize engines
import traceback

from vnpy.trader.constant import Interval

print("初始化引擎...")
event_engine = EventEngine()
# MainEngine starts the EventEngine's worker threads (non-daemon in vnpy);
# they are stopped by main_engine.close() in the finally clause below,
# otherwise the interpreter cannot exit once the script finishes.
main_engine = MainEngine(event_engine)
backtester_engine = BacktesterEngine(main_engine, event_engine)
# init_engine() creates the underlying BacktestingEngine that run_backtesting()
# drives (it is None until then) and loads the built-in strategy classes, so
# the local strategy is registered afterwards.
backtester_engine.init_engine()
backtester_engine.classes["SimpleTestStrategy"] = SimpleTestStrategy

# Run the backtest over a short window (2025-01-01 .. 2026-03-01, ~14 months)
# to keep memory usage low.
print("开始回测...")
try:
    backtester_engine.run_backtesting(
        class_name="SimpleTestStrategy",
        vt_symbol="510300.SSE",
        # vnpy's Interval enum encodes daily bars as "d"; the former "1d" is
        # not a valid value and made the engine raise ValueError.
        interval=Interval.DAILY.value,
        start=datetime(2025, 1, 1),
        end=datetime(2026, 3, 1),
        rate=3e-5,
        slippage=0.002,
        # NOTE(review): size is the per-unit multiplier. Combined with the
        # strategy's order volume of 10000, each fill's notional far exceeds
        # capital; confirm the intended value for an ETF.
        size=10000,
        pricetick=0.001,
        capital=1000000,
        setting={},
    )

    # run_backtesting() clears the previous statistics and only stores new
    # ones when the run succeeds (empty history or a strategy exception is
    # logged, not raised), so None here means the run failed.
    result = backtester_engine.get_result_statistics()
    if result is None:
        print("❌ 失败: 回测未生成结果(历史数据为空或策略运行出错)")
    else:
        print("\n" + "=" * 60)
        print("回测结果:")
        # The statistics already express returns and drawdown in percent
        # units, so they are printed with a literal "%" instead of the ":%"
        # format spec (which would multiply by 100 again). "max_drawdown" is
        # an absolute balance amount; the percentage is "max_ddpercent".
        print(f"总收益率: {result.get('total_return', 0):.2f}%")
        print(f"年化收益率: {result.get('annual_return', 0):.2f}%")
        print(f"最大回撤: {result.get('max_ddpercent', 0):.2f}%")
        print(f"夏普比率: {result.get('sharpe_ratio', 0):.2f}")
        # The trade-count key is "total_trade_count" (there is no "total_trades").
        print(f"总交易次数: {result.get('total_trade_count', 0)}")
        # Daily statistics carry no per-trade win rate; report profitable days.
        print(f"盈利天数: {result.get('profit_days', 0)} / {result.get('total_days', 0)}")

        trades = backtester_engine.get_all_trades()
        print(f"\n交易记录: {len(trades)} 笔")
        for t in trades:
            direction_str = "LONG" if t.direction == Direction.LONG else "SHORT"
            print(f" {t.datetime.date()} {direction_str} @ {t.price:.2f} × {t.volume}")

        print("\n✅ 回测完成!")

except Exception as e:
    print(f"❌ 失败: {e}")
    traceback.print_exc()

finally:
    # Stop the event engine threads so the script can terminate.
    main_engine.close()