From 4c6e14f9e8e76f24f4dc69a411835257009856ab Mon Sep 17 00:00:00 2001 From: cfdaily Date: Sun, 31 May 2026 18:33:45 +0800 Subject: [PATCH] auto-sync: 2026-05-31 18:33:45 --- indicators/rsi.py | 99 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 indicators/rsi.py diff --git a/indicators/rsi.py b/indicators/rsi.py new file mode 100644 index 000000000..63fe83d80 --- /dev/null +++ b/indicators/rsi.py @@ -0,0 +1,99 @@ +""" +RSI (Relative Strength Index) 指标计算函数 +使用 Wilder 平滑方法 + +Author: 张飞 翼德 +Date: 2026-05-31 +""" + +from typing import List + + +def calc_rsi(closes: List[float], period: int = 14) -> List[float]: + """计算 RSI (Relative Strength Index)。 + + 使用 Wilder 平滑方法: + - 首次 avg_gain / avg_loss = 前 period 个涨跌幅的 SMA + - 后续 avg_gain = (prev_avg_gain * (period-1) + current_gain) / period + - RSI = 100 - 100 / (1 + avg_gain / avg_loss) + + Args: + closes: 收盘价序列 + period: RSI 周期,默认 14 + + Returns: + RSI 值列表。数据长度不足 period 时返回空列表。 + 值域 [0, 100]。 + """ + n = len(closes) + if n <= period or period <= 0: + return [] + + # 计算价格变化 + deltas = [closes[i] - closes[i - 1] for i in range(1, n)] + + # 首次平均涨跌 + gains = [d if d > 0 else 0.0 for d in deltas[:period]] + losses = [-d if d < 0 else 0.0 for d in deltas[:period]] + + avg_gain = sum(gains) / period + avg_loss = sum(losses) / period + + # 第一个 RSI(对应 closes[period]) + result: List[float] = [] + if avg_loss == 0: + result.append(100.0) + else: + rs = avg_gain / avg_loss + result.append(100.0 - 100.0 / (1.0 + rs)) + + # 后续用 Wilder 平滑 + for i in range(period, len(deltas)): + d = deltas[i] + gain = d if d > 0 else 0.0 + loss = -d if d < 0 else 0.0 + avg_gain = (avg_gain * (period - 1) + gain) / period + avg_loss = (avg_loss * (period - 1) + loss) / period + + if avg_loss == 0: + result.append(100.0) + else: + rs = avg_gain / avg_loss + result.append(100.0 - 100.0 / (1.0 + rs)) + + return result + + +# ---------- 单元测试 ---------- +if __name__ == "__main__": + # 基本用例:单边上涨 → RSI 接近 100 + prices_up = [10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0] + rsi = calc_rsi(prices_up, 14) + assert len(rsi) == 2 # 16 prices, period=14 → 2 values + assert rsi[0] == 100.0 # 全部上涨 + + # 单边下跌 → RSI 接近 0 + prices_down = list(reversed(prices_up)) + rsi_down = calc_rsi(prices_down, 14) + assert rsi_down[0] == 0.0 + + # 数据不足 → 空列表 + assert calc_rsi([1.0, 2.0, 3.0], 14) == [] + + # period <= 0 + assert calc_rsi(prices_up, 0) == [] + assert calc_rsi(prices_up, -1) == [] + + # 刚好 period+1 个数据 → 1 个 RSI + rsi_exact = calc_rsi(prices_up[:15], 14) + assert len(rsi_exact) == 1 + + # 值域检查 + prices_mixed = [44.0, 44.34, 44.09, 43.61, 44.33, 44.83, 45.10, 45.42, + 45.84, 46.08, 45.89, 46.03, 45.61, 46.28, 46.28, 46.00, + 46.03, 46.41, 46.22, 45.64, 46.31, 46.23, 46.35, 46.50] + rsi_mixed = calc_rsi(prices_mixed, 14) + for v in rsi_mixed: + assert 0 <= v <= 100 + + print("所有测试通过 ✅")