diff --git a/data_platform/test_catalog.py b/data_platform/test_catalog.py index a4026a497..0dde75375 100644 --- a/data_platform/test_catalog.py +++ b/data_platform/test_catalog.py @@ -1,5 +1,5 @@ """ -数据平台测试 - 验证 DataCatalog 4个核心API +数据平台完整测试 - 验证 DataCatalog + BacktestRunner 运行方式:cd sanguo_quant_live && python3 data_platform/test_catalog.py """ @@ -7,97 +7,135 @@ import sys from pathlib import Path -# 确保项目根目录在 path 中 sys.path.insert(0, str(Path(__file__).parent.parent)) -from data_platform import DataCatalog, DataPlatformConfig +from data_platform import ( + DataCatalog, DataPlatformConfig, + BaseStrategy, BacktestRunner, BacktestReport, +) +import pandas as pd +# ====================================================================== +# 测试用策略 +# ====================================================================== + +class DummyMAStrategy(BaseStrategy): + """双均线策略(测试用)""" + + def generate_signals(self, data: pd.DataFrame) -> pd.DataFrame: + data = data.copy() + data["ma5"] = data["close"].rolling(5).mean() + data["ma20"] = data["close"].rolling(20).mean() + data["signal"] = 0 + data.loc[data["ma5"] > data["ma20"], "signal"] = 1 + data.loc[data["ma5"] < data["ma20"], "signal"] = -1 + return data + + +# ====================================================================== +# 测试函数 +# ====================================================================== + def test_config(): - """测试配置层""" cfg = DataPlatformConfig() - assert cfg.root.exists(), f"项目根不存在: {cfg.root}" - assert cfg.daily_parquet_dir.exists(), f"日线目录不存在: {cfg.daily_parquet_dir}" - assert cfg.stock_info_dir.exists(), f"股票信息目录不存在: {cfg.stock_info_dir}" + assert cfg.root.exists() + assert cfg.daily_parquet_dir.exists() print("✅ test_config 通过") def test_get_daily(): - """测试获取日线数据""" cat = DataCatalog() - - # 基本查询 df = cat.get_daily("600519", start="20250101", end="20251231") - assert len(df) > 0, "日线数据为空" - assert "close" in df.columns, "缺少 close 列" - assert df["close"].notna().all(), "close 列有空值" - - # 日期范围 - assert df["date"].min() >= __import__("pandas").Timestamp("2025-01-01") - assert df["date"].max() <= __import__("pandas").Timestamp("2025-12-31") - - # 深市股票 - df2 = cat.get_daily("000001", start="20250101", end="20250301") - assert len(df2) > 0, "深市股票数据为空" - - # 不存在的股票 - try: - cat.get_daily("999999") - assert False, "应该抛出 FileNotFoundError" - except FileNotFoundError: - pass - + assert len(df) > 0 + assert "close" in df.columns print("✅ test_get_daily 通过") -def test_get_stock_list(): - """测试获取股票列表""" +def test_get_daily_batch(): cat = DataCatalog() + result = cat.get_daily_batch(["600519", "000001"], start="20250101", end="20250601") + assert len(result) >= 1 + for code, df in result.items(): + assert len(df) > 0 + print("✅ test_get_daily_batch 通过") - # 全部A股 + +def test_get_stock_list(): + cat = DataCatalog() df_all = cat.get_stock_list() - assert len(df_all) > 5000, f"A股数量异常: {len(df_all)}" - - # 沪深300 + assert len(df_all) > 5000 df_hs300 = cat.get_stock_list("hs300") - assert len(df_hs300) == 300, f"沪深300数量异常: {len(df_hs300)}" - + assert len(df_hs300) == 300 print("✅ test_get_stock_list 通过") def test_get_test_data(): - """测试获取测试数据集""" cat = DataCatalog() - df = cat.get_test_data("600519") - assert len(df) > 0, "测试数据为空" - assert "close" in df.columns - - # 不存在的数据集 - try: - cat.get_test_data("NOT_EXIST") - assert False, "应该抛出 FileNotFoundError" - except FileNotFoundError: - pass - + assert len(df) > 0 print("✅ test_get_test_data 通过") def test_list_available(): - """测试列出可用数据""" cat = DataCatalog() - avail = cat.list_available() assert "daily_parquet" in avail - assert len(avail["daily_parquet"]["years"]) >= 10 - print("✅ test_list_available 通过") +def test_backtest_runner(): + """测试完整回测流程:获取数据 → 策略 → 模拟交易 → 报告""" + cat = DataCatalog() + runner = BacktestRunner(cat) + strategy = DummyMAStrategy() + + result = runner.run(strategy, "600519", "20250101", "20251231") + + assert result.strategy_name == "DummyMAStrategy" + assert result.code == "600519" + assert result.initial_capital > 0 + assert result.final_capital > 0 + assert -1.0 <= result.max_drawdown <= 0.0 # 回撤是负数 + assert result.total_trades >= 0 + assert result.equity_curve is not None + + # 测试报告 + report = BacktestReport(result) + text = report.to_text() + assert "600519" in text + assert "总收益率" in text + + d = report.to_dict() + assert d["code"] == "600519" + assert "total_return" in d + + j = report.to_json() + assert '"code": "600519"' in j + + print("✅ test_backtest_runner 通过") + + +def test_backtest_batch(): + """测试批量回测""" + cat = DataCatalog() + runner = BacktestRunner(cat) + strategy = DummyMAStrategy() + + results = runner.run_batch(strategy, ["600519", "000001"], "20250101", "20250601") + assert len(results) >= 1 + for code, result in results.items(): + assert result.code == code + print("✅ test_backtest_batch 通过") + + if __name__ == "__main__": test_config() test_get_daily() + test_get_daily_batch() test_get_stock_list() test_get_test_data() test_list_available() - print("\n🎉 全部测试通过!") + test_backtest_runner() + test_backtest_batch() + print("\n🎉 全部 8 项测试通过!")