auto-sync: 2026-04-17 20:23:21

This commit is contained in:
cfdaily
2026-04-17 20:23:21 +08:00
parent e74b4f41e8
commit 5cb2b13b0c
@@ -0,0 +1,275 @@
"""
ATR (Average True Range) 指标计算工具
支持SMA和EMA两种计算方式
Author: 张飞 翼德
Date: 2026-04-17
"""
import numpy as np
import pandas as pd
from typing import Optional, Union
class ATRIndicator:
"""
ATR (Average True Range) 指标计算类
支持两种计算方式:
- SMA: 简单移动平均
- EMA: 指数移动平均
"""
def __init__(self, period: int = 14, method: str = 'sma'):
"""
初始化ATR指标
Parameters:
period: int, 默认14
ATR计算周期
method: str, 默认'sma'
计算方法,可选 'sma''ema'
"""
if period <= 0:
raise ValueError(f"周期必须大于0,当前为 {period}")
if method.lower() not in ['sma', 'ema']:
raise ValueError(f"计算方法必须是 'sma''ema',当前为 {method}")
self.period = period
self.method = method.lower()
self._last_atr: Optional[float] = None
@staticmethod
def calculate_tr(
high: Union[pd.Series, np.ndarray],
low: Union[pd.Series, np.ndarray],
close: Union[pd.Series, np.ndarray]
) -> Union[pd.Series, np.ndarray]:
"""
计算True Range (TR)
公式: TR = max[|high - low|, |high - prev_close|, |low - prev_close|]
Parameters:
high: 最高价序列
low: 最低价序列
close: 收盘价序列
Returns:
TR序列,第一个值为NaN
"""
# 转换为numpy数组便于计算
if isinstance(high, pd.Series):
tr_values = ATRIndicator._calculate_tr_numpy(
high.values, low.values, close.values
)
return pd.Series(tr_values, index=high.index, name='TR')
else:
return ATRIndicator._calculate_tr_numpy(high, low, close)
@staticmethod
def _calculate_tr_numpy(
high: np.ndarray,
low: np.ndarray,
close: np.ndarray
) -> np.ndarray:
"""使用numpy计算TR"""
if len(high) != len(low) or len(high) != len(close):
raise ValueError("high, low, close序列长度必须一致")
n = len(high)
tr = np.full(n, np.nan)
if n < 2:
return tr
# 计算三个成分
high_low = high[1:] - low[1:]
high_prev_close = np.abs(high[1:] - close[:-1])
low_prev_close = np.abs(low[1:] - close[:-1])
# 取最大值
tr[1:] = np.maximum(np.maximum(high_low, high_prev_close), low_prev_close)
return tr
def calculate(
self,
high: Union[pd.Series, np.ndarray],
low: Union[pd.Series, np.ndarray],
close: Union[pd.Series, np.ndarray]
) -> Union[pd.Series, np.ndarray]:
"""
计算ATR指标
Parameters:
high: 最高价序列
low: 最低价序列
close: 收盘价序列
Returns:
ATR序列
"""
# 先计算TR
tr = self.calculate_tr(high, low, close)
if self.method == 'sma':
return self._calculate_sma_atr(tr)
else:
return self._calculate_ema_atr(tr)
def _calculate_sma_atr(self, tr: Union[pd.Series, np.ndarray]) -> Union[pd.Series, np.ndarray]:
"""使用SMA计算ATR"""
if isinstance(tr, pd.Series):
atr = tr.rolling(window=self.period, min_periods=self.period).mean()
atr.name = f'ATR_{self.period}'
if len(atr) > 0:
self._last_atr = atr.iloc[-1]
return atr
else:
n = len(tr)
atr = np.full(n, np.nan)
for i in range(self.period - 1, n):
if i >= self.period:
atr[i] = np.mean(tr[i - self.period + 1:i + 1])
if n > 0:
self._last_atr = atr[-1]
return atr
def _calculate_ema_atr(self, tr: Union[pd.Series, np.ndarray]) -> Union[pd.Series, np.ndarray]:
"""使用EMA计算ATRWilder平滑方法"""
if isinstance(tr, pd.Series):
return self._calculate_ema_atr_pandas(tr)
else:
return self._calculate_ema_atr_numpy(tr)
def _calculate_ema_atr_pandas(self, tr: pd.Series) -> pd.Series:
"""Pandas版本EMA-ATR计算,使用Wilder平滑"""
n = len(tr)
atr = pd.Series(np.full(n, np.nan), index=tr.index, name=f'ATR_{self.period}')
if n < self.period:
self._last_atr = None
return atr
# 第一个ATR用SMA计算
first_atr = tr.iloc[1:self.period + 1].mean()
atr.iloc[self.period] = first_atr
self._last_atr = first_atr
# Wilder平滑: ATR_t = (ATR_{t-1} * (period - 1) + TR_t) / period
alpha = 1.0 / self.period
for i in range(self.period + 1, n):
if np.isnan(tr.iloc[i]):
atr.iloc[i] = atr.iloc[i - 1]
else:
atr.iloc[i] = atr.iloc[i - 1] * (1 - alpha) + tr.iloc[i] * alpha
if len(atr) > 0:
self._last_atr = atr.iloc[-1]
return atr
def _calculate_ema_atr_numpy(self, tr: np.ndarray) -> np.ndarray:
"""Numpy版本EMA-ATR计算,使用Wilder平滑"""
n = len(tr)
atr = np.full(n, np.nan)
if n < self.period:
self._last_atr = None
return atr
# 第一个ATR用SMA计算
first_atr = np.mean(tr[1:self.period + 1])
atr[self.period] = first_atr
self._last_atr = first_atr
# Wilder平滑: ATR_t = (ATR_{t-1} * (period - 1) + TR_t) / period
alpha = 1.0 / self.period
for i in range(self.period + 1, n):
if np.isnan(tr[i]):
atr[i] = atr[i - 1]
else:
atr[i] = atr[i - 1] * (1 - alpha) + tr[i] * alpha
self._last_atr = atr[-1]
return atr
def get_last_atr(self) -> Optional[float]:
"""获取最后一个ATR值"""
return self._last_atr
def update(self, high: float, low: float, prev_close: float) -> float:
"""
更新单个ATR值(实盘时增量更新使用)
Parameters:
high: 当前K线最高价
low: 当前K线最低价
prev_close: 前一根K线收盘价
Returns:
新的ATR值
"""
# 计算当前TR
tr1 = abs(high - low)
tr2 = abs(high - prev_close)
tr3 = abs(low - prev_close)
tr = max(tr1, tr2, tr3)
if self._last_atr is None:
# 还没有足够数据,直接返回TR
return tr
if self.method == 'sma':
# SMA不适合增量更新,这里简化处理,用户应该重新计算完整序列
raise ValueError("SMA方法不支持增量更新,请重新计算完整序列")
else:
# EMA支持增量更新
alpha = 1.0 / self.period
new_atr = self._last_atr * (1 - alpha) + tr * alpha
self._last_atr = new_atr
return new_atr
def calculate_atr(
df: pd.DataFrame,
period: int = 14,
method: str = 'sma',
high_col: str = 'high',
low_col: str = 'low',
close_col: str = 'close',
drop_na: bool = False
) -> pd.DataFrame:
"""
便捷函数:直接对DataFrame计算ATR并添加到原DataFrame
Parameters:
df: 输入DataFrame,必须包含high, low, close列
period: ATR周期,默认14
method: 计算方法 'sma''ema',默认'sma'
high_col: 最高价列名,默认'high'
low_col: 最低价列名,默认'low'
close_col: 收盘价列名,默认'close'
drop_na: 是否删除NaN值,默认False
Returns:
添加了TR和ATR列的DataFrame
"""
atr_indicator = ATRIndicator(period, method)
result = df.copy()
result['TR'] = atr_indicator.calculate_tr(
df[high_col], df[low_col], df[close_col]
)
result[f'ATR_{period}'] = atr_indicator.calculate(
df[high_col], df[low_col], df[close_col]
)
if drop_na:
result = result.dropna()
return result