143 lines
3.7 KiB
Python
143 lines
3.7 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
修复后的FastAPI回测服务
|
|
包含vnpy.app兼容性修复
|
|
"""
|
|
|
|
from fastapi import FastAPI
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
import zmq
|
|
import pydantic
|
|
from typing import Optional, Dict, Any
|
|
|
|
# 配置
|
|
ZMQ_HOST = "127.0.0.1"
|
|
ZMQ_PORT = 8001
|
|
ZMQ_TIMEOUT = 30000 # 30秒超时
|
|
|
|
# 创建FastAPI应用
|
|
app = FastAPI(
|
|
title="回测API服务",
|
|
description="vn.py策略回测API服务",
|
|
version="1.0.0",
|
|
)
|
|
|
|
# 配置CORS
|
|
app.add_middleware(
|
|
CORSMiddleware,
|
|
allow_origins=["*"],
|
|
allow_credentials=True,
|
|
allow_methods=["*"],
|
|
allow_headers=["*"],
|
|
)
|
|
|
|
# 创建ZMQ上下文
|
|
context = zmq.Context()
|
|
|
|
# 请求模型
|
|
class BacktestRequest(pydantic.BaseModel):
|
|
strategy_code: str
|
|
symbol: str
|
|
interval: str = "1d"
|
|
start: int
|
|
end: int
|
|
capital: float = 1000000.0
|
|
rate: float = 0.00003
|
|
slippage: float = 0.2
|
|
size: int = 1
|
|
pricetick: float = 0.2
|
|
|
|
# 响应模型
|
|
class ApiResponse(pydantic.BaseModel):
|
|
code: int
|
|
msg: str
|
|
data: Optional[Dict[str, Any]] = None
|
|
error: Optional[str] = None
|
|
error_detail: Optional[str] = None
|
|
|
|
@app.get("/")
|
|
async def root():
|
|
return {"message": "回测API服务正常运行", "version": "1.0.0"}
|
|
|
|
@app.post("/api/backtest/run", response_model=ApiResponse)
|
|
async def run_backtest(request: BacktestRequest):
|
|
"""运行策略回测"""
|
|
try:
|
|
# 创建ZMQ客户端
|
|
socket = context.socket(zmq.REQ)
|
|
socket.connect(f"tcp://{ZMQ_HOST}:{ZMQ_PORT}")
|
|
|
|
# 准备请求
|
|
req = {
|
|
"function": "run_strategy_backtest",
|
|
"args": [],
|
|
"kwargs": {
|
|
"strategy_code": request.strategy_code,
|
|
"symbol": request.symbol,
|
|
"interval": request.interval,
|
|
"start": request.start,
|
|
"end": request.end,
|
|
"capital": request.capital,
|
|
"rate": request.rate,
|
|
"slippage": request.slippage,
|
|
"size": request.size,
|
|
"pricetick": request.pricetick,
|
|
}
|
|
}
|
|
|
|
# 发送请求
|
|
socket.send_pyobj(req)
|
|
|
|
# 设置轮询器
|
|
poller = zmq.Poller()
|
|
poller.register(socket, zmq.POLLIN)
|
|
events = poller.poll(ZMQ_TIMEOUT)
|
|
|
|
if not events:
|
|
socket.close()
|
|
return ApiResponse(
|
|
code=504,
|
|
msg="回测请求超时",
|
|
error="请求超时,请检查服务状态",
|
|
)
|
|
|
|
# 接收响应
|
|
result = socket.recv_pyobj()
|
|
socket.close()
|
|
|
|
if "error" in result:
|
|
# 回测执行出错
|
|
return ApiResponse(
|
|
code=400,
|
|
msg="回测执行出错",
|
|
data=result,
|
|
error=result.get("error"),
|
|
error_detail=result.get("traceback"),
|
|
)
|
|
else:
|
|
# 回测成功
|
|
return ApiResponse(
|
|
code=200,
|
|
msg="回测完成",
|
|
data=result,
|
|
error=None,
|
|
error_detail=None,
|
|
)
|
|
|
|
except Exception as e:
|
|
import traceback
|
|
error_tb = traceback.format_exc()
|
|
return ApiResponse(
|
|
code=500,
|
|
msg="API服务内部错误",
|
|
error=str(e),
|
|
error_detail=error_tb,
|
|
)
|
|
|
|
if __name__ == "__main__":
|
|
import uvicorn
|
|
print("🚀 启动修复后的回测API服务...")
|
|
print(f" 监听地址: 0.0.0.0:8088")
|
|
print(f" ZMQ RPC: tcp://{ZMQ_HOST}:{ZMQ_PORT}")
|
|
print(f" vnpy.app兼容性: ✅ 已修复")
|
|
uvicorn.run(app, host="0.0.0.0", port=8088) |