initial-import: 2026-04-11 21:18:55
This commit is contained in:
@@ -0,0 +1,143 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
修复后的FastAPI回测服务
|
||||
包含vnpy.app兼容性修复
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import zmq
|
||||
import pydantic
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
# 配置
|
||||
ZMQ_HOST = "127.0.0.1"
|
||||
ZMQ_PORT = 8001
|
||||
ZMQ_TIMEOUT = 30000 # 30秒超时
|
||||
|
||||
# 创建FastAPI应用
|
||||
app = FastAPI(
|
||||
title="回测API服务",
|
||||
description="vn.py策略回测API服务",
|
||||
version="1.0.0",
|
||||
)
|
||||
|
||||
# 配置CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 创建ZMQ上下文
|
||||
context = zmq.Context()
|
||||
|
||||
# 请求模型
|
||||
class BacktestRequest(pydantic.BaseModel):
|
||||
strategy_code: str
|
||||
symbol: str
|
||||
interval: str = "1d"
|
||||
start: int
|
||||
end: int
|
||||
capital: float = 1000000.0
|
||||
rate: float = 0.00003
|
||||
slippage: float = 0.2
|
||||
size: int = 1
|
||||
pricetick: float = 0.2
|
||||
|
||||
# 响应模型
|
||||
class ApiResponse(pydantic.BaseModel):
|
||||
code: int
|
||||
msg: str
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
error: Optional[str] = None
|
||||
error_detail: Optional[str] = None
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "回测API服务正常运行", "version": "1.0.0"}
|
||||
|
||||
@app.post("/api/backtest/run", response_model=ApiResponse)
|
||||
async def run_backtest(request: BacktestRequest):
|
||||
"""运行策略回测"""
|
||||
try:
|
||||
# 创建ZMQ客户端
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.connect(f"tcp://{ZMQ_HOST}:{ZMQ_PORT}")
|
||||
|
||||
# 准备请求
|
||||
req = {
|
||||
"function": "run_strategy_backtest",
|
||||
"args": [],
|
||||
"kwargs": {
|
||||
"strategy_code": request.strategy_code,
|
||||
"symbol": request.symbol,
|
||||
"interval": request.interval,
|
||||
"start": request.start,
|
||||
"end": request.end,
|
||||
"capital": request.capital,
|
||||
"rate": request.rate,
|
||||
"slippage": request.slippage,
|
||||
"size": request.size,
|
||||
"pricetick": request.pricetick,
|
||||
}
|
||||
}
|
||||
|
||||
# 发送请求
|
||||
socket.send_pyobj(req)
|
||||
|
||||
# 设置轮询器
|
||||
poller = zmq.Poller()
|
||||
poller.register(socket, zmq.POLLIN)
|
||||
events = poller.poll(ZMQ_TIMEOUT)
|
||||
|
||||
if not events:
|
||||
socket.close()
|
||||
return ApiResponse(
|
||||
code=504,
|
||||
msg="回测请求超时",
|
||||
error="请求超时,请检查服务状态",
|
||||
)
|
||||
|
||||
# 接收响应
|
||||
result = socket.recv_pyobj()
|
||||
socket.close()
|
||||
|
||||
if "error" in result:
|
||||
# 回测执行出错
|
||||
return ApiResponse(
|
||||
code=400,
|
||||
msg="回测执行出错",
|
||||
data=result,
|
||||
error=result.get("error"),
|
||||
error_detail=result.get("traceback"),
|
||||
)
|
||||
else:
|
||||
# 回测成功
|
||||
return ApiResponse(
|
||||
code=200,
|
||||
msg="回测完成",
|
||||
data=result,
|
||||
error=None,
|
||||
error_detail=None,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
error_tb = traceback.format_exc()
|
||||
return ApiResponse(
|
||||
code=500,
|
||||
msg="API服务内部错误",
|
||||
error=str(e),
|
||||
error_detail=error_tb,
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
print("🚀 启动修复后的回测API服务...")
|
||||
print(f" 监听地址: 0.0.0.0:8088")
|
||||
print(f" ZMQ RPC: tcp://{ZMQ_HOST}:{ZMQ_PORT}")
|
||||
print(f" vnpy.app兼容性: ✅ 已修复")
|
||||
uvicorn.run(app, host="0.0.0.0", port=8088)
|
||||
Reference in New Issue
Block a user