Files
sanguo_quant_live/data_platform/test_catalog.py
T
2026-04-30 21:12:55 +08:00

104 lines
2.7 KiB
Python

"""
数据平台测试 - 验证 DataCatalog 4个核心API
运行方式:cd sanguo_quant_live && python3 data_platform/test_catalog.py
"""
import sys
from pathlib import Path
# 确保项目根目录在 path 中
sys.path.insert(0, str(Path(__file__).parent.parent))
from data_platform import DataCatalog, DataPlatformConfig
def test_config():
"""测试配置层"""
cfg = DataPlatformConfig()
assert cfg.root.exists(), f"项目根不存在: {cfg.root}"
assert cfg.daily_parquet_dir.exists(), f"日线目录不存在: {cfg.daily_parquet_dir}"
assert cfg.stock_info_dir.exists(), f"股票信息目录不存在: {cfg.stock_info_dir}"
print("✅ test_config 通过")
def test_get_daily():
"""测试获取日线数据"""
cat = DataCatalog()
# 基本查询
df = cat.get_daily("600519", start="20250101", end="20251231")
assert len(df) > 0, "日线数据为空"
assert "close" in df.columns, "缺少 close 列"
assert df["close"].notna().all(), "close 列有空值"
# 日期范围
assert df["date"].min() >= __import__("pandas").Timestamp("2025-01-01")
assert df["date"].max() <= __import__("pandas").Timestamp("2025-12-31")
# 深市股票
df2 = cat.get_daily("000001", start="20250101", end="20250301")
assert len(df2) > 0, "深市股票数据为空"
# 不存在的股票
try:
cat.get_daily("999999")
assert False, "应该抛出 FileNotFoundError"
except FileNotFoundError:
pass
print("✅ test_get_daily 通过")
def test_get_stock_list():
"""测试获取股票列表"""
cat = DataCatalog()
# 全部A股
df_all = cat.get_stock_list()
assert len(df_all) > 5000, f"A股数量异常: {len(df_all)}"
# 沪深300
df_hs300 = cat.get_stock_list("hs300")
assert len(df_hs300) == 300, f"沪深300数量异常: {len(df_hs300)}"
print("✅ test_get_stock_list 通过")
def test_get_test_data():
"""测试获取测试数据集"""
cat = DataCatalog()
df = cat.get_test_data("600519")
assert len(df) > 0, "测试数据为空"
assert "close" in df.columns
# 不存在的数据集
try:
cat.get_test_data("NOT_EXIST")
assert False, "应该抛出 FileNotFoundError"
except FileNotFoundError:
pass
print("✅ test_get_test_data 通过")
def test_list_available():
"""测试列出可用数据"""
cat = DataCatalog()
avail = cat.list_available()
assert "daily_parquet" in avail
assert len(avail["daily_parquet"]["years"]) >= 10
print("✅ test_list_available 通过")
if __name__ == "__main__":
test_config()
test_get_daily()
test_get_stock_list()
test_get_test_data()
test_list_available()
print("\n🎉 全部测试通过!")