""" ATR指标计算示例代码 演示如何使用ATRIndicator类计算ATR指标 """ import pandas as pd import numpy as np from atr_indicator import ATRIndicator, calculate_atr def generate_sample_data(days: int = 100) -> pd.DataFrame: """生成示例测试数据""" np.random.seed(42) dates = pd.date_range(start='2025-01-01', periods=days) close = np.zeros(days) close[0] = 100.0 # 生成随机游走价格 for i in range(1, days): close[i] = close[i-1] + np.random.normal(0, 1.5) # 根据收盘价生成高低价 high = close + np.random.uniform(0.5, 2.0, days) low = close - np.random.uniform(0.5, 2.0, days) df = pd.DataFrame({ 'date': dates, 'open': close, 'high': high, 'low': low, 'close': close }) df = df.set_index('date') return df def example_basic_usage(): """基本使用示例""" print("=" * 60) print("示例1: 基本使用方法") print("=" * 60) # 生成示例数据 df = generate_sample_data(60) print(f"生成了 {len(df)} 根K线数据") print(df.head()) print() # 创建ATR指标实例 atr_sma = ATRIndicator(period=14, method='sma') # 计算ATR tr_sma = atr_sma.calculate_tr(df['high'], df['low'], df['close']) atr_values_sma = atr_sma.calculate(df['high'], df['low'], df['close']) print("TR计算结果(前10行):") print(tr_sma.head(10)) print() print("ATR(SMA)计算结果(最后10行):") print(atr_values_sma.tail(10)) print() print(f"最后一个ATR值: {atr_sma.get_last_atr():.4f}") print() def example_compare_methods(): """比较SMA和EMA两种方法""" print("=" * 60) print("示例2: 比较SMA和EMA两种ATR计算方法") print("=" * 60) df = generate_sample_data(100) # 分别计算两种ATR atr_sma = ATRIndicator(period=14, method='sma') atr_ema = ATRIndicator(period=14, method='ema') atr_sma_values = atr_sma.calculate(df['high'], df['low'], df['close']) atr_ema_values = atr_ema.calculate(df['high'], df['low'], df['close']) # 创建对比DataFrame compare_df = pd.DataFrame({ 'SMA_ATR_14': atr_sma_values, 'EMA_ATR_14': atr_ema_values }) print("最后15行对比结果:") print(compare_df.tail(15)) print() print(f"SMA方法最后ATR: {atr_sma.get_last_atr():.4f}") print(f"EMA方法最后ATR: {atr_ema.get_last_atr():.4f}") print() def example_convenience_function(): """便捷函数使用示例""" print("=" * 60) print("示例3: 使用便捷函数calculate_atr") print("=" * 60) df = generate_sample_data(50) print("原始数据(前5行):") print(df.head()) print() # 一键计算ATR,直接添加到DataFrame result_df = calculate_atr(df, period=14, method='ema', drop_na=True) print("计算结果(包含TR和ATR,drop_na=True后):") print(result_df) print() print(f"结果形状: {result_df.shape}") def example_incremental_update(): """增量更新示例(实盘场景)""" print("=" * 60) print("示例4: EMA增量更新(实盘场景)") print("=" * 60) # 先用历史数据计算 df = generate_sample_data(30) atr_indicator = ATRIndicator(period=14, method='ema') atr_values = atr_indicator.calculate(df['high'], df['low'], df['close']) print(f"历史数据计算完成,最后ATR: {atr_indicator.get_last_atr():.4f}") print() # 模拟新K线到来,增量更新 new_high = 105.2 new_low = 103.8 prev_close = 104.5 new_atr = atr_indicator.update(new_high, new_low, prev_close) print(f"新增一根K线后,新的ATR: {new_atr:.4f}") print() def main(): """运行所有示例""" example_basic_usage() example_compare_methods() example_convenience_function() example_incremental_update() print("=" * 60) print("所有示例运行完成!") print("=" * 60) if __name__ == "__main__": main()