""" ATR (Average True Range) 指标计算函数 使用 Wilder 平滑方法 Author: 张飞 翼德 Date: 2026-06-01 """ from typing import List def calc_atr( highs: List[float], lows: List[float], closes: List[float], period: int = 14, ) -> List[float]: """计算 ATR (Average True Range)。 Args: highs: 最高价序列 lows: 最低价序列 closes: 收盘价序列 period: ATR 周期,默认 14 Returns: ATR 值列表。数据不足 period+1 时返回空列表。 """ n = len(closes) if period <= 0 or n < period + 1: return [] if n != len(highs) or n != len(lows): raise ValueError("highs, lows, closes 长度必须一致") # 计算 True Range(需要前一根收盘价,从 i=1 开始) tr: List[float] = [] for i in range(1, n): tr.append(max( highs[i] - lows[i], abs(highs[i] - closes[i - 1]), abs(lows[i] - closes[i - 1]), )) # 数据不足 period 个 TR if len(tr) < period: return [] # 初值:前 period 个 TR 的 SMA result: List[float] = [] atr = sum(tr[:period]) / period result.append(atr) # 后续用 Wilder 平滑 for i in range(period, len(tr)): atr = (atr * (period - 1) + tr[i]) / period result.append(atr) return result # ---------- 单元测试 ---------- if __name__ == "__main__": # 基本用例 highs = [12.0, 13.0, 11.0, 14.0, 13.0, 15.0, 14.0, 16.0, 15.0, 14.0, 13.0, 12.0, 14.0, 15.0, 16.0, 14.0] lows = [10.0, 11.0, 9.0, 12.0, 11.0, 13.0, 12.0, 14.0, 13.0, 12.0, 11.0, 10.0, 12.0, 13.0, 14.0, 12.0] closes = [11.0, 12.0, 10.0, 13.0, 12.0, 14.0, 13.0, 15.0, 14.0, 13.0, 12.0, 11.0, 13.0, 14.0, 15.0, 13.0] atr = calc_atr(highs, lows, closes, period=5) assert len(atr) > 0 # 手动验证第一个 ATR trs = [] for i in range(1, 6): trs.append(max(highs[i]-lows[i], abs(highs[i]-closes[i-1]), abs(lows[i]-closes[i-1]))) expected_first = sum(trs) / 5 assert abs(atr[0] - expected_first) < 1e-9 # 验证后续平滑 prev = atr[0] for i in range(1, len(atr)): idx = i + 5 # 对应 tr 的索引 tr_val = max(highs[idx]-lows[idx], abs(highs[idx]-closes[idx-1]), abs(lows[idx]-closes[idx-1])) expected = (prev * 4 + tr_val) / 5 assert abs(atr[i] - expected) < 1e-9 prev = atr[i] # ATR 值应为正 for v in atr: assert v > 0 # 数据不足 → 空列表 assert calc_atr([1, 2], [1, 2], [1, 2], period=14) == [] # period+1 刚好 assert len(calc_atr(highs[:6], lows[:6], closes[:6], period=5)) == 1 # period <= 0 assert calc_atr(highs, lows, closes, period=0) == [] assert calc_atr(highs, lows, closes, period=-1) == [] # 空数据 assert calc_atr([], [], [], period=5) == [] print("所有测试通过 ✅")