auto-sync: 2026-05-31 18:33:45
This commit is contained in:
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
RSI (Relative Strength Index) 指标计算函数
|
||||
使用 Wilder 平滑方法
|
||||
|
||||
Author: 张飞 翼德
|
||||
Date: 2026-05-31
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
|
||||
def calc_rsi(closes: List[float], period: int = 14) -> List[float]:
|
||||
"""计算 RSI (Relative Strength Index)。
|
||||
|
||||
使用 Wilder 平滑方法:
|
||||
- 首次 avg_gain / avg_loss = 前 period 个涨跌幅的 SMA
|
||||
- 后续 avg_gain = (prev_avg_gain * (period-1) + current_gain) / period
|
||||
- RSI = 100 - 100 / (1 + avg_gain / avg_loss)
|
||||
|
||||
Args:
|
||||
closes: 收盘价序列
|
||||
period: RSI 周期,默认 14
|
||||
|
||||
Returns:
|
||||
RSI 值列表。数据长度不足 period 时返回空列表。
|
||||
值域 [0, 100]。
|
||||
"""
|
||||
n = len(closes)
|
||||
if n <= period or period <= 0:
|
||||
return []
|
||||
|
||||
# 计算价格变化
|
||||
deltas = [closes[i] - closes[i - 1] for i in range(1, n)]
|
||||
|
||||
# 首次平均涨跌
|
||||
gains = [d if d > 0 else 0.0 for d in deltas[:period]]
|
||||
losses = [-d if d < 0 else 0.0 for d in deltas[:period]]
|
||||
|
||||
avg_gain = sum(gains) / period
|
||||
avg_loss = sum(losses) / period
|
||||
|
||||
# 第一个 RSI(对应 closes[period])
|
||||
result: List[float] = []
|
||||
if avg_loss == 0:
|
||||
result.append(100.0)
|
||||
else:
|
||||
rs = avg_gain / avg_loss
|
||||
result.append(100.0 - 100.0 / (1.0 + rs))
|
||||
|
||||
# 后续用 Wilder 平滑
|
||||
for i in range(period, len(deltas)):
|
||||
d = deltas[i]
|
||||
gain = d if d > 0 else 0.0
|
||||
loss = -d if d < 0 else 0.0
|
||||
avg_gain = (avg_gain * (period - 1) + gain) / period
|
||||
avg_loss = (avg_loss * (period - 1) + loss) / period
|
||||
|
||||
if avg_loss == 0:
|
||||
result.append(100.0)
|
||||
else:
|
||||
rs = avg_gain / avg_loss
|
||||
result.append(100.0 - 100.0 / (1.0 + rs))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ---------- 单元测试 ----------
|
||||
if __name__ == "__main__":
|
||||
# 基本用例:单边上涨 → RSI 接近 100
|
||||
prices_up = [10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0]
|
||||
rsi = calc_rsi(prices_up, 14)
|
||||
assert len(rsi) == 2 # 16 prices, period=14 → 2 values
|
||||
assert rsi[0] == 100.0 # 全部上涨
|
||||
|
||||
# 单边下跌 → RSI 接近 0
|
||||
prices_down = list(reversed(prices_up))
|
||||
rsi_down = calc_rsi(prices_down, 14)
|
||||
assert rsi_down[0] == 0.0
|
||||
|
||||
# 数据不足 → 空列表
|
||||
assert calc_rsi([1.0, 2.0, 3.0], 14) == []
|
||||
|
||||
# period <= 0
|
||||
assert calc_rsi(prices_up, 0) == []
|
||||
assert calc_rsi(prices_up, -1) == []
|
||||
|
||||
# 刚好 period+1 个数据 → 1 个 RSI
|
||||
rsi_exact = calc_rsi(prices_up[:15], 14)
|
||||
assert len(rsi_exact) == 1
|
||||
|
||||
# 值域检查
|
||||
prices_mixed = [44.0, 44.34, 44.09, 43.61, 44.33, 44.83, 45.10, 45.42,
|
||||
45.84, 46.08, 45.89, 46.03, 45.61, 46.28, 46.28, 46.00,
|
||||
46.03, 46.41, 46.22, 45.64, 46.31, 46.23, 46.35, 46.50]
|
||||
rsi_mixed = calc_rsi(prices_mixed, 14)
|
||||
for v in rsi_mixed:
|
||||
assert 0 <= v <= 100
|
||||
|
||||
print("所有测试通过 ✅")
|
||||
Reference in New Issue
Block a user