100 lines
2.9 KiB
Python
100 lines
2.9 KiB
Python
"""
|
||
RSI (Relative Strength Index) 指标计算函数
|
||
使用 Wilder 平滑方法
|
||
|
||
Author: 张飞 翼德
|
||
Date: 2026-05-31
|
||
"""
|
||
|
||
from typing import List
|
||
|
||
|
||
def calc_rsi(closes: List[float], period: int = 14) -> List[float]:
|
||
"""计算 RSI (Relative Strength Index)。
|
||
|
||
使用 Wilder 平滑方法:
|
||
- 首次 avg_gain / avg_loss = 前 period 个涨跌幅的 SMA
|
||
- 后续 avg_gain = (prev_avg_gain * (period-1) + current_gain) / period
|
||
- RSI = 100 - 100 / (1 + avg_gain / avg_loss)
|
||
|
||
Args:
|
||
closes: 收盘价序列
|
||
period: RSI 周期,默认 14
|
||
|
||
Returns:
|
||
RSI 值列表。数据长度不足 period 时返回空列表。
|
||
值域 [0, 100]。
|
||
"""
|
||
n = len(closes)
|
||
if n <= period or period <= 0:
|
||
return []
|
||
|
||
# 计算价格变化
|
||
deltas = [closes[i] - closes[i - 1] for i in range(1, n)]
|
||
|
||
# 首次平均涨跌
|
||
gains = [d if d > 0 else 0.0 for d in deltas[:period]]
|
||
losses = [-d if d < 0 else 0.0 for d in deltas[:period]]
|
||
|
||
avg_gain = sum(gains) / period
|
||
avg_loss = sum(losses) / period
|
||
|
||
# 第一个 RSI(对应 closes[period])
|
||
result: List[float] = []
|
||
if avg_loss == 0:
|
||
result.append(100.0)
|
||
else:
|
||
rs = avg_gain / avg_loss
|
||
result.append(100.0 - 100.0 / (1.0 + rs))
|
||
|
||
# 后续用 Wilder 平滑
|
||
for i in range(period, len(deltas)):
|
||
d = deltas[i]
|
||
gain = d if d > 0 else 0.0
|
||
loss = -d if d < 0 else 0.0
|
||
avg_gain = (avg_gain * (period - 1) + gain) / period
|
||
avg_loss = (avg_loss * (period - 1) + loss) / period
|
||
|
||
if avg_loss == 0:
|
||
result.append(100.0)
|
||
else:
|
||
rs = avg_gain / avg_loss
|
||
result.append(100.0 - 100.0 / (1.0 + rs))
|
||
|
||
return result
|
||
|
||
|
||
# ---------- 单元测试 ----------
|
||
if __name__ == "__main__":
|
||
# 基本用例:单边上涨 → RSI 接近 100
|
||
prices_up = [10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0]
|
||
rsi = calc_rsi(prices_up, 14)
|
||
assert len(rsi) == 2 # 16 prices, period=14 → 2 values
|
||
assert rsi[0] == 100.0 # 全部上涨
|
||
|
||
# 单边下跌 → RSI 接近 0
|
||
prices_down = list(reversed(prices_up))
|
||
rsi_down = calc_rsi(prices_down, 14)
|
||
assert rsi_down[0] == 0.0
|
||
|
||
# 数据不足 → 空列表
|
||
assert calc_rsi([1.0, 2.0, 3.0], 14) == []
|
||
|
||
# period <= 0
|
||
assert calc_rsi(prices_up, 0) == []
|
||
assert calc_rsi(prices_up, -1) == []
|
||
|
||
# 刚好 period+1 个数据 → 1 个 RSI
|
||
rsi_exact = calc_rsi(prices_up[:15], 14)
|
||
assert len(rsi_exact) == 1
|
||
|
||
# 值域检查
|
||
prices_mixed = [44.0, 44.34, 44.09, 43.61, 44.33, 44.83, 45.10, 45.42,
|
||
45.84, 46.08, 45.89, 46.03, 45.61, 46.28, 46.28, 46.00,
|
||
46.03, 46.41, 46.22, 45.64, 46.31, 46.23, 46.35, 46.50]
|
||
rsi_mixed = calc_rsi(prices_mixed, 14)
|
||
for v in rsi_mixed:
|
||
assert 0 <= v <= 100
|
||
|
||
print("所有测试通过 ✅")
|