Files
2026-05-31 18:33:45 +08:00

100 lines
2.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
RSI (Relative Strength Index) 指标计算函数
使用 Wilder 平滑方法
Author: 张飞 翼德
Date: 2026-05-31
"""
from typing import List, Sequence
def _rsi_from_averages(avg_gain: float, avg_loss: float) -> float:
    """Convert a smoothed average gain / average loss pair into one RSI value.

    Returns 100.0 when ``avg_loss`` is zero (no losses in the smoothed window),
    which also sidesteps the division by zero in ``avg_gain / avg_loss``.
    Otherwise returns ``100 - 100 / (1 + avg_gain / avg_loss)``.
    """
    if avg_loss == 0:
        return 100.0
    rs = avg_gain / avg_loss
    return 100.0 - 100.0 / (1.0 + rs)


def calc_rsi(closes: Sequence[float], period: int = 14) -> List[float]:
    """Compute the Relative Strength Index (RSI) using Wilder smoothing.

    Algorithm:
      - Price changes are ``closes[i] - closes[i - 1]``. A positive change
        counts as a gain and a negative change as a loss (stored as a positive
        magnitude). A zero change contributes 0 to both.
      - The first avg_gain / avg_loss are the simple means (SMA) of the first
        ``period`` gains / losses.
      - Each later average is Wilder-smoothed:
        ``avg = (prev_avg * (period - 1) + current) / period``.
      - ``RSI = 100 - 100 / (1 + avg_gain / avg_loss)``, or 100.0 when
        avg_loss is 0.

    NOTE(review): a completely flat series (no gains and no losses) yields
    100.0 under the avg_loss == 0 rule. Some libraries return 50 or 0 in that
    case, so confirm this is the intended convention.

    Args:
        closes: Closing prices, oldest first.
        period: RSI look-back period, default 14. A non-positive period
            yields an empty list.

    Returns:
        One RSI value per close from ``closes[period]`` onward, i.e.
        ``len(closes) - period`` values. An empty list is returned when
        ``period <= 0`` or when fewer than ``period + 1`` closes are given,
        since ``period`` price changes are needed to seed the averages.
        Every value lies in [0, 100].
    """
    n = len(closes)
    if n <= period or period <= 0:
        return []

    # deltas[k] is the price move from closes[k] to closes[k + 1].
    deltas = [closes[i] - closes[i - 1] for i in range(1, n)]

    # Seed the averages with a simple mean over the first `period` changes.
    seed = deltas[:period]
    gains = [d if d > 0 else 0.0 for d in seed]
    losses = [-d if d < 0 else 0.0 for d in seed]
    avg_gain = sum(gains) / period
    avg_loss = sum(losses) / period

    # The first RSI value corresponds to closes[period].
    result: List[float] = [_rsi_from_averages(avg_gain, avg_loss)]

    # Each remaining change produces one more RSI value via Wilder smoothing.
    for d in deltas[period:]:
        gain = d if d > 0 else 0.0
        loss = -d if d < 0 else 0.0
        avg_gain = (avg_gain * (period - 1) + gain) / period
        avg_loss = (avg_loss * (period - 1) + loss) / period
        result.append(_rsi_from_averages(avg_gain, avg_loss))

    return result
# ---------- Self-tests (run this file directly) ----------
if __name__ == "__main__":
    # Strictly rising closes 10.0 .. 25.0: every change is a gain, so
    # avg_loss stays 0 and the first RSI is pinned at exactly 100.
    rising = [float(price) for price in range(10, 26)]
    up_values = calc_rsi(rising, 14)
    assert len(up_values) == 2  # 16 closes with period 14 -> 16 - 14 = 2 values
    assert up_values[0] == 100.0

    # Strictly falling closes: no gains at all, so the first RSI is exactly 0.
    falling = rising[::-1]
    assert calc_rsi(falling, 14)[0] == 0.0

    # Not enough data for one full period -> empty result.
    assert calc_rsi([1.0, 2.0, 3.0], 14) == []

    # Non-positive periods are rejected with an empty result.
    for bad_period in (0, -1):
        assert calc_rsi(rising, bad_period) == []

    # Exactly period + 1 closes -> exactly one RSI value.
    assert len(calc_rsi(rising[:15], 14)) == 1

    # Mixed up/down series: every RSI value must stay inside [0, 100].
    mixed = [44.0, 44.34, 44.09, 43.61, 44.33, 44.83, 45.10, 45.42,
             45.84, 46.08, 45.89, 46.03, 45.61, 46.28, 46.28, 46.00,
             46.03, 46.41, 46.22, 45.64, 46.31, 46.23, 46.35, 46.50]
    assert all(0 <= value <= 100 for value in calc_rsi(mixed, 14))

    print("所有测试通过 ✅")