Files
sanguo_quant_live/indicators/atr.py
T
2026-06-01 00:27:35 +08:00

105 lines
2.9 KiB
Python

"""
ATR (Average True Range) 指标计算函数
使用 Wilder 平滑方法
Author: 张飞 翼德
Date: 2026-06-01
"""
from typing import List
def calc_atr(
highs: List[float],
lows: List[float],
closes: List[float],
period: int = 14,
) -> List[float]:
"""计算 ATR (Average True Range)。
Args:
highs: 最高价序列
lows: 最低价序列
closes: 收盘价序列
period: ATR 周期,默认 14
Returns:
ATR 值列表。数据不足 period+1 时返回空列表。
"""
n = len(closes)
if period <= 0 or n < period + 1:
return []
if n != len(highs) or n != len(lows):
raise ValueError("highs, lows, closes 长度必须一致")
# 计算 True Range(需要前一根收盘价,从 i=1 开始)
tr: List[float] = []
for i in range(1, n):
tr.append(max(
highs[i] - lows[i],
abs(highs[i] - closes[i - 1]),
abs(lows[i] - closes[i - 1]),
))
# 数据不足 period 个 TR
if len(tr) < period:
return []
# 初值:前 period 个 TR 的 SMA
result: List[float] = []
atr = sum(tr[:period]) / period
result.append(atr)
# 后续用 Wilder 平滑
for i in range(period, len(tr)):
atr = (atr * (period - 1) + tr[i]) / period
result.append(atr)
return result
# ---------- 单元测试 ----------
if __name__ == "__main__":
# 基本用例
highs = [12.0, 13.0, 11.0, 14.0, 13.0, 15.0, 14.0, 16.0, 15.0, 14.0, 13.0, 12.0, 14.0, 15.0, 16.0, 14.0]
lows = [10.0, 11.0, 9.0, 12.0, 11.0, 13.0, 12.0, 14.0, 13.0, 12.0, 11.0, 10.0, 12.0, 13.0, 14.0, 12.0]
closes = [11.0, 12.0, 10.0, 13.0, 12.0, 14.0, 13.0, 15.0, 14.0, 13.0, 12.0, 11.0, 13.0, 14.0, 15.0, 13.0]
atr = calc_atr(highs, lows, closes, period=5)
assert len(atr) > 0
# 手动验证第一个 ATR
trs = []
for i in range(1, 6):
trs.append(max(highs[i]-lows[i], abs(highs[i]-closes[i-1]), abs(lows[i]-closes[i-1])))
expected_first = sum(trs) / 5
assert abs(atr[0] - expected_first) < 1e-9
# 验证后续平滑
prev = atr[0]
for i in range(1, len(atr)):
idx = i + 5 # 对应 tr 的索引
tr_val = max(highs[idx]-lows[idx], abs(highs[idx]-closes[idx-1]), abs(lows[idx]-closes[idx-1]))
expected = (prev * 4 + tr_val) / 5
assert abs(atr[i] - expected) < 1e-9
prev = atr[i]
# ATR 值应为正
for v in atr:
assert v > 0
# 数据不足 → 空列表
assert calc_atr([1, 2], [1, 2], [1, 2], period=14) == []
# period+1 刚好
assert len(calc_atr(highs[:6], lows[:6], closes[:6], period=5)) == 1
# period <= 0
assert calc_atr(highs, lows, closes, period=0) == []
assert calc_atr(highs, lows, closes, period=-1) == []
# 空数据
assert calc_atr([], [], [], period=5) == []
print("所有测试通过 ✅")