"""Full data-platform test suite: exercises DataCatalog + BacktestRunner.

Run with:
    cd sanguo_quant_live && python3 data_platform/test_catalog.py

Every check is a plain ``test_*`` function, so a collector such as pytest can
also pick them up. The tests read real data through ``DataCatalog``.
NOTE(review): they presumably require the local data store (daily parquet
files, stock lists, test data) to be populated -- confirm before running in CI.
"""
import sys
from pathlib import Path

import pandas as pd

# Make the project root (the directory containing the ``data_platform``
# package) importable when this file is executed directly as a script.
sys.path.insert(0, str(Path(__file__).parent.parent))

from data_platform import (  # noqa: E402  (must follow the sys.path tweak)
    DataCatalog,
    DataPlatformConfig,
    BaseStrategy,
    BacktestRunner,
    BacktestReport,
)


# ======================================================================
# Strategy used by the tests
# ======================================================================

class DummyMAStrategy(BaseStrategy):
    """Dual moving-average crossover strategy (test fixture only).

    ``signal`` is 1 where MA5 > MA20, -1 where MA5 < MA20, and 0 otherwise.
    """

    def generate_signals(self, data: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``data`` with ``ma5``, ``ma20`` and ``signal`` columns.

        Args:
            data: price frame; must contain a ``close`` column.

        Returns:
            A new DataFrame; the input frame is not modified.
        """
        data = data.copy()  # never mutate the caller's frame
        data["ma5"] = data["close"].rolling(5).mean()
        data["ma20"] = data["close"].rolling(20).mean()
        data["signal"] = 0
        # During the rolling warm-up the means are NaN; comparisons with NaN
        # are False, so those rows keep signal 0.
        data.loc[data["ma5"] > data["ma20"], "signal"] = 1
        data.loc[data["ma5"] < data["ma20"], "signal"] = -1
        return data


# ======================================================================
# Test functions
# ======================================================================

def test_config():
    """The default config points at existing root and daily-parquet dirs."""
    cfg = DataPlatformConfig()
    assert cfg.root.exists(), f"config root missing: {cfg.root}"
    assert cfg.daily_parquet_dir.exists(), (
        f"daily parquet dir missing: {cfg.daily_parquet_dir}"
    )
    print("✅ test_config 通过")


def test_get_daily():
    """Daily bars for a single code are non-empty and include ``close``."""
    cat = DataCatalog()
    df = cat.get_daily("600519", start="20250101", end="20251231")
    assert len(df) > 0, "get_daily returned no rows for 600519"
    assert "close" in df.columns, f"'close' missing; columns={list(df.columns)}"
    print("✅ test_get_daily 通过")


def test_get_daily_batch():
    """Batch fetch returns at least one code, each with a non-empty frame."""
    cat = DataCatalog()
    result = cat.get_daily_batch(["600519", "000001"], start="20250101", end="20250601")
    assert len(result) >= 1, "get_daily_batch returned no codes"
    for code, df in result.items():
        assert len(df) > 0, f"empty daily frame for {code}"
    print("✅ test_get_daily_batch 通过")


def test_get_stock_list():
    """Full stock list is large; the hs300 universe has exactly 300 names."""
    cat = DataCatalog()
    df_all = cat.get_stock_list()
    assert len(df_all) > 5000, f"expected >5000 stocks, got {len(df_all)}"
    df_hs300 = cat.get_stock_list("hs300")
    assert len(df_hs300) == 300, f"expected 300 hs300 members, got {len(df_hs300)}"
    print("✅ test_get_stock_list 通过")


def test_get_test_data():
    """Test data for 600519 is available and non-empty."""
    cat = DataCatalog()
    df = cat.get_test_data("600519")
    assert len(df) > 0, "get_test_data returned no rows for 600519"
    print("✅ test_get_test_data 通过")


def test_list_available():
    """The catalog advertises the ``daily_parquet`` dataset."""
    cat = DataCatalog()
    avail = cat.list_available()
    assert "daily_parquet" in avail, f"'daily_parquet' not listed in {avail!r}"
    print("✅ test_list_available 通过")


def test_backtest_runner():
    """End-to-end backtest: fetch data -> strategy -> simulated trades -> report."""
    cat = DataCatalog()
    runner = BacktestRunner(cat)
    strategy = DummyMAStrategy()

    result = runner.run(strategy, "600519", "20250101", "20251231")
    assert result.strategy_name == "DummyMAStrategy"
    assert result.code == "600519"
    assert result.initial_capital > 0
    assert result.final_capital > 0
    # Drawdown is reported as a non-positive fraction in [-1, 0].
    assert -1.0 <= result.max_drawdown <= 0.0, (
        f"max_drawdown out of range: {result.max_drawdown}"
    )
    assert result.total_trades >= 0
    assert result.equity_curve is not None

    # Report rendering in all three formats.
    report = BacktestReport(result)
    text = report.to_text()
    assert "600519" in text
    assert "总收益率" in text

    d = report.to_dict()
    assert d["code"] == "600519"
    assert "total_return" in d

    j = report.to_json()
    assert '"code": "600519"' in j

    print("✅ test_backtest_runner 通过")


def test_backtest_batch():
    """Batch backtest returns results keyed by code, each tagged with that code."""
    cat = DataCatalog()
    runner = BacktestRunner(cat)
    strategy = DummyMAStrategy()

    results = runner.run_batch(strategy, ["600519", "000001"], "20250101", "20250601")
    assert len(results) >= 1, "run_batch returned no results"
    for code, result in results.items():
        assert result.code == code, f"result keyed {code} reports code {result.code}"
    print("✅ test_backtest_batch 通过")


if __name__ == "__main__":
    # Run in definition order; the first failing assert aborts the run.
    tests = (
        test_config,
        test_get_daily,
        test_get_daily_batch,
        test_get_stock_list,
        test_get_test_data,
        test_list_available,
        test_backtest_runner,
        test_backtest_batch,
    )
    for test in tests:
        test()
    # Derive the count so the summary cannot drift from the tests actually run.
    print(f"\n🎉 全部 {len(tests)} 项测试通过!")