auto-sync: 2026-04-17 20:23:35
This commit is contained in:
@@ -0,0 +1,157 @@
|
|||||||
|
"""
|
||||||
|
ATR指标计算示例代码
|
||||||
|
演示如何使用ATRIndicator类计算ATR指标
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from atr_indicator import ATRIndicator, calculate_atr
|
||||||
|
|
||||||
|
|
||||||
|
def generate_sample_data(days: int = 100) -> pd.DataFrame:
|
||||||
|
"""生成示例测试数据"""
|
||||||
|
np.random.seed(42)
|
||||||
|
|
||||||
|
dates = pd.date_range(start='2025-01-01', periods=days)
|
||||||
|
close = np.zeros(days)
|
||||||
|
close[0] = 100.0
|
||||||
|
|
||||||
|
# 生成随机游走价格
|
||||||
|
for i in range(1, days):
|
||||||
|
close[i] = close[i-1] + np.random.normal(0, 1.5)
|
||||||
|
|
||||||
|
# 根据收盘价生成高低价
|
||||||
|
high = close + np.random.uniform(0.5, 2.0, days)
|
||||||
|
low = close - np.random.uniform(0.5, 2.0, days)
|
||||||
|
|
||||||
|
df = pd.DataFrame({
|
||||||
|
'date': dates,
|
||||||
|
'open': close,
|
||||||
|
'high': high,
|
||||||
|
'low': low,
|
||||||
|
'close': close
|
||||||
|
})
|
||||||
|
df = df.set_index('date')
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def example_basic_usage():
|
||||||
|
"""基本使用示例"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("示例1: 基本使用方法")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# 生成示例数据
|
||||||
|
df = generate_sample_data(60)
|
||||||
|
print(f"生成了 {len(df)} 根K线数据")
|
||||||
|
print(df.head())
|
||||||
|
print()
|
||||||
|
|
||||||
|
# 创建ATR指标实例
|
||||||
|
atr_sma = ATRIndicator(period=14, method='sma')
|
||||||
|
|
||||||
|
# 计算ATR
|
||||||
|
tr_sma = atr_sma.calculate_tr(df['high'], df['low'], df['close'])
|
||||||
|
atr_values_sma = atr_sma.calculate(df['high'], df['low'], df['close'])
|
||||||
|
|
||||||
|
print("TR计算结果(前10行):")
|
||||||
|
print(tr_sma.head(10))
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("ATR(SMA)计算结果(最后10行):")
|
||||||
|
print(atr_values_sma.tail(10))
|
||||||
|
print()
|
||||||
|
|
||||||
|
print(f"最后一个ATR值: {atr_sma.get_last_atr():.4f}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
def example_compare_methods():
|
||||||
|
"""比较SMA和EMA两种方法"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("示例2: 比较SMA和EMA两种ATR计算方法")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
df = generate_sample_data(100)
|
||||||
|
|
||||||
|
# 分别计算两种ATR
|
||||||
|
atr_sma = ATRIndicator(period=14, method='sma')
|
||||||
|
atr_ema = ATRIndicator(period=14, method='ema')
|
||||||
|
|
||||||
|
atr_sma_values = atr_sma.calculate(df['high'], df['low'], df['close'])
|
||||||
|
atr_ema_values = atr_ema.calculate(df['high'], df['low'], df['close'])
|
||||||
|
|
||||||
|
# 创建对比DataFrame
|
||||||
|
compare_df = pd.DataFrame({
|
||||||
|
'SMA_ATR_14': atr_sma_values,
|
||||||
|
'EMA_ATR_14': atr_ema_values
|
||||||
|
})
|
||||||
|
|
||||||
|
print("最后15行对比结果:")
|
||||||
|
print(compare_df.tail(15))
|
||||||
|
print()
|
||||||
|
|
||||||
|
print(f"SMA方法最后ATR: {atr_sma.get_last_atr():.4f}")
|
||||||
|
print(f"EMA方法最后ATR: {atr_ema.get_last_atr():.4f}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
def example_convenience_function():
|
||||||
|
"""便捷函数使用示例"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("示例3: 使用便捷函数calculate_atr")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
df = generate_sample_data(50)
|
||||||
|
print("原始数据(前5行):")
|
||||||
|
print(df.head())
|
||||||
|
print()
|
||||||
|
|
||||||
|
# 一键计算ATR,直接添加到DataFrame
|
||||||
|
result_df = calculate_atr(df, period=14, method='ema', drop_na=True)
|
||||||
|
|
||||||
|
print("计算结果(包含TR和ATR,drop_na=True后):")
|
||||||
|
print(result_df)
|
||||||
|
print()
|
||||||
|
print(f"结果形状: {result_df.shape}")
|
||||||
|
|
||||||
|
|
||||||
|
def example_incremental_update():
|
||||||
|
"""增量更新示例(实盘场景)"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("示例4: EMA增量更新(实盘场景)")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# 先用历史数据计算
|
||||||
|
df = generate_sample_data(30)
|
||||||
|
atr_indicator = ATRIndicator(period=14, method='ema')
|
||||||
|
atr_values = atr_indicator.calculate(df['high'], df['low'], df['close'])
|
||||||
|
|
||||||
|
print(f"历史数据计算完成,最后ATR: {atr_indicator.get_last_atr():.4f}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# 模拟新K线到来,增量更新
|
||||||
|
new_high = 105.2
|
||||||
|
new_low = 103.8
|
||||||
|
prev_close = 104.5
|
||||||
|
|
||||||
|
new_atr = atr_indicator.update(new_high, new_low, prev_close)
|
||||||
|
print(f"新增一根K线后,新的ATR: {new_atr:.4f}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""运行所有示例"""
|
||||||
|
example_basic_usage()
|
||||||
|
example_compare_methods()
|
||||||
|
example_convenience_function()
|
||||||
|
example_incremental_update()
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("所有示例运行完成!")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user