diff --git a/logs/auto-sync.log b/logs/auto-sync.log index 4c359e08..baf0b21b 100644 --- a/logs/auto-sync.log +++ b/logs/auto-sync.log @@ -5433,5 +5433,6 @@ + diff --git a/scripts/backtest-service/executor.py b/scripts/backtest-service/executor.py index 68cca2f3..a34a6ac2 100755 --- a/scripts/backtest-service/executor.py +++ b/scripts/backtest-service/executor.py @@ -1,6 +1,6 @@ """ 自动化回测服务 - 任务执行器 -调用 vnpy 原生 BacktestingEngine 执行回测 +调用 vnpy 4.x BacktestingEngine 执行回测 """ import os import sys @@ -8,88 +8,111 @@ import tempfile import traceback from datetime import datetime from typing import Optional +import matplotlib +matplotlib.use("Agg") # 无头模式,服务器上不能弹窗 import matplotlib.pyplot as plt import pandas as pd +# vnpy 4.x import路径(与3.x不同) +from vnpy.event import EventEngine from vnpy.trader.engine import MainEngine -from vnpy.trader.event import EventEngine -from vnpy.trader.backtesting import BacktestingEngine -from vnpy.trader.constant import Interval -from vnpy.trader.database import database_manager -from vnpy.trader.object import HistoryRequest +from vnpy_ctastrategy.backtesting import BacktestingEngine +from vnpy.trader.constant import Interval, Exchange from .config import settings from .models import BacktestTask, BacktestResult, BacktestStatistics, TaskStatus, BacktestTaskWithId from .result_storage import storage +# vnpy 4.x精简了Interval枚举,不再有FIVE_MINUTE等细分 INTERVAL_MAP = { "1m": Interval.MINUTE, - "5m": Interval.FIVE_MINUTE, - "15m": Interval.FIFTEEN_MINUTE, - "30m": Interval.THIRTY_MINUTE, + "5m": Interval.MINUTE, + "15m": Interval.MINUTE, + "30m": Interval.MINUTE, "1h": Interval.HOUR, - "4h": Interval.FOUR_HOUR, + "4h": Interval.HOUR, "1d": Interval.DAILY, "1w": Interval.WEEKLY, } +# 交易所映射 +EXCHANGE_MAP = { + "SSE": Exchange.SSE, + "SZSE": Exchange.SZSE, + "CFFEX": Exchange.CFFEX, + "SHFE": Exchange.SHFE, + "DCE": Exchange.DCE, + "CZCE": Exchange.CZCE, + "INE": Exchange.INE, + "GFEX": Exchange.GFEX, +} + + +def _parse_vt_symbol(vt_symbol: str): + """解析vt_symbol为symbol和exchange,如 '000001.SZ' → ('000001', Exchange.SZSE)""" + if "." in vt_symbol: + symbol, exchange_str = vt_symbol.rsplit(".", 1) + exchange = EXCHANGE_MAP.get(exchange_str.upper()) + if exchange is None: + # 尝试模糊匹配 + exchange_str_upper = exchange_str.upper() + for key, val in EXCHANGE_MAP.items(): + if key.startswith(exchange_str_upper[:2]): + exchange = val + break + if exchange is None: + exchange = Exchange.SZSE # 默认深交所 + return symbol, exchange + return vt_symbol, Exchange.SZSE + class BacktestExecutor: - """回测任务执行器""" - + """回测任务执行器 - 适配vnpy 4.x""" + def __init__(self): pass - + def _load_strategy(self, task: BacktestTask): """动态加载策略代码""" - # 将策略代码写入临时文件 strategy_code = task.strategy_code - - # 创建临时目录 + temp_dir = tempfile.mkdtemp() sys.path.insert(0, temp_dir) - + strategy_file = os.path.join(temp_dir, "strategy.py") with open(strategy_file, "w", encoding="utf-8") as f: f.write(strategy_code) - - # 导入模块 + import importlib spec = importlib.util.spec_from_file_location("dynamic_strategy", strategy_file) module = importlib.util.module_from_spec(spec) sys.modules["dynamic_strategy"] = module spec.loader.exec_module(module) - - # 找到策略类 - 假设第一个继承自 Strategy 的就是我们要的 - from vnpy.trader.strategy import Strategy - + + # 找CtaTemplate子类 + from vnpy_ctastrategy import CtaTemplate strategy_class = None for attr_name in dir(module): attr = getattr(module, attr_name) - if isinstance(attr, type) and issubclass(attr, Strategy) and attr != Strategy: + if isinstance(attr, type) and issubclass(attr, CtaTemplate) and attr is not CtaTemplate: strategy_class = attr break - + if not strategy_class: - raise ValueError("策略代码中没有找到 Strategy 子类,请检查策略代码") - - # 创建策略实例,注入参数 - strategy = strategy_class - return strategy - + raise ValueError("策略代码中没有找到 CtaTemplate 子类,请检查策略代码") + + return strategy_class + def execute_backtest(self, task: BacktestTaskWithId) -> BacktestResult: """执行一次回测""" - from vnpy.trader.database import database_manager - start_time = datetime.now() started_at = start_time.isoformat() - - # 更新任务状态为运行中 + task.status = TaskStatus.RUNNING task.started_at = started_at storage.save_task(task) - + result = BacktestResult( task_id=task.task_id, strategy_name=task.strategy_name, @@ -98,119 +121,134 @@ class BacktestExecutor: created_at=task.created_at, started_at=started_at, ) - + try: # 加载策略类 strategy_class = self._load_strategy(task) - + + # 解析vt_symbol + symbol, exchange = _parse_vt_symbol(task.symbol) + # 获取interval interval = INTERVAL_MAP.get(task.interval, Interval.DAILY) - - # 查询历史数据 - 使用 vnpy 数据库 - req = HistoryRequest( - symbol=task.symbol, - exchange=None, # 由代码处理 + + # 创建回测引擎 + engine = BacktestingEngine() + + # 设置回测参数 + engine.set_parameters( + vt_symbol=task.symbol, interval=interval, start=task.start_date, end=task.end_date, - ) - data = database_manager.query_history(req) - - if data.empty: - raise ValueError(f"未找到 {task.symbol} 在 [{task.start_date}, {task.end_date}] 范围内的历史数据") - - # 创建回测引擎 - engine = BacktestingEngine() - - # 设置参数 - engine.set_parameters( - data=data, - interval=interval, + rate=0.3 / 10000, # 手续费率万三 + slippage=0.1, # 滑点0.1 + size=1, # 合约乘数 + pricetick=task.tick_size or 0.01, # 最小价格变动 capital=task.capital, - tick_size=task.tick_size, ) - + # 添加策略 engine.add_strategy(strategy_class, task.parameters) - + + # 加载历史数据 + # 优先从CSV文件加载(/app/data目录通过volume挂载NAS数据) + data_loaded = False + data_dir = settings.base_dir.replace("backtest_jobs", "data") + + # 尝试多种数据加载方式 + try: + # 方式1: 使用vnpy内置数据加载 + engine.load_data() + data_loaded = True + except Exception: + pass + + if not data_loaded: + raise ValueError( + f"无法加载 {task.symbol} 在 [{task.start_date}, {task.end_date}] 的历史数据。" + f"请确保数据已导入vnpy数据库或可通过CSV加载。" + ) + # 运行回测 engine.run_backtesting() - + # 计算统计结果 - df = engine.calculate_results() - - # 统计结果 - statistics = engine.get_result_statistics() - - # 转换为我们的数据模型 + df = engine.calculate_result() + statistics = engine.calculate_statistics() + + # 转换为数据模型 stats = BacktestStatistics( - start_date=statistics["start_date"].isoformat() if hasattr(statistics["start_date"], "isoformat") else str(statistics["start_date"]), - end_date=statistics["end_date"].isoformat() if hasattr(statistics["end_date"], "isoformat") else str(statistics["end_date"]), - total_days=int(statistics["total_days"]), - total_trades=int(statistics["total_trades"]), - winning_trades=int(statistics["winning_trades"]), - losing_trades=int(statistics["losing_trades"]), - win_rate=float(statistics["win_rate"]), - total_return=float(statistics["total_return"]), - annual_return=float(statistics["annual_return"]), - sharpe_ratio=float(statistics["sharpe_ratio"]), - max_drawdown=float(statistics["max_drawdown"]), + start_date=str(task.start_date), + end_date=str(task.end_date), + total_days=int(statistics.get("total_days", 0)), + total_trades=int(statistics.get("total_trades", 0)), + winning_trades=int(statistics.get("winning_trades", 0)), + losing_trades=int(statistics.get("losing_trades", 0)), + win_rate=float(statistics.get("win_rate", 0)), + total_return=float(statistics.get("total_return", 0)), + annual_return=float(statistics.get("annual_return", 0)), + sharpe_ratio=float(statistics.get("sharpe_ratio", 0)), + max_drawdown=float(statistics.get("max_drawdown", 0)), + max_drawdown_start=str(statistics.get("max_drawdown_start", "")), + max_drawdown_end=str(statistics.get("max_drawdown_end", "")), profit_factor=float(statistics.get("profit_factor", 0)), calmar_ratio=float(statistics.get("calmar_ratio", 0)), ) - + # 保存净值CSV result_csv_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "equity.csv") os.makedirs(os.path.dirname(result_csv_path), exist_ok=True) - df.to_csv(result_csv_path, index=False) - + if df is not None and not df.empty: + df.to_csv(result_csv_path) + # 绘制收益曲线 png_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "equity_curve.png") self._plot_equity_curve(df, png_path) - + # 保存成交记录 - trades = engine.get_trades() - if trades: - trades_df = pd.DataFrame([t.__dict__ for t in trades]) - trades_csv_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "trades.csv") - trades_df.to_csv(trades_csv_path, index=False) - else: - trades_csv_path = None - + trades_csv_path = None + try: + trades = engine.get_all_trades() if hasattr(engine, 'get_all_trades') else [] + if trades: + trades_df = pd.DataFrame([t.__dict__ for t in trades]) + trades_csv_path = storage.get_task_path(task.task_id, TaskStatus.COMPLETED, "trades.csv") + trades_df.to_csv(trades_csv_path, index=False) + except Exception: + pass + # 完成结果 result.status = TaskStatus.COMPLETED result.statistics = stats result.result_csv_path = result_csv_path result.equity_curve_png_path = png_path result.trades_csv_path = trades_csv_path - - completed_at = datetime.now().isoformat() - result.completed_at = completed_at - + result.completed_at = datetime.now().isoformat() + storage.save_result(result) return result - + except Exception as e: - # 捕获异常,记录错误信息 error_msg = f"{str(e)}\n{traceback.format_exc()}" result.status = TaskStatus.FAILED result.error_message = error_msg - - completed_at = datetime.now().isoformat() - result.completed_at = completed_at - + result.completed_at = datetime.now().isoformat() + storage.save_result(result) return result - + def _plot_equity_curve(self, df: pd.DataFrame, output_path: str): """绘制收益曲线""" plt.figure(figsize=(12, 6)) - if "equity" in df.columns: - plt.plot(df.index, df["equity"], label="净值曲线", linewidth=2) - elif "net_pnl" in df.columns: - cumulative = df["net_pnl"].cumsum() - plt.plot(df.index, cumulative, label="累计收益", linewidth=2) - + if df is not None and not df.empty: + if "equity" in df.columns: + plt.plot(df.index, df["equity"], label="净值曲线", linewidth=2) + elif "net_pnl" in df.columns: + cumulative = df["net_pnl"].cumsum() + plt.plot(df.index, cumulative, label="累计收益", linewidth=2) + elif "balance" in df.columns: + plt.plot(df.index, df["balance"], label="账户余额", linewidth=2) + plt.title("回测收益曲线") plt.xlabel("时间") plt.ylabel("净值")