Files
2026-04-30 23:09:48 +08:00

142 lines
3.9 KiB
Python

"""
数据平台完整测试 - 验证 DataCatalog + BacktestRunner
运行方式:cd sanguo_quant_live && python3 data_platform/test_catalog.py
"""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from data_platform import (
DataCatalog, DataPlatformConfig,
BaseStrategy, BacktestRunner, BacktestReport,
)
import pandas as pd
# ======================================================================
# 测试用策略
# ======================================================================
class DummyMAStrategy(BaseStrategy):
    """Dual moving-average crossover strategy (test-only).

    Adds three columns to a copy of the input frame:

    - ``ma5``: 5-bar rolling mean of ``close``
    - ``ma20``: 20-bar rolling mean of ``close``
    - ``signal``: 1 when ``ma5 > ma20``, -1 when ``ma5 < ma20``, otherwise 0
      (this includes exact ties and the warm-up rows where a mean is NaN)
    """

    def generate_signals(self, data: pd.DataFrame) -> pd.DataFrame:
        frame = data.copy()
        closes = frame["close"]
        fast = closes.rolling(5).mean()
        slow = closes.rolling(20).mean()
        frame["ma5"] = fast
        frame["ma20"] = slow
        # Comparisons against NaN are False, so warm-up rows net out to 0.
        above = (fast > slow).astype("int64")
        below = (fast < slow).astype("int64")
        frame["signal"] = above - below
        return frame
# ======================================================================
# 测试函数
# ======================================================================
def test_config():
    """Default config must point at an existing root and daily parquet dir."""
    config = DataPlatformConfig()
    # Check in the same order as declared: the root first, then the parquet dir.
    for attr in ("root", "daily_parquet_dir"):
        assert getattr(config, attr).exists()
    print("✅ test_config 通过")
def test_get_daily():
    """Load 2025 daily bars for 600519; expect rows and a ``close`` column."""
    daily = DataCatalog().get_daily("600519", start="20250101", end="20251231")
    row_count = len(daily)
    assert row_count > 0
    assert "close" in daily.columns
    print("✅ test_get_daily 通过")
def test_get_daily_batch():
    """Batch-load daily bars for several codes.

    At least one requested code must be returned, and every returned
    frame must contain at least one row.
    """
    cat = DataCatalog()
    result = cat.get_daily_batch(["600519", "000001"], start="20250101", end="20250601")
    assert len(result) >= 1, "get_daily_batch returned no codes"
    for code, df in result.items():
        # Name the offending code: the documented plain `python3 ...` run
        # has no pytest assertion rewriting to show it otherwise.
        assert len(df) > 0, f"empty daily data for {code}"
    print("✅ test_get_daily_batch 通过")
def test_get_stock_list():
    """Check the full stock universe size and the HS300 constituent count.

    Expects more than 5000 stocks overall and exactly 300 HS300 members.
    """
    cat = DataCatalog()
    df_all = cat.get_stock_list()
    # Report actual counts: bare asserts show nothing under plain `python3`.
    assert len(df_all) > 5000, f"expected >5000 stocks, got {len(df_all)}"
    df_hs300 = cat.get_stock_list("hs300")
    assert len(df_hs300) == 300, f"expected 300 hs300 members, got {len(df_hs300)}"
    print("✅ test_get_stock_list 通过")
def test_get_test_data():
    """The test dataset for 600519 must contain at least one row."""
    sample = DataCatalog().get_test_data("600519")
    assert len(sample) > 0
    print("✅ test_get_test_data 通过")
def test_list_available():
    """``list_available()`` must include the ``daily_parquet`` entry."""
    sources = DataCatalog().list_available()
    assert "daily_parquet" in sources
    print("✅ test_list_available 通过")
def test_backtest_runner():
    """Test the full backtest flow: fetch data -> strategy -> simulated trades -> report.

    Runs ``DummyMAStrategy`` on 600519 over 2025, checks the result fields,
    then checks the text, dict and JSON renderings of ``BacktestReport``.
    """
    import json  # local import: only this test needs to parse JSON output

    cat = DataCatalog()
    runner = BacktestRunner(cat)
    strategy = DummyMAStrategy()
    result = runner.run(strategy, "600519", "20250101", "20251231")
    assert result.strategy_name == "DummyMAStrategy"
    assert result.code == "600519"
    assert result.initial_capital > 0
    assert result.final_capital > 0
    # Drawdown is expected as a non-positive fraction in [-1, 0].
    assert -1.0 <= result.max_drawdown <= 0.0, f"max_drawdown={result.max_drawdown}"
    assert result.total_trades >= 0
    assert result.equity_curve is not None

    # Report rendering
    report = BacktestReport(result)
    text = report.to_text()
    assert "600519" in text
    assert "总收益率" in text
    d = report.to_dict()
    assert d["code"] == "600519"
    assert "total_return" in d
    # Parse rather than substring-match, so the check does not depend on the
    # serializer's separators/indentation (e.g. '"code":"600519"' is valid too).
    j = report.to_json()
    assert json.loads(j)["code"] == "600519"
    print("✅ test_backtest_runner 通过")
def test_backtest_batch():
    """Batch backtest: at least one result, each keyed by its own stock code."""
    runner = BacktestRunner(DataCatalog())
    codes = ["600519", "000001"]
    batch = runner.run_batch(DummyMAStrategy(), codes, "20250101", "20250601")
    assert len(batch) >= 1
    assert all(res.code == key for key, res in batch.items())
    print("✅ test_backtest_batch 通过")
if __name__ == "__main__":
    # Script entry point: run every test in order. The first failing
    # assertion aborts the run with a traceback.
    _tests = (
        test_config,
        test_get_daily,
        test_get_daily_batch,
        test_get_stock_list,
        test_get_test_data,
        test_list_available,
        test_backtest_runner,
        test_backtest_batch,
    )
    for _test in _tests:
        _test()
    # Derive the count from the tuple so the summary cannot go stale.
    print(f"\n🎉 全部 {len(_tests)} 项测试通过!")