277 lines
8.4 KiB
Python
277 lines
8.4 KiB
Python
"""
|
||
ATR (Average True Range) 指标计算工具
|
||
支持SMA和EMA两种计算方式
|
||
|
||
Author: 张飞 翼德
|
||
Date: 2026-04-17
|
||
"""
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
from typing import Optional, Union
|
||
|
||
|
||
class ATRIndicator:
|
||
"""
|
||
ATR (Average True Range) 指标计算类
|
||
|
||
支持两种计算方式:
|
||
- SMA: 简单移动平均
|
||
- EMA: 指数移动平均
|
||
"""
|
||
|
||
def __init__(self, period: int = 14, method: str = 'sma'):
|
||
"""
|
||
初始化ATR指标
|
||
|
||
Parameters:
|
||
period: int, 默认14
|
||
ATR计算周期
|
||
method: str, 默认'sma'
|
||
计算方法,可选 'sma' 或 'ema'
|
||
"""
|
||
if period <= 0:
|
||
raise ValueError(f"周期必须大于0,当前为 {period}")
|
||
|
||
if method.lower() not in ['sma', 'ema']:
|
||
raise ValueError(f"计算方法必须是 'sma' 或 'ema',当前为 {method}")
|
||
|
||
self.period = period
|
||
self.method = method.lower()
|
||
self._last_atr: Optional[float] = None
|
||
|
||
@staticmethod
|
||
def calculate_tr(
|
||
high: Union[pd.Series, np.ndarray],
|
||
low: Union[pd.Series, np.ndarray],
|
||
close: Union[pd.Series, np.ndarray]
|
||
) -> Union[pd.Series, np.ndarray]:
|
||
"""
|
||
计算True Range (TR)
|
||
|
||
公式: TR = max[|high - low|, |high - prev_close|, |low - prev_close|]
|
||
|
||
Parameters:
|
||
high: 最高价序列
|
||
low: 最低价序列
|
||
close: 收盘价序列
|
||
|
||
Returns:
|
||
TR序列,第一个值为NaN
|
||
"""
|
||
# 转换为numpy数组便于计算
|
||
if isinstance(high, pd.Series):
|
||
tr_values = ATRIndicator._calculate_tr_numpy(
|
||
high.values, low.values, close.values
|
||
)
|
||
return pd.Series(tr_values, index=high.index, name='TR')
|
||
else:
|
||
return ATRIndicator._calculate_tr_numpy(high, low, close)
|
||
|
||
@staticmethod
|
||
def _calculate_tr_numpy(
|
||
high: np.ndarray,
|
||
low: np.ndarray,
|
||
close: np.ndarray
|
||
) -> np.ndarray:
|
||
"""使用numpy计算TR"""
|
||
if len(high) != len(low) or len(high) != len(close):
|
||
raise ValueError("high, low, close序列长度必须一致")
|
||
|
||
n = len(high)
|
||
tr = np.full(n, np.nan)
|
||
|
||
if n < 2:
|
||
return tr
|
||
|
||
# 计算三个成分
|
||
high_low = high[1:] - low[1:]
|
||
high_prev_close = np.abs(high[1:] - close[:-1])
|
||
low_prev_close = np.abs(low[1:] - close[:-1])
|
||
|
||
# 取最大值
|
||
tr[1:] = np.maximum(np.maximum(high_low, high_prev_close), low_prev_close)
|
||
|
||
return tr
|
||
|
||
def calculate(
|
||
self,
|
||
high: Union[pd.Series, np.ndarray],
|
||
low: Union[pd.Series, np.ndarray],
|
||
close: Union[pd.Series, np.ndarray]
|
||
) -> Union[pd.Series, np.ndarray]:
|
||
"""
|
||
计算ATR指标
|
||
|
||
Parameters:
|
||
high: 最高价序列
|
||
low: 最低价序列
|
||
close: 收盘价序列
|
||
|
||
Returns:
|
||
ATR序列
|
||
"""
|
||
# 先计算TR
|
||
tr = self.calculate_tr(high, low, close)
|
||
|
||
if self.method == 'sma':
|
||
return self._calculate_sma_atr(tr)
|
||
else:
|
||
return self._calculate_ema_atr(tr)
|
||
|
||
def _calculate_sma_atr(self, tr: Union[pd.Series, np.ndarray]) -> Union[pd.Series, np.ndarray]:
|
||
"""使用SMA计算ATR"""
|
||
if isinstance(tr, pd.Series):
|
||
atr = tr.rolling(window=self.period, min_periods=self.period).mean()
|
||
atr.name = f'ATR_{self.period}'
|
||
if len(atr) > 0:
|
||
self._last_atr = atr.iloc[-1]
|
||
return atr
|
||
else:
|
||
n = len(tr)
|
||
atr = np.full(n, np.nan)
|
||
for i in range(self.period - 1, n):
|
||
if i >= self.period:
|
||
atr[i] = np.mean(tr[i - self.period + 1:i + 1])
|
||
if n > 0:
|
||
self._last_atr = atr[-1]
|
||
return atr
|
||
|
||
def _calculate_ema_atr(self, tr: Union[pd.Series, np.ndarray]) -> Union[pd.Series, np.ndarray]:
|
||
"""使用EMA计算ATR,Wilder平滑方法"""
|
||
if isinstance(tr, pd.Series):
|
||
return self._calculate_ema_atr_pandas(tr)
|
||
else:
|
||
return self._calculate_ema_atr_numpy(tr)
|
||
|
||
def _calculate_ema_atr_pandas(self, tr: pd.Series) -> pd.Series:
|
||
"""Pandas版本EMA-ATR计算,使用Wilder平滑"""
|
||
n = len(tr)
|
||
atr = pd.Series(np.full(n, np.nan), index=tr.index, name=f'ATR_{self.period}')
|
||
|
||
if n < self.period:
|
||
self._last_atr = None
|
||
return atr
|
||
|
||
# 第一个ATR用SMA计算
|
||
first_atr = tr.iloc[1:self.period + 1].mean()
|
||
atr.iloc[self.period] = first_atr
|
||
self._last_atr = first_atr
|
||
|
||
# Wilder平滑: ATR_t = (ATR_{t-1} * (period - 1) + TR_t) / period
|
||
alpha = 1.0 / self.period
|
||
|
||
for i in range(self.period + 1, n):
|
||
if np.isnan(tr.iloc[i]):
|
||
atr.iloc[i] = atr.iloc[i - 1]
|
||
else:
|
||
atr.iloc[i] = atr.iloc[i - 1] * (1 - alpha) + tr.iloc[i] * alpha
|
||
|
||
if len(atr) > 0:
|
||
self._last_atr = atr.iloc[-1]
|
||
|
||
return atr
|
||
|
||
def _calculate_ema_atr_numpy(self, tr: np.ndarray) -> np.ndarray:
|
||
"""Numpy版本EMA-ATR计算,使用Wilder平滑"""
|
||
n = len(tr)
|
||
atr = np.full(n, np.nan)
|
||
|
||
if n < self.period:
|
||
self._last_atr = None
|
||
return atr
|
||
|
||
# 第一个ATR用SMA计算
|
||
first_atr = np.mean(tr[1:self.period + 1])
|
||
atr[self.period] = first_atr
|
||
self._last_atr = first_atr
|
||
|
||
# Wilder平滑: ATR_t = (ATR_{t-1} * (period - 1) + TR_t) / period
|
||
alpha = 1.0 / self.period
|
||
|
||
for i in range(self.period + 1, n):
|
||
if np.isnan(tr[i]):
|
||
atr[i] = atr[i - 1]
|
||
else:
|
||
atr[i] = atr[i - 1] * (1 - alpha) + tr[i] * alpha
|
||
|
||
self._last_atr = atr[-1]
|
||
return atr
|
||
|
||
def get_last_atr(self) -> Optional[float]:
|
||
"""获取最后一个ATR值"""
|
||
return self._last_atr
|
||
|
||
def update(self, high: float, low: float, prev_close: float) -> float:
|
||
"""
|
||
更新单个ATR值(实盘时增量更新使用)
|
||
|
||
Parameters:
|
||
high: 当前K线最高价
|
||
low: 当前K线最低价
|
||
prev_close: 前一根K线收盘价
|
||
|
||
Returns:
|
||
新的ATR值
|
||
"""
|
||
if self.method == 'sma':
|
||
# SMA不适合增量更新,这里简化处理,用户应该重新计算完整序列
|
||
raise ValueError("SMA方法不支持增量更新,请重新计算完整序列")
|
||
|
||
# 计算当前TR
|
||
tr1 = abs(high - low)
|
||
tr2 = abs(high - prev_close)
|
||
tr3 = abs(low - prev_close)
|
||
tr = max(tr1, tr2, tr3)
|
||
|
||
if self._last_atr is None:
|
||
# 还没有足够数据,初始化ATR为TR
|
||
self._last_atr = tr
|
||
return tr
|
||
|
||
# EMA支持增量更新
|
||
alpha = 1.0 / self.period
|
||
new_atr = self._last_atr * (1 - alpha) + tr * alpha
|
||
self._last_atr = new_atr
|
||
return new_atr
|
||
|
||
|
||
def calculate_atr(
|
||
df: pd.DataFrame,
|
||
period: int = 14,
|
||
method: str = 'sma',
|
||
high_col: str = 'high',
|
||
low_col: str = 'low',
|
||
close_col: str = 'close',
|
||
drop_na: bool = False
|
||
) -> pd.DataFrame:
|
||
"""
|
||
便捷函数:直接对DataFrame计算ATR并添加到原DataFrame
|
||
|
||
Parameters:
|
||
df: 输入DataFrame,必须包含high, low, close列
|
||
period: ATR周期,默认14
|
||
method: 计算方法 'sma' 或 'ema',默认'sma'
|
||
high_col: 最高价列名,默认'high'
|
||
low_col: 最低价列名,默认'low'
|
||
close_col: 收盘价列名,默认'close'
|
||
drop_na: 是否删除NaN值,默认False
|
||
|
||
Returns:
|
||
添加了TR和ATR列的DataFrame
|
||
"""
|
||
atr_indicator = ATRIndicator(period, method)
|
||
|
||
result = df.copy()
|
||
result['TR'] = atr_indicator.calculate_tr(
|
||
df[high_col], df[low_col], df[close_col]
|
||
)
|
||
result[f'ATR_{period}'] = atr_indicator.calculate(
|
||
df[high_col], df[low_col], df[close_col]
|
||
)
|
||
|
||
if drop_na:
|
||
result = result.dropna()
|
||
|
||
return result
|