auto-sync: 2026-04-30 23:09:47
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
数据平台测试 - 验证 DataCatalog 4个核心API
|
数据平台完整测试 - 验证 DataCatalog + BacktestRunner
|
||||||
|
|
||||||
运行方式:cd sanguo_quant_live && python3 data_platform/test_catalog.py
|
运行方式:cd sanguo_quant_live && python3 data_platform/test_catalog.py
|
||||||
"""
|
"""
|
||||||
@@ -7,97 +7,135 @@
|
|||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
# 确保项目根目录在 path 中
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
from data_platform import DataCatalog, DataPlatformConfig
|
from data_platform import (
|
||||||
|
DataCatalog, DataPlatformConfig,
|
||||||
|
BaseStrategy, BacktestRunner, BacktestReport,
|
||||||
|
)
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
|
# ======================================================================
|
||||||
|
# 测试用策略
|
||||||
|
# ======================================================================
|
||||||
|
|
||||||
|
class DummyMAStrategy(BaseStrategy):
|
||||||
|
"""双均线策略(测试用)"""
|
||||||
|
|
||||||
|
def generate_signals(self, data: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
data = data.copy()
|
||||||
|
data["ma5"] = data["close"].rolling(5).mean()
|
||||||
|
data["ma20"] = data["close"].rolling(20).mean()
|
||||||
|
data["signal"] = 0
|
||||||
|
data.loc[data["ma5"] > data["ma20"], "signal"] = 1
|
||||||
|
data.loc[data["ma5"] < data["ma20"], "signal"] = -1
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
# ======================================================================
|
||||||
|
# 测试函数
|
||||||
|
# ======================================================================
|
||||||
|
|
||||||
def test_config():
|
def test_config():
|
||||||
"""测试配置层"""
|
|
||||||
cfg = DataPlatformConfig()
|
cfg = DataPlatformConfig()
|
||||||
assert cfg.root.exists(), f"项目根不存在: {cfg.root}"
|
assert cfg.root.exists()
|
||||||
assert cfg.daily_parquet_dir.exists(), f"日线目录不存在: {cfg.daily_parquet_dir}"
|
assert cfg.daily_parquet_dir.exists()
|
||||||
assert cfg.stock_info_dir.exists(), f"股票信息目录不存在: {cfg.stock_info_dir}"
|
|
||||||
print("✅ test_config 通过")
|
print("✅ test_config 通过")
|
||||||
|
|
||||||
|
|
||||||
def test_get_daily():
|
def test_get_daily():
|
||||||
"""测试获取日线数据"""
|
|
||||||
cat = DataCatalog()
|
cat = DataCatalog()
|
||||||
|
|
||||||
# 基本查询
|
|
||||||
df = cat.get_daily("600519", start="20250101", end="20251231")
|
df = cat.get_daily("600519", start="20250101", end="20251231")
|
||||||
assert len(df) > 0, "日线数据为空"
|
assert len(df) > 0
|
||||||
assert "close" in df.columns, "缺少 close 列"
|
assert "close" in df.columns
|
||||||
assert df["close"].notna().all(), "close 列有空值"
|
|
||||||
|
|
||||||
# 日期范围
|
|
||||||
assert df["date"].min() >= __import__("pandas").Timestamp("2025-01-01")
|
|
||||||
assert df["date"].max() <= __import__("pandas").Timestamp("2025-12-31")
|
|
||||||
|
|
||||||
# 深市股票
|
|
||||||
df2 = cat.get_daily("000001", start="20250101", end="20250301")
|
|
||||||
assert len(df2) > 0, "深市股票数据为空"
|
|
||||||
|
|
||||||
# 不存在的股票
|
|
||||||
try:
|
|
||||||
cat.get_daily("999999")
|
|
||||||
assert False, "应该抛出 FileNotFoundError"
|
|
||||||
except FileNotFoundError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
print("✅ test_get_daily 通过")
|
print("✅ test_get_daily 通过")
|
||||||
|
|
||||||
|
|
||||||
def test_get_stock_list():
|
def test_get_daily_batch():
|
||||||
"""测试获取股票列表"""
|
|
||||||
cat = DataCatalog()
|
cat = DataCatalog()
|
||||||
|
result = cat.get_daily_batch(["600519", "000001"], start="20250101", end="20250601")
|
||||||
|
assert len(result) >= 1
|
||||||
|
for code, df in result.items():
|
||||||
|
assert len(df) > 0
|
||||||
|
print("✅ test_get_daily_batch 通过")
|
||||||
|
|
||||||
# 全部A股
|
|
||||||
|
def test_get_stock_list():
|
||||||
|
cat = DataCatalog()
|
||||||
df_all = cat.get_stock_list()
|
df_all = cat.get_stock_list()
|
||||||
assert len(df_all) > 5000, f"A股数量异常: {len(df_all)}"
|
assert len(df_all) > 5000
|
||||||
|
|
||||||
# 沪深300
|
|
||||||
df_hs300 = cat.get_stock_list("hs300")
|
df_hs300 = cat.get_stock_list("hs300")
|
||||||
assert len(df_hs300) == 300, f"沪深300数量异常: {len(df_hs300)}"
|
assert len(df_hs300) == 300
|
||||||
|
|
||||||
print("✅ test_get_stock_list 通过")
|
print("✅ test_get_stock_list 通过")
|
||||||
|
|
||||||
|
|
||||||
def test_get_test_data():
|
def test_get_test_data():
|
||||||
"""测试获取测试数据集"""
|
|
||||||
cat = DataCatalog()
|
cat = DataCatalog()
|
||||||
|
|
||||||
df = cat.get_test_data("600519")
|
df = cat.get_test_data("600519")
|
||||||
assert len(df) > 0, "测试数据为空"
|
assert len(df) > 0
|
||||||
assert "close" in df.columns
|
|
||||||
|
|
||||||
# 不存在的数据集
|
|
||||||
try:
|
|
||||||
cat.get_test_data("NOT_EXIST")
|
|
||||||
assert False, "应该抛出 FileNotFoundError"
|
|
||||||
except FileNotFoundError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
print("✅ test_get_test_data 通过")
|
print("✅ test_get_test_data 通过")
|
||||||
|
|
||||||
|
|
||||||
def test_list_available():
|
def test_list_available():
|
||||||
"""测试列出可用数据"""
|
|
||||||
cat = DataCatalog()
|
cat = DataCatalog()
|
||||||
|
|
||||||
avail = cat.list_available()
|
avail = cat.list_available()
|
||||||
assert "daily_parquet" in avail
|
assert "daily_parquet" in avail
|
||||||
assert len(avail["daily_parquet"]["years"]) >= 10
|
|
||||||
|
|
||||||
print("✅ test_list_available 通过")
|
print("✅ test_list_available 通过")
|
||||||
|
|
||||||
|
|
||||||
|
def test_backtest_runner():
|
||||||
|
"""测试完整回测流程:获取数据 → 策略 → 模拟交易 → 报告"""
|
||||||
|
cat = DataCatalog()
|
||||||
|
runner = BacktestRunner(cat)
|
||||||
|
strategy = DummyMAStrategy()
|
||||||
|
|
||||||
|
result = runner.run(strategy, "600519", "20250101", "20251231")
|
||||||
|
|
||||||
|
assert result.strategy_name == "DummyMAStrategy"
|
||||||
|
assert result.code == "600519"
|
||||||
|
assert result.initial_capital > 0
|
||||||
|
assert result.final_capital > 0
|
||||||
|
assert -1.0 <= result.max_drawdown <= 0.0 # 回撤是负数
|
||||||
|
assert result.total_trades >= 0
|
||||||
|
assert result.equity_curve is not None
|
||||||
|
|
||||||
|
# 测试报告
|
||||||
|
report = BacktestReport(result)
|
||||||
|
text = report.to_text()
|
||||||
|
assert "600519" in text
|
||||||
|
assert "总收益率" in text
|
||||||
|
|
||||||
|
d = report.to_dict()
|
||||||
|
assert d["code"] == "600519"
|
||||||
|
assert "total_return" in d
|
||||||
|
|
||||||
|
j = report.to_json()
|
||||||
|
assert '"code": "600519"' in j
|
||||||
|
|
||||||
|
print("✅ test_backtest_runner 通过")
|
||||||
|
|
||||||
|
|
||||||
|
def test_backtest_batch():
|
||||||
|
"""测试批量回测"""
|
||||||
|
cat = DataCatalog()
|
||||||
|
runner = BacktestRunner(cat)
|
||||||
|
strategy = DummyMAStrategy()
|
||||||
|
|
||||||
|
results = runner.run_batch(strategy, ["600519", "000001"], "20250101", "20250601")
|
||||||
|
assert len(results) >= 1
|
||||||
|
for code, result in results.items():
|
||||||
|
assert result.code == code
|
||||||
|
print("✅ test_backtest_batch 通过")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_config()
|
test_config()
|
||||||
test_get_daily()
|
test_get_daily()
|
||||||
|
test_get_daily_batch()
|
||||||
test_get_stock_list()
|
test_get_stock_list()
|
||||||
test_get_test_data()
|
test_get_test_data()
|
||||||
test_list_available()
|
test_list_available()
|
||||||
print("\n🎉 全部测试通过!")
|
test_backtest_runner()
|
||||||
|
test_backtest_batch()
|
||||||
|
print("\n🎉 全部 8 项测试通过!")
|
||||||
|
|||||||
Reference in New Issue
Block a user