105 lines
2.9 KiB
Python
105 lines
2.9 KiB
Python
"""
|
|
ATR (Average True Range) 指标计算函数
|
|
使用 Wilder 平滑方法
|
|
|
|
Author: 张飞 翼德
|
|
Date: 2026-06-01
|
|
"""
|
|
|
|
from typing import List
|
|
|
|
|
|
def calc_atr(
|
|
highs: List[float],
|
|
lows: List[float],
|
|
closes: List[float],
|
|
period: int = 14,
|
|
) -> List[float]:
|
|
"""计算 ATR (Average True Range)。
|
|
|
|
Args:
|
|
highs: 最高价序列
|
|
lows: 最低价序列
|
|
closes: 收盘价序列
|
|
period: ATR 周期,默认 14
|
|
|
|
Returns:
|
|
ATR 值列表。数据不足 period+1 时返回空列表。
|
|
"""
|
|
n = len(closes)
|
|
if period <= 0 or n < period + 1:
|
|
return []
|
|
if n != len(highs) or n != len(lows):
|
|
raise ValueError("highs, lows, closes 长度必须一致")
|
|
|
|
# 计算 True Range(需要前一根收盘价,从 i=1 开始)
|
|
tr: List[float] = []
|
|
for i in range(1, n):
|
|
tr.append(max(
|
|
highs[i] - lows[i],
|
|
abs(highs[i] - closes[i - 1]),
|
|
abs(lows[i] - closes[i - 1]),
|
|
))
|
|
|
|
# 数据不足 period 个 TR
|
|
if len(tr) < period:
|
|
return []
|
|
|
|
# 初值:前 period 个 TR 的 SMA
|
|
result: List[float] = []
|
|
atr = sum(tr[:period]) / period
|
|
result.append(atr)
|
|
|
|
# 后续用 Wilder 平滑
|
|
for i in range(period, len(tr)):
|
|
atr = (atr * (period - 1) + tr[i]) / period
|
|
result.append(atr)
|
|
|
|
return result
|
|
|
|
|
|
# ---------- 单元测试 ----------
|
|
if __name__ == "__main__":
|
|
# 基本用例
|
|
highs = [12.0, 13.0, 11.0, 14.0, 13.0, 15.0, 14.0, 16.0, 15.0, 14.0, 13.0, 12.0, 14.0, 15.0, 16.0, 14.0]
|
|
lows = [10.0, 11.0, 9.0, 12.0, 11.0, 13.0, 12.0, 14.0, 13.0, 12.0, 11.0, 10.0, 12.0, 13.0, 14.0, 12.0]
|
|
closes = [11.0, 12.0, 10.0, 13.0, 12.0, 14.0, 13.0, 15.0, 14.0, 13.0, 12.0, 11.0, 13.0, 14.0, 15.0, 13.0]
|
|
|
|
atr = calc_atr(highs, lows, closes, period=5)
|
|
assert len(atr) > 0
|
|
|
|
# 手动验证第一个 ATR
|
|
trs = []
|
|
for i in range(1, 6):
|
|
trs.append(max(highs[i]-lows[i], abs(highs[i]-closes[i-1]), abs(lows[i]-closes[i-1])))
|
|
expected_first = sum(trs) / 5
|
|
assert abs(atr[0] - expected_first) < 1e-9
|
|
|
|
# 验证后续平滑
|
|
prev = atr[0]
|
|
for i in range(1, len(atr)):
|
|
idx = i + 5 # 对应 tr 的索引
|
|
tr_val = max(highs[idx]-lows[idx], abs(highs[idx]-closes[idx-1]), abs(lows[idx]-closes[idx-1]))
|
|
expected = (prev * 4 + tr_val) / 5
|
|
assert abs(atr[i] - expected) < 1e-9
|
|
prev = atr[i]
|
|
|
|
# ATR 值应为正
|
|
for v in atr:
|
|
assert v > 0
|
|
|
|
# 数据不足 → 空列表
|
|
assert calc_atr([1, 2], [1, 2], [1, 2], period=14) == []
|
|
|
|
# period+1 刚好
|
|
assert len(calc_atr(highs[:6], lows[:6], closes[:6], period=5)) == 1
|
|
|
|
# period <= 0
|
|
assert calc_atr(highs, lows, closes, period=0) == []
|
|
assert calc_atr(highs, lows, closes, period=-1) == []
|
|
|
|
# 空数据
|
|
assert calc_atr([], [], [], period=5) == []
|
|
|
|
print("所有测试通过 ✅")
|