58 lines
1.6 KiB
Python
58 lines
1.6 KiB
Python
"""
|
|
BaseStrategy - 策略基类
|
|
|
|
策略开发者只需继承此类并实现 generate_signals() 方法。
|
|
回测引擎通过统一接口调用策略。
|
|
|
|
Usage:
|
|
from data_platform.strategy_base import BaseStrategy
|
|
|
|
class MAStrategy(BaseStrategy):
|
|
def generate_signals(self, data: pd.DataFrame) -> pd.DataFrame:
|
|
data = data.copy()
|
|
data["ma5"] = data["close"].rolling(5).mean()
|
|
data["ma20"] = data["close"].rolling(20).mean()
|
|
data["signal"] = 0
|
|
data.loc[data["ma5"] > data["ma20"], "signal"] = 1 # 买入
|
|
data.loc[data["ma5"] < data["ma20"], "signal"] = -1 # 卖出
|
|
return data
|
|
"""
|
|
|
|
import pandas as pd
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
|
class BaseStrategy(ABC):
|
|
"""
|
|
策略基类 —— 所有回测策略必须继承此类
|
|
|
|
子类只需实现 generate_signals(data) -> data_with_signals
|
|
signal 列约定:
|
|
1 = 买入信号
|
|
-1 = 卖出信号
|
|
0 = 无操作
|
|
"""
|
|
|
|
@abstractmethod
|
|
def generate_signals(self, data: pd.DataFrame) -> pd.DataFrame:
|
|
"""
|
|
根据行情数据生成交易信号
|
|
|
|
Args:
|
|
data: 日线行情 DataFrame,至少包含 date, open, high, low, close, volume
|
|
|
|
Returns:
|
|
原始数据追加 signal 列(1=买, -1=卖, 0=无操作)
|
|
"""
|
|
...
|
|
|
|
@property
|
|
def name(self) -> str:
|
|
"""策略名称,默认取类名"""
|
|
return self.__class__.__name__
|
|
|
|
@property
|
|
def description(self) -> str:
|
|
"""策略描述,子类可覆盖"""
|
|
return ""
|