Files
sanguo_vnpy/archive/2026-04-29-cleanup/scripts/utils/get_result_json_fixed.py
T
2026-04-29 20:15:43 +08:00

236 lines
8.0 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""获取回测结果JSON(精简版,修复numpy int64序列化问题)"""
import json
import os
import sys
import traceback
from datetime import datetime, timezone

import numpy as np
import zmq


# Custom JSON encoder so backtest results containing numpy values can be
# written with json.dump.
class NumpyEncoder(json.JSONEncoder):
    """``json.JSONEncoder`` that converts numpy types to native Python types.

    - integer scalars (``np.int32``, ``np.int64``, ...) -> ``int``
    - floating scalars (``np.float32``, ...) -> ``float``
    - ``np.bool_`` -> ``bool``
    - ``np.ndarray`` -> nested ``list`` via ``tolist()``

    Anything else falls through to the base class, which raises ``TypeError``.
    """

    def default(self, obj):
        # Every sized integer type (np.int32, np.int64, ...) subclasses np.integer.
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        # np.bool_ is not a subclass of bool, so json cannot encode it natively.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
# Guan Yu's complete strategy source. It is sent verbatim to the RPC server as
# request["kwargs"]["strategy_code"], so the string must be valid, correctly
# indented Python.
#
# Position state comes from the engine's ``self.pos`` and the entry cost from
# actual fills (``on_trade``), not from the moment an order is sent. A limit
# order at the bar close is not guaranteed to fill, and the old bookkeeping then
# drifted out of sync with the real position: exit sells were skipped because
# pos was 0, and repeated buy orders could stack up.
strategy_code = '''from vnpy_ctastrategy import (
    CtaTemplate,
    StopOrder,
    TickData,
    BarData,
    TradeData,
    OrderData,
    BarGenerator,
    ArrayManager,
)
from vnpy.trader.constant import Direction, Offset


class SingleStockStopLossStrategy(CtaTemplate):
    """单票固定比例止损策略 - 均线趋势跟踪+固定比例止损"""

    author = "关羽 (云长)"

    # 策略参数
    fast_window = 5         # 短期均线窗口
    slow_window = 20        # 长期均线窗口
    stop_loss_pct = 0.15    # 止损比例,亏损超过这个比例止损

    # 参数列表
    parameters = ["fast_window", "slow_window", "stop_loss_pct"]
    # 变量列表
    variables = ["fast_ma", "slow_ma", "cost_price", "in_position"]

    def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
        """初始化"""
        super().__init__(cta_engine, strategy_name, vt_symbol, setting)
        self.bg = BarGenerator(self.on_bar)
        self.am = ArrayManager(max(self.slow_window + 10, 30))
        # 均线数值
        self.fast_ma = 0.0
        self.slow_ma = 0.0
        # 开仓成本 (average fill price of the open long position, set in on_trade)
        self.cost_price = 0.0
        # 是否持仓 (mirrors self.pos > 0, refreshed on every bar and trade)
        self.in_position = False

    def on_init(self):
        """初始化策略"""
        self.write_log(f"策略初始化,fast={self.fast_window}, slow={self.slow_window}, stop_loss={self.stop_loss_pct:.1%}")
        self.load_bar(self.slow_window + 10)
        self.put_event()

    def on_start(self):
        """启动策略"""
        self.put_event()

    def on_stop(self):
        """停止策略"""
        self.put_event()

    def on_bar(self, bar):
        """K线更新"""
        # Cancel any order from the previous bar that did not fill, so at most
        # one order is working and the signals are re-evaluated on every bar.
        self.cancel_all()

        self.am.update_bar(bar)
        if not self.am.inited:
            return

        # 计算均线
        self.fast_ma = self.am.sma(self.fast_window)
        self.slow_ma = self.am.sma(self.slow_window)

        # Use the engine's filled position, not the last order we sent.
        self.in_position = self.pos > 0

        # 检查止损(只有持仓时才检查)
        if self.in_position and self.cost_price > 0:
            current_drawdown = (bar.close_price - self.cost_price) / self.cost_price
            if current_drawdown <= -self.stop_loss_pct:
                # 触发止损,全部平仓; no trend signal is evaluated on this bar
                self.sell(bar.close_price, self.pos)
                self.put_event()
                return

        if not self.in_position:
            # 金叉:短期上穿长期,开多
            if self.fast_ma > self.slow_ma:
                self.buy(bar.close_price, 10000)
        elif self.fast_ma < self.slow_ma:
            # 死叉:短期下穿长期,平多
            self.sell(bar.close_price, self.pos)

        self.put_event()

    def on_trade(self, trade):
        """交易成交回调"""
        # vnpy applies the fill to self.pos before calling on_trade, so
        # self.pos already includes trade.volume here.
        if trade.direction == Direction.LONG and self.pos > 0:
            previous_volume = self.pos - trade.volume
            self.cost_price = (
                self.cost_price * previous_volume + trade.price * trade.volume
            ) / self.pos
        if self.pos <= 0:
            # Flat again: forget the old entry cost.
            self.cost_price = 0.0
        self.in_position = self.pos > 0
        self.put_event()

    def on_order(self, order):
        """订单回调"""
        self.put_event()

    def on_stop_order(self, stop_order):
        """停止单回调"""
        self.put_event()
'''
# RPC request payload for a full-range backtest, 2021-01-01 ~ 2026-03-01.
# The epoch bounds are derived from calendar dates (UTC midnight) instead of
# being hand-typed. The old literal end value 1772515200 decoded to
# 2026-03-03 05:20:00 UTC, past the intended end date.
BACKTEST_START = int(datetime(2021, 1, 1, tzinfo=timezone.utc).timestamp())  # 1609459200
BACKTEST_END = int(datetime(2026, 3, 1, tzinfo=timezone.utc).timestamp())  # 1772323200

request = {
    "function": "run_strategy_backtest",
    "args": [],
    "kwargs": {
        "strategy_code": strategy_code,
        "symbol": "510300.SSE",
        "interval": "1d",
        # Epoch seconds, matching the units of the original literals.
        "start": BACKTEST_START,
        "end": BACKTEST_END,
        "capital": 1000000,
        "rate": 3e-5,
        "slippage": 0.002,
        # NOTE(review): if the server forwards this to vnpy's BacktestingEngine,
        # `size` is the contract multiplier (normally 1 for an ETF). Combined
        # with the strategy's 10000-share orders, that would make the notional
        # far exceed `capital`. Confirm the intent.
        "size": 10000,
        "pricetick": 0.001,
        "data_source": "sqlite",
        "setting": {"stop_loss_pct": 0.15},
    },
}
RPC_ENDPOINT = "tcp://127.0.0.1:8008"
RPC_TIMEOUT_MS = 900000  # 15 minutes, applied to both send and receive
OUTPUT_FILE = "/app/guanyu_510300_backtest_result.json"


def _print_statistics(stats):
    """Print the headline performance metrics of a backtest.

    ``stats`` is the ``result["statistics"]`` mapping from the server reply.
    Missing or ``None`` values print as 0. Calmar ratio, win rate and
    profit/loss ratio are printed only when their keys are present.
    """

    def num(key):
        return float(stats.get(key) or 0)

    print("\n📊 绩效指标:")
    print(f" 总收益率: {num('total_return'):.2%}")
    print(f" 年化收益率: {num('annual_return'):.2%}")
    print(f" 最大回撤: {num('max_drawdown'):.2%}")
    print(f" 夏普比率: {num('sharpe_ratio'):.2f}")
    if "calmar_ratio" in stats:
        print(f" 卡玛比率: {num('calmar_ratio'):.2f}")
    print(f" 总交易次数: {int(stats.get('total_trades') or 0)}")
    if "win_rate" in stats:
        print(f" 胜率: {num('win_rate'):.2%}")
    if "profit_loss_ratio" in stats:
        print(f" 盈亏比: {num('profit_loss_ratio'):.2f}")


def _print_trades(trades):
    """Print the trade count and a one-line summary of the first five trades."""
    print(f"\n📝 交易记录: 共 {len(trades)}")
    if not trades:
        return
    print(" 前5笔:")
    for idx, trade in enumerate(trades[:5], 1):
        dt = str(trade.get("datetime") or "")[:10]
        # Keep only the part after the last dot ("X.Y" -> "Y"), e.g. an enum name.
        direction = str(trade.get("direction") or "").rsplit(".", 1)[-1]
        price = float(trade.get("price", 0))
        volume = int(trade.get("volume", 0))
        print(f" {idx}. {dt} {direction} @ {price:.2f} × {volume}")


print("🔗 连接RPC: tcp://127.0.0.1:8008 (容器内部)")
# The context managers run on every exit path (success, sys.exit, timeout,
# unexpected error). The socket is closed first, then the context is terminated.
with zmq.Context() as context, context.socket(zmq.REQ) as sock:
    # LINGER=0 so closing never blocks on an unsent request. The timeouts make
    # a hung server raise zmq.error.Again instead of blocking forever.
    sock.setsockopt(zmq.LINGER, 0)
    sock.setsockopt(zmq.RCVTIMEO, RPC_TIMEOUT_MS)
    sock.setsockopt(zmq.SNDTIMEO, RPC_TIMEOUT_MS)
    sock.connect(RPC_ENDPOINT)

    print("🚀 发送请求 (全区间 2021-01-01 ~ 2026-03-01, 止损15%)")
    print(" 等待响应... 大约需要几分钟")
    try:
        sock.send_pyobj(request)
        # NOTE(security): recv_pyobj unpickles the reply, which can execute
        # arbitrary code. Only use this against a trusted (local) RPC server.
        result = sock.recv_pyobj()

        if not isinstance(result, dict):
            print(f"\n❌ ERROR: unexpected response type {type(result).__name__}")
            sys.exit(1)

        if "error" in result:
            print(f"\n❌ ERROR: {result['error']}")
            if "traceback" in result:
                print("\nTraceback:")
                print(result["traceback"])
            sys.exit(1)

        print("\n✅ SUCCESS! 回测完成!")
        print(f" 交易笔数: {result.get('trades_count', 0)}")
        daily_data = result.get("daily_data", [])
        print(f" 每日数据点数: {len(daily_data)}")

        # Save the full, unmodified result. NumpyEncoder handles numpy values.
        with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
            json.dump(result, f, ensure_ascii=False, indent=2, cls=NumpyEncoder)
        file_size = os.path.getsize(OUTPUT_FILE)
        print(f"\n📝 完整JSON已保存到容器: {OUTPUT_FILE}")
        print(f" 文件大小: {file_size} bytes ({file_size / 1024 / 1024:.2f} MB)")

        if "statistics" in result:
            _print_statistics(result["statistics"])
        if "trades" in result:
            _print_trades(result["trades"])

        print("\n✅ 完成!JSON已保存到容器。")
    except zmq.error.Again:
        print("\n⏱️ ❌ TIMEOUT: 超过15分钟仍未完成")
        sys.exit(1)
    except Exception as e:
        print(f"\n❌ ERROR: {e}")
        traceback.print_exc()
        sys.exit(1)