223 lines
6.5 KiB
Python
223 lines
6.5 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
在Docker容器内直接运行回测 - 不经过HTTP API
|
|
针对 510300.SSE 的单票回测
|
|
"""
|
|
|
|
# ============================================
|
|
# 1. 风控模块代码 (risk_control.py)
|
|
# ============================================
|
|
# Source of the risk-control module. It is exec'd into this script's globals
# further down, so every class defined here becomes a module-level name.
RISK_CONTROL_CODE = '''
"""
风控模块 - 量化策略风控系统
功能:
1. 单票15%止损规则
2. 整体回撤分级风控(10%/20%/25% 分级降仓)
3. 黑天鹅过滤(ST、跌停、财务造假排除)

Author: 关羽(云长)
Date: 2026-03-27
"""

from dataclasses import dataclass


@dataclass
class StockInfo:
    """单票基本信息"""
    code: str
    name: str
    cost_price: float
    current_price: float
    is_st: bool = False
    is_limit_down: bool = False
    is_fraud: bool = False
    volume: float = 0.0  # 日成交额(亿)


@dataclass
class PortfolioInfo:
    """组合信息"""
    total_capital: float
    current_capital: float
    positions: dict[str, float]  # code -> position_size


class SingleStockRiskControl:
    """单票风控:15%止损规则"""

    def __init__(self, stop_loss_pct: float = 0.15):
        self.stop_loss_pct = stop_loss_pct

    def check_stop_loss(self, stock: StockInfo) -> bool:
        """检查是否触发止损"""
        # A non-positive cost price cannot define a return, so never trigger.
        if stock.cost_price <= 0:
            return False

        drawdown = (stock.current_price - stock.cost_price) / stock.cost_price
        return drawdown <= -self.stop_loss_pct

    def get_drawdown(self, stock: StockInfo) -> float:
        """计算单票当前回撤"""
        if stock.cost_price <= 0:
            return 0.0
        return (stock.current_price - stock.cost_price) / stock.cost_price


class PortfolioDrawdownRiskControl:
    """整体回撤分级风控"""

    def __init__(self, drawdown_levels=None, reduce_ratios=None):
        levels = drawdown_levels or [0.10, 0.20, 0.25]
        ratios = reduce_ratios or [0.50, 0.25, 0.00]
        # zip() would silently drop unmatched tiers; reject that configuration.
        if len(levels) != len(ratios):
            raise ValueError(
                "drawdown_levels and reduce_ratios must have the same length"
            )
        # get_target_position_ratio scans tiers from deepest to shallowest,
        # which is only correct when the tiers are in ascending level order.
        tiers = sorted(zip(levels, ratios))
        self.drawdown_levels = [level for level, _ in tiers]
        self.reduce_ratios = [ratio for _, ratio in tiers]

    def calculate_total_drawdown(self, portfolio: PortfolioInfo) -> float:
        """计算组合总回撤"""
        if portfolio.total_capital <= 0:
            return 0.0
        return (portfolio.total_capital - portfolio.current_capital) / portfolio.total_capital

    def get_target_position_ratio(self, portfolio: PortfolioInfo) -> float:
        """获取目标仓位比例"""
        drawdown = self.calculate_total_drawdown(portfolio)

        # Return the ratio of the deepest tier whose threshold is reached.
        for level, ratio in reversed(list(zip(self.drawdown_levels, self.reduce_ratios))):
            if drawdown >= level:
                return ratio

        return 1.0


class BlackSwanFilter:
    """Black-swan filter: exclude ST, limit-down and fraud-flagged stocks."""

    def should_exclude(self, stock: StockInfo) -> bool:
        """Return True when the stock carries any of the black-swan flags."""
        return any((stock.is_st, stock.is_limit_down, stock.is_fraud))

    def filter_stocks(self, stocks: list[StockInfo]) -> list[StockInfo]:
        """Return the stocks that pass the filter, preserving input order."""
        return [stock for stock in stocks if not self.should_exclude(stock)]


class RiskController:
    """总风控控制器"""

    def __init__(self):
        self.single_stock_rc = SingleStockRiskControl()
        self.portfolio_rc = PortfolioDrawdownRiskControl()
        # Third rule listed in the module docstring (black-swan exclusion).
        self.black_swan_filter = BlackSwanFilter()
'''
|
|
|
|
# ============================================
|
|
# 2. 简化策略代码 (回测510300.SSE单票)
|
|
# ============================================
|
|
# Source of the demo CTA strategy. It is exec'd into this script's globals
# further down, so SingleStockStopLossStrategy can be registered with the
# backtester by class object.
SIMPLE_STRATEGY_CODE = '''
"""
简化版策略 - 针对510300.SSE的单票回测测试
"""

from vnpy.app.cta_strategy import CtaTemplate
from vnpy.trader.constant import Direction
from vnpy.trader.utility import ArrayManager


class SingleStockStopLossStrategy(CtaTemplate):
    """
    简化版单票策略 - 测试510300.SSE
    """

    # Class-level defaults. CtaTemplate.update_setting() overrides them from
    # the setting dict, and the class-parameter/variable introspection reads
    # them by name (it fails if a listed parameter has no class attribute).
    fast_window = 5
    slow_window = 20
    stop_loss_pct = 0.15

    stop_loss_triggered = False
    entry_price = 0.0

    parameters = ["fast_window", "slow_window", "stop_loss_pct"]
    variables = ["stop_loss_triggered", "entry_price"]

    def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
        # vnpy_ctastrategy constructs strategies as
        # cls(engine, strategy_name, vt_symbol, setting); a three-argument
        # signature raises TypeError on instantiation.
        super().__init__(cta_engine, strategy_name, vt_symbol, setting)
        self.stop_loss_triggered = False
        self.entry_price = 0.0
        # Rolling close-price buffer; it must hold at least slow_window bars
        # before the slow moving average is meaningful.
        self.am = ArrayManager(max(self.fast_window, self.slow_window) + 1)

    def on_init(self):
        self.write_log("策略初始化")
        self.load_bar(100)

    def on_bar(self, bar):
        # Cancel limit orders left unfilled from the previous bar so that a
        # repeated entry signal cannot stack several pending buy orders.
        self.cancel_all()

        self.am.update_bar(bar)
        if not self.am.inited:
            return

        # 检查止损
        if self.pos > 0 and self.entry_price > 0:
            profit = (bar.close_price - self.entry_price) / self.entry_price
            if profit <= -self.stop_loss_pct:
                self.write_log(f"触发止损: {bar.datetime}, 回撤: {profit:.2%}")
                self.stop_loss_triggered = True
                self.sell(bar.close_price, abs(self.pos))
                return

        # 简单均线策略
        if not self.stop_loss_triggered:
            if bar.close_price > self.am.sma(self.fast_window):
                if self.pos == 0:
                    self.buy(bar.close_price, 10000)
            elif bar.close_price < self.am.sma(self.slow_window):
                if self.pos > 0:
                    self.sell(bar.close_price, abs(self.pos))

    def on_trade(self, trade):
        # CtaTemplate does not track a cost basis, so keep a volume-weighted
        # entry price for the open long position here. vnpy's engines update
        # self.pos before calling on_trade, so self.pos is already post-fill.
        if self.pos <= 0:
            self.entry_price = 0.0
        elif trade.direction == Direction.LONG:
            prev_pos = self.pos - trade.volume
            self.entry_price = (
                self.entry_price * prev_pos + trade.price * trade.volume
            ) / self.pos
        self.put_event()
'''
|
|
|
|
# ============================================
|
|
# 3. 执行回测
|
|
# ============================================
|
|
|
|
# Import the stdlib modules used to fabricate placeholder packages.
import types
import sys

print("=" * 80)
print("🚀 在Docker容器内直接运行回测")
print("=" * 80)

# Install a vnpy.app compatibility shim.
# The embedded strategy source imports from the ``vnpy.app.cta_strategy`` path,
# so empty placeholder modules are registered under the ``vnpy.app.*`` names in
# sys.modules. NOTE: this unconditionally replaces any existing 'vnpy.app' entry.
vnpy_app_module = types.ModuleType('vnpy.app')
sys.modules['vnpy.app'] = vnpy_app_module

submodules = ['cta_strategy', 'cta_backtester', 'data_manager']
for name in submodules:
    full_name = f'vnpy.app.{name}'
    submodule = types.ModuleType(full_name)
    sys.modules[full_name] = submodule
    # Also expose each placeholder as an attribute of the parent module.
    setattr(vnpy_app_module, name, submodule)

# Populate the placeholders with the real classes from the split-out
# vnpy_ctastrategy / vnpy_ctabacktester packages. 'data_manager' stays empty
# in the code visible here.
from vnpy_ctastrategy import CtaTemplate, BarData
sys.modules['vnpy.app.cta_strategy'].CtaTemplate = CtaTemplate

from vnpy_ctabacktester import BacktesterEngine
sys.modules['vnpy.app.cta_backtester'].BacktesterEngine = BacktesterEngine

print("✅ vnpy.app兼容性模块加载完成")
|
|
|
|
# Execute the embedded risk-control source in this module's namespace; the
# classes it defines (StockInfo, RiskController, ...) become module globals.
# The source is a constant defined in this file, not external input.
exec(RISK_CONTROL_CODE, globals())

# Execute the embedded strategy source the same way, which defines
# SingleStockStopLossStrategy as a module global. This must run after the
# vnpy.app shim is installed, because that source imports from
# vnpy.app.cta_strategy.
exec(SIMPLE_STRATEGY_CODE, globals())
|
|
|
|
# Run the backtest: import vnpy's engine classes, constants and datetime.
from vnpy.event import EventEngine
from vnpy.trader.engine import MainEngine
from vnpy.trader.constant import Exchange, Interval
from datetime import datetime

print("\n" + "=" * 80)
print("初始化回测引擎...")
print("=" * 80)

event_engine = EventEngine()
main_engine = MainEngine(event_engine)

# Instantiate BacktesterEngine manually and register the exec-defined strategy
# class under its class name in the engine's ``classes`` mapping.
backtester_engine = BacktesterEngine(main_engine, event_engine)
backtester_engine.classes["SingleStockStopLossStrategy"] = SingleStockStopLossStrategy

print("✅ BacktesterEngine 初始化完成")
|
|
|
|
# Load historical bar data from vnpy's database backend.
print("\n" + "=" * 80)
print("加载数据...")
print("=" * 80)

from vnpy.trader.database import get_database
# NOTE(review): the backend returned here presumably comes from vnpy's global
# settings — confirm which database the container is configured to use.
db = get_database()

# Backtest target: 510300 on SSE, daily bars, from 2021-01-01 to 2026-03-01
# (naive datetimes).
symbol = "510300"
exchange = Exchange.SSE
interval = Interval.DAILY
start = datetime(2021, 1, 1)
end = datetime(2026, 3, 1)

bars = db.load_bar_data(symbol, exchange, interval, start, end)
print(f"✅ 加载了 {len(bars)} 条bar数据")
|
|
|
|
if len(bars) == 0:
|
|
[TRUNCATED] |