""" RSI (Relative Strength Index) 指标计算函数 使用 Wilder 平滑方法 Author: 张飞 翼德 Date: 2026-05-31 """ from typing import List def calc_rsi(closes: List[float], period: int = 14) -> List[float]: """计算 RSI (Relative Strength Index)。 使用 Wilder 平滑方法: - 首次 avg_gain / avg_loss = 前 period 个涨跌幅的 SMA - 后续 avg_gain = (prev_avg_gain * (period-1) + current_gain) / period - RSI = 100 - 100 / (1 + avg_gain / avg_loss) Args: closes: 收盘价序列 period: RSI 周期,默认 14 Returns: RSI 值列表。数据长度不足 period 时返回空列表。 值域 [0, 100]。 """ n = len(closes) if n <= period or period <= 0: return [] # 计算价格变化 deltas = [closes[i] - closes[i - 1] for i in range(1, n)] # 首次平均涨跌 gains = [d if d > 0 else 0.0 for d in deltas[:period]] losses = [-d if d < 0 else 0.0 for d in deltas[:period]] avg_gain = sum(gains) / period avg_loss = sum(losses) / period # 第一个 RSI(对应 closes[period]) result: List[float] = [] if avg_loss == 0: result.append(100.0) else: rs = avg_gain / avg_loss result.append(100.0 - 100.0 / (1.0 + rs)) # 后续用 Wilder 平滑 for i in range(period, len(deltas)): d = deltas[i] gain = d if d > 0 else 0.0 loss = -d if d < 0 else 0.0 avg_gain = (avg_gain * (period - 1) + gain) / period avg_loss = (avg_loss * (period - 1) + loss) / period if avg_loss == 0: result.append(100.0) else: rs = avg_gain / avg_loss result.append(100.0 - 100.0 / (1.0 + rs)) return result # ---------- 单元测试 ---------- if __name__ == "__main__": # 基本用例:单边上涨 → RSI 接近 100 prices_up = [10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0] rsi = calc_rsi(prices_up, 14) assert len(rsi) == 2 # 16 prices, period=14 → 2 values assert rsi[0] == 100.0 # 全部上涨 # 单边下跌 → RSI 接近 0 prices_down = list(reversed(prices_up)) rsi_down = calc_rsi(prices_down, 14) assert rsi_down[0] == 0.0 # 数据不足 → 空列表 assert calc_rsi([1.0, 2.0, 3.0], 14) == [] # period <= 0 assert calc_rsi(prices_up, 0) == [] assert calc_rsi(prices_up, -1) == [] # 刚好 period+1 个数据 → 1 个 RSI rsi_exact = calc_rsi(prices_up[:15], 14) assert len(rsi_exact) == 1 # 值域检查 prices_mixed = [44.0, 44.34, 44.09, 43.61, 44.33, 44.83, 45.10, 45.42, 45.84, 46.08, 45.89, 46.03, 45.61, 46.28, 46.28, 46.00, 46.03, 46.41, 46.22, 45.64, 46.31, 46.23, 46.35, 46.50] rsi_mixed = calc_rsi(prices_mixed, 14) for v in rsi_mixed: assert 0 <= v <= 100 print("所有测试通过 ✅")