"""
Data platform tests - verify the four core DataCatalog APIs.

File location: data_platform/test_catalog.py

Usage: cd sanguo_quant_live && python3 data_platform/test_catalog.py
"""

import sys
from pathlib import Path

import pandas as pd

# Make sure the project root is on sys.path so that `data_platform` can be
# imported when this file is executed directly as a script.
sys.path.insert(0, str(Path(__file__).parent.parent))

from data_platform import DataCatalog, DataPlatformConfig  # noqa: E402


def _assert_raises_file_not_found(func, *args):
    """Fail unless ``func(*args)`` raises ``FileNotFoundError``.

    An explicit ``raise AssertionError`` is used instead of ``assert False``
    so that the check still fires when Python runs with ``-O`` (which strips
    ``assert`` statements).

    Args:
        func: Callable expected to raise ``FileNotFoundError``.
        *args: Positional arguments forwarded to ``func``.

    Raises:
        AssertionError: If ``func(*args)`` returns without raising
            ``FileNotFoundError``.
    """
    try:
        func(*args)
    except FileNotFoundError:
        return
    raise AssertionError("应该抛出 FileNotFoundError")


def test_config():
    """Check that the configured root and data directories exist on disk."""
    cfg = DataPlatformConfig()
    assert cfg.root.exists(), f"项目根不存在: {cfg.root}"
    assert cfg.daily_parquet_dir.exists(), f"日线目录不存在: {cfg.daily_parquet_dir}"
    assert cfg.stock_info_dir.exists(), f"股票信息目录不存在: {cfg.stock_info_dir}"
    print("✅ test_config 通过")


def test_get_daily():
    """Check ``DataCatalog.get_daily`` for valid and unknown stock codes."""
    cat = DataCatalog()

    # Basic query
    df = cat.get_daily("600519", start="20250101", end="20251231")
    assert len(df) > 0, "日线数据为空"
    assert "close" in df.columns, "缺少 close 列"
    assert df["close"].notna().all(), "close 列有空值"

    # Returned dates must stay inside the requested range
    assert df["date"].min() >= pd.Timestamp("2025-01-01")
    assert df["date"].max() <= pd.Timestamp("2025-12-31")

    # Shenzhen-listed stock
    df2 = cat.get_daily("000001", start="20250101", end="20250301")
    assert len(df2) > 0, "深市股票数据为空"

    # Unknown stock code
    _assert_raises_file_not_found(cat.get_daily, "999999")

    print("✅ test_get_daily 通过")


def test_get_stock_list():
    """Check ``DataCatalog.get_stock_list`` for the full list and ``hs300``."""
    cat = DataCatalog()

    # All A-shares
    df_all = cat.get_stock_list()
    assert len(df_all) > 5000, f"A股数量异常: {len(df_all)}"

    # CSI 300 (hs300) constituents
    df_hs300 = cat.get_stock_list("hs300")
    assert len(df_hs300) == 300, f"沪深300数量异常: {len(df_hs300)}"

    print("✅ test_get_stock_list 通过")


def test_get_test_data():
    """Check ``DataCatalog.get_test_data`` for a known and an unknown name."""
    cat = DataCatalog()

    df = cat.get_test_data("600519")
    assert len(df) > 0, "测试数据为空"
    assert "close" in df.columns

    # Unknown dataset name
    _assert_raises_file_not_found(cat.get_test_data, "NOT_EXIST")

    print("✅ test_get_test_data 通过")


def test_list_available():
    """Check that ``list_available`` reports at least 10 daily-parquet years."""
    cat = DataCatalog()

    avail = cat.list_available()
    assert "daily_parquet" in avail
    assert len(avail["daily_parquet"]["years"]) >= 10

    print("✅ test_list_available 通过")


if __name__ == "__main__":
    # Run every test in order; the first failing assertion aborts the script.
    test_config()
    test_get_daily()
    test_get_stock_list()
    test_get_test_data()
    test_list_available()
    print("\n🎉 全部测试通过!")