auto-sync: 2026-04-17 20:23:35

This commit is contained in:
cfdaily
2026-04-17 20:23:35 +08:00
parent a5fce664d3
commit 385bbabce5
@@ -0,0 +1,157 @@
"""
ATR指标计算示例代码
演示如何使用ATRIndicator类计算ATR指标
"""
import pandas as pd
import numpy as np
from atr_indicator import ATRIndicator, calculate_atr
def generate_sample_data(days: int = 100) -> pd.DataFrame:
"""生成示例测试数据"""
np.random.seed(42)
dates = pd.date_range(start='2025-01-01', periods=days)
close = np.zeros(days)
close[0] = 100.0
# 生成随机游走价格
for i in range(1, days):
close[i] = close[i-1] + np.random.normal(0, 1.5)
# 根据收盘价生成高低价
high = close + np.random.uniform(0.5, 2.0, days)
low = close - np.random.uniform(0.5, 2.0, days)
df = pd.DataFrame({
'date': dates,
'open': close,
'high': high,
'low': low,
'close': close
})
df = df.set_index('date')
return df
def example_basic_usage():
"""基本使用示例"""
print("=" * 60)
print("示例1: 基本使用方法")
print("=" * 60)
# 生成示例数据
df = generate_sample_data(60)
print(f"生成了 {len(df)} 根K线数据")
print(df.head())
print()
# 创建ATR指标实例
atr_sma = ATRIndicator(period=14, method='sma')
# 计算ATR
tr_sma = atr_sma.calculate_tr(df['high'], df['low'], df['close'])
atr_values_sma = atr_sma.calculate(df['high'], df['low'], df['close'])
print("TR计算结果(前10行):")
print(tr_sma.head(10))
print()
print("ATR(SMA)计算结果(最后10行):")
print(atr_values_sma.tail(10))
print()
print(f"最后一个ATR值: {atr_sma.get_last_atr():.4f}")
print()
def example_compare_methods():
"""比较SMA和EMA两种方法"""
print("=" * 60)
print("示例2: 比较SMA和EMA两种ATR计算方法")
print("=" * 60)
df = generate_sample_data(100)
# 分别计算两种ATR
atr_sma = ATRIndicator(period=14, method='sma')
atr_ema = ATRIndicator(period=14, method='ema')
atr_sma_values = atr_sma.calculate(df['high'], df['low'], df['close'])
atr_ema_values = atr_ema.calculate(df['high'], df['low'], df['close'])
# 创建对比DataFrame
compare_df = pd.DataFrame({
'SMA_ATR_14': atr_sma_values,
'EMA_ATR_14': atr_ema_values
})
print("最后15行对比结果:")
print(compare_df.tail(15))
print()
print(f"SMA方法最后ATR: {atr_sma.get_last_atr():.4f}")
print(f"EMA方法最后ATR: {atr_ema.get_last_atr():.4f}")
print()
def example_convenience_function():
"""便捷函数使用示例"""
print("=" * 60)
print("示例3: 使用便捷函数calculate_atr")
print("=" * 60)
df = generate_sample_data(50)
print("原始数据(前5行):")
print(df.head())
print()
# 一键计算ATR,直接添加到DataFrame
result_df = calculate_atr(df, period=14, method='ema', drop_na=True)
print("计算结果(包含TR和ATRdrop_na=True后):")
print(result_df)
print()
print(f"结果形状: {result_df.shape}")
def example_incremental_update():
"""增量更新示例(实盘场景)"""
print("=" * 60)
print("示例4: EMA增量更新(实盘场景)")
print("=" * 60)
# 先用历史数据计算
df = generate_sample_data(30)
atr_indicator = ATRIndicator(period=14, method='ema')
atr_values = atr_indicator.calculate(df['high'], df['low'], df['close'])
print(f"历史数据计算完成,最后ATR: {atr_indicator.get_last_atr():.4f}")
print()
# 模拟新K线到来,增量更新
new_high = 105.2
new_low = 103.8
prev_close = 104.5
new_atr = atr_indicator.update(new_high, new_low, prev_close)
print(f"新增一根K线后,新的ATR: {new_atr:.4f}")
print()
def main():
"""运行所有示例"""
example_basic_usage()
example_compare_methods()
example_convenience_function()
example_incremental_update()
print("=" * 60)
print("所有示例运行完成!")
print("=" * 60)
if __name__ == "__main__":
main()