# File: sanguo_vnpy/test/backtest/run_backtest_via_rpc.py
# Last modified: 2026-04-11 21:18:55 +08:00 (committer shown as "T")
# Size: 176 lines, 5.5 KiB, Python
# Note: this file intentionally contains non-ASCII characters (Chinese text,
# punctuation and emoji) that some code viewers flag as ambiguous Unicode.
#!/usr/bin/env python3
"""
通过RPC执行回测 - 完整版
"""
import json
import traceback
from datetime import datetime, timezone

import zmq
# 策略代码
# Source of the CTA strategy that the RPC service executes remotely. It is
# shipped as plain text, so it must be complete, correctly indented Python on
# its own (the service compiles it, this script never does).
STRATEGY_CODE = '''
"""
单票固定比例止损策略 - vnpy CTA回测
"""
from vnpy_ctastrategy import (
    CtaTemplate, StopOrder, TickData, BarData, TradeData, OrderData, BarGenerator, ArrayManager
)
from vnpy.trader.constant import Direction, Offset


class SingleStockStopLossStrategy(CtaTemplate):
    """单票固定比例止损策略"""

    author = "关羽 (云长)"

    # Class-level defaults for every name in `parameters`. A backtest
    # `setting` may override them; without defaults, on_init/on_bar raise
    # AttributeError whenever the setting omits a parameter.
    fast_window = 10
    slow_window = 20
    stop_loss_pct = 0.15

    parameters = ["fast_window", "slow_window", "stop_loss_pct"]
    variables = ["fast_ma", "slow_ma", "cost_price", "in_position"]

    def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
        super().__init__(cta_engine, strategy_name, vt_symbol, setting)
        self.bg = BarGenerator(self.on_bar)
        # Keep a 100-bar warm-up, but never fewer bars than the slow window,
        # otherwise sma(slow_window) cannot be computed.
        self.am = ArrayManager(max(100, self.slow_window))

        self.fast_ma = 0.0
        self.slow_ma = 0.0
        self.cost_price = 0.0
        self.in_position = False

    def on_init(self):
        self.write_log(f"策略初始化,fast={self.fast_window}, slow={self.slow_window}, stop_loss={self.stop_loss_pct:.1%}")
        self.put_event()

    def on_bar(self, bar: BarData):
        # Cancel whatever was left unfilled from the previous bar; it is
        # re-issued below if the signal still holds, so orders never pile up.
        self.cancel_all()

        am = self.am
        am.update_bar(bar)
        if not am.inited:
            return

        self.fast_ma = am.sma(self.fast_window)
        self.slow_ma = am.sma(self.slow_window)

        # Decide from the filled position (self.pos), not from orders sent,
        # so the strategy state cannot drift from the real position.
        if self.pos > 0:
            current_drawdown = 0.0
            if self.cost_price > 0:
                current_drawdown = (bar.close_price - self.cost_price) / self.cost_price
            if current_drawdown <= -self.stop_loss_pct:
                self.sell(bar.close_price, self.pos)
                self.write_log(f"触发止损:成本{self.cost_price:.2f},当前{bar.close_price:.2f},回撤{current_drawdown:.1%}")
            elif self.fast_ma < self.slow_ma:
                self.sell(bar.close_price, self.pos)
                self.write_log(f"死叉平仓:价格{bar.close_price:.2f}")
        elif self.fast_ma > self.slow_ma:
            # NOTE(review): like the original logic, this re-enters on the bar
            # after a stop-loss exit if fast_ma is still above slow_ma;
            # confirm that is intended.
            self.buy(bar.close_price, 10000)
            self.write_log(f"金叉开多:价格{bar.close_price:.2f}")

        self.in_position = self.pos > 0
        self.put_event()

    def on_trade(self, trade: TradeData):
        # Record the actual fill price so the stop-loss is measured from the
        # real entry rather than from the signal bar's close.
        if trade.direction == Direction.LONG:
            self.cost_price = trade.price
        elif self.pos == 0:
            self.cost_price = 0.0
        self.in_position = self.pos > 0
        self.put_event()
'''

print("=" * 80)
print("🚀 通过RPC执行回测")
print("=" * 80)
print(f"✅ 策略代码: {len(STRATEGY_CODE)} 字符")
def utc_date(timestamp: int) -> str:
    """Format a Unix timestamp in seconds as a ``YYYY-MM-DD`` date in UTC.

    The printed summary is built from the values actually sent, so it cannot
    drift from the payload the way hand-written text did.
    """
    return datetime.fromtimestamp(timestamp, tz=timezone.utc).strftime("%Y-%m-%d")


# RPC请求
# Payload sent to the backtest service. `start`/`end` are Unix timestamps in
# seconds: 1609459200 is 2021-01-01 00:00 UTC, 1772515200 is 2026-03-03
# 05:20 UTC.
# NOTE(review): an earlier summary claimed the range ends on 2026-03-01;
# confirm which end date is intended.
# NOTE(review): if the service passes `size` to vnpy's BacktestingEngine as
# the contract multiplier, size=10000 combined with the strategy's 10000-unit
# orders gives a notional far above `capital` (ETFs typically use size=1);
# confirm against the service.
request = {
    "strategy_code": STRATEGY_CODE,
    "symbol": "510300.S.SSE",
    "interval": "1d",
    "start": 1609459200,
    "end": 1772515200,
    "capital": 1000000,
    "rate": 3e-5,
    "slippage": 0.002,
    "size": 10000,
    "pricetick": 0.001,
    "data_source": "sqlite",
}

print("\n请求配置:")
print(f" 标的: {request['symbol']}")
print(f" 时间: {utc_date(request['start'])} ~ {utc_date(request['end'])}")
print(f" 资金: {request['capital']:,}")
# The request carries no strategy setting; this mirrors the stop_loss_pct
# default expected in STRATEGY_CODE.
print(" 止损: 15%")
# 连接RPC
# host:port of the backtest RPC service.
RPC_ENDPOINT = "127.0.0.1:8008"
# Send and receive timeout for the request, in milliseconds.
RPC_TIMEOUT_MS = 30000
# Number of trades echoed to the console before the rest are summarised.
MAX_TRADES_SHOWN = 20


def print_statistics(stats):
    """Print the performance metrics of a backtest response.

    Missing keys are shown as 0. Percent formatting assumes the service
    reports ratios as fractions (0.12 == 12%). NOTE(review): confirm this,
    since vnpy's own ``calculate_statistics`` reports returns multiplied by 100.
    """
    print("\n📊 绩效指标:")
    print(f" 总收益率: {stats.get('total_return', 0):.2%}")
    print(f" 年化收益率: {stats.get('annual_return', 0):.2%}")
    print(f" 最大回撤: {stats.get('max_drawdown', 0):.2%}")
    print(f" 夏普比率: {stats.get('sharpe_ratio', 0):.2f}")
    print(f" 卡玛比率: {stats.get('calmar_ratio', 0):.2f}")
    print(f" 总交易次数: {stats.get('total_trades', 0)}")
    print(f" 胜率: {stats.get('win_rate', 0):.2%}")
    print(f" 盈亏比: {stats.get('profit_loss_ratio', 0):.2f}")


def print_trades(trades, limit=MAX_TRADES_SHOWN):
    """Print the first ``limit`` trades and a count of the remaining ones.

    A missing or null ``price`` is shown as 0.00 instead of aborting the
    listing with a format error.
    """
    print(f"\n📝 交易记录: 共 {len(trades)}")
    for idx, trade in enumerate(trades[:limit], 1):
        price = trade.get("price") or 0.0
        print(f" {idx}. {trade.get('datetime')} {trade.get('direction')} {trade.get('symbol')} @ {price:.2f} × {trade.get('volume')}")
    if len(trades) > limit:
        print(f" ... 还有 {len(trades) - limit}")


def main(endpoint=RPC_ENDPOINT, timeout_ms=RPC_TIMEOUT_MS, payload=None):
    """Send a backtest request through a ZeroMQ REQ socket and print the reply.

    Args:
        endpoint: ``host:port`` of the RPC service.
        timeout_ms: send and receive timeout in milliseconds.
        payload: request dict to send; defaults to the module-level ``request``.

    Returns:
        Exit status: 0 when a result was printed; 1 when the service reported
        an error, the exchange timed out, or anything else failed.
    """
    if payload is None:
        payload = request

    print(f"\n连接RPC: {endpoint}")
    context = zmq.Context()
    socket = context.socket(zmq.REQ)
    # LINGER=0 discards an unsent message on close, so context.term() below
    # cannot block when the service is unreachable.
    socket.setsockopt(zmq.LINGER, 0)
    socket.setsockopt(zmq.RCVTIMEO, timeout_ms)
    socket.setsockopt(zmq.SNDTIMEO, timeout_ms)
    try:
        # Everything that can fail lives inside try so the socket and context
        # are always released.
        socket.connect(f"tcp://{endpoint}")

        print("\n发送请求...")
        socket.send_string(json.dumps(payload))
        print("✅ 请求已发送,等待响应...")

        response = json.loads(socket.recv_string())
        print("✅ 收到响应")

        if "error" in response:
            print(f"\n❌ 回测失败: {response['error']}")
            if "traceback" in response:
                print("\n错误堆栈:")
                print(response["traceback"])
            return 1

        print("\n" + "=" * 80)
        print("回测结果:")
        print("=" * 80)
        if "statistics" in response:
            print_statistics(response["statistics"])
        if "trades" in response:
            print_trades(response["trades"])
        print("\n" + "=" * 80)
        print("✅ 回测执行完成!")
        print("=" * 80)
        return 0
    except zmq.error.Again:
        print("❌ 请求超时: RPC服务响应时间过长")
        return 1
    except Exception as e:
        print(f"❌ 接收响应失败: {e}")
        traceback.print_exc()
        return 1
    finally:
        socket.close()
        context.term()


if __name__ == "__main__":
    raise SystemExit(main())