157 lines
4.2 KiB
Python
157 lines
4.2 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
最终修复版本API - 使用已映射的端口 8088
|
||
"""
|
||
|
||
from fastapi import FastAPI
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
import zmq
|
||
import pydantic
|
||
from typing import Optional, Dict, Any
|
||
|
||
# 配置 - RPC端口 8004 (容器内),API端口 8088 (已映射)
|
||
ZMQ_HOST = "127.0.0.1"
|
||
ZMQ_PORT = 8004
|
||
ZMQ_TIMEOUT = 30000
|
||
|
||
# 创建FastAPI应用
|
||
app = FastAPI(
|
||
title="回测API服务",
|
||
description="vn.py策略回测API服务 - 最终修复版本",
|
||
version="6.0.0",
|
||
)
|
||
|
||
# 配置CORS
|
||
app.add_middleware(
|
||
CORSMiddleware,
|
||
allow_origins=["*"],
|
||
allow_credentials=True,
|
||
allow_methods=["*"],
|
||
allow_headers=["*"],
|
||
)
|
||
|
||
# 创建ZMQ上下文
|
||
context = zmq.Context()
|
||
|
||
# 请求模型
|
||
class BacktestRequest(pydantic.BaseModel):
|
||
strategy_code: str
|
||
symbol: str
|
||
interval: str = "1d"
|
||
start: int
|
||
end: int
|
||
capital: float = 1000000.0
|
||
rate: float = 0.00003
|
||
slippage: float = 0.2
|
||
size: int = 1
|
||
pricetick: float = 0.2
|
||
|
||
# 响应模型
|
||
class ApiResponse(pydantic.BaseModel):
|
||
code: int
|
||
msg: str
|
||
data: Optional[Dict[str, Any]] = None
|
||
error: Optional[str] = None
|
||
error_detail: Optional[str] = None
|
||
|
||
@app.get("/")
|
||
async def root():
|
||
return {
|
||
"message": "回测API服务正常运行",
|
||
"version": "6.0.0",
|
||
"fixes": [
|
||
"✅ vnpy.app模块兼容性修复",
|
||
"✅ BacktesterEngine 初始化正确修复 (传入main_engine + event_engine)",
|
||
"✅ 510300.SSE 数据已导入 (3361行)",
|
||
],
|
||
"endpoints": {
|
||
"run_backtest": "/api/backtest/run",
|
||
"docs": "/docs",
|
||
},
|
||
}
|
||
|
||
@app.post("/api/backtest/run", response_model=ApiResponse)
|
||
async def run_backtest(request: BacktestRequest):
|
||
"""运行策略回测"""
|
||
try:
|
||
# 创建ZMQ客户端
|
||
socket = context.socket(zmq.REQ)
|
||
socket.connect(f"tcp://{ZMQ_HOST}:{ZMQ_PORT}")
|
||
|
||
# 准备请求
|
||
req = {
|
||
"function": "run_strategy_backtest",
|
||
"args": [],
|
||
"kwargs": {
|
||
"strategy_code": request.strategy_code,
|
||
"symbol": request.symbol,
|
||
"interval": request.interval,
|
||
"start": request.start,
|
||
"end": request.end,
|
||
"capital": request.capital,
|
||
"rate": request.rate,
|
||
"slippage": request.slippage,
|
||
"size": request.size,
|
||
"pricetick": request.pricetick,
|
||
},
|
||
}
|
||
|
||
# 发送请求
|
||
socket.send_pyobj(req)
|
||
|
||
# 设置轮询器
|
||
poller = zmq.Poller()
|
||
poller.register(socket, zmq.POLLIN)
|
||
events = poller.poll(ZMQ_TIMEOUT)
|
||
|
||
if not events:
|
||
socket.close()
|
||
return ApiResponse(
|
||
code=504,
|
||
msg="回测请求超时",
|
||
error="请求超时,请检查服务状态",
|
||
)
|
||
|
||
# 接收响应
|
||
result = socket.recv_pyobj()
|
||
socket.close()
|
||
|
||
if "error" in result:
|
||
# 回测执行出错
|
||
return ApiResponse(
|
||
code=400,
|
||
msg="回测执行出错",
|
||
data=result,
|
||
error=result.get("error"),
|
||
error_detail=result.get("traceback"),
|
||
)
|
||
else:
|
||
# 回测成功
|
||
return ApiResponse(
|
||
code=200,
|
||
msg="回测完成",
|
||
data=result,
|
||
error=None,
|
||
error_detail=None,
|
||
)
|
||
|
||
except Exception as e:
|
||
import traceback
|
||
error_tb = traceback.format_exc()
|
||
return ApiResponse(
|
||
code=500,
|
||
msg="API服务内部错误",
|
||
error=str(e),
|
||
error_detail=error_tb,
|
||
)
|
||
|
||
if __name__ == "__main__":
|
||
import uvicorn
|
||
print("🚀 启动最终修复版本回测API服务")
|
||
print(f" 监听地址: 0.0.0.0:8088 (已映射到主机)")
|
||
print(f" ZMQ RPC: tcp://{ZMQ_HOST}:{ZMQ_PORT}")
|
||
print(f" BacktesterEngine: ✅ 正确传入两个参数")
|
||
print(f" 510300.SSE: ✅ 3361行数据已导入")
|
||
print(f" vnpy.app: ✅ 兼容性已修复")
|
||
uvicorn.run(app, host="0.0.0.0", port=8088)
|