# Origin: zhangfei-technical/research/task-20260417-atr-indicator/atr_indicator.py
"""
ATR (Average True Range) indicator utilities.

Two smoothing methods are supported: SMA and EMA (Wilder smoothing).

Author: 张飞 翼德
Date: 2026-04-17
"""

import numpy as np
import pandas as pd
from typing import Optional, Union


class ATRIndicator:
    """
    Average True Range (ATR) calculator.

    Two smoothing methods are supported:
    - SMA: simple moving average of the True Range
    - EMA: Wilder smoothing, seeded with a simple average of the first
      ``period`` True Range values
    """

    def __init__(self, period: int = 14, method: str = 'sma'):
        """
        Create an ATR calculator.

        Parameters:
            period: int, default 14
                Look-back period of the ATR.
            method: str, default 'sma'
                Smoothing method, 'sma' or 'ema' (case-insensitive).

        Raises:
            ValueError: if ``period`` is not positive or ``method`` is unknown.
        """
        if period <= 0:
            raise ValueError(f"周期必须大于0,当前为 {period}")

        if method.lower() not in ['sma', 'ema']:
            raise ValueError(f"计算方法必须是 'sma' 或 'ema',当前为 {method}")

        self.period = period
        self.method = method.lower()
        # Most recent ATR value produced by a calculation or by update().
        self._last_atr: Optional[float] = None

    @staticmethod
    def calculate_tr(
        high: Union[pd.Series, np.ndarray],
        low: Union[pd.Series, np.ndarray],
        close: Union[pd.Series, np.ndarray]
    ) -> Union[pd.Series, np.ndarray]:
        """
        Compute the True Range (TR).

        Formula: TR = max(|high - low|, |high - prev_close|, |low - prev_close|)

        Parameters:
            high: high prices
            low: low prices
            close: close prices

        Returns:
            TR sequence whose first value is NaN (no previous close exists).
            A Series named 'TR' on ``high``'s index when ``high`` is a Series,
            otherwise a NumPy array.

        Raises:
            ValueError: if the three inputs differ in length.
        """
        tr_values = ATRIndicator._calculate_tr_numpy(high, low, close)
        if isinstance(high, pd.Series):
            return pd.Series(tr_values, index=high.index, name='TR')
        return tr_values

    @staticmethod
    def _calculate_tr_numpy(
        high: np.ndarray,
        low: np.ndarray,
        close: np.ndarray
    ) -> np.ndarray:
        """Compute TR positionally with NumPy; accepts any array-like input."""
        # Normalising here lets Series, ndarrays and plain lists be mixed freely.
        high = np.asarray(high, dtype=float)
        low = np.asarray(low, dtype=float)
        close = np.asarray(close, dtype=float)

        if not (len(high) == len(low) == len(close)):
            raise ValueError("high, low, close序列长度必须一致")

        n = len(high)
        tr = np.full(n, np.nan)
        if n < 2:
            return tr

        prev_close = close[:-1]
        # abs() on high - low matches the documented formula and update().
        high_low = np.abs(high[1:] - low[1:])
        high_prev_close = np.abs(high[1:] - prev_close)
        low_prev_close = np.abs(low[1:] - prev_close)

        tr[1:] = np.maximum(np.maximum(high_low, high_prev_close), low_prev_close)
        return tr

    def calculate(
        self,
        high: Union[pd.Series, np.ndarray],
        low: Union[pd.Series, np.ndarray],
        close: Union[pd.Series, np.ndarray]
    ) -> Union[pd.Series, np.ndarray]:
        """
        Compute the ATR from price data.

        Parameters:
            high: high prices
            low: low prices
            close: close prices

        Returns:
            ATR sequence (Series if ``high`` is a Series, else a NumPy array).
        """
        return self.calculate_from_tr(self.calculate_tr(high, low, close))

    def calculate_from_tr(
        self, tr: Union[pd.Series, np.ndarray]
    ) -> Union[pd.Series, np.ndarray]:
        """
        Compute the ATR from an already computed True Range sequence.

        Parameters:
            tr: TR sequence, e.g. the output of ``calculate_tr``

        Returns:
            ATR sequence smoothed with the configured method.
        """
        if self.method == 'sma':
            return self._calculate_sma_atr(tr)
        return self._calculate_ema_atr(tr)

    def _calculate_sma_atr(self, tr: Union[pd.Series, np.ndarray]) -> Union[pd.Series, np.ndarray]:
        """Compute the ATR as a rolling simple mean of TR over ``period`` bars."""
        is_series = isinstance(tr, pd.Series)
        series = tr if is_series else pd.Series(np.asarray(tr, dtype=float))

        # min_periods == window: any window containing NaN (including the
        # leading TR NaN) yields NaN, for Series and array inputs alike.
        atr = series.rolling(window=self.period, min_periods=self.period).mean()

        if len(atr) > 0:
            self._last_atr = atr.iloc[-1]

        if is_series:
            atr.name = f'ATR_{self.period}'
            return atr
        return atr.to_numpy()

    def _calculate_ema_atr(self, tr: Union[pd.Series, np.ndarray]) -> Union[pd.Series, np.ndarray]:
        """Compute the ATR with Wilder smoothing, dispatching on input type."""
        if isinstance(tr, pd.Series):
            return self._calculate_ema_atr_pandas(tr)
        return self._calculate_ema_atr_numpy(tr)

    def _calculate_ema_atr_pandas(self, tr: pd.Series) -> pd.Series:
        """Wilder-smoothed ATR for a Series; keeps the index of ``tr``."""
        values = self._wilder_smooth(tr.to_numpy(dtype=float))
        return pd.Series(values, index=tr.index, name=f'ATR_{self.period}')

    def _calculate_ema_atr_numpy(self, tr: np.ndarray) -> np.ndarray:
        """Wilder-smoothed ATR for a NumPy array (or any array-like)."""
        return self._wilder_smooth(np.asarray(tr, dtype=float))

    def _wilder_smooth(self, tr: np.ndarray) -> np.ndarray:
        """Apply Wilder smoothing to a float TR array and record the last ATR."""
        period = self.period
        n = len(tr)
        atr = np.full(n, np.nan)

        # tr[0] is NaN, so the seed needs tr[1:period + 1], i.e. at least
        # period + 1 bars; with exactly `period` bars there is nothing to seed.
        if n <= period:
            self._last_atr = None
            return atr

        # Seed: simple average of the first `period` TR values, skipping NaN
        # so one missing bar does not poison every later value.
        seed_window = tr[1:period + 1]
        valid = seed_window[~np.isnan(seed_window)]
        atr[period] = valid.mean() if valid.size else np.nan

        # Wilder smoothing: ATR_t = (ATR_{t-1} * (period - 1) + TR_t) / period
        alpha = 1.0 / period
        for i in range(period + 1, n):
            if np.isnan(tr[i]):
                # Missing TR: carry the previous ATR forward.
                atr[i] = atr[i - 1]
            else:
                atr[i] = atr[i - 1] * (1 - alpha) + tr[i] * alpha

        self._last_atr = atr[-1]
        return atr

    def get_last_atr(self) -> Optional[float]:
        """Return the most recent ATR value, or None if none is available."""
        return self._last_atr

    def update(self, high: float, low: float, prev_close: float) -> float:
        """
        Incrementally update the ATR with one new bar (for live trading).

        Parameters:
            high: high of the current bar
            low: low of the current bar
            prev_close: close of the previous bar

        Returns:
            The new ATR value. If no ATR has been computed yet, the current
            bar's TR is returned and no state is stored.

        Raises:
            ValueError: if the method is 'sma' and a previous ATR exists;
                SMA needs the full window, so recompute the whole series.
        """
        tr = max(abs(high - low), abs(high - prev_close), abs(low - prev_close))

        if self._last_atr is None:
            # Not enough history yet: fall back to the raw TR.
            return tr

        if self.method == 'sma':
            raise ValueError("SMA方法不支持增量更新,请重新计算完整序列")

        # Wilder smoothing only needs the previous ATR and the new TR.
        alpha = 1.0 / self.period
        new_atr = self._last_atr * (1 - alpha) + tr * alpha
        self._last_atr = new_atr
        return new_atr


def calculate_atr(
    df: pd.DataFrame,
    period: int = 14,
    method: str = 'sma',
    high_col: str = 'high',
    low_col: str = 'low',
    close_col: str = 'close',
    drop_na: bool = False
) -> pd.DataFrame:
    """
    Convenience function: compute TR and ATR for a DataFrame.

    Parameters:
        df: input DataFrame containing the high, low and close columns
        period: ATR period, default 14
        method: smoothing method, 'sma' or 'ema', default 'sma'
        high_col: name of the high column, default 'high'
        low_col: name of the low column, default 'low'
        close_col: name of the close column, default 'close'
        drop_na: drop rows containing NaN from the result, default False

    Returns:
        A copy of ``df`` with 'TR' and 'ATR_{period}' columns added; the
        input DataFrame is not modified.
    """
    atr_indicator = ATRIndicator(period, method)

    result = df.copy()
    # TR is computed once and reused for the ATR column.
    tr = atr_indicator.calculate_tr(df[high_col], df[low_col], df[close_col])
    result['TR'] = tr
    result[f'ATR_{period}'] = atr_indicator.calculate_from_tr(tr)

    if drop_na:
        result = result.dropna()

    return result