diff --git a/indicators/atr.py b/indicators/atr.py new file mode 100644 index 000000000..31f6ba472 --- /dev/null +++ b/indicators/atr.py @@ -0,0 +1,104 @@ +""" +ATR (Average True Range) 指标计算函数 +使用 Wilder 平滑方法 + +Author: 张飞 翼德 +Date: 2026-06-01 +""" + +from typing import List + + +def calc_atr( + highs: List[float], + lows: List[float], + closes: List[float], + period: int = 14, +) -> List[float]: + """计算 ATR (Average True Range)。 + + Args: + highs: 最高价序列 + lows: 最低价序列 + closes: 收盘价序列 + period: ATR 周期,默认 14 + + Returns: + ATR 值列表。数据不足 period+1 时返回空列表。 + """ + n = len(closes) + if period <= 0 or n < period + 1: + return [] + if n != len(highs) or n != len(lows): + raise ValueError("highs, lows, closes 长度必须一致") + + # 计算 True Range(需要前一根收盘价,从 i=1 开始) + tr: List[float] = [] + for i in range(1, n): + tr.append(max( + highs[i] - lows[i], + abs(highs[i] - closes[i - 1]), + abs(lows[i] - closes[i - 1]), + )) + + # 数据不足 period 个 TR + if len(tr) < period: + return [] + + # 初值:前 period 个 TR 的 SMA + result: List[float] = [] + atr = sum(tr[:period]) / period + result.append(atr) + + # 后续用 Wilder 平滑 + for i in range(period, len(tr)): + atr = (atr * (period - 1) + tr[i]) / period + result.append(atr) + + return result + + +# ---------- 单元测试 ---------- +if __name__ == "__main__": + # 基本用例 + highs = [12.0, 13.0, 11.0, 14.0, 13.0, 15.0, 14.0, 16.0, 15.0, 14.0, 13.0, 12.0, 14.0, 15.0, 16.0, 14.0] + lows = [10.0, 11.0, 9.0, 12.0, 11.0, 13.0, 12.0, 14.0, 13.0, 12.0, 11.0, 10.0, 12.0, 13.0, 14.0, 12.0] + closes = [11.0, 12.0, 10.0, 13.0, 12.0, 14.0, 13.0, 15.0, 14.0, 13.0, 12.0, 11.0, 13.0, 14.0, 15.0, 13.0] + + atr = calc_atr(highs, lows, closes, period=5) + assert len(atr) > 0 + + # 手动验证第一个 ATR + trs = [] + for i in range(1, 6): + trs.append(max(highs[i]-lows[i], abs(highs[i]-closes[i-1]), abs(lows[i]-closes[i-1]))) + expected_first = sum(trs) / 5 + assert abs(atr[0] - expected_first) < 1e-9 + + # 验证后续平滑 + prev = atr[0] + for i in range(1, len(atr)): + idx = i + 5 # 对应 tr 的索引 + tr_val = max(highs[idx]-lows[idx], abs(highs[idx]-closes[idx-1]), abs(lows[idx]-closes[idx-1])) + expected = (prev * 4 + tr_val) / 5 + assert abs(atr[i] - expected) < 1e-9 + prev = atr[i] + + # ATR 值应为正 + for v in atr: + assert v > 0 + + # 数据不足 → 空列表 + assert calc_atr([1, 2], [1, 2], [1, 2], period=14) == [] + + # period+1 刚好 + assert len(calc_atr(highs[:6], lows[:6], closes[:6], period=5)) == 1 + + # period <= 0 + assert calc_atr(highs, lows, closes, period=0) == [] + assert calc_atr(highs, lows, closes, period=-1) == [] + + # 空数据 + assert calc_atr([], [], [], period=5) == [] + + print("所有测试通过 ✅")