auto-sync: 2026-04-29 20:14:41

This commit is contained in:
cfdaily
2026-04-29 20:14:41 +08:00
parent e3fad59483
commit b66cc84e85
14 changed files with 1 additions and 0 deletions
-149
View File
@@ -1,149 +0,0 @@
# 自动化回测服务 - 使用说明
## 概述
基于 vnpy 原生 `BacktestingEngine` 封装的 RESTful API 自动化回测服务。
支持:
- 通过 API 提交回测任务
- 自动排队控制并发
- 使用 vnpy 原生回测引擎执行
- 保存回测结果(CSV + 图表 + JSON)
- 查询任务状态和结果
## 架构设计
严格遵循 vnpy 原生设计,不修改核心,只做外层封装:
```
[API 服务] ← 接收任务
[任务队列 + 进程池] ← 控制并发
[BacktestingEngine (vnpy 原生)] ← 执行回测
[文件存储] ← 保存结果
[API 查询] ← 返回结果
```
## 启动方式
```bash
# 手动启动
cd /app/scripts/backtest-service
python main.py
# 后台运行
nohup python main.py > backtest-service.log 2>&1 &
# 查看日志
tail -f backtest-service.log
```
## 访问地址
启动后访问:
- 服务地址:http://container-ip:8088
- API 文档:http://container-ip:8088/docs → 可以直接在网页上测试API
## API 接口说明
| 接口 | 方法 | 路径 | 说明 |
|------|------|------|------|
| 提交回测 | POST | `/api/backtest/submit` | 提交新的回测任务 |
| 任务列表 | GET | `/api/backtest/list` | 列出任务,支持分页和状态过滤 |
| 任务状态 | GET | `/api/backtest/status/{task_id}` | 查询任务状态 |
| 回测结果 | GET | `/api/backtest/result/{task_id}` | 获取完整回测结果 |
| 删除任务 | DELETE | `/api/backtest/delete/{task_id}` | 删除任务 |
| 健康检查 | GET | `/api/backtest/health` | 查看服务状态,返回任务统计 |
## 配置
可以通过环境变量覆盖默认配置:
| 环境变量 | 说明 | 默认值 |
|----------|------|--------|
| `MAX_WORKERS` | 最大并发回测数 | 2 |
| `HOST` | 监听地址 | 0.0.0.0 |
| `PORT` | 监听端口 | 8088 |
| `BASE_DIR` | 任务存储根目录 | /app/backtest_jobs |
| `DEBUG` | 调试模式 | True |
## 使用示例
### 1. 提交回测
```bash
curl -X POST http://127.0.0.1:8088/api/backtest/submit \
-H "Content-Type: application/json" \
-d '{
"strategy_name": "双均线策略",
"strategy_code": "完整的策略Python代码...",
"symbol": "IF888.CFFEX",
"interval": "1h",
"start_date": "2020-01-01",
"end_date": "2025-01-01",
"parameters": {
"fast_window": 5,
"slow_window": 20
},
"capital": 1000000
}'
```
响应:
```json
{
"code": 0,
"msg": "任务提交成功",
"data": {
"task_id": "a1b2c3d4...",
"status": "pending",
"created_at": "2026-04-12T10:00:00+08:00"
}
}
```
### 2. 查询任务状态
```bash
curl http://127.0.0.1:8088/api/backtest/status/a1b2c3d4
```
### 3. 获取回测结果
```bash
curl http://127.0.0.1:8088/api/backtest/result/a1b2c3d4
```
## 结果存储结构
```
/app/backtest_jobs/
├── pending/ # 等待执行
├── running/ # 执行中
├── completed/ # 已完成
│ └── <task_id>/
│ ├── task.json # 任务信息
│ ├── result.json # 回测结果
│ ├── equity.csv # 每日净值
│ ├── equity_curve.png # 收益曲线图
│ └── trades.csv # 成交记录
└── failed/ # 执行失败
└── <task_id>/
├── task.json
└── result.json # 包含错误信息
```
## 设计原则
1. **不改动 vnpy 核心**:完全复用原生 `BacktestingEngine`
2. **轻量级**:只用 Python 标准库,不引入额外第三方任务队列
3. **隔离性**:每个回测在独立进程,一个失败不影响其他
4. **可配置并发**:根据 CPU 性能调整 `MAX_WORKERS`,避免资源耗尽
## 作者
三国量化团队 姜维 伯约
2026-04-12
-92
View File
@@ -1,92 +0,0 @@
"""
自动化回测服务 - API路由
"""
from typing import Optional
from fastapi import APIRouter, Query
from .models import (
BacktestTask,
BacktestTaskWithId,
BacktestResult,
TaskListResponse,
ApiResponse,
HealthCheckResponse,
TaskStatus,
)
from .task_queue import task_queue
from .result_storage import storage
from .executor import executor
router = APIRouter()
@router.post("/submit", summary="提交回测任务")
def submit_task(task: BacktestTask) -> ApiResponse[BacktestTaskWithId]:
    """Submit a new backtest task and return it with its assigned ID.

    ``task_queue.submit_task`` assigns the ID, persists the task through
    ``storage.save_task`` and appends it to the pending queue, so the task is
    not written to disk a second time here.

    Args:
        task: The backtest request body.

    Returns:
        An ``ApiResponse`` whose ``data`` is the queued task (status pending).
    """
    task_with_id = task_queue.submit_task(task)
    return ApiResponse(
        code=0,
        msg="任务提交成功",
        data=task_with_id,
    )
@router.get("/list", summary="列出回测任务")
def list_tasks(
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(10, ge=1, le=100, description="每页数量"),
    status: Optional[str] = Query(None, description="状态过滤 pending/running/completed/failed"),
) -> ApiResponse[TaskListResponse]:
    """Return one page of backtest tasks, optionally filtered by status.

    Args:
        page: 1-based page number.
        page_size: Number of tasks per page (1 to 100).
        status: Optional status name used as a filter.

    Returns:
        An ``ApiResponse`` wrapping a ``TaskListResponse`` built from the
        dictionary returned by ``task_queue.list_tasks``.
    """
    listing = task_queue.list_tasks(page, page_size, status)
    payload = TaskListResponse(**listing)
    return ApiResponse(code=0, msg="success", data=payload)
@router.get("/status/{task_id}", summary="查询任务状态")
def get_status(task_id: str) -> ApiResponse[Optional[BacktestTaskWithId]]:
    """Look up a single task on disk and report it.

    Args:
        task_id: ID of the task to look up.

    Returns:
        An ``ApiResponse`` with the task as ``data``, or with ``code=404`` in
        the response body and no data when the task cannot be found.
    """
    found = storage.find_task(task_id)
    if found:
        return ApiResponse(code=0, msg="success", data=found)
    return ApiResponse(code=404, msg="任务不存在", data=None)
@router.get("/result/{task_id}", summary="获取回测结果")
def get_result(task_id: str) -> ApiResponse[Optional[BacktestResult]]:
    """Return the stored backtest result for a task, read from disk.

    Args:
        task_id: ID of the task whose result is requested.

    Returns:
        An ``ApiResponse`` carrying the result when one exists; a successful
        response without data when the task exists but has no result yet; a
        ``code=404`` body when the task itself is unknown.
    """
    outcome = storage.find_result(task_id)
    if outcome:
        return ApiResponse(code=0, msg="success", data=outcome)
    # No result on disk: distinguish "unknown task" from "not finished yet".
    if storage.find_task(task_id):
        return ApiResponse(code=0, msg="任务尚未完成", data=None)
    return ApiResponse(code=404, msg="任务不存在", data=None)
@router.delete("/delete/{task_id}", summary="删除回测任务")
def delete_task(task_id: str) -> ApiResponse[None]:
    """Acknowledge a delete request for a backtest task.

    TODO: physical deletion is not implemented yet; nothing is removed and a
    success response is always returned.

    Args:
        task_id: ID of the task the caller wants deleted (currently unused).
    """
    acknowledgement = ApiResponse(code=0, msg="删除成功(待实现物理删除)", data=None)
    return acknowledgement
@router.get("/health", summary="健康检查")
def health_check() -> ApiResponse[HealthCheckResponse]:
    """Report service health as the sizes of the in-memory task lists."""
    stats = HealthCheckResponse(
        pending_count=len(task_queue.pending_tasks),
        running_count=len(task_queue.running_tasks),
        completed_count=len(task_queue.completed_tasks),
        failed_count=len(task_queue.failed_tasks),
        max_workers=task_queue.max_workers,
    )
    return ApiResponse(code=0, msg="ok", data=stats)
-32
View File
@@ -1,32 +0,0 @@
"""
自动化回测服务 - 配置
"""
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
    """Configuration values for the backtest service.

    Declared as a ``pydantic_settings.BaseSettings`` model; each field below
    carries its default value.
    """

    # Address and port the service listens on.
    host: str = "0.0.0.0"
    port: int = 8088
    # Maximum number of concurrent backtests (tune to the CPU core count;
    # 2 is the suggested value for a Celeron NAS).
    max_workers: int = 2
    # Root directory under which backtest task files are stored.
    base_dir: str = "/app/backtest_jobs"
    # CORS: allow every origin (development setting).
    cors_allow_all: bool = True
    # Debug mode.
    debug: bool = True
    # Whether strategy code is allowed to import modules.
    allow_imports: bool = True
    # Upper bound on backtest run time in seconds, meant to guard against
    # endless loops.
    max_execution_time: int = 3600  # 1 hour


# Module-level settings instance shared by the rest of the service.
settings = Settings()
-262
View File
@@ -1,262 +0,0 @@
"""
自动化回测服务 - 任务执行器
调用 vnpy 4.x BacktestingEngine 执行回测
"""
import os
import sys
import tempfile
import traceback
from datetime import datetime
from typing import Optional
import matplotlib
matplotlib.use("Agg") # 无头模式,服务器上不能弹窗
import matplotlib.pyplot as plt
import pandas as pd
# vnpy 4.x import路径(与3.x不同)
from vnpy.event import EventEngine
from vnpy.trader.engine import MainEngine
from vnpy_ctastrategy.backtesting import BacktestingEngine
from vnpy.trader.constant import Interval, Exchange
from .config import settings
from .models import BacktestTask, BacktestResult, BacktestStatistics, TaskStatus, BacktestTaskWithId
from .result_storage import storage
# vnpy 4.x trimmed the Interval enum; members such as FIVE_MINUTE no longer exist.
# NOTE(review): as a consequence every sub-hour key maps to Interval.MINUTE and
# "4h" maps to Interval.HOUR, so the requested bar width is not preserved by
# this lookup — confirm callers expect that.
INTERVAL_MAP = {
    "1m": Interval.MINUTE,
    "5m": Interval.MINUTE,
    "15m": Interval.MINUTE,
    "30m": Interval.MINUTE,
    "1h": Interval.HOUR,
    "4h": Interval.HOUR,
    "1d": Interval.DAILY,
    "1w": Interval.WEEKLY,
}
# Exchange mapping: upper-case suffix -> vnpy Exchange member.
EXCHANGE_MAP = {
    "SSE": Exchange.SSE,
    "SZSE": Exchange.SZSE,
    "CFFEX": Exchange.CFFEX,
    "SHFE": Exchange.SHFE,
    "DCE": Exchange.DCE,
    "CZCE": Exchange.CZCE,
    "INE": Exchange.INE,
    "GFEX": Exchange.GFEX,
}
def _parse_vt_symbol(vt_symbol: str):
    """Split a ``symbol.exchange`` string into ``(symbol, Exchange)``.

    Example: ``'000001.SZ'`` -> ``('000001', Exchange.SZSE)``.

    Resolution order for the suffix after the last dot:

    1. exact (case-insensitive) key of ``EXCHANGE_MAP``;
    2. the common two-letter stock suffixes ``SH``/``SS`` (Shanghai) and
       ``SZ`` (Shenzhen) — previously ``SH`` fell through to the prefix match
       and was resolved to ``SHFE``, the first map key starting with "SH";
    3. the first ``EXCHANGE_MAP`` key that starts with the suffix's first two
       characters (only for a non-empty suffix, since an empty prefix would
       match every key);
    4. ``Exchange.SZSE`` as the default.

    Args:
        vt_symbol: Symbol with an optional ``.EXCHANGE`` suffix.

    Returns:
        ``(symbol, exchange)``; a string without a dot is returned unchanged
        together with ``Exchange.SZSE``.
    """
    if "." not in vt_symbol:
        return vt_symbol, Exchange.SZSE

    suffix_aliases = {
        "SH": Exchange.SSE,
        "SS": Exchange.SSE,
        "SZ": Exchange.SZSE,
    }

    symbol, exchange_str = vt_symbol.rsplit(".", 1)
    suffix = exchange_str.strip().upper()

    exchange = EXCHANGE_MAP.get(suffix)
    if exchange is None:
        exchange = suffix_aliases.get(suffix)
    if exchange is None and suffix:
        # Fuzzy match on the first two characters of the suffix.
        prefix = suffix[:2]
        for key, val in EXCHANGE_MAP.items():
            if key.startswith(prefix):
                exchange = val
                break
    if exchange is None:
        exchange = Exchange.SZSE  # default: Shenzhen Stock Exchange
    return symbol, exchange
class BacktestExecutor:
    """Run backtest tasks with the vnpy 4.x ``BacktestingEngine``.

    ``execute_backtest`` is the entry point: it loads the submitted strategy
    code, runs the engine, writes the result files through ``storage`` and
    keeps the task's ``task.json`` in step with the outcome.
    """

    def __init__(self):
        pass

    @staticmethod
    def _to_float(value, default: float = 0.0) -> float:
        """Convert *value* to float, returning *default* when it cannot be converted."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _to_int(value, default: int = 0) -> int:
        """Convert *value* to int, returning *default* for None, NaN, inf or junk."""
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return default

    def _persist_task_status(self, task, status, completed_at=None) -> None:
        """Save *task* under *status* and drop its task.json from other status dirs.

        Without this the task file written at submission/start time is never
        updated, and lookups keep reporting the task as pending or running
        after it has finished.
        """
        task.status = status
        if completed_at is not None:
            task.completed_at = completed_at
        storage.save_task(task)

        for other in TaskStatus:
            if other == status:
                continue
            stale_file = storage.get_task_path(task.task_id, other.value, "task.json")
            try:
                os.remove(stale_file)
            except OSError:
                # Nothing to clean up (or not removable): best-effort only.
                continue
            try:
                # Succeeds only when the directory is now empty.
                os.rmdir(os.path.dirname(stale_file))
            except OSError:
                pass

    def _load_strategy(self, task: BacktestTask):
        """Load the strategy class from the task's source code.

        The code is written to a temporary ``strategy.py``, executed as module
        ``dynamic_strategy`` and searched for a ``CtaTemplate`` subclass.
        Classes defined by the strategy module itself are preferred over
        subclasses it merely imported.

        SECURITY NOTE: this executes caller-supplied Python code inside the
        service process with the service's privileges. Only expose the API to
        trusted callers.

        Args:
            task: Task whose ``strategy_code`` should be loaded.

        Returns:
            The strategy class.

        Raises:
            ValueError: No ``CtaTemplate`` subclass was found.
        """
        import importlib.util
        import shutil

        from vnpy_ctastrategy import CtaTemplate

        temp_dir = tempfile.mkdtemp()
        strategy_file = os.path.join(temp_dir, "strategy.py")
        sys.path.insert(0, temp_dir)
        try:
            with open(strategy_file, "w", encoding="utf-8") as f:
                f.write(task.strategy_code)

            spec = importlib.util.spec_from_file_location("dynamic_strategy", strategy_file)
            module = importlib.util.module_from_spec(spec)
            sys.modules["dynamic_strategy"] = module
            spec.loader.exec_module(module)
        finally:
            # Do not leak one sys.path entry and one temp directory per task.
            try:
                sys.path.remove(temp_dir)
            except ValueError:
                pass
            shutil.rmtree(temp_dir, ignore_errors=True)

        # Find CtaTemplate subclasses (names visited in dir() order).
        candidates = []
        for attr_name in dir(module):
            attr = getattr(module, attr_name)
            if isinstance(attr, type) and issubclass(attr, CtaTemplate) and attr is not CtaTemplate:
                candidates.append(attr)

        defined_here = [cls for cls in candidates if cls.__module__ == module.__name__]
        chosen = defined_here or candidates
        if not chosen:
            raise ValueError("策略代码中没有找到 CtaTemplate 子类,请检查策略代码")
        return chosen[0]

    def execute_backtest(self, task: BacktestTaskWithId) -> BacktestResult:
        """Run one backtest and persist its outcome.

        Args:
            task: The task to run; its status fields are updated in place.

        Returns:
            A ``BacktestResult`` with status COMPLETED (statistics and file
            paths filled in) or FAILED (``error_message`` holds the message
            and traceback). Errors raised while running are captured in the
            result rather than propagated.
        """
        started_at = datetime.now().isoformat()

        task.started_at = started_at
        self._persist_task_status(task, TaskStatus.RUNNING)

        result = BacktestResult(
            task_id=task.task_id,
            strategy_name=task.strategy_name,
            status=TaskStatus.RUNNING,
            result_csv_path="",
            created_at=task.created_at,
            started_at=started_at,
        )

        try:
            strategy_class = self._load_strategy(task)

            # Unknown interval strings fall back to daily bars.
            interval = INTERVAL_MAP.get(task.interval, Interval.DAILY)

            # The task carries plain dates; hand the engine full timestamps.
            start_dt = datetime.combine(task.start_date, datetime.min.time())
            end_dt = datetime.combine(task.end_date, datetime.min.time())

            engine = BacktestingEngine()
            engine.set_parameters(
                vt_symbol=task.symbol,
                interval=interval,
                start=start_dt,
                end=end_dt,
                rate=0.3 / 10000,  # commission rate: 0.3 per 10,000
                slippage=0.1,  # slippage 0.1
                size=1,  # contract multiplier
                pricetick=task.tick_size or 0.01,  # minimum price move
                capital=task.capital,
            )
            engine.add_strategy(strategy_class, task.parameters)

            # Load history through vnpy's built-in loader.
            no_data_message = (
                f"无法加载 {task.symbol} 在 [{task.start_date}, {task.end_date}] 的历史数据。"
                f"请确保数据已导入vnpy数据库或可通过CSV加载。"
            )
            try:
                engine.load_data()
            except Exception as exc:
                # Keep the original cause attached instead of discarding it.
                raise ValueError(no_data_message) from exc
            # NOTE(review): assumes the engine exposes the loaded bars as
            # `history_data`; when the attribute is absent this check is skipped.
            if not getattr(engine, "history_data", True):
                raise ValueError(no_data_message)

            engine.run_backtesting()

            df = engine.calculate_result()
            statistics = engine.calculate_statistics() or {}

            # NOTE(review): "total_trade_count" is tried as a fallback name for
            # the trade count — confirm against the installed vnpy version.
            total_trades = statistics.get("total_trades", statistics.get("total_trade_count", 0))
            stats = BacktestStatistics(
                start_date=str(task.start_date),
                end_date=str(task.end_date),
                total_days=self._to_int(statistics.get("total_days", 0)),
                total_trades=self._to_int(total_trades),
                winning_trades=self._to_int(statistics.get("winning_trades", 0)),
                losing_trades=self._to_int(statistics.get("losing_trades", 0)),
                win_rate=self._to_float(statistics.get("win_rate", 0)),
                total_return=self._to_float(statistics.get("total_return", 0)),
                annual_return=self._to_float(statistics.get("annual_return", 0)),
                sharpe_ratio=self._to_float(statistics.get("sharpe_ratio", 0)),
                max_drawdown=self._to_float(statistics.get("max_drawdown", 0)),
                max_drawdown_start=str(statistics.get("max_drawdown_start", "")),
                max_drawdown_end=str(statistics.get("max_drawdown_end", "")),
                profit_factor=self._to_float(statistics.get("profit_factor", 0)),
                calmar_ratio=self._to_float(statistics.get("calmar_ratio", 0)),
            )

            # Daily equity CSV.
            result_csv_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "equity.csv")
            os.makedirs(os.path.dirname(result_csv_path), exist_ok=True)
            if df is not None and not df.empty:
                df.to_csv(result_csv_path)

            # Equity curve image.
            png_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "equity_curve.png")
            self._plot_equity_curve(df, png_path)

            # Trade records: best-effort, but report why the export failed.
            trades_csv_path = None
            try:
                trades = engine.get_all_trades() if hasattr(engine, 'get_all_trades') else []
                if trades:
                    trades_df = pd.DataFrame([t.__dict__ for t in trades])
                    trades_csv_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "trades.csv")
                    trades_df.to_csv(trades_csv_path, index=False)
            except Exception:
                traceback.print_exc()
                trades_csv_path = None

            result.status = TaskStatus.COMPLETED
            result.statistics = stats
            result.result_csv_path = result_csv_path
            result.equity_curve_png_path = png_path
            result.trades_csv_path = trades_csv_path
            result.completed_at = datetime.now().isoformat()

            storage.save_result(result)
            self._persist_task_status(task, TaskStatus.COMPLETED, result.completed_at)
            return result

        except Exception as e:
            error_msg = f"{str(e)}\n{traceback.format_exc()}"
            result.status = TaskStatus.FAILED
            result.error_message = error_msg
            result.completed_at = datetime.now().isoformat()
            storage.save_result(result)
            self._persist_task_status(task, TaskStatus.FAILED, result.completed_at)
            return result

    def _plot_equity_curve(self, df: pd.DataFrame, output_path: str):
        """Draw the equity curve of *df* and save it as an image.

        The first available column among ``equity``, ``net_pnl`` (plotted as
        a cumulative sum) and ``balance`` is used. With no usable data an
        empty chart is saved.

        Args:
            df: Daily result frame; may be ``None`` or empty.
            output_path: Destination image path.
        """
        fig, ax = plt.subplots(figsize=(12, 6))
        try:
            plotted = False
            if df is not None and not df.empty:
                if "equity" in df.columns:
                    ax.plot(df.index, df["equity"], label="净值曲线", linewidth=2)
                    plotted = True
                elif "net_pnl" in df.columns:
                    cumulative = df["net_pnl"].cumsum()
                    ax.plot(df.index, cumulative, label="累计收益", linewidth=2)
                    plotted = True
                elif "balance" in df.columns:
                    ax.plot(df.index, df["balance"], label="账户余额", linewidth=2)
                    plotted = True

            ax.set_title("回测收益曲线")
            ax.set_xlabel("时间")
            ax.set_ylabel("净值")
            ax.grid(True, alpha=0.3)
            if plotted:
                # A legend without labelled lines only produces a warning.
                ax.legend()
            fig.tight_layout()
            fig.savefig(output_path, dpi=150)
        finally:
            # Release the figure even when saving fails.
            plt.close(fig)


# Module-level executor instance used by the task queue.
executor = BacktestExecutor()
-72
View File
@@ -1,72 +0,0 @@
"""
自动化回测服务 - 主入口
启动 FastAPI 服务,接受回测任务提交,执行回测,返回结果
遵循 vnpy 原生设计,只做外层封装
"""
import uvicorn
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from .config import settings
from .api import router
from .task_queue import task_queue
from .models import ApiResponse, HealthCheckResponse
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook.

    Starts the task queue's background worker before the application begins
    serving and stops it again once the application shuts down.
    """
    task_queue.start_worker_pool()
    startup_banner = f"✅ 回测服务启动 (max_workers={settings.max_workers})"
    print(startup_banner)

    yield

    task_queue.close_worker_pool()
    print("回测服务已停止")
app = FastAPI(
    title="sanguo 自动化回测服务",
    description="基于 vnpy 原生 BacktestingEngine 的自动化回测API服务",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS configuration.
# NOTE(review): a wildcard origin combined with allow_credentials=True is a
# development-only setting; restrict allow_origins before exposing the service.
if settings.cors_allow_all:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

# Register the API routes under /api/backtest. The router declares relative
# paths (/submit, /list, /status/{task_id}, ...), while the documented
# endpoints and this module's own health route live under /api/backtest, so
# the router has to be mounted with that prefix.
app.include_router(router, prefix="/api/backtest", tags=["backtest"])
@app.get("/api/backtest/health", summary="服务健康检查", response_model=ApiResponse[HealthCheckResponse])
def health():
    """Health check: report the current size of each in-memory task list."""
    counters = HealthCheckResponse(
        pending_count=len(task_queue.pending_tasks),
        running_count=len(task_queue.running_tasks),
        completed_count=len(task_queue.completed_tasks),
        failed_count=len(task_queue.failed_tasks),
        max_workers=settings.max_workers,
    )
    return ApiResponse(code=0, msg="ok", data=counters)
# Script entry point: serve the app with uvicorn using the configured host and
# port; auto-reload follows the debug flag.
# NOTE(review): the import string assumes this module is importable as
# `backtest_service.main`, and the module uses relative imports — confirm how
# the service is meant to be launched.
if __name__ == "__main__":
    uvicorn.run(
        "backtest_service.main:app",
        host=settings.host,
        port=settings.port,
        reload=settings.debug,
    )
-97
View File
@@ -1,97 +0,0 @@
"""
自动化回测服务 - 数据模型
"""
from enum import Enum
from datetime import date, datetime
from typing import Dict, Optional, Any, List, Generic, TypeVar
from pydantic import BaseModel, Field
class TaskStatus(str, Enum):
    """Task status values; each member's value is its lower-case name."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
class BacktestTask(BaseModel):
    """Request body describing one backtest to run."""

    # Required fields: strategy name and full source, instrument, bar
    # interval and the date range to test.
    strategy_name: str = Field(..., description="策略名称")
    strategy_code: str = Field(..., description="策略完整Python代码")
    symbol: str = Field(..., description="交易品种,例如 000001.SSE")
    interval: str = Field(..., description="K线周期,例如 1d、1h")
    start_date: date = Field(..., description="回测开始日期")
    end_date: date = Field(..., description="回测结束日期")
    # Optional fields: strategy parameters (empty dict by default), starting
    # capital (1,000,000 by default) and minimum price move.
    parameters: Dict[str, Any] = Field(default_factory=dict, description="策略参数字典")
    capital: float = Field(default=1_000_000, description="起始资金")
    tick_size: Optional[float] = Field(None, description="最小价格变动,不指定则自动获取")
class BacktestTaskWithId(BacktestTask):
    """A backtest task plus its ID, status and timestamps."""

    task_id: str = Field(..., description="任务唯一ID")
    status: TaskStatus = Field(..., description="任务状态")
    # Timestamps are stored as strings; created_at is described as ISO format.
    created_at: str = Field(..., description="创建时间 ISO格式")
    started_at: Optional[str] = Field(None, description="开始时间")
    completed_at: Optional[str] = Field(None, description="完成时间")
class BacktestStatistics(BaseModel):
    """Summary statistics of a backtest run."""

    start_date: str
    end_date: str
    total_days: int
    total_trades: int
    winning_trades: int
    losing_trades: int
    win_rate: float
    total_return: float  # total return
    annual_return: float  # annualised return
    sharpe_ratio: float  # Sharpe ratio
    max_drawdown: float  # maximum drawdown
    max_drawdown_start: Optional[str] = None
    max_drawdown_end: Optional[str] = None
    profit_factor: float  # profit factor (gross profit / gross loss)
    calmar_ratio: float  # Calmar ratio (annualised return / maximum drawdown)
class BacktestResult(BaseModel):
    """Complete result record of one backtest."""

    task_id: str
    strategy_name: str
    status: TaskStatus
    statistics: Optional[BacktestStatistics] = None
    result_csv_path: str  # path of the daily equity CSV
    equity_curve_png_path: Optional[str] = None  # path of the equity curve image
    trades_csv_path: Optional[str] = None  # path of the trade records CSV
    error_message: Optional[str] = None  # error details when the run failed
    created_at: str
    started_at: Optional[str] = None
    completed_at: Optional[str] = None
class TaskListResponse(BaseModel):
    """Paged task list: total count, page number, page size and the tasks."""

    total: int
    page: int
    page_size: int
    tasks: List[BacktestTaskWithId]
# Type of the payload carried in ApiResponse.data.
T = TypeVar("T")


class ApiResponse(BaseModel, Generic[T]):
    """Generic API response envelope: status code, message and optional data."""

    code: int = Field(0, description="0表示成功,非0表示错误")
    msg: str = Field("success", description="响应消息")
    data: Optional[T] = Field(None, description="响应数据")
class HealthCheckResponse(BaseModel):
    """Health check payload: task counts per status plus the worker limit."""

    pending_count: int
    running_count: int
    completed_count: int
    failed_count: int
    max_workers: int
-95
View File
@@ -1,95 +0,0 @@
"""
自动化回测服务 - 数据模型
"""
from enum import Enum
from datetime import date, datetime
from typing import Dict, Optional, Any, List
from pydantic import BaseModel, Field
class TaskStatus(str, Enum):
    """Task status values; each member's value is its lower-case name."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
class BacktestTask(BaseModel):
    """Request body describing one backtest to run."""

    # Required fields: strategy name and full source, instrument, bar
    # interval and the date range to test.
    strategy_name: str = Field(..., description="策略名称")
    strategy_code: str = Field(..., description="策略完整Python代码")
    symbol: str = Field(..., description="交易品种,例如 000001.SSE")
    interval: str = Field(..., description="K线周期,例如 1d、1h")
    start_date: date = Field(..., description="回测开始日期")
    end_date: date = Field(..., description="回测结束日期")
    # Optional fields: strategy parameters (empty dict by default), starting
    # capital (1,000,000 by default) and minimum price move.
    parameters: Dict[str, Any] = Field(default_factory=dict, description="策略参数字典")
    capital: float = Field(default=1_000_000, description="起始资金")
    tick_size: Optional[float] = Field(None, description="最小价格变动,不指定则自动获取")
class BacktestTaskWithId(BacktestTask):
    """A backtest task plus its ID, status and timestamps."""

    task_id: str = Field(..., description="任务唯一ID")
    status: TaskStatus = Field(..., description="任务状态")
    # Timestamps are stored as strings; created_at is described as ISO format.
    created_at: str = Field(..., description="创建时间 ISO格式")
    started_at: Optional[str] = Field(None, description="开始时间")
    completed_at: Optional[str] = Field(None, description="完成时间")
class BacktestStatistics(BaseModel):
    """Summary statistics of a backtest run."""

    start_date: str
    end_date: str
    total_days: int
    total_trades: int
    winning_trades: int
    losing_trades: int
    win_rate: float
    total_return: float  # total return
    annual_return: float  # annualised return
    sharpe_ratio: float  # Sharpe ratio
    max_drawdown: float  # maximum drawdown
    max_drawdown_start: Optional[str] = None
    max_drawdown_end: Optional[str] = None
    profit_factor: float  # profit factor (gross profit / gross loss)
    calmar_ratio: float  # Calmar ratio (annualised return / maximum drawdown)
class BacktestResult(BaseModel):
    """Complete result record of one backtest."""

    task_id: str
    strategy_name: str
    status: TaskStatus
    statistics: Optional[BacktestStatistics] = None
    result_csv_path: str  # path of the daily equity CSV
    equity_curve_png_path: Optional[str] = None  # path of the equity curve image
    trades_csv_path: Optional[str] = None  # path of the trade records CSV
    error_message: Optional[str] = None  # error details when the run failed
    created_at: str
    started_at: Optional[str] = None
    completed_at: Optional[str] = None
class TaskListResponse(BaseModel):
    """Paged task list: total count, page number, page size and the tasks."""

    total: int
    page: int
    page_size: int
    tasks: List[BacktestTaskWithId]
class ApiResponse[T](BaseModel):
    """Generic API response envelope: status code, message and optional data.

    Declared with the inline type-parameter syntax (``class Name[T]``).
    """

    code: int = Field(0, description="0表示成功,非0表示错误")
    msg: str = Field("success", description="响应消息")
    data: Optional[T] = Field(None, description="响应数据")
class HealthCheckResponse(BaseModel):
    """Health check payload: task counts per status plus the worker limit."""

    pending_count: int
    running_count: int
    completed_count: int
    failed_count: int
    max_workers: int
-101
View File
@@ -1,101 +0,0 @@
"""
自动化回测服务 - 结果存储
"""
import json
import os
from datetime import date, datetime
from typing import Optional
from .models import BacktestTaskWithId, BacktestResult
from .config import settings
def _json_serial(obj):
"""JSON序列化辅助:处理date/datetime"""
if isinstance(obj, (date, datetime)):
return obj.isoformat()
raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
class ResultStorage:
    """File-based storage for task and result records.

    Layout: ``<base_dir>/<status>/<task_id>/task.json`` and ``result.json``,
    where ``<status>`` is one of pending / running / completed / failed.
    """

    # Terminal states are checked first so that a finished task is never
    # shadowed by a copy left behind in an earlier state.
    _TASK_LOOKUP_ORDER = ("completed", "failed", "running", "pending")

    def __init__(self):
        self.base_dir = settings.base_dir
        self._ensure_dirs()

    def _ensure_dirs(self):
        """Create the four status directories if they do not exist yet."""
        for status_dir in ["pending", "running", "completed", "failed"]:
            os.makedirs(os.path.join(self.base_dir, status_dir), exist_ok=True)

    @staticmethod
    def _status_name(status) -> str:
        """Return the directory name for *status* (enum member or plain string)."""
        return str(getattr(status, "value", status))

    @staticmethod
    def _is_safe_task_id(task_id) -> bool:
        """Tell whether *task_id* is usable as a single directory name.

        Task IDs arrive from request paths; anything containing a path
        separator, or equal to ``.`` / ``..``, would escape the status
        directory and is treated as unknown.
        """
        if not isinstance(task_id, str) or task_id in ("", ".", ".."):
            return False
        return "/" not in task_id and "\\" not in task_id

    def _task_dir(self, task_id: str, status: str) -> str:
        """Return the directory holding the files of *task_id* under *status*."""
        return os.path.join(self.base_dir, self._status_name(status), task_id)

    def save_task(self, task: BacktestTaskWithId) -> None:
        """Write ``task.json`` under the task's current status.

        A task has exactly one status, so ``task.json`` copies under the
        other status directories are removed afterwards; otherwise lookups
        could keep returning an out-of-date copy.
        """
        status = self._status_name(task.status)
        task_dir = self._task_dir(task.task_id, status)
        os.makedirs(task_dir, exist_ok=True)
        info_file = os.path.join(task_dir, "task.json")
        with open(info_file, "w", encoding="utf-8") as f:
            json.dump(task.model_dump(), f, indent=2, ensure_ascii=False, default=_json_serial)

        for other in self._TASK_LOOKUP_ORDER:
            if other == status:
                continue
            stale_file = os.path.join(self._task_dir(task.task_id, other), "task.json")
            try:
                os.remove(stale_file)
            except FileNotFoundError:
                pass

    def load_task(self, task_id: str, status: str) -> Optional[BacktestTaskWithId]:
        """Load ``task.json`` of *task_id* from one status directory.

        Returns:
            The task, or ``None`` when the file does not exist or the ID is
            not a valid directory name.
        """
        if not self._is_safe_task_id(task_id):
            return None
        info_file = os.path.join(self._task_dir(task_id, status), "task.json")
        if not os.path.exists(info_file):
            return None
        with open(info_file, "r", encoding="utf-8") as f:
            data = json.load(f)
        return BacktestTaskWithId(**data)

    def save_result(self, result: BacktestResult) -> None:
        """Write ``result.json`` under the result's status directory."""
        task_dir = self._task_dir(result.task_id, result.status)
        os.makedirs(task_dir, exist_ok=True)
        result_file = os.path.join(task_dir, "result.json")
        with open(result_file, "w", encoding="utf-8") as f:
            json.dump(result.model_dump(), f, indent=2, ensure_ascii=False, default=_json_serial)

    def load_result(self, task_id: str, status: str) -> Optional[BacktestResult]:
        """Load ``result.json`` of *task_id* from one status directory.

        Returns:
            The result, or ``None`` when the file does not exist or the ID is
            not a valid directory name.
        """
        if not self._is_safe_task_id(task_id):
            return None
        result_file = os.path.join(self._task_dir(task_id, status), "result.json")
        if not os.path.exists(result_file):
            return None
        with open(result_file, "r", encoding="utf-8") as f:
            data = json.load(f)
        return BacktestResult(**data)

    def find_task(self, task_id: str) -> Optional[BacktestTaskWithId]:
        """Find a task in any status directory.

        When a stored result exists for the task, the returned task's status
        (and missing completion time) is taken from that result, so a task
        file that was last written while the task was running does not hide
        the fact that the task has finished.
        """
        task = None
        for status_dir in self._TASK_LOOKUP_ORDER:
            task = self.load_task(task_id, status_dir)
            if task is not None:
                break
        if task is None:
            return None

        result = self.find_result(task_id)
        if result is not None and result.status != task.status:
            task.status = result.status
            if task.completed_at is None:
                task.completed_at = result.completed_at
        return task

    def find_result(self, task_id: str) -> Optional[BacktestResult]:
        """Find a result in any status directory."""
        for status_dir in ["failed", "completed", "running", "pending"]:
            result = self.load_result(task_id, status_dir)
            if result is not None:
                return result
        return None

    def get_task_path(self, task_id: str, status: str, filename: str) -> str:
        """Return the path of *filename* inside the task's status directory."""
        return os.path.join(self._task_dir(task_id, status), filename)


# Module-level storage instance shared by the API, queue and executor.
storage = ResultStorage()
-141
View File
@@ -1,141 +0,0 @@
"""
自动化回测服务 - 任务队列
简单后台线程调度:submit后自动触发执行,同一时间只跑一个回测
"""
import os
import uuid
import threading
import time
import traceback
from datetime import datetime
from typing import List, Optional, Dict
from .config import settings
from .models import TaskStatus, BacktestTask, BacktestTaskWithId
from .result_storage import storage
class TaskQueue:
    """Task queue driven by a single background worker thread.

    Task IDs are kept in four in-memory lists (pending, running, completed,
    failed). The worker takes one pending task at a time and runs it; no
    second task starts while one is running. The lists are not rebuilt from
    disk, so they only describe tasks submitted since the process started.
    """

    def __init__(self):
        self.max_workers = settings.max_workers

        self.pending_tasks: List[str] = []
        self.running_tasks: List[str] = []
        self.completed_tasks: List[str] = []
        self.failed_tasks: List[str] = []

        self._worker_thread: Optional[threading.Thread] = None
        self._stop_event = threading.Event()

    def _generate_task_id(self) -> str:
        """Return a new random task ID: 32 hex characters, no dashes."""
        return uuid.uuid4().hex

    def _buckets(self):
        """Return ``(status, id_list)`` pairs in listing order."""
        return [
            (TaskStatus.PENDING, self.pending_tasks),
            (TaskStatus.RUNNING, self.running_tasks),
            (TaskStatus.COMPLETED, self.completed_tasks),
            (TaskStatus.FAILED, self.failed_tasks),
        ]

    def _load_from_disk(self, task_id: str) -> Optional[BacktestTaskWithId]:
        """Load the task file from the first status directory that has one."""
        for status_dir in ["pending", "running", "completed", "failed"]:
            task = storage.load_task(task_id, status_dir)
            if task:
                return task
        return None

    def submit_task(self, task: BacktestTask) -> BacktestTaskWithId:
        """Assign an ID to *task*, persist it and append it to the pending list."""
        task_id = self._generate_task_id()
        now = datetime.now().isoformat()

        task_with_id = BacktestTaskWithId(
            task_id=task_id,
            status=TaskStatus.PENDING,
            created_at=now,
            **task.model_dump()
        )

        storage.save_task(task_with_id)
        self.pending_tasks.append(task_id)
        return task_with_id

    def list_tasks(self, page: int = 1, page_size: int = 10, status: Optional[str] = None) -> Dict:
        """Return one page of tasks, optionally restricted to one status.

        The status of each returned task is the one tracked by this queue.
        The task file on disk may have been written in an earlier state and
        would otherwise make finished tasks appear as pending or running.

        Args:
            page: 1-based page number.
            page_size: Number of tasks per page.
            status: ``pending`` / ``running`` / ``completed`` / ``failed``;
                any other value (including ``None``) lists all tasks.

        Returns:
            Dict with ``total``, ``page``, ``page_size`` and ``tasks``.
        """
        buckets = self._buckets()
        known_statuses = {bucket_status.value for bucket_status, _ in buckets}

        # Copy each list so the worker thread cannot change it mid-iteration.
        selected = [
            (task_id, bucket_status)
            for bucket_status, ids in buckets
            if status not in known_statuses or bucket_status.value == status
            for task_id in list(ids)
        ]

        total = len(selected)
        start = (page - 1) * page_size
        end = start + page_size

        result = []
        for task_id, current_status in selected[start:end]:
            task = self._load_from_disk(task_id)
            if task is None:
                continue
            task.status = current_status
            result.append(task)

        return {
            "total": total,
            "page": page,
            "page_size": page_size,
            "tasks": result
        }

    def get_task(self, task_id: str) -> Optional[BacktestTaskWithId]:
        """Return the task with *task_id*, or ``None`` when no task file exists.

        When the queue tracks the task, its in-memory status overrides the
        status stored in the file.
        """
        task = self._load_from_disk(task_id)
        if task is None:
            return None
        for current_status, ids in self._buckets():
            if task_id in ids:
                task.status = current_status
                break
        return task

    def _worker_loop(self):
        """Worker thread body: run pending tasks one at a time until stopped."""
        from .executor import executor

        while not self._stop_event.is_set():
            # Start a task only when something is pending and nothing is running.
            if self.pending_tasks and not self.running_tasks:
                task_id = self.pending_tasks.pop(0)
                self.running_tasks.append(task_id)

                succeeded = False
                try:
                    task = storage.load_task(task_id, "pending")
                    if task is None:
                        # The task file is missing: record the task as failed.
                        print(f"任务执行异常: 找不到任务文件 {task_id}")
                    else:
                        result = executor.execute_backtest(task)
                        succeeded = result.status == TaskStatus.COMPLETED
                except Exception as e:
                    print(f"任务执行异常: {e}\n{traceback.format_exc()}")
                finally:
                    # Move the ID out of running exactly once, whatever happened.
                    if task_id in self.running_tasks:
                        self.running_tasks.remove(task_id)
                    if succeeded:
                        self.completed_tasks.append(task_id)
                    else:
                        self.failed_tasks.append(task_id)

            # Wait one second (or until stop is requested) before checking again.
            self._stop_event.wait(1.0)

    def start_worker_pool(self):
        """Start the background worker thread if it is not already running."""
        if self._worker_thread is None or not self._worker_thread.is_alive():
            self._stop_event.clear()
            self._worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
            self._worker_thread.start()
            print(f"工作线程已启动 (max_workers={self.max_workers})")

    def close_worker_pool(self):
        """Ask the worker thread to stop and wait up to five seconds for it."""
        self._stop_event.set()
        if self._worker_thread and self._worker_thread.is_alive():
            self._worker_thread.join(timeout=5)
        self._worker_thread = None


# Module-level queue instance shared by the API and the application lifespan.
task_queue = TaskQueue()
-17
View File
@@ -1,17 +0,0 @@
#!/usr/bin/env python3
"""
自动化回测服务启动脚本
启动 FastAPI 服务监听 8088 端口接受回测任务提交
"""
import sys
import os
# Put the "backtest-service" directory that sits next to this script at the
# front of sys.path so its modules can be imported by bare name.
backtest_dir = os.path.join(os.path.dirname(__file__), "backtest-service")
sys.path.insert(0, backtest_dir)
# Import the entry point now that the path is set up.
# NOTE(review): this requires backtest-service/main.py to expose a callable
# named `main` and to be importable as a top-level module — confirm both.
from main import main
if __name__ == "__main__":
    main()
-71
View File
@@ -1,71 +0,0 @@
#!/usr/bin/env python3
"""
VNPY RPC 交易核心服务启动脚本
启动 vnpy 交易核心开启 RPC 服务端 VNPY Web Trader 连接
按照 vnpy 官方标准双进程架构
- 这个进程交易核心 + RPC 服务端 (端口 2018/4102)
- 另一个进程VNPY Web Trader (端口 8000)
"""
from vnpy.trader.event_engine import EventEngine
from vnpy.trader.main_engine import MainEngine
from vnpy.rpc import RpcServer
# 导入你的 gateway
from vnpy_ctp import CtpGateway
# from vnpy_tap import TapGateway
# from vnpy_ib import IbGateway
# 其他 gateway 根据需要添加
# 导入你的策略应用
from vnpy_ctastrategy import CtaStrategyApp
# from vnpy_portfoliostrategy import PortfolioStrategyApp
# from vnpy_spreadtrading import SpreadTradingApp
def main():
    """Start the trading core and its RPC server, then block until interrupted.

    Builds the event engine and main engine, registers the CTP gateway and
    the CTA strategy app, starts the RPC server and keeps the process alive.
    """
    import threading

    # Create the event engine and the main engine.
    event_engine = EventEngine()
    main_engine = MainEngine(event_engine)

    # Add gateways (add further gateways here, e.g. TapGateway).
    main_engine.add_gateway(CtpGateway)

    # Add strategy apps (add further apps here, e.g. PortfolioStrategyApp).
    main_engine.add_app(CtaStrategyApp)

    # RPC server: request port 2018, subscribe port 4102.
    rpc_server = RpcServer(
        main_engine,
        ("0.0.0.0", 2018),
        ("0.0.0.0", 4102)
    )

    print("=" * 50)
    print("VNPY RPC 交易核心服务启动")
    print(f"RPC 请求地址: tcp://0.0.0.0:2018")
    print(f"RPC 订阅地址: tcp://0.0.0.0:4102")
    print("请确保已经在 vnpy 配置中配置好你的 gateway")
    print("=" * 50)

    # Start the RPC service.
    rpc_server.start()

    # Keep the process running.
    print("RPC 服务已启动,按 Ctrl+C 退出")
    try:
        try:
            input()
        except EOFError:
            # No interactive stdin (nohup, service manager, container without
            # a TTY): input() fails immediately, which used to shut the
            # service down right after start-up. Block until interrupted.
            threading.Event().wait()
    except KeyboardInterrupt:
        pass

    print("RPC 服务退出")
if __name__ == "__main__":
main()
-38
View File
@@ -1,38 +0,0 @@
#!/usr/bin/env python3
"""
VNPY Web Trader 服务启动脚本
按照 vnpy 官方标准双进程架构
- 需要先启动 start_rpc_server.py 交易核心
- 然后启动这个 Web Trader 服务
- Web Trader 通过 RPC 连接交易核心
"""
from vnpy_webtrader import run_web_trader
def main():
    """Print the start-up banner and launch the web trader.

    The web trader connects to the trading core through the two RPC
    addresses below (same machine by default) and listens on 0.0.0.0:8000.
    """
    # RPC addresses of the trading core; default is the local machine.
    rpc_request_address = "tcp://127.0.0.1:2018"
    rpc_subscribe_address = "tcp://127.0.0.1:4102"

    separator = "=" * 50
    banner = (
        separator,
        "VNPY Web Trader 服务启动",
        f"RPC 请求地址: {rpc_request_address}",
        f"RPC 订阅地址: {rpc_subscribe_address}",
        f"Web 服务监听: 0.0.0.0:8000",
        separator,
        "请确保先启动 start_rpc_server.py",
        separator,
    )
    for line in banner:
        print(line)

    # Start the web trader.
    run_web_trader(
        rpc_request_address,
        rpc_subscribe_address,
        host="0.0.0.0",
        port=8000,
        cors_allow_all=True,  # allow cross-origin requests (development)
    )
if __name__ == "__main__":
main()
-469
View File
@@ -1,469 +0,0 @@
#!/usr/bin/env python3
"""
vn.py本地数据适配器 - 姜维
功能:让vn.py优先加载赵云将军下载的本地数据;本地没有时,再去akshare接口下载
"""
import pandas as pd
import os
import glob
import logging
from datetime import datetime, timedelta
from typing import Optional, Dict, List, Tuple
import akshare as ak
# 配置日志
# Configure logging at import time: INFO level, written both to the file
# 'vnpy_local_data_adapter.log' (relative path) and to the console stream.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('vnpy_local_data_adapter.log'),
        logging.StreamHandler()
    ]
)
# Module-level logger used by the adapter below.
logger = logging.getLogger(__name__)
class VnpyLocalDataAdapter:
"""
vn.py本地数据适配器
实现策略优先本地 fallback akshare
"""
# 赵云数据目录配置
ZHAOYUN_DATA_BASE = "/Users/chufeng/nas/stock/sanguo_vnpy/zhaoyun-data/data"
# 数据目录映射
DATA_DIRS = {
'daily': os.path.join(ZHAOYUN_DATA_BASE, "raw/daily"),
'financial': os.path.join(ZHAOYUN_DATA_BASE, "raw/financial"),
'stock_info': os.path.join(ZHAOYUN_DATA_BASE, "raw/stock_info"),
'minute': os.path.join(ZHAOYUN_DATA_BASE, "raw/minute_kline"),
}
# vn.py需要的字段映射
VNPY_FIELD_MAP = {
'date': 'datetime',
'open': 'open_price',
'high': 'high_price',
'low': 'low_price',
'close': 'close_price',
'volume': 'volume',
'amount': 'turnover',
'turnover': 'turnover_rate',
}
def __init__(self, use_local_first: bool = True):
    """Create the adapter and log which data directories are present.

    Args:
        use_local_first: When true, local files are tried before akshare.
    """
    self.use_local_first = use_local_first
    # Only logs the state of each configured directory.
    self._validate_data_dirs()
def _validate_data_dirs(self):
    """Log, for every configured data directory, whether it exists."""
    for name, path in self.DATA_DIRS.items():
        if not os.path.exists(path):
            logger.warning(f"⚠️ 赵云数据目录不存在 {name}: {path}")
            continue
        logger.info(f"✅ 赵云数据目录 {name}: {path}")
def _parse_symbol(self, symbol: str) -> Tuple[str, str]:
"""
解析股票代码返回标准化代码和交易所
Args:
symbol: 股票代码 "000001.SZ" "600000"
Returns:
(symbol_code, exchange): ("000001", "SZ")
"""
# 移除后缀
if '.' in symbol:
symbol_code, exchange = symbol.split('.')
exchange = exchange.upper()
else:
symbol_code = symbol
# 根据代码判断交易所
if symbol_code.startswith('6'):
exchange = 'SH'
elif symbol_code.startswith(('0', '3')):
exchange = 'SZ'
elif symbol_code.startswith('8'):
exchange = 'BJ'
else:
exchange = 'SZ' # 默认深交所
return symbol_code, exchange
def _get_local_daily_file_path(self, symbol: str, year: int) -> Optional[str]:
"""
获取本地日线数据文件路径
Args:
symbol: 股票代码
year: 年份
Returns:
文件路径如果不存在返回None
"""
symbol_code, exchange = self._parse_symbol(symbol)
# 构建文件名格式
if exchange == 'SH':
file_prefix = f"sh{symbol_code}"
elif exchange == 'SZ':
file_prefix = f"sz{symbol_code}"
elif exchange == 'BJ':
file_prefix = f"bj{symbol_code}"
else:
file_prefix = symbol_code
# 查找文件
pattern = os.path.join(self.DATA_DIRS['daily'], str(year), f"{file_prefix}_daily.parquet")
if os.path.exists(pattern):
return pattern
# 尝试其他可能的文件名格式
pattern2 = os.path.join(self.DATA_DIRS['daily'], str(year), f"{symbol_code}_daily.parquet")
if os.path.exists(pattern2):
return pattern2
return None
def load_local_daily_data(self, symbol: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
"""
从赵云本地数据加载日线数据
Args:
symbol: 股票代码
start_date: 开始日期 "YYYY-MM-DD"
end_date: 结束日期 "YYYY-MM-DD"
Returns:
日线数据DataFrame如果本地没有返回None
"""
if not self.use_local_first:
return None
try:
# 解析日期范围
start_dt = pd.to_datetime(start_date)
end_dt = pd.to_datetime(end_date)
# 收集所有年份的数据
all_data = []
for year in range(start_dt.year, end_dt.year + 1):
file_path = self._get_local_daily_file_path(symbol, year)
if file_path and os.path.exists(file_path):
df = pd.read_parquet(file_path)
# 过滤日期范围
df['date'] = pd.to_datetime(df['date'])
mask = (df['date'] >= start_dt) & (df['date'] <= end_dt)
df_filtered = df[mask]
if not df_filtered.empty:
all_data.append(df_filtered)
logger.debug(f"✅ 从本地加载 {symbol} {year}年数据: {len(df_filtered)}")
if all_data:
# 合并所有年份数据
result = pd.concat(all_data, ignore_index=True)
result = result.sort_values('date')
# 转换为vn.py字段名
result = result.rename(columns=self.VNPY_FIELD_MAP)
# 添加symbol和exchange字段
symbol_code, exchange = self._parse_symbol(symbol)
result['symbol'] = symbol_code
result['exchange'] = exchange
result['interval'] = '1d'
logger.info(f"✅ 成功从本地加载 {symbol} 数据: {len(result)} 条 ({start_date}{end_date})")
return result
else:
logger.info(f"⚠️ 本地没有找到 {symbol} 的数据")
return None
except Exception as e:
logger.error(f"❌ 加载本地数据失败 {symbol}: {e}")
return None
def fetch_akshare_daily_data(self, symbol: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
    """
    Fetch unadjusted daily bars from akshare (the fallback data source).

    Args:
        symbol: Stock code; split into code and exchange by
            ``self._parse_symbol``.
        start_date: Start date, "YYYY-MM-DD".
        end_date: End date, "YYYY-MM-DD".

    Returns:
        A DataFrame whose akshare columns are renamed to vn.py field names,
        with ``datetime`` formatted as "YYYY-MM-DD HH:MM:SS" strings and
        ``symbol`` / ``exchange`` / ``interval`` columns added; ``None``
        when akshare returns nothing or any error occurs (logged, not
        raised).
    """
    try:
        symbol_code, exchange = self._parse_symbol(symbol)
        # The dates are passed to akshare with the dashes removed.
        start_date_ak = start_date.replace('-', '')
        end_date_ak = end_date.replace('-', '')
        # The two dates are separated explicitly; previously they were
        # printed back-to-back and were unreadable in the log.
        logger.info(f"📡 从akshare获取 {symbol} 数据 ({start_date} ~ {end_date})")

        df = ak.stock_zh_a_hist(
            symbol=symbol_code,
            period="daily",
            start_date=start_date_ak,
            end_date=end_date_ak,
            adjust=""  # no price adjustment
        )
        if df is None or df.empty:
            logger.warning(f"⚠️ akshare没有 {symbol} 的数据")
            return None

        # Map akshare's Chinese column headers onto vn.py field names.
        df.rename(columns={
            '日期': 'datetime',
            '开盘': 'open_price',
            '收盘': 'close_price',
            '最高': 'high_price',
            '最低': 'low_price',
            '成交量': 'volume',
            '成交额': 'turnover',
        }, inplace=True)
        # Normalise the date column to a fixed string format.
        df['datetime'] = pd.to_datetime(df['datetime']).dt.strftime('%Y-%m-%d %H:%M:%S')
        # Attach the identifying columns.
        df['symbol'] = symbol_code
        df['exchange'] = exchange
        df['interval'] = '1d'
        logger.info(f"✅ 从akshare获取 {symbol} 数据成功: {len(df)}")
        return df
    except Exception as e:
        # Best-effort fetch: the caller treats None as "no data available".
        logger.error(f"❌ 从akshare获取数据失败 {symbol}: {e}")
        return None
def get_daily_data(self, symbol: str, start_date: str, end_date: str) -> pd.DataFrame:
    """
    Return daily bars, preferring local files and falling back to akshare.

    Args:
        symbol: Stock code.
        start_date: Start date, "YYYY-MM-DD".
        end_date: End date, "YYYY-MM-DD".

    Returns:
        The first non-empty DataFrame produced by the local loader (only
        consulted when ``self.use_local_first`` is true) or by akshare;
        an empty DataFrame when neither source yields data.
    """
    def _usable(frame):
        # A source counts only if it produced at least one row.
        return frame is not None and not frame.empty

    # Step 1: local files, when enabled.
    if self.use_local_first:
        from_local = self.load_local_daily_data(symbol, start_date, end_date)
        if _usable(from_local):
            return from_local

    # Step 2: akshare fallback.
    from_akshare = self.fetch_akshare_daily_data(symbol, start_date, end_date)
    if _usable(from_akshare):
        return from_akshare

    # Step 3: both sources failed.
    logger.error(f"❌ 无法获取 {symbol} 的数据")
    return pd.DataFrame()
def verify_local_data_structure(self, symbol: str, start_year: int = 2010, end_year: int = 2026) -> Dict:
    """
    Check whether the local yearly files for ``symbol`` have the expected columns.

    Args:
        symbol: Stock code.
        start_year: First year to probe (inclusive). The default matches
            the previously hard-coded assumed data range.
        end_year: Last year to probe (inclusive).

    Returns:
        A dict with keys ``symbol``, ``has_local_data``, ``data_years``,
        ``missing_fields``, ``recommendations`` and ``status``. ``status``
        is one of 'OK', 'INCOMPLETE', 'NO_DATA' or 'ERROR'; it starts as
        'UNKNOWN' and is always overwritten before returning.
    """
    result = {
        'symbol': symbol,
        'has_local_data': False,
        'data_years': [],
        'missing_fields': [],
        'recommendations': [],
        'status': 'UNKNOWN'
    }
    required_fields = ['date', 'open', 'high', 'low', 'close', 'volume']
    try:
        data_years = []
        missing_fields = result['missing_fields']
        for year in range(start_year, end_year + 1):
            file_path = self._get_local_daily_file_path(symbol, year)
            if not (file_path and os.path.exists(file_path)):
                continue
            data_years.append(year)
            # Inspect the columns of this year's file.
            df = pd.read_parquet(file_path)
            for field in required_fields:
                # Record each missing field once, even if several yearly
                # files lack it, so the report does not list duplicates.
                if field not in df.columns and field not in missing_fields:
                    missing_fields.append(field)

        result['data_years'] = data_years
        result['has_local_data'] = len(data_years) > 0
        if not result['has_local_data']:
            result['status'] = 'NO_DATA'
            result['recommendations'].append("建议:联系赵云将军下载该股票数据")
        elif result['missing_fields']:
            result['status'] = 'INCOMPLETE'
            result['recommendations'].append(f"缺少字段: {result['missing_fields']}")
            result['recommendations'].append("建议:使用data_convert_tool.py转换数据格式")
        else:
            result['status'] = 'OK'
            result['recommendations'].append(f"✅ 数据结构完整,覆盖 {min(data_years)}-{max(data_years)}")
    except Exception as e:
        # Report the failure inside the result instead of raising.
        result['status'] = 'ERROR'
        result['recommendations'].append(f"验证错误: {e}")
    return result
class DataConvertTool:
    """
    Data format conversion tool.

    Converts a parquet file in the local ("Zhaoyun") column layout into the
    column layout expected by vn.py.
    """

    @staticmethod
    def convert_zhaoyun_to_vnpy(input_path: str, output_path: str, symbol: str):
        """
        Convert one local-format parquet file to the vn.py layout.

        Args:
            input_path: Path of the source parquet file.
            output_path: Path the converted parquet file is written to.
            symbol: Stock code; split into code and exchange by
                ``VnpyLocalDataAdapter._parse_symbol``.

        Raises:
            ValueError: If any of the required columns is missing.
            Exception: Any other failure is logged and re-raised unchanged.
        """
        try:
            source_df = pd.read_parquet(input_path)

            # Refuse to convert when a required column is absent.
            required = ['date', 'open', 'high', 'low', 'close', 'volume']
            missing = [field for field in required if field not in source_df.columns]
            if missing:
                raise ValueError(f"缺少必要字段: {missing}")

            # Build the vn.py-shaped frame column by column.
            vnpy_df = pd.DataFrame()
            vnpy_df['datetime'] = pd.to_datetime(source_df['date']).dt.strftime('%Y-%m-%d %H:%M:%S')
            vnpy_df['open_price'] = source_df['open']
            vnpy_df['high_price'] = source_df['high']
            vnpy_df['low_price'] = source_df['low']
            vnpy_df['close_price'] = source_df['close']
            vnpy_df['volume'] = source_df['volume']

            if 'amount' in source_df.columns:
                vnpy_df['turnover'] = source_df['amount']
            else:
                # No traded amount available: estimate it as volume * close.
                vnpy_df['turnover'] = source_df['volume'] * source_df['close']
            # A source column named 'turnover' is stored as 'turnover_rate'.
            if 'turnover' in source_df.columns:
                vnpy_df['turnover_rate'] = source_df['turnover']

            # Attach the identifying columns. A default-constructed adapter
            # is used only for its symbol parser.
            symbol_code, exchange = VnpyLocalDataAdapter()._parse_symbol(symbol)
            vnpy_df['symbol'] = symbol_code
            vnpy_df['exchange'] = exchange
            vnpy_df['interval'] = '1d'

            vnpy_df.to_parquet(output_path, index=False)
            # The two paths are separated explicitly; previously they were
            # printed back-to-back and were unreadable in the log.
            logger.info(f"✅ 数据转换完成: {input_path} -> {output_path}")
        except Exception as e:
            logger.error(f"❌ 数据转换失败: {e}")
            raise
# Wrapper around the vn.py data manager
class VnpyDataManagerWrapper:
    """
    Wrapper around a vn.py data manager.

    Routes daily-bar requests through a ``VnpyLocalDataAdapter`` instead of
    the wrapped manager's own data retrieval.
    """

    def __init__(self, original_data_manager, adapter: VnpyLocalDataAdapter):
        """
        Store both collaborators and run the patch hook.

        Args:
            original_data_manager: The original vn.py data manager.
            adapter: The local data adapter used to serve bar requests.
        """
        self.adapter = adapter
        self.original_dm = original_data_manager
        self._patch_methods()

    def _patch_methods(self):
        """Hook for patching the vn.py data retrieval methods."""
        # NOTE: placeholder. The real patching depends on the vn.py version
        # and API in use; at present this method only emits the log line.
        logger.info("✅ vn.py数据管理器已修补为优先使用本地数据")

    def get_daily_bar_data(self, symbol: str, start_date: str, end_date: str):
        """Return daily bars for ``symbol`` by delegating to the adapter."""
        bars = self.adapter.get_daily_data(symbol, start_date, end_date)
        return bars
# Usage example
if __name__ == "__main__":
    # 1. Create the adapter.
    adapter = VnpyLocalDataAdapter(use_local_first=True)

    # 2. Sample request parameters.
    test_symbol = "000001.SZ"  # Ping An Bank
    start_date = "2024-01-01"
    end_date = "2024-03-01"
    print("=" * 60)
    print("vn.py本地数据适配器测试")
    print("=" * 60)

    # 3. Verify the local data structure.
    print("\n1. 验证本地数据结构:")
    verification = adapter.verify_local_data_structure(test_symbol)
    for key, value in verification.items():
        print(f" {key}: {value}")

    # 4. Fetch the data. The date pairs below are printed with an explicit
    #    separator; previously the two values ran together.
    print(f"\n2. 获取 {test_symbol} 数据 ({start_date} ~ {end_date}):")
    data = adapter.get_daily_data(test_symbol, start_date, end_date)
    if not data.empty:
        print(f"✅ 成功获取 {len(data)} 条数据")
        print(f"数据字段: {list(data.columns)}")
        print(f"时间范围: {data['datetime'].min()} ~ {data['datetime'].max()}")
        # Heuristic: an 'outstanding_share' column is taken to mean the
        # frame came from local files rather than akshare.
        print(f"数据来源: {'本地' if 'outstanding_share' in data.columns else 'akshare'}")
    else:
        print("❌ 获取数据失败")

    print("\n3. 使用建议:")
    print(" a) 在vn.py策略中导入此适配器")
    print(" b) 替换原有的数据获取逻辑")
    print(" c) 配置赵云数据目录路径")
    print(" d) 定期更新本地数据(联系赵云将军)")
    print("=" * 60)