Files
2026-04-29 20:15:25 +08:00

328 lines
11 KiB
Python

#!/usr/bin/env python3
"""
完整单文件回测服务
确保BacktesterEngine初始化绝对正确
"""
import sys
import os
# ============================================
# Step 1: vnpy.app compatibility shim
# ============================================
# Registers synthetic ``vnpy.app`` modules in ``sys.modules`` so that
# legacy-style imports such as
# ``from vnpy.app.cta_strategy import CtaTemplate`` resolve to the classes
# exported by the standalone ``vnpy_ctastrategy`` / ``vnpy_ctabacktester``
# packages.
print("🔧 加载vnpy.app兼容性模块...")
import types

# Synthetic parent module.
vnpy_app_module = types.ModuleType('vnpy.app')
sys.modules['vnpy.app'] = vnpy_app_module

# One empty placeholder module per legacy sub-package; each is reachable both
# through ``sys.modules`` and as an attribute of the parent module.
# ``data_manager`` stays empty: nothing is mapped into it below.
submodules = ['cta_strategy', 'cta_backtester', 'data_manager']
for name in submodules:
    full_name = 'vnpy.app.' + name
    submodule = types.ModuleType(full_name)
    setattr(vnpy_app_module, name, submodule)
    sys.modules[full_name] = submodule

# Expose the real classes under the legacy names (on the sub-module and on the
# parent module).
from vnpy_ctastrategy import CtaTemplate, CtaStrategyApp
setattr(sys.modules['vnpy.app.cta_strategy'], 'CtaTemplate', CtaTemplate)
setattr(sys.modules['vnpy.app.cta_strategy'], 'CtaStrategyApp', CtaStrategyApp)
setattr(vnpy_app_module, 'CtaTemplate', CtaTemplate)
setattr(vnpy_app_module, 'CtaStrategyApp', CtaStrategyApp)

from vnpy_ctabacktester import BacktesterEngine
setattr(sys.modules['vnpy.app.cta_backtester'], 'BacktesterEngine', BacktesterEngine)
setattr(vnpy_app_module, 'BacktesterEngine', BacktesterEngine)

print("✅ vnpy.app兼容性模块加载完成!")
print(" 现在支持: from vnpy.app.cta_strategy import CtaTemplate")
# ============================================
# Compatibility shim complete
# ============================================
from vnpy.event import EventEngine
from vnpy.trader.engine import MainEngine
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import zmq
import pydantic
import traceback
from typing import Optional, Dict, Any
# ============================================
# Sanity check: show the BacktesterEngine constructor signature at startup
# ============================================
import inspect

# Printed so the expected constructor arguments are visible in the service log.
print("🔍 BacktesterEngine.__init__ 签名:")
print(" " + str(inspect.signature(BacktesterEngine.__init__)))
# ============================================
# FastAPI application
# ============================================
app = FastAPI(
    title="回测API服务 - 最终完整版本",
    description="vn.py策略回测API服务 - BacktesterEngine初始化已修复",
    version="10.0.0-complete",
)

# CORS configuration.
# The FastAPI CORS documentation states that wildcard origins/methods/headers
# must not be combined with ``allow_credentials=True``. Nothing in this service
# reads cookies or auth headers, so credentials are disabled and the wildcard
# origin is kept. If credentialed browser requests are ever needed, replace
# ``["*"]`` with an explicit origin list before turning this back on.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ZMQ settings.
# NOTE(review): ZMQ_HOST and ZMQ_TIMEOUT are not referenced anywhere in the
# visible part of this file; only ZMQ_PORT is used (by the RPC server bind).
ZMQ_HOST: str = "127.0.0.1"
ZMQ_PORT: int = 8008  # brand-new port
ZMQ_TIMEOUT: int = 30_000
# Request model: JSON body accepted by the backtest endpoint.
# (Documented with comments rather than a class docstring so the generated
# OpenAPI schema is unchanged.)
class BacktestRequest(pydantic.BaseModel):
    # --- required fields ---
    # Python source of the strategy; executed by the backtest core.
    strategy_code: str
    # Passed to the backtest as the vt_symbol.
    symbol: str

    # --- bar interval (optional) ---
    interval: str = "1d"

    # --- required date range ---
    # The backtest core reformats 8-digit values as YYYY-MM-DD.
    start: int
    end: int

    # --- optional backtest parameters with their defaults ---
    capital: float = 1_000_000.0
    rate: float = 0.00003
    slippage: float = 0.2
    size: int = 1
    pricetick: float = 0.2
# Response model: uniform envelope returned by the backtest endpoint.
# (Documented with comments rather than a class docstring so the generated
# OpenAPI schema is unchanged.)
class ApiResponse(pydantic.BaseModel):
    # Application-level status carried in the body (the endpoint uses
    # 200 / 400 / 500).
    code: int
    # Short human-readable status message.
    msg: str

    # Result payload; omitted (None) when the endpoint itself failed.
    data: Optional[Dict[str, Any]] = None

    # Error message and traceback text, populated only on failure.
    error: Optional[str] = None
    error_detail: Optional[str] = None
def _parse_backtest_date(value):
    """Convert a YYYYMMDD date (int or str) into a ``datetime``.

    Raises:
        ValueError: if ``value`` is not an 8-digit, valid calendar date.
    """
    # Local import keeps this block self-contained (the file does not import
    # datetime at module level).
    from datetime import datetime

    text = str(value)
    if len(text) != 8 or not text.isdigit():
        raise ValueError(f"无效日期: {value!r},应为YYYYMMDD格式")
    return datetime.strptime(text, "%Y%m%d")


def _load_strategy_class(strategy_code):
    """Execute strategy source and return the first CtaTemplate subclass it defines.

    Returns ``None`` when the code defines no such class. Any exception raised
    by the strategy source propagates to the caller.
    """
    baseline = globals()
    # The code runs in ONE namespace (a copy of this module's globals).
    # With separate globals/locals dicts, exec() behaves like a class body:
    # methods of the strategy class could not see the strategy's own
    # module-level imports or helpers and would fail with NameError at call
    # time. The copy also stops the strategy from mutating the service's real
    # module globals.
    namespace = dict(baseline)
    # SECURITY: exec() runs caller-supplied code with the full privileges of
    # this process. Only expose this service to trusted callers.
    exec(strategy_code, namespace)
    for name, value in namespace.items():
        if baseline.get(name) is value:
            # Inherited from this module, not defined by the strategy code.
            continue
        if isinstance(value, type) and issubclass(value, CtaTemplate) and value is not CtaTemplate:
            return value
    return None


def run_strategy_backtest_core(strategy_code: str, symbol: str, interval: str, start: int, end: int, **kwargs):
    """Run one backtest for a strategy supplied as source code.

    Args:
        strategy_code: Python source defining a ``CtaTemplate`` subclass.
        symbol: vt_symbol to backtest.
        interval: bar interval; ``"1d"`` / ``"1w"`` are accepted as aliases of
            ``"d"`` / ``"w"``.
        start: first day of the backtest as YYYYMMDD.
        end: last day of the backtest as YYYYMMDD.
        **kwargs: optional ``rate``, ``slippage``, ``size``, ``pricetick`` and
            ``capital`` overrides.

    Returns:
        On success a dict with ``statistics``, ``trades`` and ``daily_data``.
        On failure a dict with ``error`` plus either ``traceback`` (an
        exception was raised) or ``hint`` (no strategy class was found).
        This function never raises for ordinary errors; callers test for the
        ``"error"`` key.
    """
    main_engine = None
    try:
        print(f"\n🚀 开始新回测: {symbol} [{start} - {end}]")

        # Validate the dates before running any caller-supplied code.
        start_dt = _parse_backtest_date(start)
        end_dt = _parse_backtest_date(end)

        print(f"🔧 加载策略代码...")
        StrategyClass = _load_strategy_class(strategy_code)
        if StrategyClass is None:
            return {
                "error": "策略代码中未找到CtaTemplate子类",
                "hint": "请确保策略继承自CtaTemplate"
            }
        print(f"✅ 找到策略类: {StrategyClass.__name__}")

        print(f"🔧 创建事件引擎...")
        event_engine = EventEngine()
        print(f"✅ event_engine 创建完成: {event_engine}")
        print(f"🔧 创建主引擎...")
        main_engine = MainEngine(event_engine)
        print(f"✅ main_engine 创建完成: {main_engine}")

        # BacktesterEngine.__init__(self, main_engine, event_engine)
        print(f"🔧 创建BacktesterEngine,传入两个参数...")
        backtester_engine = BacktesterEngine(main_engine, event_engine)
        print(f"✅ BacktesterEngine 创建成功: {backtester_engine}")
        # The engine is deliberately NOT passed to main_engine.add_app():
        # NOTE(review): MainEngine.add_app() expects an app *class* and
        # instantiates it, so handing it an engine instance fails — confirm
        # against the installed vn.py version.

        # NOTE(review): vn.py's Interval values are believed to be
        # "1m", "1h", "d", "w", "tick"; map the "1d"/"1w" spellings used by
        # this service's request default onto them — confirm.
        vt_interval = {"1d": "d", "1w": "w"}.get(interval, interval)

        backtest_params = {
            "vt_symbol": symbol,
            "interval": vt_interval,
            "start": start_dt,
            "end": end_dt,
            "rate": kwargs.get("rate", 0.00003),
            "slippage": kwargs.get("slippage", 0.2),
            "size": kwargs.get("size", 1),
            "pricetick": kwargs.get("pricetick", 0.2),
            "capital": kwargs.get("capital", 1000000.0),
        }
        print(f"✅ 回测参数: {backtest_params}")

        print(f"🔧 初始化引擎: backtester_engine.init_engine()")
        backtester_engine.init_engine()
        print(f"✅ 初始化完成")

        # NOTE(review): per vnpy_ctabacktester, run_backtesting() looks the
        # strategy up by name in ``engine.classes`` and takes the backtest
        # parameters individually, with ``setting`` being the *strategy*
        # parameter dict — confirm against the installed version. Register the
        # class after init_engine() so a same-named built-in cannot replace it.
        backtester_engine.classes[StrategyClass.__name__] = StrategyClass
        print(f"🔧 运行回测...")
        backtester_engine.run_backtesting(
            class_name=StrategyClass.__name__,
            setting={},
            **backtest_params,
        )

        statistics = backtester_engine.get_result_statistics()
        print(f"✅ 回测完成,统计指标: {list(statistics.keys()) if statistics else ''}")

        # Daily results as a list of row dicts (empty when nothing was produced).
        daily_df = backtester_engine.get_result_df()
        if daily_df is not None and hasattr(daily_df, 'to_dict'):
            daily_data = daily_df.to_dict(orient='records')
        else:
            daily_data = []

        # Copy each trade's attribute dict so callers never alias live objects.
        trades = backtester_engine.get_all_trades()
        trade_list = [dict(vars(t)) for t in trades] if trades else []

        return {
            "statistics": statistics,
            "trades": trade_list,
            "daily_data": daily_data
        }
    except Exception as e:
        # Top-level boundary: report every failure to the caller as data.
        error_info = {
            "error": str(e),
            "traceback": traceback.format_exc()
        }
        print(f"❌ 回测错误: {error_info['error']}")
        print(error_info['traceback'])
        return error_info
    finally:
        # A fresh MainEngine/EventEngine pair is created per call; close it so
        # repeated requests do not accumulate engine resources.
        # NOTE(review): MainEngine.close() is expected to stop the event
        # engine's threads — confirm.
        if main_engine is not None:
            try:
                main_engine.close()
            except Exception as close_error:
                print(f"⚠️ 关闭主引擎失败: {close_error}")
@app.get("/")
async def root():
    """Return a static service banner: version, fix list and endpoint paths."""
    applied_fixes = [
        "✅ vnpy.app模块兼容性修复",
        "✅ BacktesterEngine 初始化正确修复 (传入main_engine + event_engine)",
        "✅ 510300.SSE 数据已导入 (3361行)",
        "✅ 完整单文件服务,确保所有代码正确",
    ]
    endpoint_map = {
        "run_backtest": "/api/backtest/run",
        "docs": "/docs",
    }
    banner = {
        "message": "回测API服务正常运行 - 最终完整修复版本",
        "version": "10.0.0-complete",
    }
    banner["fixes"] = applied_fixes
    banner["endpoints"] = endpoint_map
    return banner
@app.post("/api/backtest/run", response_model=ApiResponse)
async def run_backtest(request: BacktestRequest):
    """Run a strategy backtest and wrap the outcome in an ``ApiResponse``.

    The body ``code`` is 200 on success, 400 when the backtest core reports an
    error (its result dict contains an ``"error"`` key), and 500 when this
    handler itself raises.
    """
    try:
        core_arguments = dict(
            strategy_code=request.strategy_code,
            symbol=request.symbol,
            interval=request.interval,
            start=request.start,
            end=request.end,
            capital=request.capital,
            rate=request.rate,
            slippage=request.slippage,
            size=request.size,
            pricetick=request.pricetick,
        )
        outcome = run_strategy_backtest_core(**core_arguments)

        # Success path first; the core signals failure via an "error" key.
        if "error" not in outcome:
            return ApiResponse(
                code=200,
                msg="回测完成",
                data=outcome,
                error=None,
                error_detail=None,
            )

        return ApiResponse(
            code=400,
            msg="回测执行出错",
            data=outcome,
            error=outcome.get("error"),
            error_detail=outcome.get("traceback"),
        )
    except Exception as exc:
        # Anything raised by the handler itself is reported as an internal error.
        return ApiResponse(
            code=500,
            msg="API服务内部错误",
            error=str(exc),
            error_detail=traceback.format_exc(),
        )
def _rpc_error_reply(exc):
    """Log ``exc`` and build the error payload sent back to an RPC client.

    Must be called from inside the ``except`` block handling ``exc`` so that
    ``traceback.format_exc()`` captures it.
    """
    reply = {
        "error": str(exc),
        "traceback": traceback.format_exc()
    }
    print(f"处理请求出错: {exc}")
    return reply


def _dispatch_rpc_request(req):
    """Run the function named in one decoded RPC request and return its result."""
    print(f"收到RPC请求: {req.get('function', 'unknown')}")
    function_name = req.get("function")
    args = req.get("args", [])
    kwargs = req.get("kwargs", {})
    if function_name == "run_strategy_backtest":
        return run_strategy_backtest_core(*args, **kwargs)
    return {"error": f"未知函数: {function_name}"}


def start_zmq_server():
    """Serve backtest RPC requests over a ZMQ REP socket until the socket fails.

    Binds to ``tcp://0.0.0.0:ZMQ_PORT`` and answers each request with exactly
    one reply, as the REQ/REP pattern requires. The socket and context are
    always released on exit, including on KeyboardInterrupt.
    """
    context = zmq.Context()
    rep_socket = context.socket(zmq.REP)
    try:
        bind_addr = f"tcp://0.0.0.0:{ZMQ_PORT}"
        rep_socket.bind(bind_addr)
        print(f"✅ ZMQ RPC服务器启动: {bind_addr}")

        while True:
            try:
                # SECURITY: recv_pyobj() unpickles whatever arrives on a socket
                # bound to all interfaces; unpickling untrusted data can
                # execute arbitrary code. Restrict network access to this port
                # or switch to a data-only format such as JSON.
                req = rep_socket.recv_pyobj()
            except zmq.ZMQError as recv_error:
                # Nothing was received, so no reply is owed and the socket is
                # not usable any more: stop instead of spinning.
                print(f"❌ ZMQ接收失败,RPC服务器停止: {recv_error}")
                break
            except Exception as e:
                # A message arrived but could not be decoded; a reply is still owed.
                reply = _rpc_error_reply(e)
            else:
                try:
                    reply = _dispatch_rpc_request(req)
                except Exception as e:
                    reply = _rpc_error_reply(e)

            try:
                rep_socket.send_pyobj(reply)
            except zmq.ZMQError as send_error:
                print(f"❌ ZMQ发送失败,RPC服务器停止: {send_error}")
                break
            except Exception as e:
                # Serialising the reply failed before anything was sent; send a
                # plain error payload so the client is not left waiting.
                rep_socket.send_pyobj(_rpc_error_reply(e))
            else:
                print(f"请求处理完成")
    finally:
        # linger=0 drops any unsent reply so term() cannot block forever.
        rep_socket.close(linger=0)
        context.term()
if __name__ == "__main__":
    # Script entry point: print the startup banner, then serve the FastAPI app.
    import uvicorn

    print("\n".join((
        "🚀 启动最终完整修复版本回测API服务",
        " 监听地址: 0.0.0.0:8088 (Docker已映射)",
        " BacktesterEngine: ✅ 正确传入两个参数 main_engine + event_engine",
        " 510300.SSE: ✅ 3361行数据已导入",
        " vnpy.app: ✅ 兼容性已修复",
        " 完整单文件服务,确保所有代码正确",
    )))
    uvicorn.run(app, host="0.0.0.0", port=8088)