Files
2026-04-17 20:23:35 +08:00

158 lines
4.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
ATR指标计算示例代码
演示如何使用ATRIndicator类计算ATR指标
"""
import pandas as pd
import numpy as np
from atr_indicator import ATRIndicator, calculate_atr
def generate_sample_data(days: int = 100) -> pd.DataFrame:
"""生成示例测试数据"""
np.random.seed(42)
dates = pd.date_range(start='2025-01-01', periods=days)
close = np.zeros(days)
close[0] = 100.0
# 生成随机游走价格
for i in range(1, days):
close[i] = close[i-1] + np.random.normal(0, 1.5)
# 根据收盘价生成高低价
high = close + np.random.uniform(0.5, 2.0, days)
low = close - np.random.uniform(0.5, 2.0, days)
df = pd.DataFrame({
'date': dates,
'open': close,
'high': high,
'low': low,
'close': close
})
df = df.set_index('date')
return df
def example_basic_usage():
"""基本使用示例"""
print("=" * 60)
print("示例1: 基本使用方法")
print("=" * 60)
# 生成示例数据
df = generate_sample_data(60)
print(f"生成了 {len(df)} 根K线数据")
print(df.head())
print()
# 创建ATR指标实例
atr_sma = ATRIndicator(period=14, method='sma')
# 计算ATR
tr_sma = atr_sma.calculate_tr(df['high'], df['low'], df['close'])
atr_values_sma = atr_sma.calculate(df['high'], df['low'], df['close'])
print("TR计算结果(前10行):")
print(tr_sma.head(10))
print()
print("ATR(SMA)计算结果(最后10行):")
print(atr_values_sma.tail(10))
print()
print(f"最后一个ATR值: {atr_sma.get_last_atr():.4f}")
print()
def example_compare_methods():
"""比较SMA和EMA两种方法"""
print("=" * 60)
print("示例2: 比较SMA和EMA两种ATR计算方法")
print("=" * 60)
df = generate_sample_data(100)
# 分别计算两种ATR
atr_sma = ATRIndicator(period=14, method='sma')
atr_ema = ATRIndicator(period=14, method='ema')
atr_sma_values = atr_sma.calculate(df['high'], df['low'], df['close'])
atr_ema_values = atr_ema.calculate(df['high'], df['low'], df['close'])
# 创建对比DataFrame
compare_df = pd.DataFrame({
'SMA_ATR_14': atr_sma_values,
'EMA_ATR_14': atr_ema_values
})
print("最后15行对比结果:")
print(compare_df.tail(15))
print()
print(f"SMA方法最后ATR: {atr_sma.get_last_atr():.4f}")
print(f"EMA方法最后ATR: {atr_ema.get_last_atr():.4f}")
print()
def example_convenience_function():
"""便捷函数使用示例"""
print("=" * 60)
print("示例3: 使用便捷函数calculate_atr")
print("=" * 60)
df = generate_sample_data(50)
print("原始数据(前5行):")
print(df.head())
print()
# 一键计算ATR,直接添加到DataFrame
result_df = calculate_atr(df, period=14, method='ema', drop_na=True)
print("计算结果(包含TR和ATRdrop_na=True后):")
print(result_df)
print()
print(f"结果形状: {result_df.shape}")
def example_incremental_update():
"""增量更新示例(实盘场景)"""
print("=" * 60)
print("示例4: EMA增量更新(实盘场景)")
print("=" * 60)
# 先用历史数据计算
df = generate_sample_data(30)
atr_indicator = ATRIndicator(period=14, method='ema')
atr_values = atr_indicator.calculate(df['high'], df['low'], df['close'])
print(f"历史数据计算完成,最后ATR: {atr_indicator.get_last_atr():.4f}")
print()
# 模拟新K线到来,增量更新
new_high = 105.2
new_low = 103.8
prev_close = 104.5
new_atr = atr_indicator.update(new_high, new_low, prev_close)
print(f"新增一根K线后,新的ATR: {new_atr:.4f}")
print()
def main():
"""运行所有示例"""
example_basic_usage()
example_compare_methods()
example_convenience_function()
example_incremental_update()
print("=" * 60)
print("所有示例运行完成!")
print("=" * 60)
if __name__ == "__main__":
main()