# NOTE: the following lines are code-hosting page metadata captured together
# with this file; they are not part of the module.
# Files
# 2026-04-17 20:24:28 +08:00
#
# 277 lines
# 8.4 KiB
# Python
# Raw Permalink Blame History
#
# This file contains ambiguous Unicode characters
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
ATR (Average True Range) 指标计算工具
支持SMA和EMA两种计算方式
Author: 张飞 翼德
Date: 2026-04-17
"""
import numpy as np
import pandas as pd
from typing import Optional, Union
class ATRIndicator:
    """
    ATR (Average True Range) indicator.

    Supported smoothing methods:
    - 'sma': simple moving average of the last ``period`` TR values
    - 'ema': Wilder smoothing,
      ATR_t = ATR_{t-1} * (1 - 1/period) + TR_t / period,
      seeded at index ``period`` with the mean of TR[1..period]

    pandas Series inputs produce pandas Series outputs; any other 1-D
    array-like (numpy array, list, ...) produces a numpy array.
    """

    def __init__(self, period: int = 14, method: str = 'sma'):
        """
        Initialise the indicator.

        Parameters:
            period: int, default 14
                ATR look-back period; must be > 0.
            method: str, default 'sma'
                Smoothing method, 'sma' or 'ema' (case-insensitive).

        Raises:
            ValueError: if period <= 0 or method is not 'sma' / 'ema'.
        """
        if period <= 0:
            raise ValueError(f"周期必须大于0,当前为 {period}")
        if method.lower() not in ['sma', 'ema']:
            raise ValueError(f"计算方法必须是 'sma''ema',当前为 {method}")
        self.period = period
        self.method = method.lower()
        # Most recent ATR produced by calculate()/update(); None until one exists.
        self._last_atr: Optional[float] = None

    @staticmethod
    def calculate_tr(
        high: Union[pd.Series, np.ndarray],
        low: Union[pd.Series, np.ndarray],
        close: Union[pd.Series, np.ndarray]
    ) -> Union[pd.Series, np.ndarray]:
        """
        Compute the True Range (TR).

        Formula: TR = max[|high - low|, |high - prev_close|, |low - prev_close|]

        Parameters:
            high: high prices (pandas Series or any 1-D array-like)
            low: low prices
            close: close prices

        Returns:
            TR values of the same length as the inputs. The first value is NaN
            because bar 0 has no previous close. When ``high`` is a Series, the
            result is a Series named 'TR' with ``high``'s index; otherwise it is
            a numpy float array. Inputs are aligned by position, not by label.

        Raises:
            ValueError: if the three inputs differ in length.
        """
        tr_values = ATRIndicator._calculate_tr_numpy(high, low, close)
        if isinstance(high, pd.Series):
            return pd.Series(tr_values, index=high.index, name='TR')
        return tr_values

    @staticmethod
    def _calculate_tr_numpy(
        high: Union[pd.Series, np.ndarray],
        low: Union[pd.Series, np.ndarray],
        close: Union[pd.Series, np.ndarray]
    ) -> np.ndarray:
        """Vectorised TR on 1-D array-likes; returns a float ndarray with TR[0] = NaN."""
        # np.asarray accepts Series, ndarrays and lists alike, so mixed input
        # types (e.g. a Series high with ndarray low/close) work.
        high = np.asarray(high, dtype=float)
        low = np.asarray(low, dtype=float)
        close = np.asarray(close, dtype=float)
        if len(high) != len(low) or len(high) != len(close):
            raise ValueError("high, low, close序列长度必须一致")
        tr = np.full(len(high), np.nan)
        if len(high) < 2:
            return tr
        prev_close = close[:-1]
        # abs() on high - low matches the documented formula and update(),
        # even for malformed bars where high < low.
        tr[1:] = np.maximum(
            np.maximum(np.abs(high[1:] - low[1:]), np.abs(high[1:] - prev_close)),
            np.abs(low[1:] - prev_close),
        )
        return tr

    def calculate(
        self,
        high: Union[pd.Series, np.ndarray],
        low: Union[pd.Series, np.ndarray],
        close: Union[pd.Series, np.ndarray]
    ) -> Union[pd.Series, np.ndarray]:
        """
        Compute the ATR series with the configured period and method.

        Parameters:
            high: high prices
            low: low prices
            close: close prices

        Returns:
            ATR values (Series named f'ATR_{period}' for Series input, otherwise
            a numpy array). Leading positions without enough data are NaN. The
            last value is also stored and available via get_last_atr().
        """
        tr = self.calculate_tr(high, low, close)
        if self.method == 'sma':
            return self._calculate_sma_atr(tr)
        return self._calculate_ema_atr(tr)

    def _calculate_sma_atr(self, tr: Union[pd.Series, np.ndarray]) -> Union[pd.Series, np.ndarray]:
        """Simple-moving-average ATR; records the final value in _last_atr (None if empty)."""
        if isinstance(tr, pd.Series):
            atr = tr.rolling(window=self.period, min_periods=self.period).mean()
            atr.name = f'ATR_{self.period}'
            self._last_atr = float(atr.iloc[-1]) if len(atr) > 0 else None
            return atr

        values = np.asarray(tr, dtype=float)
        n = len(values)
        atr = np.full(n, np.nan)
        if n > self.period:
            # windows[k] is values[k:k + period] and ends at index k + period - 1.
            # Window 0 is skipped because it contains TR[0], which has no previous
            # close, so the first ATR lands at index `period`.
            windows = np.lib.stride_tricks.sliding_window_view(values, self.period)
            atr[self.period:] = windows[1:].mean(axis=1)
        # Reset to None on empty input so a stale value from an earlier call is not kept.
        self._last_atr = float(atr[-1]) if n > 0 else None
        return atr

    def _calculate_ema_atr(self, tr: Union[pd.Series, np.ndarray]) -> Union[pd.Series, np.ndarray]:
        """Wilder-smoothed ATR; dispatches on the input type."""
        if isinstance(tr, pd.Series):
            return self._calculate_ema_atr_pandas(tr)
        return self._calculate_ema_atr_numpy(tr)

    def _calculate_ema_atr_pandas(self, tr: pd.Series) -> pd.Series:
        """pandas wrapper around the numpy Wilder implementation; keeps tr's index."""
        values = self._calculate_ema_atr_numpy(tr.to_numpy(dtype=float))
        return pd.Series(values, index=tr.index, name=f'ATR_{self.period}')

    def _calculate_ema_atr_numpy(self, tr: np.ndarray) -> np.ndarray:
        """
        Wilder-smoothed ATR on a 1-D array.

        Seeding: the ATR at index ``period`` is the mean of the non-NaN values
        in TR[1..period]. TR[0] is excluded because it has no previous close.

        Recursion: afterwards, ATR_t = ATR_{t-1} * (1 - 1/period) + TR_t / period.
        A NaN TR_t carries ATR_{t-1} forward.

        Indices before ``period`` are NaN. Sets _last_atr to the final value,
        or to None when len(tr) <= period (not enough data for a seed).
        """
        values = np.asarray(tr, dtype=float)
        n = len(values)
        atr = np.full(n, np.nan)
        # The seed uses TR[1..period], so at least period + 1 samples are required;
        # with fewer, atr[period] would be out of bounds.
        if n <= self.period:
            self._last_atr = None
            return atr

        seed_window = values[1:self.period + 1]
        valid_seed = seed_window[~np.isnan(seed_window)]
        atr[self.period] = valid_seed.mean() if valid_seed.size else np.nan

        alpha = 1.0 / self.period
        for i in range(self.period + 1, n):
            if np.isnan(values[i]):
                atr[i] = atr[i - 1]
            else:
                atr[i] = atr[i - 1] * (1 - alpha) + values[i] * alpha
        self._last_atr = float(atr[-1])
        return atr

    def get_last_atr(self) -> Optional[float]:
        """Return the most recent ATR from calculate()/update(), or None if none is available."""
        return self._last_atr

    def update(self, high: float, low: float, prev_close: float) -> float:
        """
        Incrementally update the ATR with one new bar (for live use).

        Only supported for method='ema'. If no usable previous ATR exists
        (None or NaN), the ATR is initialised to this bar's TR. This is a
        simplification: a full Wilder seed would average ``period`` TR values.

        Parameters:
            high: high of the current bar
            low: low of the current bar
            prev_close: close of the previous bar

        Returns:
            The new ATR value (also stored for get_last_atr()).

        Raises:
            ValueError: if the method is 'sma'. The past window of TRs is not
                stored, so the full series must be recalculated instead.
        """
        if self.method == 'sma':
            raise ValueError("SMA方法不支持增量更新,请重新计算完整序列")

        tr = max(abs(high - low), abs(high - prev_close), abs(low - prev_close))

        # A NaN previous ATR would make every later update NaN, so re-seed as well.
        if self._last_atr is None or np.isnan(self._last_atr):
            self._last_atr = tr
            return tr

        alpha = 1.0 / self.period
        self._last_atr = self._last_atr * (1 - alpha) + tr * alpha
        return self._last_atr
def calculate_atr(
    df: pd.DataFrame,
    period: int = 14,
    method: str = 'sma',
    high_col: str = 'high',
    low_col: str = 'low',
    close_col: str = 'close',
    drop_na: bool = False
) -> pd.DataFrame:
    """
    Convenience wrapper: compute TR and ATR for a DataFrame.

    The input frame is not modified. A copy is returned with two extra
    columns: 'TR' and f'ATR_{period}'.

    Parameters:
        df: input DataFrame; must contain the high/low/close columns named below
        period: ATR period, default 14
        method: 'sma' or 'ema', default 'sma'
        high_col: name of the high-price column, default 'high'
        low_col: name of the low-price column, default 'low'
        close_col: name of the close-price column, default 'close'
        drop_na: if True, drop every row of the result that has a NaN in any
            column (not only in the TR/ATR columns); default False

    Returns:
        A new DataFrame with the TR and ATR columns added.
    """
    indicator = ATRIndicator(period, method)
    price_columns = (df[high_col], df[low_col], df[close_col])

    enriched = df.copy()
    enriched['TR'] = indicator.calculate_tr(*price_columns)
    enriched[f'ATR_{period}'] = indicator.calculate(*price_columns)

    return enriched.dropna() if drop_na else enriched