236 lines
8.0 KiB
Python
Executable File
236 lines
8.0 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
"""获取回测结果JSON(精简版,修复numpy int64序列化问题)"""
|
||
|
||
import zmq
|
||
import json
|
||
import sys
|
||
import numpy as np
|
||
|
||
# 自定义JSON编码器,处理numpy类型
|
||
class NumpyEncoder(json.JSONEncoder):
|
||
def default(self, obj):
|
||
if isinstance(obj, (np.integer, np.int32, np.int64)):
|
||
return int(obj)
|
||
elif isinstance(obj, (np.floating, np.float32, np.float64)):
|
||
return float(obj)
|
||
elif isinstance(obj, np.ndarray):
|
||
return obj.tolist()
|
||
return super().default(obj)
|
||
|
||
# 关羽完整策略代码
|
||
strategy_code = '''from vnpy_ctastrategy import (
|
||
CtaTemplate,
|
||
StopOrder,
|
||
TickData,
|
||
BarData,
|
||
TradeData,
|
||
OrderData,
|
||
BarGenerator,
|
||
ArrayManager,
|
||
)
|
||
from vnpy.trader.constant import Direction, Offset
|
||
|
||
|
||
class SingleStockStopLossStrategy(CtaTemplate):
|
||
"""单票固定比例止损策略 - 均线趋势跟踪+固定比例止损"""
|
||
|
||
author = "关羽 (云长)"
|
||
|
||
# 策略参数
|
||
fast_window = 5 # 短期均线窗口
|
||
slow_window = 20 # 长期均线窗口
|
||
stop_loss_pct = 0.15 # 止损比例,亏损超过这个比例止损
|
||
|
||
# 参数列表
|
||
parameters = ["fast_window", "slow_window", "stop_loss_pct"]
|
||
|
||
# 变量列表
|
||
variables = ["fast_ma", "slow_ma", "cost_price", "in_position"]
|
||
|
||
def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
|
||
"""初始化"""
|
||
super().__init__(cta_engine, strategy_name, vt_symbol, setting)
|
||
|
||
self.bg = BarGenerator(self.on_bar)
|
||
self.am = ArrayManager(max(self.slow_window + 10, 30))
|
||
|
||
# 均线数值
|
||
self.fast_ma = 0.0
|
||
self.slow_ma = 0.0
|
||
|
||
# 开仓成本
|
||
self.cost_price = 0.0
|
||
|
||
# 是否持仓
|
||
self.in_position = False
|
||
|
||
def on_init(self):
|
||
"""初始化策略"""
|
||
self.write_log(f"策略初始化,fast={self.fast_window}, slow={self.slow_window}, stop_loss={self.stop_loss_pct:.1%}")
|
||
self.load_bar(self.slow_window + 10)
|
||
self.put_event()
|
||
|
||
def on_start(self):
|
||
"""启动策略"""
|
||
self.put_event()
|
||
|
||
def on_stop(self):
|
||
"""停止策略"""
|
||
self.put_event()
|
||
|
||
def on_bar(self, bar):
|
||
"""K线更新"""
|
||
self.am.update_bar(bar)
|
||
|
||
if not self.am.inited:
|
||
return
|
||
|
||
# 计算均线
|
||
self.fast_ma = self.am.sma(self.fast_window)
|
||
self.slow_ma = self.am.sma(self.slow_window)
|
||
|
||
# 检查止损(只有持仓时才检查)
|
||
have_signal = True
|
||
if self.in_position and self.cost_price > 0:
|
||
current_drawdown = (bar.close_price - self.cost_price) / self.cost_price
|
||
|
||
if current_drawdown <= -self.stop_loss_pct:
|
||
# 触发止损,全部平仓
|
||
if self.pos > 0:
|
||
self.sell(bar.close_price, self.pos)
|
||
self.in_position = False
|
||
have_signal = False
|
||
|
||
# 如果没有触发止损,继续处理信号
|
||
if have_signal:
|
||
# 均线金叉死叉信号
|
||
if not self.in_position:
|
||
# 金叉:短期上穿长期,开多
|
||
if self.fast_ma > self.slow_ma:
|
||
self.buy(bar.close_price, 10000)
|
||
self.cost_price = bar.close_price
|
||
self.in_position = True
|
||
else:
|
||
# 死叉:短期下穿长期,平多
|
||
if self.fast_ma < self.slow_ma:
|
||
if self.pos > 0:
|
||
self.sell(bar.close_price, self.pos)
|
||
self.in_position = False
|
||
|
||
self.put_event()
|
||
|
||
def on_trade(self, trade):
|
||
"""交易成交回调"""
|
||
self.put_event()
|
||
|
||
def on_order(self, order):
|
||
"""订单回调"""
|
||
self.put_event()
|
||
|
||
def on_stop_order(self, stop_order):
|
||
"""停止单回调"""
|
||
self.put_event()
|
||
'''
|
||
|
||
# RPC请求 - 完整区间 2021-01-01 ~ 2026-03-01
|
||
request = {
|
||
"function": "run_strategy_backtest",
|
||
"args": [],
|
||
"kwargs": {
|
||
"strategy_code": strategy_code,
|
||
"symbol": "510300.SSE",
|
||
"interval": "1d",
|
||
"start": 1609459200, # 2021-01-01
|
||
"end": 1772515200, # 2026-03-01
|
||
"capital": 1000000,
|
||
"rate": 3e-5,
|
||
"slippage": 0.002,
|
||
"size": 10000,
|
||
"pricetick": 0.001,
|
||
"data_source": "sqlite",
|
||
"setting": {"stop_loss_pct": 0.15}
|
||
}
|
||
}
|
||
|
||
print("🔗 连接RPC: tcp://127.0.0.1:8008 (容器内部)")
|
||
context = zmq.Context()
|
||
socket = context.socket(zmq.REQ)
|
||
socket.connect("tcp://127.0.0.1:8008")
|
||
socket.setsockopt(zmq.LINGER, 0)
|
||
socket.setsockopt(zmq.RCVTIMEO, 900000) # 15分钟超时
|
||
socket.setsockopt(zmq.SNDTIMEO, 900000)
|
||
|
||
print("🚀 发送请求 (全区间 2021-01-01 ~ 2026-03-01, 止损15%)")
|
||
print(" 等待响应... 大约需要几分钟")
|
||
|
||
try:
|
||
socket.send_pyobj(request)
|
||
result = socket.recv_pyobj()
|
||
|
||
if "error" in result:
|
||
print(f"\n❌ ERROR: {result['error']}")
|
||
if "traceback" in result:
|
||
print("\nTraceback:")
|
||
print(result["traceback"])
|
||
sys.exit(1)
|
||
else:
|
||
print(f"\n✅ SUCCESS! 回测完成!")
|
||
print(f" 交易笔数: {result.get('trades_count', 0)}")
|
||
|
||
# 统计数据就是完整的,不需要精简
|
||
# daily_data只保留必要字段,减少大小
|
||
daily_data = result.get('daily_data', [])
|
||
print(f" 每日数据点数: {len(daily_data)}")
|
||
|
||
# 保存完整JSON(包含所有你需要的数据)
|
||
output_file = "/app/guanyu_510300_backtest_result.json"
|
||
with open(output_file, "w", encoding="utf-8") as f:
|
||
json.dump(result, f, ensure_ascii=False, indent=2, cls=NumpyEncoder)
|
||
|
||
# 获取文件大小
|
||
import os
|
||
file_size = os.path.getsize(output_file)
|
||
print(f"\n📝 完整JSON已保存到容器: {output_file}")
|
||
print(f" 文件大小: {file_size} bytes ({file_size / 1024 / 1024:.2f} MB)")
|
||
|
||
# 打印统计信息
|
||
if "statistics" in result:
|
||
stats = result["statistics"]
|
||
print(f"\n📊 绩效指标:")
|
||
print(f" 总收益率: {float(stats.get('total_return', 0)):.2%}")
|
||
print(f" 年化收益率: {float(stats.get('annual_return', 0)):.2%}")
|
||
print(f" 最大回撤: {float(stats.get('max_drawdown', 0)):.2%}")
|
||
print(f" 夏普比率: {float(stats.get('sharpe_ratio', 0)):.2f}")
|
||
if 'calmar_ratio' in stats:
|
||
print(f" 卡玛比率: {float(stats.get('calmar_ratio', 0)):.2f}")
|
||
print(f" 总交易次数: {int(stats.get('total_trades', 0))}")
|
||
if 'win_rate' in stats:
|
||
print(f" 胜率: {float(stats.get('win_rate', 0)):.2%}")
|
||
if 'profit_loss_ratio' in stats:
|
||
print(f" 盈亏比: {float(stats.get('profit_loss_ratio', 0)):.2f}")
|
||
|
||
if "trades" in result:
|
||
trades = result["trades"]
|
||
print(f"\n📝 交易记录: 共 {len(trades)} 笔")
|
||
if len(trades) > 0:
|
||
print(f" 前5笔:")
|
||
for idx, trade in enumerate(trades[:5], 1):
|
||
dt = trade.get('datetime', '')[:10] if trade.get('datetime') else ''
|
||
direction = trade.get('direction', '').split('.')[-1] if '.' in trade.get('direction', '') else trade.get('direction', '')
|
||
price = float(trade.get('price', 0))
|
||
volume = int(trade.get('volume', 0))
|
||
print(f" {idx}. {dt} {direction} @ {price:.2f} × {volume}")
|
||
|
||
socket.close()
|
||
context.term()
|
||
print("\n✅ 完成!JSON已保存到容器。")
|
||
|
||
except zmq.error.Again:
|
||
print("\n⏱️ ❌ TIMEOUT: 超过15分钟仍未完成")
|
||
sys.exit(1)
|
||
except Exception as e:
|
||
print(f"\n❌ ERROR: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
sys.exit(1)
|