auto-sync: 2026-04-29 20:14:41
This commit is contained in:
@@ -1,149 +0,0 @@
|
||||
# 自动化回测服务 - 使用说明
|
||||
|
||||
## 概述
|
||||
|
||||
基于 vnpy 原生 `BacktestingEngine` 封装的 RESTful API 自动化回测服务。
|
||||
|
||||
支持:
|
||||
- 通过 API 提交回测任务
|
||||
- 自动排队控制并发
|
||||
- 使用 vnpy 原生回测引擎执行
|
||||
- 保存回测结果(CSV + 图表 + JSON)
|
||||
- 查询任务状态和结果
|
||||
|
||||
## 架构设计
|
||||
|
||||
严格遵循 vnpy 原生设计,不修改核心,只做外层封装:
|
||||
|
||||
```
|
||||
[API 服务] ← 接收任务
|
||||
↓
|
||||
[任务队列 + 进程池] ← 控制并发
|
||||
↓
|
||||
[BacktestingEngine (vnpy 原生)] ← 执行回测
|
||||
↓
|
||||
[文件存储] ← 保存结果
|
||||
↓
|
||||
[API 查询] ← 返回结果
|
||||
```
|
||||
|
||||
## 启动方式
|
||||
|
||||
```bash
|
||||
# 手动启动
|
||||
cd /app/scripts/backtest-service
|
||||
python main.py
|
||||
|
||||
# 后台运行
|
||||
nohup python main.py > backtest-service.log 2>&1 &
|
||||
|
||||
# 查看日志
|
||||
tail -f backtest-service.log
|
||||
```
|
||||
|
||||
## 访问地址
|
||||
|
||||
启动后访问:
|
||||
- 服务地址:http://container-ip:8088
|
||||
- API 文档:http://container-ip:8088/docs → 可以直接在网页上测试API
|
||||
|
||||
## API 接口说明
|
||||
|
||||
| 接口 | 方法 | 路径 | 说明 |
|
||||
|------|------|------|------|
|
||||
| 提交回测 | POST | `/api/backtest/submit` | 提交新的回测任务 |
|
||||
| 任务列表 | GET | `/api/backtest/list` | 列出任务,支持分页和状态过滤 |
|
||||
| 任务状态 | GET | `/api/backtest/status/{task_id}` | 查询任务状态 |
|
||||
| 回测结果 | GET | `/api/backtest/result/{task_id}` | 获取完整回测结果 |
|
||||
| 删除任务 | DELETE | `/api/backtest/delete/{task_id}` | 删除任务 |
|
||||
| 健康检查 | GET | `/api/backtest/health` | 查看服务状态,返回任务统计 |
|
||||
|
||||
## 配置
|
||||
|
||||
可以通过环境变量覆盖默认配置:
|
||||
|
||||
| 环境变量 | 说明 | 默认值 |
|
||||
|----------|------|--------|
|
||||
| `MAX_WORKERS` | 最大并发回测数 | 2 |
|
||||
| `HOST` | 监听地址 | 0.0.0.0 |
|
||||
| `PORT` | 监听端口 | 8088 |
|
||||
| `BASE_DIR` | 任务存储根目录 | /app/backtest_jobs |
|
||||
| `DEBUG` | 调试模式 | True |
|
||||
|
||||
## 使用示例
|
||||
|
||||
### 1. 提交回测
|
||||
|
||||
```bash
|
||||
curl -X POST http://127.0.0.1:8088/api/backtest/submit \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"strategy_name": "双均线策略",
|
||||
"strategy_code": "完整的策略Python代码...",
|
||||
"symbol": "IF888.CFFEX",
|
||||
"interval": "1h",
|
||||
"start_date": "2020-01-01",
|
||||
"end_date": "2025-01-01",
|
||||
"parameters": {
|
||||
"fast_window": 5,
|
||||
"slow_window": 20
|
||||
},
|
||||
"capital": 1000000
|
||||
}'
|
||||
```
|
||||
|
||||
响应:
|
||||
```json
|
||||
{
|
||||
"code": 0,
|
||||
"msg": "任务提交成功",
|
||||
"data": {
|
||||
"task_id": "a1b2c3d4...",
|
||||
"status": "pending",
|
||||
"created_at": "2026-04-12T10:00:00+08:00"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 2. 查询任务状态
|
||||
|
||||
```bash
|
||||
curl http://127.0.0.1:8088/api/backtest/status/a1b2c3d4
|
||||
```
|
||||
|
||||
### 3. 获取回测结果
|
||||
|
||||
```bash
|
||||
curl http://127.0.0.1:8088/api/backtest/result/a1b2c3d4
|
||||
```
|
||||
|
||||
## 结果存储结构
|
||||
|
||||
```
|
||||
/app/backtest_jobs/
|
||||
├── pending/ # 等待执行
|
||||
├── running/ # 执行中
|
||||
├── completed/ # 已完成
|
||||
│ └── <task_id>/
|
||||
│ ├── task.json # 任务信息
|
||||
│ ├── result.json # 回测结果
|
||||
│ ├── equity.csv # 每日净值
|
||||
│ ├── equity_curve.png # 收益曲线图
|
||||
│ └── trades.csv # 成交记录
|
||||
└── failed/ # 执行失败
|
||||
└── <task_id>/
|
||||
├── task.json
|
||||
└── result.json # 包含错误信息
|
||||
```
|
||||
|
||||
## 设计原则
|
||||
|
||||
1. **不改动 vnpy 核心**:完全复用原生 `BacktestingEngine`
|
||||
2. **轻量级**:只用 Python 标准库,不引入额外第三方任务队列
|
||||
3. **隔离性**:每个回测任务单独执行并捕获异常,一个任务失败不影响其他任务(当前实现由后台线程逐个执行,而非独立进程)
|
||||
4. **可配置并发**:根据 CPU 性能调整 `MAX_WORKERS`,避免资源耗尽
|
||||
|
||||
## 作者
|
||||
|
||||
三国量化团队 姜维 伯约
|
||||
2026-04-12
|
||||
@@ -1,92 +0,0 @@
|
||||
"""
|
||||
自动化回测服务 - API路由
|
||||
"""
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Query
|
||||
from .models import (
|
||||
BacktestTask,
|
||||
BacktestTaskWithId,
|
||||
BacktestResult,
|
||||
TaskListResponse,
|
||||
ApiResponse,
|
||||
HealthCheckResponse,
|
||||
TaskStatus,
|
||||
)
|
||||
from .task_queue import task_queue
|
||||
from .result_storage import storage
|
||||
from .executor import executor
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/submit", summary="提交回测任务")
def submit_task(task: BacktestTask) -> ApiResponse[BacktestTaskWithId]:
    """Queue a new backtest task and return it together with its generated ID.

    ``task_queue.submit_task`` already persists the task to storage before it
    enqueues the ID, so the task is not saved a second time here: a second
    write could land after the worker has picked the task up and would
    re-create a stale ``pending`` copy on disk.

    Args:
        task: The backtest request body.

    Returns:
        An ``ApiResponse`` with ``code=0`` wrapping the queued task.
    """
    queued = task_queue.submit_task(task)
    return ApiResponse(code=0, msg="任务提交成功", data=queued)
|
||||
|
||||
|
||||
@router.get("/list", summary="列出回测任务")
def list_tasks(
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(10, ge=1, le=100, description="每页数量"),
    status: Optional[str] = Query(None, description="状态过滤 pending/running/completed/failed"),
) -> ApiResponse[TaskListResponse]:
    """List backtest tasks with pagination and an optional status filter.

    Args:
        page: 1-based page number.
        page_size: Number of tasks per page (1-100).
        status: Optional status name to filter by.

    Returns:
        An ``ApiResponse`` wrapping a ``TaskListResponse`` built from the
        dictionary returned by ``task_queue.list_tasks``.
    """
    listing = task_queue.list_tasks(page, page_size, status)
    payload = TaskListResponse(**listing)
    return ApiResponse(code=0, msg="success", data=payload)
|
||||
|
||||
|
||||
@router.get("/status/{task_id}", summary="查询任务状态")
def get_status(task_id: str) -> ApiResponse[Optional[BacktestTaskWithId]]:
    """Look up a single task on disk and return its stored state.

    Args:
        task_id: ID of the task to look up.

    Returns:
        ``code=0`` with the task when it is found, otherwise ``code=404``
        with ``data=None``.
    """
    found = storage.find_task(task_id)
    if found:
        return ApiResponse(code=0, msg="success", data=found)
    return ApiResponse(code=404, msg="任务不存在", data=None)
|
||||
|
||||
|
||||
@router.get("/result/{task_id}", summary="获取回测结果")
def get_result(task_id: str) -> ApiResponse[Optional[BacktestResult]]:
    """Return the stored backtest result for a task, read from disk.

    Args:
        task_id: ID of the task whose result is requested.

    Returns:
        ``code=0`` with the result when one is stored; ``code=0`` with
        ``data=None`` when the task exists but has no result yet; ``code=404``
        when the task itself cannot be found.
    """
    stored = storage.find_result(task_id)
    if stored:
        return ApiResponse(code=0, msg="success", data=stored)

    # No result file: distinguish "still in progress" from "unknown task".
    if storage.find_task(task_id):
        return ApiResponse(code=0, msg="任务尚未完成", data=None)
    return ApiResponse(code=404, msg="任务不存在", data=None)
|
||||
|
||||
|
||||
@router.delete("/delete/{task_id}", summary="删除回测任务")
def delete_task(task_id: str) -> ApiResponse[None]:
    """Delete a finished backtest task and its files.

    The previous implementation reported success without deleting anything.
    This version removes the task's directory from every status folder and
    drops its ID from the queue's completed/failed bookkeeping lists.

    Tasks that are still queued or running are refused, so the lists the
    worker thread consumes are never modified from here.

    Args:
        task_id: ID of the task to delete.

    Returns:
        ``code=0`` when something was deleted, ``code=409`` when the task has
        not finished yet, ``code=404`` when no such task is known.
    """
    import os
    import shutil

    # Task IDs are generated as dash-less UUID hex strings. Rejecting anything
    # that is not purely alphanumeric keeps "." / ".." and path separators out
    # of the directory paths that are removed below.
    if not task_id.isalnum():
        return ApiResponse(code=404, msg="任务不存在", data=None)

    if task_id in task_queue.pending_tasks or task_id in task_queue.running_tasks:
        return ApiResponse(code=409, msg="任务尚未结束,无法删除", data=None)

    removed = False
    for status_dir in ("pending", "running", "completed", "failed"):
        task_dir = os.path.dirname(
            storage.get_task_path(task_id, status_dir, "task.json")
        )
        if os.path.isdir(task_dir):
            shutil.rmtree(task_dir, ignore_errors=True)
            removed = True

    for finished_ids in (task_queue.completed_tasks, task_queue.failed_tasks):
        if task_id in finished_ids:
            finished_ids.remove(task_id)
            removed = True

    if not removed:
        return ApiResponse(code=404, msg="任务不存在", data=None)
    return ApiResponse(code=0, msg="删除成功", data=None)
|
||||
|
||||
|
||||
@router.get("/health", summary="健康检查")
def health_check() -> ApiResponse[HealthCheckResponse]:
    """Report service health as per-status task counts from the task queue.

    Returns:
        An ``ApiResponse`` with ``code=0`` wrapping the length of each of the
        queue's in-memory task lists and its ``max_workers`` value.
    """
    counters = HealthCheckResponse(
        pending_count=len(task_queue.pending_tasks),
        running_count=len(task_queue.running_tasks),
        completed_count=len(task_queue.completed_tasks),
        failed_count=len(task_queue.failed_tasks),
        max_workers=task_queue.max_workers,
    )
    return ApiResponse(code=0, msg="ok", data=counters)
|
||||
@@ -1,32 +0,0 @@
|
||||
"""
|
||||
自动化回测服务 - 配置
|
||||
"""
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Configuration values for the backtest service."""

    # Address and port the HTTP server listens on
    host: str = "0.0.0.0"
    port: int = 8088

    # Maximum number of concurrent backtests (tune to the CPU core count;
    # 2 is suggested for a Celeron-class NAS)
    max_workers: int = 2

    # Root directory under which backtest task files are stored
    base_dir: str = "/app/backtest_jobs"

    # CORS: allow every origin (intended for development environments)
    cors_allow_all: bool = True

    # Debug mode
    debug: bool = True

    # Whether strategy code is allowed to import modules
    allow_imports: bool = True

    # Upper bound on a single backtest's run time in seconds, meant to guard
    # against infinite loops
    max_execution_time: int = 3600  # one hour
|
||||
|
||||
|
||||
settings = Settings()
|
||||
@@ -1,262 +0,0 @@
|
||||
"""
|
||||
自动化回测服务 - 任务执行器
|
||||
调用 vnpy 4.x BacktestingEngine 执行回测
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
import matplotlib
|
||||
matplotlib.use("Agg") # 无头模式,服务器上不能弹窗
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
|
||||
# vnpy 4.x import路径(与3.x不同)
|
||||
from vnpy.event import EventEngine
|
||||
from vnpy.trader.engine import MainEngine
|
||||
from vnpy_ctastrategy.backtesting import BacktestingEngine
|
||||
from vnpy.trader.constant import Interval, Exchange
|
||||
|
||||
from .config import settings
|
||||
from .models import BacktestTask, BacktestResult, BacktestStatistics, TaskStatus, BacktestTaskWithId
|
||||
from .result_storage import storage
|
||||
|
||||
|
||||
# vnpy 4.x精简了Interval枚举,不再有FIVE_MINUTE等细分
|
||||
# Request interval string -> vnpy Interval. Several request values share one
# Interval member: every minute-based value maps to MINUTE and both hour-based
# values map to HOUR.
INTERVAL_MAP = {
    "1m": Interval.MINUTE,
    "5m": Interval.MINUTE,
    "15m": Interval.MINUTE,
    "30m": Interval.MINUTE,
    "1h": Interval.HOUR,
    "4h": Interval.HOUR,
    "1d": Interval.DAILY,
    "1w": Interval.WEEKLY,
}

# Exchange code -> vnpy Exchange member of the same name
EXCHANGE_MAP = {
    code: getattr(Exchange, code)
    for code in ("SSE", "SZSE", "CFFEX", "SHFE", "DCE", "CZCE", "INE", "GFEX")
}
|
||||
|
||||
|
||||
def _parse_vt_symbol(vt_symbol: str):
    """Split a ``symbol.exchange`` string into ``(symbol, Exchange)``.

    Resolution order for the exchange suffix (case-insensitive):

    1. exact key of ``EXCHANGE_MAP`` (e.g. ``CFFEX``);
    2. the common stock-data suffixes ``SH`` / ``SZ``, mapped to the Shanghai
       and Shenzhen stock exchanges. The previous prefix match resolved
       ``SH`` to ``SHFE``, a futures exchange;
    3. a 2-character prefix match against ``EXCHANGE_MAP`` keys;
    4. ``Exchange.SZSE`` as the final fallback.

    Example: ``'000001.SZ'`` -> ``('000001', Exchange.SZSE)``.

    Args:
        vt_symbol: Symbol text, with or without a ``.exchange`` suffix.

    Returns:
        A ``(symbol, exchange)`` tuple. Input without a dot is returned
        unchanged together with ``Exchange.SZSE``.
    """
    if "." not in vt_symbol:
        return vt_symbol, Exchange.SZSE

    symbol, exchange_str = vt_symbol.rsplit(".", 1)
    suffix = exchange_str.upper()

    exchange = EXCHANGE_MAP.get(suffix)
    if exchange is None:
        stock_aliases = {"SH": Exchange.SSE, "SZ": Exchange.SZSE}
        exchange = stock_aliases.get(suffix)

    # Fuzzy match on the first two characters. An empty suffix is skipped:
    # every key "starts with" the empty string, which would pick an arbitrary
    # exchange.
    if exchange is None and suffix:
        for key, val in EXCHANGE_MAP.items():
            if key.startswith(suffix[:2]):
                exchange = val
                break

    if exchange is None:
        exchange = Exchange.SZSE  # default: Shenzhen Stock Exchange
    return symbol, exchange
|
||||
|
||||
|
||||
class BacktestExecutor:
    """Backtest task executor built on the vnpy 4.x ``BacktestingEngine``.

    ``execute_backtest`` is the entry point. It loads the submitted strategy
    code, runs the engine, writes the result artifacts through ``storage`` and
    records the task's final state on disk.
    """

    def __init__(self):
        # Temporary directory holding the strategy file of the task that is
        # currently being executed; removed again once the task has finished.
        self._strategy_dir: Optional[str] = None

    def _load_strategy(self, task: BacktestTask):
        """Load the submitted strategy source and return its strategy class.

        The source is written to ``strategy.py`` in a fresh temporary
        directory and executed as the module ``dynamic_strategy``.

        NOTE(security): this executes user-submitted Python code inside the
        service process without any sandboxing.

        Args:
            task: Task whose ``strategy_code`` should be loaded.

        Returns:
            A ``CtaTemplate`` subclass. Classes defined in the submitted code
            itself are preferred over subclasses it merely imports.

        Raises:
            ValueError: If the code exposes no ``CtaTemplate`` subclass.
        """
        # ``import importlib`` alone does not guarantee that the ``util``
        # submodule is loaded, so import it explicitly.
        import importlib.util

        from vnpy_ctastrategy import CtaTemplate

        temp_dir = tempfile.mkdtemp()
        self._strategy_dir = temp_dir
        sys.path.insert(0, temp_dir)

        strategy_file = os.path.join(temp_dir, "strategy.py")
        with open(strategy_file, "w", encoding="utf-8") as f:
            f.write(task.strategy_code)

        spec = importlib.util.spec_from_file_location("dynamic_strategy", strategy_file)
        module = importlib.util.module_from_spec(spec)
        sys.modules["dynamic_strategy"] = module
        spec.loader.exec_module(module)

        # Collect every CtaTemplate subclass visible in the module, in the
        # same (name-sorted) order the original lookup used.
        candidates = []
        for attr_name in dir(module):
            attr = getattr(module, attr_name)
            if isinstance(attr, type) and issubclass(attr, CtaTemplate) and attr is not CtaTemplate:
                candidates.append(attr)

        if not candidates:
            raise ValueError("策略代码中没有找到 CtaTemplate 子类,请检查策略代码")

        # An imported base class can sort ahead of the user's own strategy, so
        # prefer classes that were defined by the submitted code.
        own_classes = [cls for cls in candidates if cls.__module__ == module.__name__]
        return (own_classes or candidates)[0]

    def _discard_strategy_module(self) -> None:
        """Undo the side effects of ``_load_strategy`` for the finished task."""
        import shutil

        temp_dir, self._strategy_dir = self._strategy_dir, None
        sys.modules.pop("dynamic_strategy", None)
        if temp_dir is None:
            return
        if temp_dir in sys.path:
            sys.path.remove(temp_dir)
        shutil.rmtree(temp_dir, ignore_errors=True)

    def _remove_stale_task_copies(self, task_id: str, statuses) -> None:
        """Delete ``task.json`` copies left behind in earlier status folders.

        ``storage.save_task`` writes into the folder of the task's current
        status but does not remove the copy written for a previous status.
        """
        for status in statuses:
            stale_file = storage.get_task_path(task_id, status, "task.json")
            try:
                os.remove(stale_file)
            except OSError:
                continue
            try:
                # Only succeeds when the folder is empty; other files are kept.
                os.rmdir(os.path.dirname(stale_file))
            except OSError:
                pass

    def _finalize_task(self, task: BacktestTaskWithId, result: BacktestResult) -> None:
        """Persist the task's terminal status next to its result."""
        task.status = result.status
        task.completed_at = result.completed_at
        storage.save_task(task)
        self._remove_stale_task_copies(
            task.task_id, (TaskStatus.PENDING, TaskStatus.RUNNING)
        )

    def _build_statistics(self, task: BacktestTaskWithId, statistics: dict) -> BacktestStatistics:
        """Convert the engine's statistics dictionary into the API model.

        Keys missing from ``statistics`` fall back to ``0`` / ``""``.
        """
        # vnpy reports the number of trades as ``total_trade_count``; keep the
        # original key as the first choice for compatibility.
        total_trades = statistics.get("total_trades", statistics.get("total_trade_count", 0))
        return BacktestStatistics(
            start_date=str(task.start_date),
            end_date=str(task.end_date),
            total_days=int(statistics.get("total_days", 0)),
            total_trades=int(total_trades),
            winning_trades=int(statistics.get("winning_trades", 0)),
            losing_trades=int(statistics.get("losing_trades", 0)),
            win_rate=float(statistics.get("win_rate", 0)),
            total_return=float(statistics.get("total_return", 0)),
            annual_return=float(statistics.get("annual_return", 0)),
            sharpe_ratio=float(statistics.get("sharpe_ratio", 0)),
            max_drawdown=float(statistics.get("max_drawdown", 0)),
            max_drawdown_start=str(statistics.get("max_drawdown_start", "")),
            max_drawdown_end=str(statistics.get("max_drawdown_end", "")),
            profit_factor=float(statistics.get("profit_factor", 0)),
            calmar_ratio=float(statistics.get("calmar_ratio", 0)),
        )

    def _save_trades(self, engine, task_id: str) -> Optional[str]:
        """Write the trade records to ``trades.csv`` on a best-effort basis.

        Returns:
            The CSV path, or ``None`` when there are no trades or writing
            failed. A failure here does not fail the backtest.
        """
        try:
            trades = engine.get_all_trades() if hasattr(engine, "get_all_trades") else []
            if not trades:
                return None
            trades_df = pd.DataFrame([t.__dict__ for t in trades])
            trades_csv_path = storage.get_task_path(task_id, TaskStatus.COMPLETED, "trades.csv")
            trades_df.to_csv(trades_csv_path, index=False)
            return trades_csv_path
        except Exception as exc:
            print(f"保存成交记录失败: {exc}")
            return None

    def _run_engine(self, task: BacktestTaskWithId, result: BacktestResult) -> None:
        """Run the backtest and fill ``result`` with the completed outcome.

        ``result`` is only modified after every artifact has been produced, so
        an exception leaves it untouched for the caller's failure handling.
        """
        strategy_class = self._load_strategy(task)
        interval = INTERVAL_MAP.get(task.interval, Interval.DAILY)

        # The task model carries ``date`` values; the engine is given
        # ``datetime`` values at midnight of those days instead.
        midnight = datetime.min.time()

        engine = BacktestingEngine()
        engine.set_parameters(
            vt_symbol=task.symbol,
            interval=interval,
            start=datetime.combine(task.start_date, midnight),
            end=datetime.combine(task.end_date, midnight),
            rate=0.3 / 10000,                  # commission rate: 0.3 per 10,000
            slippage=0.1,                      # slippage
            size=1,                            # contract multiplier
            pricetick=task.tick_size or 0.01,  # minimum price movement
            capital=task.capital,
        )
        engine.add_strategy(strategy_class, task.parameters)

        # Load history through vnpy's own data loading. Both an exception and
        # an empty data set count as "no data".
        load_error = None
        try:
            engine.load_data()
        except Exception as exc:
            load_error = exc
        if load_error is not None or not getattr(engine, "history_data", True):
            raise ValueError(
                f"无法加载 {task.symbol} 在 [{task.start_date}, {task.end_date}] 的历史数据。"
                "请确保数据已导入vnpy数据库或可通过CSV加载。"
            ) from load_error

        engine.run_backtesting()
        df = engine.calculate_result()
        stats = self._build_statistics(task, engine.calculate_statistics())

        # Daily equity CSV
        result_csv_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "equity.csv")
        os.makedirs(os.path.dirname(result_csv_path), exist_ok=True)
        if df is not None and not df.empty:
            df.to_csv(result_csv_path)

        # Equity curve image
        png_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "equity_curve.png")
        self._plot_equity_curve(df, png_path)

        trades_csv_path = self._save_trades(engine, task.task_id)

        result.status = TaskStatus.COMPLETED
        result.statistics = stats
        result.result_csv_path = result_csv_path
        result.equity_curve_png_path = png_path
        result.trades_csv_path = trades_csv_path
        result.completed_at = datetime.now().isoformat()

    def execute_backtest(self, task: BacktestTaskWithId) -> BacktestResult:
        """Execute one backtest task from start to finish.

        The task is marked as running and saved, the engine is run, and the
        result (completed or failed) is saved. The task file is then updated
        to the same terminal status so that status queries no longer report
        the task as running.

        Args:
            task: The task to run. Its ``status``, ``started_at`` and
                ``completed_at`` fields are updated in place.

        Returns:
            The ``BacktestResult``. On failure its ``error_message`` holds the
            exception text followed by the traceback.
        """
        started_at = datetime.now().isoformat()

        task.status = TaskStatus.RUNNING
        task.started_at = started_at
        storage.save_task(task)
        self._remove_stale_task_copies(task.task_id, (TaskStatus.PENDING,))

        result = BacktestResult(
            task_id=task.task_id,
            strategy_name=task.strategy_name,
            status=TaskStatus.RUNNING,
            result_csv_path="",
            created_at=task.created_at,
            started_at=started_at,
        )

        try:
            self._run_engine(task, result)
        except Exception as e:
            result.status = TaskStatus.FAILED
            result.error_message = f"{str(e)}\n{traceback.format_exc()}"
            result.completed_at = datetime.now().isoformat()
        finally:
            # Runs after the traceback above has been formatted, so source
            # lines from the strategy file are still available to it.
            self._discard_strategy_module()

        storage.save_result(result)
        self._finalize_task(task, result)
        return result

    def _plot_equity_curve(self, df: pd.DataFrame, output_path: str):
        """Draw the equity curve and save it as an image.

        The first available column of ``equity``, ``net_pnl`` (plotted as a
        cumulative sum) and ``balance`` is used. With no usable data an empty
        chart is saved.

        Args:
            df: Result frame from the engine; may be ``None`` or empty.
            output_path: Destination path of the image file.
        """
        plt.figure(figsize=(12, 6))
        try:
            has_curve = False
            if df is not None and not df.empty:
                if "equity" in df.columns:
                    plt.plot(df.index, df["equity"], label="净值曲线", linewidth=2)
                    has_curve = True
                elif "net_pnl" in df.columns:
                    cumulative = df["net_pnl"].cumsum()
                    plt.plot(df.index, cumulative, label="累计收益", linewidth=2)
                    has_curve = True
                elif "balance" in df.columns:
                    plt.plot(df.index, df["balance"], label="账户余额", linewidth=2)
                    has_curve = True

            plt.title("回测收益曲线")
            plt.xlabel("时间")
            plt.ylabel("净值")
            plt.grid(True, alpha=0.3)
            if has_curve:
                # legend() without any labelled line only emits a warning.
                plt.legend()
            plt.tight_layout()
            plt.savefig(output_path, dpi=150)
        finally:
            # Always release the figure, even when saving fails.
            plt.close()
|
||||
|
||||
|
||||
executor = BacktestExecutor()
|
||||
@@ -1,72 +0,0 @@
|
||||
"""
|
||||
自动化回测服务 - 主入口
|
||||
启动 FastAPI 服务,接受回测任务提交,执行回测,返回结果
|
||||
遵循 vnpy 原生设计,只做外层封装
|
||||
"""
|
||||
import uvicorn
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from .config import settings
|
||||
from .api import router
|
||||
from .task_queue import task_queue
|
||||
from .models import ApiResponse, HealthCheckResponse
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook.

    Starts the task queue's worker before the application begins serving and
    stops it again when the application shuts down.
    """
    # Startup phase
    task_queue.start_worker_pool()
    print(f"✅ 回测服务启动 (max_workers={settings.max_workers})")

    yield

    # Shutdown phase
    task_queue.close_worker_pool()
    print("回测服务已停止")
|
||||
|
||||
|
||||
app = FastAPI(
    title="sanguo 自动化回测服务",
    description="基于 vnpy 原生 BacktestingEngine 的自动化回测API服务",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS configuration
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; confirm this is only enabled in development.
if settings.cors_allow_all:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

# Register the API routes. The router declares paths such as "/submit" and
# "/list" without a prefix, while the documented endpoints (and the health
# route defined in this module) live under "/api/backtest", so the prefix is
# supplied here.
app.include_router(router, prefix="/api/backtest", tags=["backtest"])
|
||||
|
||||
|
||||
@app.get("/api/backtest/health", summary="服务健康检查", response_model=ApiResponse[HealthCheckResponse])
def health():
    """Service health check returning the current per-status task counts.

    The counts are the lengths of the task queue's in-memory lists;
    ``max_workers`` is read from the settings.
    """
    snapshot = HealthCheckResponse(
        pending_count=len(task_queue.pending_tasks),
        running_count=len(task_queue.running_tasks),
        completed_count=len(task_queue.completed_tasks),
        failed_count=len(task_queue.failed_tasks),
        max_workers=settings.max_workers,
    )
    return ApiResponse(code=0, msg="ok", data=snapshot)
|
||||
|
||||
|
||||
def main() -> None:
    """Start the uvicorn server for this application.

    Exposed as a function so that a launcher script can import and call it;
    running the module directly goes through the same code path.
    Auto-reload follows the ``debug`` setting.
    """
    uvicorn.run(
        "backtest_service.main:app",
        host=settings.host,
        port=settings.port,
        reload=settings.debug,
    )


if __name__ == "__main__":
    main()
|
||||
@@ -1,97 +0,0 @@
|
||||
"""
|
||||
自动化回测服务 - 数据模型
|
||||
"""
|
||||
from enum import Enum
|
||||
from datetime import date, datetime
|
||||
from typing import Dict, Optional, Any, List, Generic, TypeVar
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
    """Lifecycle states of a backtest task.

    Members are also ``str`` instances, so they compare equal to their plain
    string values.
    """

    PENDING = "pending"      # accepted, waiting to run
    RUNNING = "running"      # currently executing
    COMPLETED = "completed"  # finished successfully
    FAILED = "failed"        # finished with an error
|
||||
|
||||
|
||||
class BacktestTask(BaseModel):
    """Request body for submitting a backtest.

    ``capital`` must be positive and ``tick_size``, when given, must not be
    negative; such requests are rejected at validation time instead of being
    passed on to the backtest.
    """

    strategy_name: str = Field(..., description="策略名称")
    strategy_code: str = Field(..., description="策略完整Python代码")
    symbol: str = Field(..., description="交易品种,例如 000001.SSE")
    interval: str = Field(..., description="K线周期,例如 1d、1h")
    start_date: date = Field(..., description="回测开始日期")
    end_date: date = Field(..., description="回测结束日期")
    parameters: Dict[str, Any] = Field(default_factory=dict, description="策略参数字典")
    capital: float = Field(default=1_000_000, gt=0, description="起始资金")
    # 0 and None are both accepted and mean "not specified".
    tick_size: Optional[float] = Field(None, ge=0, description="最小价格变动,不指定则自动获取")
|
||||
|
||||
|
||||
class BacktestTaskWithId(BacktestTask):
    """A backtest task together with its ID, status and timestamps."""

    task_id: str = Field(..., description="任务唯一ID")
    status: TaskStatus = Field(..., description="任务状态")
    # Timestamps are stored as ISO-format strings
    created_at: str = Field(..., description="创建时间 ISO格式")
    started_at: Optional[str] = Field(None, description="开始时间")
    completed_at: Optional[str] = Field(None, description="完成时间")
|
||||
|
||||
|
||||
class BacktestStatistics(BaseModel):
    """Summary statistics of a finished backtest."""

    start_date: str
    end_date: str
    total_days: int
    total_trades: int
    winning_trades: int
    losing_trades: int
    win_rate: float
    total_return: float   # total return
    annual_return: float  # annualized return
    sharpe_ratio: float   # Sharpe ratio
    max_drawdown: float   # maximum drawdown
    max_drawdown_start: Optional[str] = None
    max_drawdown_end: Optional[str] = None
    profit_factor: float  # profit factor (gross profit / gross loss)
    calmar_ratio: float   # Calmar ratio (annualized return / max drawdown)
|
||||
|
||||
|
||||
class BacktestResult(BaseModel):
    """Complete result record of one backtest run."""

    task_id: str
    strategy_name: str
    status: TaskStatus
    statistics: Optional[BacktestStatistics] = None
    result_csv_path: str                         # path of the daily equity CSV
    equity_curve_png_path: Optional[str] = None  # path of the equity curve image
    trades_csv_path: Optional[str] = None        # path of the trade records CSV
    error_message: Optional[str] = None          # error details when the run failed
    created_at: str
    started_at: Optional[str] = None
    completed_at: Optional[str] = None
|
||||
|
||||
|
||||
class TaskListResponse(BaseModel):
    """One page of the task listing."""

    total: int      # number of tasks matching the query
    page: int       # page number that was requested
    page_size: int  # page size that was requested
    tasks: List[BacktestTaskWithId]
|
||||
|
||||
|
||||
# Payload type carried by ApiResponse
T = TypeVar("T")


class ApiResponse(BaseModel, Generic[T]):
    """Generic envelope used by every API response."""

    code: int = Field(0, description="0表示成功,非0表示错误")
    msg: str = Field("success", description="响应消息")
    data: Optional[T] = Field(None, description="响应数据")
|
||||
|
||||
|
||||
class HealthCheckResponse(BaseModel):
    """Payload of the health check: task counts per status plus worker limit."""

    pending_count: int
    running_count: int
    completed_count: int
    failed_count: int
    max_workers: int
|
||||
@@ -1,95 +0,0 @@
|
||||
"""
|
||||
自动化回测服务 - 数据模型
|
||||
"""
|
||||
from enum import Enum
|
||||
from datetime import date, datetime
|
||||
from typing import Dict, Optional, Any, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
    """Lifecycle states of a backtest task.

    Members are also ``str`` instances, so they compare equal to their plain
    string values.
    """

    PENDING = "pending"      # accepted, waiting to run
    RUNNING = "running"      # currently executing
    COMPLETED = "completed"  # finished successfully
    FAILED = "failed"        # finished with an error
|
||||
|
||||
|
||||
class BacktestTask(BaseModel):
    """Request body for submitting a backtest.

    ``capital`` must be positive and ``tick_size``, when given, must not be
    negative; such requests are rejected at validation time instead of being
    passed on to the backtest.
    """

    strategy_name: str = Field(..., description="策略名称")
    strategy_code: str = Field(..., description="策略完整Python代码")
    symbol: str = Field(..., description="交易品种,例如 000001.SSE")
    interval: str = Field(..., description="K线周期,例如 1d、1h")
    start_date: date = Field(..., description="回测开始日期")
    end_date: date = Field(..., description="回测结束日期")
    parameters: Dict[str, Any] = Field(default_factory=dict, description="策略参数字典")
    capital: float = Field(default=1_000_000, gt=0, description="起始资金")
    # 0 and None are both accepted and mean "not specified".
    tick_size: Optional[float] = Field(None, ge=0, description="最小价格变动,不指定则自动获取")
|
||||
|
||||
|
||||
class BacktestTaskWithId(BacktestTask):
    """A backtest task together with its ID, status and timestamps."""

    task_id: str = Field(..., description="任务唯一ID")
    status: TaskStatus = Field(..., description="任务状态")
    # Timestamps are stored as ISO-format strings
    created_at: str = Field(..., description="创建时间 ISO格式")
    started_at: Optional[str] = Field(None, description="开始时间")
    completed_at: Optional[str] = Field(None, description="完成时间")
|
||||
|
||||
|
||||
class BacktestStatistics(BaseModel):
    """Summary statistics of a finished backtest."""

    start_date: str
    end_date: str
    total_days: int
    total_trades: int
    winning_trades: int
    losing_trades: int
    win_rate: float
    total_return: float   # total return
    annual_return: float  # annualized return
    sharpe_ratio: float   # Sharpe ratio
    max_drawdown: float   # maximum drawdown
    max_drawdown_start: Optional[str] = None
    max_drawdown_end: Optional[str] = None
    profit_factor: float  # profit factor (gross profit / gross loss)
    calmar_ratio: float   # Calmar ratio (annualized return / max drawdown)
|
||||
|
||||
|
||||
class BacktestResult(BaseModel):
    """Complete result record of one backtest run."""

    task_id: str
    strategy_name: str
    status: TaskStatus
    statistics: Optional[BacktestStatistics] = None
    result_csv_path: str                         # path of the daily equity CSV
    equity_curve_png_path: Optional[str] = None  # path of the equity curve image
    trades_csv_path: Optional[str] = None        # path of the trade records CSV
    error_message: Optional[str] = None          # error details when the run failed
    created_at: str
    started_at: Optional[str] = None
    completed_at: Optional[str] = None
|
||||
|
||||
|
||||
class TaskListResponse(BaseModel):
    """One page of the task listing."""

    total: int      # number of tasks matching the query
    page: int       # page number that was requested
    page_size: int  # page size that was requested
    tasks: List[BacktestTaskWithId]
|
||||
|
||||
|
||||
class ApiResponse[T](BaseModel):
|
||||
"""通用API响应包装"""
|
||||
code: int = Field(0, description="0表示成功,非0表示错误")
|
||||
msg: str = Field("success", description="响应消息")
|
||||
data: Optional[T] = Field(None, description="响应数据")
|
||||
|
||||
|
||||
class HealthCheckResponse(BaseModel):
    """Payload of the health check: task counts per status plus worker limit."""

    pending_count: int
    running_count: int
    completed_count: int
    failed_count: int
    max_workers: int
|
||||
@@ -1,101 +0,0 @@
|
||||
"""
|
||||
自动化回测服务 - 结果存储
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
from datetime import date, datetime
|
||||
from typing import Optional
|
||||
from .models import BacktestTaskWithId, BacktestResult
|
||||
from .config import settings
|
||||
|
||||
|
||||
def _json_serial(obj):
|
||||
"""JSON序列化辅助:处理date/datetime"""
|
||||
if isinstance(obj, (date, datetime)):
|
||||
return obj.isoformat()
|
||||
raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
|
||||
|
||||
|
||||
class ResultStorage:
    """File-based storage for task and result records.

    Layout: ``<base_dir>/<status>/<task_id>/task.json`` and ``result.json``,
    where ``<status>`` is one of the names in ``_STATUS_DIRS``.
    """

    # Names of the per-status sub-directories below ``base_dir``
    _STATUS_DIRS = ("pending", "running", "completed", "failed")

    def __init__(self):
        self.base_dir = settings.base_dir
        self._ensure_dirs()

    def _ensure_dirs(self):
        """Create the per-status directories if they do not exist yet."""
        for status_dir in self._STATUS_DIRS:
            os.makedirs(os.path.join(self.base_dir, status_dir), exist_ok=True)

    def _task_dir(self, task_id: str, status: str) -> str:
        """Return the directory of a task under the given status."""
        return os.path.join(self.base_dir, status, task_id)

    def save_task(self, task: BacktestTaskWithId) -> None:
        """Write ``task.json`` into the folder of the task's current status.

        After the new file has been written, ``task.json`` copies left in the
        other status folders are removed. This keeps a task from showing up
        under several statuses at once with out-of-date content.
        """
        task_dir = self._task_dir(task.task_id, task.status)
        os.makedirs(task_dir, exist_ok=True)

        info_file = os.path.join(task_dir, "task.json")
        with open(info_file, "w", encoding="utf-8") as f:
            json.dump(task.model_dump(), f, indent=2, ensure_ascii=False, default=_json_serial)

        for other_status in self._STATUS_DIRS:
            if other_status == task.status:
                continue
            stale_dir = self._task_dir(task.task_id, other_status)
            try:
                os.remove(os.path.join(stale_dir, "task.json"))
            except FileNotFoundError:
                continue
            try:
                # Only succeeds when the folder is now empty.
                os.rmdir(stale_dir)
            except OSError:
                pass

    def load_task(self, task_id: str, status: str) -> Optional[BacktestTaskWithId]:
        """Load a task from the given status folder, or ``None`` if absent."""
        info_file = os.path.join(self._task_dir(task_id, status), "task.json")
        if not os.path.exists(info_file):
            return None

        with open(info_file, "r", encoding="utf-8") as f:
            data = json.load(f)
        return BacktestTaskWithId(**data)

    def save_result(self, result: BacktestResult) -> None:
        """Write ``result.json`` into the folder of the result's status."""
        task_dir = self._task_dir(result.task_id, result.status)
        os.makedirs(task_dir, exist_ok=True)

        result_file = os.path.join(task_dir, "result.json")
        with open(result_file, "w", encoding="utf-8") as f:
            json.dump(result.model_dump(), f, indent=2, ensure_ascii=False, default=_json_serial)

    def load_result(self, task_id: str, status: str) -> Optional[BacktestResult]:
        """Load a result from the given status folder, or ``None`` if absent."""
        result_file = os.path.join(self._task_dir(task_id, status), "result.json")
        if not os.path.exists(result_file):
            return None

        with open(result_file, "r", encoding="utf-8") as f:
            data = json.load(f)
        return BacktestResult(**data)

    def find_task(self, task_id: str) -> Optional[BacktestTaskWithId]:
        """Search every status folder for a task.

        Terminal statuses are checked first so that a leftover copy from an
        earlier status cannot hide the task's final state.
        """
        for status_dir in ("completed", "failed", "running", "pending"):
            task = self.load_task(task_id, status_dir)
            if task:
                return task
        return None

    def find_result(self, task_id: str) -> Optional[BacktestResult]:
        """Search every status folder for a stored result."""
        for status_dir in ("failed", "completed", "running", "pending"):
            result = self.load_result(task_id, status_dir)
            if result:
                return result
        return None

    def get_task_path(self, task_id: str, status: str, filename: str) -> str:
        """Return the path of ``filename`` inside a task's status folder."""
        return os.path.join(self._task_dir(task_id, status), filename)
|
||||
|
||||
|
||||
storage = ResultStorage()
|
||||
@@ -1,141 +0,0 @@
|
||||
"""
|
||||
自动化回测服务 - 任务队列
|
||||
简单后台线程调度:submit后自动触发执行,同一时间只跑一个回测
|
||||
"""
|
||||
import os
|
||||
import uuid
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict
|
||||
|
||||
from .config import settings
|
||||
from .models import TaskStatus, BacktestTask, BacktestTaskWithId
|
||||
from .result_storage import storage
|
||||
|
||||
|
||||
class TaskQueue:
|
||||
"""任务队列管理器 - 后台线程调度"""
|
||||
|
||||
def __init__(self):
|
||||
self.max_workers = settings.max_workers
|
||||
self.pending_tasks: List[str] = []
|
||||
self.running_tasks: List[str] = []
|
||||
self.completed_tasks: List[str] = []
|
||||
self.failed_tasks: List[str] = []
|
||||
self._worker_thread: Optional[threading.Thread] = None
|
||||
self._stop_event = threading.Event()
|
||||
|
||||
def _generate_task_id(self) -> str:
|
||||
return str(uuid.uuid4()).replace("-", "")
|
||||
|
||||
def submit_task(self, task: BacktestTask) -> BacktestTaskWithId:
|
||||
"""提交新任务到队列"""
|
||||
task_id = self._generate_task_id()
|
||||
now = datetime.now().isoformat()
|
||||
|
||||
task_with_id = BacktestTaskWithId(
|
||||
task_id=task_id,
|
||||
status=TaskStatus.PENDING,
|
||||
created_at=now,
|
||||
**task.model_dump()
|
||||
)
|
||||
|
||||
storage.save_task(task_with_id)
|
||||
self.pending_tasks.append(task_id)
|
||||
return task_with_id
|
||||
|
||||
def list_tasks(self, page: int = 1, page_size: int = 10, status: Optional[str] = None) -> Dict:
|
||||
if status == "pending":
|
||||
task_ids = self.pending_tasks
|
||||
elif status == "running":
|
||||
task_ids = self.running_tasks
|
||||
elif status == "completed":
|
||||
task_ids = self.completed_tasks
|
||||
elif status == "failed":
|
||||
task_ids = self.failed_tasks
|
||||
else:
|
||||
task_ids = (
|
||||
self.pending_tasks +
|
||||
self.running_tasks +
|
||||
self.completed_tasks +
|
||||
self.failed_tasks
|
||||
)
|
||||
|
||||
total = len(task_ids)
|
||||
start = (page - 1) * page_size
|
||||
end = start + page_size
|
||||
|
||||
result = []
|
||||
for task_id in task_ids[start:end]:
|
||||
for status_dir in ["pending", "running", "completed", "failed"]:
|
||||
task = storage.load_task(task_id, status_dir)
|
||||
if task:
|
||||
result.append(task)
|
||||
break
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"tasks": result
|
||||
}
|
||||
|
||||
def get_task(self, task_id: str) -> Optional[BacktestTaskWithId]:
|
||||
for status_dir in ["pending", "running", "completed", "failed"]:
|
||||
task = storage.load_task(task_id, status_dir)
|
||||
if task:
|
||||
return task
|
||||
return None
|
||||
|
||||
def _worker_loop(self):
|
||||
"""后台工作线程:循环检查pending任务并执行"""
|
||||
from .executor import executor
|
||||
|
||||
while not self._stop_event.is_set():
|
||||
# 检查是否有pending任务且当前无running任务
|
||||
if self.pending_tasks and not self.running_tasks:
|
||||
task_id = self.pending_tasks.pop(0)
|
||||
self.running_tasks.append(task_id)
|
||||
|
||||
# 在后台线程中执行回测
|
||||
try:
|
||||
task = storage.load_task(task_id, "pending")
|
||||
if task:
|
||||
# 移动到running目录
|
||||
storage.save_task(task)
|
||||
result = executor.execute_backtest(task)
|
||||
|
||||
# 从running移到completed/failed
|
||||
self.running_tasks.remove(task_id)
|
||||
if result.status == TaskStatus.COMPLETED:
|
||||
self.completed_tasks.append(task_id)
|
||||
else:
|
||||
self.failed_tasks.append(task_id)
|
||||
except Exception as e:
|
||||
print(f"任务执行异常: {e}\n{traceback.format_exc()}")
|
||||
if task_id in self.running_tasks:
|
||||
self.running_tasks.remove(task_id)
|
||||
self.failed_tasks.append(task_id)
|
||||
|
||||
# 等待1秒再检查
|
||||
self._stop_event.wait(1.0)
|
||||
|
||||
def start_worker_pool(self):
|
||||
"""启动后台工作线程"""
|
||||
if self._worker_thread is None or not self._worker_thread.is_alive():
|
||||
self._stop_event.clear()
|
||||
self._worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
|
||||
self._worker_thread.start()
|
||||
print(f"工作线程已启动 (max_workers={self.max_workers})")
|
||||
|
||||
def close_worker_pool(self):
    """Ask the worker thread to stop and wait briefly for it to finish.

    Sets the stop event, joins a live worker for at most five seconds, and
    then drops the reference to the thread.
    """
    self._stop_event.set()

    worker = self._worker_thread
    if worker is not None and worker.is_alive():
        worker.join(timeout=5)

    self._worker_thread = None
|
||||
|
||||
|
||||
task_queue = TaskQueue()
|
||||
@@ -1,17 +0,0 @@
|
||||
#!/usr/bin/env python3
"""
Launcher for the automated backtest service.

Puts the sibling ``backtest-service`` directory at the front of ``sys.path``
and runs its ``main`` entry point. Per the original note, that entry point
starts a FastAPI service on port 8088 that accepts backtest submissions.
"""
import sys
import os

# Make the backtest-service directory (next to this script) importable.
_script_dir = os.path.dirname(__file__)
backtest_dir = os.path.join(_script_dir, "backtest-service")
sys.path.insert(0, backtest_dir)

# This import must stay below the sys.path change above.
from main import main

if __name__ == "__main__":
    main()
|
||||
@@ -1,71 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
VNPY RPC 交易核心服务启动脚本
|
||||
启动 vnpy 交易核心,开启 RPC 服务端,供 VNPY Web Trader 连接
|
||||
|
||||
按照 vnpy 官方标准双进程架构:
|
||||
- 这个进程:交易核心 + RPC 服务端 (端口 2018/4102)
|
||||
- 另一个进程:VNPY Web Trader (端口 8000)
|
||||
"""
|
||||
|
||||
from vnpy.trader.event_engine import EventEngine
|
||||
from vnpy.trader.main_engine import MainEngine
|
||||
from vnpy.rpc import RpcServer
|
||||
|
||||
# 导入你的 gateway
|
||||
from vnpy_ctp import CtpGateway
|
||||
# from vnpy_tap import TapGateway
|
||||
# from vnpy_ib import IbGateway
|
||||
# 其他 gateway 根据需要添加
|
||||
|
||||
# 导入你的策略应用
|
||||
from vnpy_ctastrategy import CtaStrategyApp
|
||||
# from vnpy_portfoliostrategy import PortfolioStrategyApp
|
||||
# from vnpy_spreadtrading import SpreadTradingApp
|
||||
|
||||
|
||||
def main():
    """Start the vn.py trading core and serve it over RPC until interrupted.

    Creates the event engine and main engine, registers the CTP gateway and
    the CTA strategy app, starts the RPC server on request port 2018 and
    subscribe port 4102, then blocks until Ctrl+C.

    When stdin is closed (for example under ``nohup`` or a service manager)
    ``input()`` raises ``EOFError`` immediately; in that case the process
    keeps serving instead of exiting right after start-up. On the way out
    the RPC server is stopped and the main engine is closed so their worker
    threads do not keep the process alive.
    """
    import time

    host = "0.0.0.0"
    request_port = 2018
    subscribe_port = 4102

    # Create the event engine and the main engine.
    event_engine = EventEngine()
    main_engine = MainEngine(event_engine)

    # Register gateways; add more here if needed,
    # e.g. main_engine.add_gateway(TapGateway).
    main_engine.add_gateway(CtpGateway)

    # Register strategy apps; add more here if needed,
    # e.g. main_engine.add_app(PortfolioStrategyApp).
    main_engine.add_app(CtaStrategyApp)

    # RPC server: request port 2018, subscribe port 4102.
    # NOTE(review): confirm this constructor signature against the installed
    # vnpy.rpc.RpcServer — the call is kept exactly as the original wrote it.
    rpc_server = RpcServer(
        main_engine,
        (host, request_port),
        (host, subscribe_port),
    )

    print("=" * 50)
    print("VNPY RPC 交易核心服务启动")
    print(f"RPC 请求地址: tcp://{host}:{request_port}")
    print(f"RPC 订阅地址: tcp://{host}:{subscribe_port}")
    print("请确保已经在 vnpy 配置中配置好你的 gateway")
    print("=" * 50)

    rpc_server.start()

    # Keep the process alive until the operator interrupts it.
    print("RPC 服务已启动,按 Ctrl+C 退出")
    try:
        try:
            input()
        except EOFError:
            # No usable stdin (daemonised): idle until interrupted.
            while True:
                time.sleep(1)
    except KeyboardInterrupt:
        pass
    finally:
        # Release engine/server threads so the interpreter can exit.
        rpc_server.stop()
        main_engine.close()

    print("RPC 服务退出")


if __name__ == "__main__":
    main()
|
||||
@@ -1,38 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
VNPY Web Trader 服务启动脚本
|
||||
按照 vnpy 官方标准双进程架构:
|
||||
- 需要先启动 start_rpc_server.py 交易核心
|
||||
- 然后启动这个 Web Trader 服务
|
||||
- Web Trader 通过 RPC 连接交易核心
|
||||
"""
|
||||
|
||||
from vnpy_webtrader import run_web_trader
|
||||
|
||||
|
||||
def main():
    """Start the VNPY Web Trader and connect it to the RPC trading core.

    Prints a start-up banner and then calls ``run_web_trader`` with the RPC
    request/subscribe addresses of a trading core on the same machine,
    listening on 0.0.0.0:8000 with CORS allowed for every origin.
    """
    # RPC endpoints; by default the trading core runs on the same machine.
    rpc_request_address = "tcp://127.0.0.1:2018"
    rpc_subscribe_address = "tcp://127.0.0.1:4102"
    web_host = "0.0.0.0"
    web_port = 8000

    separator = "=" * 50
    print(separator)
    print("VNPY Web Trader 服务启动")
    print(f"RPC 请求地址: {rpc_request_address}")
    print(f"RPC 订阅地址: {rpc_subscribe_address}")
    print(f"Web 服务监听: {web_host}:{web_port}")
    print(separator)
    print("请确保先启动 start_rpc_server.py")
    print(separator)

    # Start the Web Trader.
    run_web_trader(
        rpc_request_address,
        rpc_subscribe_address,
        host=web_host,
        port=web_port,
        cors_allow_all=True,  # cross-origin requests allowed (development setting)
    )


if __name__ == "__main__":
    main()
|
||||
@@ -1,469 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
vn.py本地数据适配器 - 姜维
|
||||
功能:让vn.py优先加载赵云将军下载的本地数据,本地没有再去akshare接口下载
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import os
|
||||
import glob
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, List, Tuple
|
||||
import akshare as ak
|
||||
|
||||
# Logging setup: INFO and above go to a log file in the working directory
# and to the console.
_log_handlers = [
    logging.FileHandler('vnpy_local_data_adapter.log'),
    logging.StreamHandler(),
]
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=_log_handlers,
)
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VnpyLocalDataAdapter:
    """
    Local-first daily bar loader for vn.py.

    Strategy: read per-year parquet files from the Zhaoyun data directory
    first; fall back to akshare when nothing usable is found locally.
    """

    # Root of the Zhaoyun data directory. The ZHAOYUN_DATA_BASE environment
    # variable overrides the built-in default so the adapter can run on
    # machines other than the one the default path was written for.
    ZHAOYUN_DATA_BASE = os.environ.get(
        "ZHAOYUN_DATA_BASE",
        "/Users/chufeng/nas/stock/sanguo_vnpy/zhaoyun-data/data",
    )

    # Data category -> directory.
    DATA_DIRS = {
        'daily': os.path.join(ZHAOYUN_DATA_BASE, "raw/daily"),
        'financial': os.path.join(ZHAOYUN_DATA_BASE, "raw/financial"),
        'stock_info': os.path.join(ZHAOYUN_DATA_BASE, "raw/stock_info"),
        'minute': os.path.join(ZHAOYUN_DATA_BASE, "raw/minute_kline"),
    }

    # Local column name -> vn.py field name.
    VNPY_FIELD_MAP = {
        'date': 'datetime',
        'open': 'open_price',
        'high': 'high_price',
        'low': 'low_price',
        'close': 'close_price',
        'volume': 'volume',
        'amount': 'turnover',
        'turnover': 'turnover_rate',
    }

    # Year window probed by verify_local_data_structure; the upper bound is
    # extended to the current year at call time.
    _FIRST_DATA_YEAR = 2010
    _LAST_DATA_YEAR = 2026

    def __init__(self, use_local_first: bool = True):
        """
        Initialise the adapter and log which data directories exist.

        Args:
            use_local_first: Whether local files are tried before akshare.
        """
        self.use_local_first = use_local_first
        self._validate_data_dirs()

    def _validate_data_dirs(self):
        """Log, for each configured data directory, whether it exists."""
        for name, path in self.DATA_DIRS.items():
            if os.path.exists(path):
                logger.info(f"✅ 赵云数据目录 {name}: {path}")
            else:
                logger.warning(f"⚠️ 赵云数据目录不存在 {name}: {path}")

    def _parse_symbol(self, symbol: str) -> Tuple[str, str]:
        """
        Split a stock symbol into its code and exchange.

        A suffixed symbol such as "000001.SZ" is split on its last dot and
        the suffix is upper-cased. For a bare code the exchange is inferred
        from the leading digits: "6" -> SH; "0"/"3" -> SZ; "8", "43" and
        "92" -> BJ; anything else defaults to SZ.

        Args:
            symbol: Stock symbol, e.g. "000001.SZ" or "600000".

        Returns:
            (symbol_code, exchange), e.g. ("000001", "SZ").
        """
        symbol = symbol.strip()

        # rpartition never fails to unpack, even with several dots.
        code, dot, suffix = symbol.rpartition('.')
        if dot:
            return code, suffix.upper()

        if symbol.startswith('6'):
            exchange = 'SH'
        elif symbol.startswith(('0', '3')):
            exchange = 'SZ'
        elif symbol.startswith(('8', '43', '92')):
            exchange = 'BJ'
        else:
            exchange = 'SZ'  # default: Shenzhen

        return symbol, exchange

    def _get_local_daily_file_path(self, symbol: str, year: int) -> Optional[str]:
        """
        Locate the local daily parquet file for a symbol and year.

        Two file names are tried inside ``DATA_DIRS['daily']/<year>``:
        ``<exchange-prefix><code>_daily.parquet`` (prefix sh/sz/bj) and then
        ``<code>_daily.parquet``.

        Args:
            symbol: Stock symbol.
            year: Calendar year of the file.

        Returns:
            Path of the first existing candidate, or None when neither exists.
        """
        symbol_code, exchange = self._parse_symbol(symbol)

        # Only the three known exchanges get a file-name prefix.
        prefix = exchange.lower() if exchange in ('SH', 'SZ', 'BJ') else ''
        year_dir = os.path.join(self.DATA_DIRS['daily'], str(year))

        candidates = (
            f"{prefix}{symbol_code}_daily.parquet",
            f"{symbol_code}_daily.parquet",
        )
        for file_name in candidates:
            path = os.path.join(year_dir, file_name)
            if os.path.exists(path):
                return path

        return None

    def load_local_daily_data(self, symbol: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """
        Load daily bars from the local Zhaoyun files.

        Reads one parquet file per year in the requested range, keeps the
        rows whose ``date`` falls inside [start_date, end_date], and renames
        the columns to vn.py field names.

        Args:
            symbol: Stock symbol.
            start_date: Start date, "YYYY-MM-DD".
            end_date: End date, "YYYY-MM-DD".

        Returns:
            DataFrame of daily bars, or None when local loading is disabled,
            no rows were found, or an error occurred (the error is logged).
        """
        if not self.use_local_first:
            return None

        try:
            start_dt = pd.to_datetime(start_date)
            end_dt = pd.to_datetime(end_date)

            # NOTE(review): years without a local file are skipped silently,
            # so a partially covered range returns partial data and does not
            # trigger the akshare fallback.
            yearly_frames = []
            for year in range(start_dt.year, end_dt.year + 1):
                file_path = self._get_local_daily_file_path(symbol, year)
                if not file_path:
                    continue

                df = pd.read_parquet(file_path)

                # Restrict to the requested date range.
                df['date'] = pd.to_datetime(df['date'])
                in_range = (df['date'] >= start_dt) & (df['date'] <= end_dt)
                df_filtered = df[in_range]

                if not df_filtered.empty:
                    yearly_frames.append(df_filtered)
                    logger.debug(f"✅ 从本地加载 {symbol} {year}年数据: {len(df_filtered)} 条")

            if not yearly_frames:
                logger.info(f"⚠️ 本地没有找到 {symbol} 的数据")
                return None

            # Merge all years, order by date, then switch to vn.py names.
            result = pd.concat(yearly_frames, ignore_index=True)
            result = result.sort_values('date')
            result = result.rename(columns=self.VNPY_FIELD_MAP)

            symbol_code, exchange = self._parse_symbol(symbol)
            result['symbol'] = symbol_code
            result['exchange'] = exchange
            result['interval'] = '1d'

            logger.info(f"✅ 成功从本地加载 {symbol} 数据: {len(result)} 条 ({start_date} 到 {end_date})")
            return result

        except Exception as e:
            logger.error(f"❌ 加载本地数据失败 {symbol}: {e}")
            return None

    def fetch_akshare_daily_data(self, symbol: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """
        Fetch unadjusted daily bars from akshare (the fallback source).

        Args:
            symbol: Stock symbol.
            start_date: Start date, "YYYY-MM-DD".
            end_date: End date, "YYYY-MM-DD".

        Returns:
            DataFrame with vn.py field names and ``datetime`` formatted as
            "YYYY-MM-DD HH:MM:SS" strings, or None when akshare returns
            nothing or the request fails (the error is logged).
        """
        try:
            symbol_code, exchange = self._parse_symbol(symbol)

            # akshare takes dates without dashes.
            start_date_ak = start_date.replace('-', '')
            end_date_ak = end_date.replace('-', '')

            logger.info(f"📡 从akshare获取 {symbol} 数据 ({start_date} 到 {end_date})")

            df = ak.stock_zh_a_hist(
                symbol=symbol_code,
                period="daily",
                start_date=start_date_ak,
                end_date=end_date_ak,
                adjust="",  # no price adjustment
            )

            if df is None or df.empty:
                logger.warning(f"⚠️ akshare没有 {symbol} 的数据")
                return None

            # Map akshare's Chinese column headers to vn.py field names.
            df = df.rename(columns={
                '日期': 'datetime',
                '开盘': 'open_price',
                '收盘': 'close_price',
                '最高': 'high_price',
                '最低': 'low_price',
                '成交量': 'volume',
                '成交额': 'turnover',
            })

            df['datetime'] = pd.to_datetime(df['datetime']).dt.strftime('%Y-%m-%d %H:%M:%S')

            df['symbol'] = symbol_code
            df['exchange'] = exchange
            df['interval'] = '1d'

            logger.info(f"✅ 从akshare获取 {symbol} 数据成功: {len(df)} 条")
            return df

        except Exception as e:
            logger.error(f"❌ 从akshare获取数据失败 {symbol}: {e}")
            return None

    def get_daily_data(self, symbol: str, start_date: str, end_date: str) -> pd.DataFrame:
        """
        Get daily bars, preferring local files and falling back to akshare.

        Args:
            symbol: Stock symbol.
            start_date: Start date, "YYYY-MM-DD".
            end_date: End date, "YYYY-MM-DD".

        Returns:
            DataFrame of daily bars; an empty DataFrame when both sources fail.
        """
        # 1. Local data first.
        if self.use_local_first:
            local_data = self.load_local_daily_data(symbol, start_date, end_date)
            if local_data is not None and not local_data.empty:
                return local_data

        # 2. Fall back to akshare.
        akshare_data = self.fetch_akshare_daily_data(symbol, start_date, end_date)
        if akshare_data is not None and not akshare_data.empty:
            return akshare_data

        # 3. Both sources failed.
        logger.error(f"❌ 无法获取 {symbol} 的数据")
        return pd.DataFrame()

    def verify_local_data_structure(self, symbol: str) -> Dict:
        """
        Check whether the local files for a symbol have the required columns.

        Every year from 2010 up to the later of 2026 and the current year is
        probed. Each missing column is reported once, no matter how many
        yearly files lack it.

        Args:
            symbol: Stock symbol.

        Returns:
            Dict with keys ``symbol``, ``has_local_data``, ``data_years``,
            ``missing_fields``, ``recommendations`` and ``status`` (one of
            'OK', 'INCOMPLETE', 'NO_DATA', 'ERROR').
        """
        result = {
            'symbol': symbol,
            'has_local_data': False,
            'data_years': [],
            'missing_fields': [],
            'recommendations': [],
            'status': 'UNKNOWN'
        }
        required_fields = ['date', 'open', 'high', 'low', 'close', 'volume']

        try:
            last_year = max(self._LAST_DATA_YEAR, datetime.now().year)

            data_years = []
            for year in range(self._FIRST_DATA_YEAR, last_year + 1):
                file_path = self._get_local_daily_file_path(symbol, year)
                if not file_path:
                    continue

                data_years.append(year)

                # Record each missing column only once across all years.
                df = pd.read_parquet(file_path)
                for field in required_fields:
                    if field not in df.columns and field not in result['missing_fields']:
                        result['missing_fields'].append(field)

            result['data_years'] = data_years
            result['has_local_data'] = len(data_years) > 0

            if not result['has_local_data']:
                result['status'] = 'NO_DATA'
                result['recommendations'].append("建议:联系赵云将军下载该股票数据")
            elif result['missing_fields']:
                result['status'] = 'INCOMPLETE'
                result['recommendations'].append(f"缺少字段: {result['missing_fields']}")
                result['recommendations'].append("建议:使用data_convert_tool.py转换数据格式")
            else:
                result['status'] = 'OK'
                result['recommendations'].append(f"✅ 数据结构完整,覆盖 {min(data_years)}-{max(data_years)} 年")

        except Exception as e:
            result['status'] = 'ERROR'
            result['recommendations'].append(f"验证错误: {e}")

        return result
|
||||
|
||||
|
||||
class DataConvertTool:
    """
    Data format conversion tool.

    Converts a Zhaoyun-format parquet file into the column layout vn.py
    expects.
    """

    @staticmethod
    def convert_zhaoyun_to_vnpy(input_path: str, output_path: str, symbol: str):
        """
        Convert one Zhaoyun parquet file to vn.py format and save it.

        Args:
            input_path: Path of the Zhaoyun parquet file to read.
            output_path: Path of the parquet file to write.
            symbol: Stock symbol used to fill the symbol/exchange columns.

        Raises:
            ValueError: When a required column is missing from the input.
            Exception: Any read/convert/write error is logged and re-raised.
        """
        try:
            df = pd.read_parquet(input_path)

            # The input must carry the basic OHLCV columns.
            required = ['date', 'open', 'high', 'low', 'close', 'volume']
            missing = [field for field in required if field not in df.columns]
            if missing:
                raise ValueError(f"缺少必要字段: {missing}")

            # Build the vn.py-style frame column by column.
            vnpy_df = pd.DataFrame()
            vnpy_df['datetime'] = pd.to_datetime(df['date']).dt.strftime('%Y-%m-%d %H:%M:%S')
            vnpy_df['open_price'] = df['open']
            vnpy_df['high_price'] = df['high']
            vnpy_df['low_price'] = df['low']
            vnpy_df['close_price'] = df['close']
            vnpy_df['volume'] = df['volume']

            if 'amount' in df.columns:
                vnpy_df['turnover'] = df['amount']
            else:
                # Estimated turnover.
                # NOTE(review): assumes volume is expressed in shares — confirm.
                vnpy_df['turnover'] = df['volume'] * df['close']

            if 'turnover' in df.columns:
                vnpy_df['turnover_rate'] = df['turnover']

            # _parse_symbol does not read instance state, so it is called
            # unbound; constructing a full adapter here would re-run the data
            # directory validation (and its logging) on every conversion.
            symbol_code, exchange = VnpyLocalDataAdapter._parse_symbol(None, symbol)
            vnpy_df['symbol'] = symbol_code
            vnpy_df['exchange'] = exchange
            vnpy_df['interval'] = '1d'

            vnpy_df.to_parquet(output_path, index=False)
            logger.info(f"✅ 数据转换完成: {input_path} → {output_path}")

        except Exception as e:
            logger.error(f"❌ 数据转换失败: {e}")
            raise
|
||||
|
||||
|
||||
# vn.py数据管理器包装器
|
||||
class VnpyDataManagerWrapper:
    """
    Wrapper pairing a vn.py data manager with the local data adapter.

    NOTE(review): ``_patch_methods`` is a placeholder. It only logs a message
    and does not modify ``original_data_manager``; daily bars are served by
    ``get_daily_bar_data`` on this wrapper instead.
    """

    def __init__(self, original_data_manager, adapter: VnpyLocalDataAdapter):
        """
        Store both collaborators and run the (placeholder) patch step.

        Args:
            original_data_manager: The original vn.py data manager.
            adapter: Local data adapter used to serve daily bars.
        """
        self.original_dm = original_data_manager
        self.adapter = adapter
        self._patch_methods()

    def _patch_methods(self):
        """Placeholder for patching the vn.py data manager; only logs.

        The real patch depends on the vn.py version and API in use and is
        not implemented here.
        """
        logger.info("✅ vn.py数据管理器已修补为优先使用本地数据")

    def get_daily_bar_data(self, symbol: str, start_date: str, end_date: str):
        """Return daily bars for ``symbol`` by delegating to the adapter."""
        bars = self.adapter.get_daily_data(symbol, start_date, end_date)
        return bars
|
||||
|
||||
|
||||
# 使用示例
|
||||
def _run_demo():
    """Usage example: verify local files for one symbol, then load its bars."""
    # 1. Create the adapter.
    adapter = VnpyLocalDataAdapter(use_local_first=True)

    # 2. Sample query (Ping An Bank).
    test_symbol = "000001.SZ"
    start_date = "2024-01-01"
    end_date = "2024-03-01"
    rule = "=" * 60

    print(rule)
    print("vn.py本地数据适配器测试")
    print(rule)

    # 3. Check the local data layout.
    print("\n1. 验证本地数据结构:")
    verification = adapter.verify_local_data_structure(test_symbol)
    for key in verification:
        print(f" {key}: {verification[key]}")

    # 4. Load the data.
    print(f"\n2. 获取 {test_symbol} 数据 ({start_date} 到 {end_date}):")
    data = adapter.get_daily_data(test_symbol, start_date, end_date)

    if data.empty:
        print("❌ 获取数据失败")
    else:
        print(f"✅ 成功获取 {len(data)} 条数据")
        print(f"数据字段: {list(data.columns)}")
        print(f"时间范围: {data['datetime'].min()} 到 {data['datetime'].max()}")
        # The source is guessed from the presence of an 'outstanding_share' column.
        source = '本地' if 'outstanding_share' in data.columns else 'akshare'
        print(f"数据来源: {source}")

    print("\n3. 使用建议:")
    for tip in (
        " a) 在vn.py策略中导入此适配器",
        " b) 替换原有的数据获取逻辑",
        " c) 配置赵云数据目录路径",
        " d) 定期更新本地数据(联系赵云将军)",
    ):
        print(tip)

    print(rule)


# Usage example
if __name__ == "__main__":
    _run_demo()
|
||||
Reference in New Issue
Block a user