diff --git a/data_platform/strategy_base.py b/data_platform/strategy_base.py new file mode 100644 index 000000000..b97b8e691 --- /dev/null +++ b/data_platform/strategy_base.py @@ -0,0 +1,57 @@ +""" +BaseStrategy - 策略基类 + +策略开发者只需继承此类并实现 generate_signals() 方法。 +回测引擎通过统一接口调用策略。 + +Usage: + from data_platform.strategy_base import BaseStrategy + + class MAStrategy(BaseStrategy): + def generate_signals(self, data: pd.DataFrame) -> pd.DataFrame: + data = data.copy() + data["ma5"] = data["close"].rolling(5).mean() + data["ma20"] = data["close"].rolling(20).mean() + data["signal"] = 0 + data.loc[data["ma5"] > data["ma20"], "signal"] = 1 # 买入 + data.loc[data["ma5"] < data["ma20"], "signal"] = -1 # 卖出 + return data +""" + +import pandas as pd +from abc import ABC, abstractmethod + + +class BaseStrategy(ABC): + """ + 策略基类 —— 所有回测策略必须继承此类 + + 子类只需实现 generate_signals(data) -> data_with_signals + signal 列约定: + 1 = 买入信号 + -1 = 卖出信号 + 0 = 无操作 + """ + + @abstractmethod + def generate_signals(self, data: pd.DataFrame) -> pd.DataFrame: + """ + 根据行情数据生成交易信号 + + Args: + data: 日线行情 DataFrame,至少包含 date, open, high, low, close, volume + + Returns: + 原始数据追加 signal 列(1=买, -1=卖, 0=无操作) + """ + ... + + @property + def name(self) -> str: + """策略名称,默认取类名""" + return self.__class__.__name__ + + @property + def description(self) -> str: + """策略描述,子类可覆盖""" + return ""