""" ATR (Average True Range) 指标计算工具 支持SMA和EMA两种计算方式 Author: 张飞 翼德 Date: 2026-04-17 """ import numpy as np import pandas as pd from typing import Optional, Union class ATRIndicator: """ ATR (Average True Range) 指标计算类 支持两种计算方式: - SMA: 简单移动平均 - EMA: 指数移动平均 """ def __init__(self, period: int = 14, method: str = 'sma'): """ 初始化ATR指标 Parameters: period: int, 默认14 ATR计算周期 method: str, 默认'sma' 计算方法,可选 'sma' 或 'ema' """ if period <= 0: raise ValueError(f"周期必须大于0,当前为 {period}") if method.lower() not in ['sma', 'ema']: raise ValueError(f"计算方法必须是 'sma' 或 'ema',当前为 {method}") self.period = period self.method = method.lower() self._last_atr: Optional[float] = None @staticmethod def calculate_tr( high: Union[pd.Series, np.ndarray], low: Union[pd.Series, np.ndarray], close: Union[pd.Series, np.ndarray] ) -> Union[pd.Series, np.ndarray]: """ 计算True Range (TR) 公式: TR = max[|high - low|, |high - prev_close|, |low - prev_close|] Parameters: high: 最高价序列 low: 最低价序列 close: 收盘价序列 Returns: TR序列,第一个值为NaN """ # 转换为numpy数组便于计算 if isinstance(high, pd.Series): tr_values = ATRIndicator._calculate_tr_numpy( high.values, low.values, close.values ) return pd.Series(tr_values, index=high.index, name='TR') else: return ATRIndicator._calculate_tr_numpy(high, low, close) @staticmethod def _calculate_tr_numpy( high: np.ndarray, low: np.ndarray, close: np.ndarray ) -> np.ndarray: """使用numpy计算TR""" if len(high) != len(low) or len(high) != len(close): raise ValueError("high, low, close序列长度必须一致") n = len(high) tr = np.full(n, np.nan) if n < 2: return tr # 计算三个成分 high_low = high[1:] - low[1:] high_prev_close = np.abs(high[1:] - close[:-1]) low_prev_close = np.abs(low[1:] - close[:-1]) # 取最大值 tr[1:] = np.maximum(np.maximum(high_low, high_prev_close), low_prev_close) return tr def calculate( self, high: Union[pd.Series, np.ndarray], low: Union[pd.Series, np.ndarray], close: Union[pd.Series, np.ndarray] ) -> Union[pd.Series, np.ndarray]: """ 计算ATR指标 Parameters: high: 最高价序列 low: 最低价序列 close: 收盘价序列 Returns: ATR序列 """ # 先计算TR tr = self.calculate_tr(high, low, close) if self.method == 'sma': return self._calculate_sma_atr(tr) else: return self._calculate_ema_atr(tr) def _calculate_sma_atr(self, tr: Union[pd.Series, np.ndarray]) -> Union[pd.Series, np.ndarray]: """使用SMA计算ATR""" if isinstance(tr, pd.Series): atr = tr.rolling(window=self.period, min_periods=self.period).mean() atr.name = f'ATR_{self.period}' if len(atr) > 0: self._last_atr = atr.iloc[-1] return atr else: n = len(tr) atr = np.full(n, np.nan) for i in range(self.period - 1, n): if i >= self.period: atr[i] = np.mean(tr[i - self.period + 1:i + 1]) if n > 0: self._last_atr = atr[-1] return atr def _calculate_ema_atr(self, tr: Union[pd.Series, np.ndarray]) -> Union[pd.Series, np.ndarray]: """使用EMA计算ATR,Wilder平滑方法""" if isinstance(tr, pd.Series): return self._calculate_ema_atr_pandas(tr) else: return self._calculate_ema_atr_numpy(tr) def _calculate_ema_atr_pandas(self, tr: pd.Series) -> pd.Series: """Pandas版本EMA-ATR计算,使用Wilder平滑""" n = len(tr) atr = pd.Series(np.full(n, np.nan), index=tr.index, name=f'ATR_{self.period}') if n < self.period: self._last_atr = None return atr # 第一个ATR用SMA计算 first_atr = tr.iloc[1:self.period + 1].mean() atr.iloc[self.period] = first_atr self._last_atr = first_atr # Wilder平滑: ATR_t = (ATR_{t-1} * (period - 1) + TR_t) / period alpha = 1.0 / self.period for i in range(self.period + 1, n): if np.isnan(tr.iloc[i]): atr.iloc[i] = atr.iloc[i - 1] else: atr.iloc[i] = atr.iloc[i - 1] * (1 - alpha) + tr.iloc[i] * alpha if len(atr) > 0: self._last_atr = atr.iloc[-1] return atr def _calculate_ema_atr_numpy(self, tr: np.ndarray) -> np.ndarray: """Numpy版本EMA-ATR计算,使用Wilder平滑""" n = len(tr) atr = np.full(n, np.nan) if n < self.period: self._last_atr = None return atr # 第一个ATR用SMA计算 first_atr = np.mean(tr[1:self.period + 1]) atr[self.period] = first_atr self._last_atr = first_atr # Wilder平滑: ATR_t = (ATR_{t-1} * (period - 1) + TR_t) / period alpha = 1.0 / self.period for i in range(self.period + 1, n): if np.isnan(tr[i]): atr[i] = atr[i - 1] else: atr[i] = atr[i - 1] * (1 - alpha) + tr[i] * alpha self._last_atr = atr[-1] return atr def get_last_atr(self) -> Optional[float]: """获取最后一个ATR值""" return self._last_atr def update(self, high: float, low: float, prev_close: float) -> float: """ 更新单个ATR值(实盘时增量更新使用) Parameters: high: 当前K线最高价 low: 当前K线最低价 prev_close: 前一根K线收盘价 Returns: 新的ATR值 """ if self.method == 'sma': # SMA不适合增量更新,这里简化处理,用户应该重新计算完整序列 raise ValueError("SMA方法不支持增量更新,请重新计算完整序列") # 计算当前TR tr1 = abs(high - low) tr2 = abs(high - prev_close) tr3 = abs(low - prev_close) tr = max(tr1, tr2, tr3) if self._last_atr is None: # 还没有足够数据,初始化ATR为TR self._last_atr = tr return tr # EMA支持增量更新 alpha = 1.0 / self.period new_atr = self._last_atr * (1 - alpha) + tr * alpha self._last_atr = new_atr return new_atr def calculate_atr( df: pd.DataFrame, period: int = 14, method: str = 'sma', high_col: str = 'high', low_col: str = 'low', close_col: str = 'close', drop_na: bool = False ) -> pd.DataFrame: """ 便捷函数:直接对DataFrame计算ATR并添加到原DataFrame Parameters: df: 输入DataFrame,必须包含high, low, close列 period: ATR周期,默认14 method: 计算方法 'sma' 或 'ema',默认'sma' high_col: 最高价列名,默认'high' low_col: 最低价列名,默认'low' close_col: 收盘价列名,默认'close' drop_na: 是否删除NaN值,默认False Returns: 添加了TR和ATR列的DataFrame """ atr_indicator = ATRIndicator(period, method) result = df.copy() result['TR'] = atr_indicator.calculate_tr( df[high_col], df[low_col], df[close_col] ) result[f'ATR_{period}'] = atr_indicator.calculate( df[high_col], df[low_col], df[close_col] ) if drop_na: result = result.dropna() return result