"""
DataCatalog - unified data access interface.

Strategy developers obtain data exclusively through DataCatalog and do not
need to know the underlying file paths or storage formats.

Four core APIs:
- get_daily()       daily bars for a single stock
- get_stock_list()  stock basic info / index constituents
- get_test_data()   standard test datasets
- list_available()  overview of the available data assets
"""

import os
import logging
from pathlib import Path
from typing import Optional, List

import pandas as pd

from data_platform.config import DataPlatformConfig

logger = logging.getLogger(__name__)


class DataCatalog:
    """
    Unified data catalog: the single data entry point of the project.

    All paths are taken from ``DataPlatformConfig`` (``daily_parquet_dir``,
    ``stock_info_dir``, ``test_datasets_dir``).

    Usage:
        from data_platform import DataCatalog

        cat = DataCatalog()
        df = cat.get_daily("600519", start="20250101", end="20260101")
        stocks = cat.get_stock_list()
    """

    # Human-readable labels used in error messages for well-known indexes.
    _INDEX_LABELS = {"hs300": "沪深300"}

    def __init__(self, project_root: Optional[str] = None):
        """Build the catalog on top of a ``DataPlatformConfig``.

        Args:
            project_root: Forwarded unchanged to ``DataPlatformConfig``.
        """
        self.config = DataPlatformConfig(project_root)

    # ------------------------------------------------------------------
    # F1: get_daily - daily bars for a single stock
    # ------------------------------------------------------------------

    def get_daily(
        self,
        code: str,
        start: Optional[str] = None,
        end: Optional[str] = None,
        years: Optional[List[int]] = None,
    ) -> pd.DataFrame:
        """
        Load the daily bars of one stock from the local Parquet files.

        Files are looked up as
        ``<daily_parquet_dir>/<year>/<sh|sz><code>_daily.parquet``; every
        existing yearly file is read and the results are concatenated.

        Args:
            code: Stock code, e.g. "600519". It is stripped and left-padded
                with zeros to six characters.
            start: Inclusive start date, e.g. "YYYYMMDD".
            end: Inclusive end date, e.g. "YYYYMMDD".
            years: Years to scan, e.g. [2024, 2025]. When None, the year
                directories present on disk are used.

        Returns:
            DataFrame sorted by ``date`` with a contiguous 0..n-1 index.
            ``date`` is converted to datetime.

        Raises:
            FileNotFoundError: If no yearly file exists for the stock.
        """
        code = str(code).strip().zfill(6)
        # Codes starting with "6" use the "sh" prefix, everything else "sz".
        # NOTE(review): other exchanges/boards may need their own prefix -
        # confirm against the file naming used by the ingestion job.
        prefix = "sh" if code.startswith("6") else "sz"
        filename = f"{prefix}{code}_daily.parquet"

        scan_years = self._detect_years() if years is None else years

        daily_dir = self.config.daily_parquet_dir
        frames = []
        for year in sorted(scan_years):
            fp = daily_dir / str(year) / filename
            if fp.exists():
                frames.append(pd.read_parquet(fp))

        if not frames:
            raise FileNotFoundError(
                f"未找到股票 {code} 的日线数据,扫描年份: {scan_years}"
            )

        df = pd.concat(frames, ignore_index=True)
        df["date"] = pd.to_datetime(df["date"])

        # Filter first, then sort and reset the index, so that a filtered
        # result has the same contiguous index as an unfiltered one.
        if start:
            df = df[df["date"] >= pd.Timestamp(start)]
        if end:
            df = df[df["date"] <= pd.Timestamp(end)]

        return df.sort_values("date").reset_index(drop=True)

    # ------------------------------------------------------------------
    # F1: get_stock_list - stock list / index constituents
    # ------------------------------------------------------------------

    def get_stock_list(self, index: Optional[str] = None) -> pd.DataFrame:
        """
        Return stock basic information or the constituents of an index.

        Args:
            index: Index code such as "hs300" (case-insensitive). The
                constituents are read from
                ``<stock_info_dir>/<index>_constituents_latest.csv``.
                None (or an empty string) returns the basic-info table
                instead.

        Returns:
            DataFrame read from the matching CSV file.

        Raises:
            FileNotFoundError: If the constituents file of the requested
                index, or any ``stock_basic_info_raw_*.csv`` file, is
                missing.
        """
        info_dir = self.config.stock_info_dir
        index_key = str(index).strip().lower() if index is not None else ""

        if index_key:
            # An unknown index must fail loudly rather than silently fall
            # through to the full stock list.
            fp = info_dir / f"{index_key}_constituents_latest.csv"
            if not fp.exists():
                label = self._INDEX_LABELS.get(index_key, f"指数 {index_key} ")
                raise FileNotFoundError(f"{label}成分股文件不存在: {fp}")
            # NOTE(review): read_csv infers dtypes, so a purely numeric code
            # column would lose leading zeros - confirm the column schema.
            return pd.read_csv(fp)

        # Basic info: pick the lexicographically last raw file.
        # Presumably the suffix is a sortable date stamp - verify naming.
        candidates = sorted(info_dir.glob("stock_basic_info_raw_*.csv"))
        if not candidates:
            raise FileNotFoundError(f"未找到股票基础信息文件: {info_dir}")
        return pd.read_csv(candidates[-1])

    # ------------------------------------------------------------------
    # F1: get_test_data - standard test datasets
    # ------------------------------------------------------------------

    def get_test_data(self, name: str) -> pd.DataFrame:
        """
        Load a standard test dataset by (partial) file name.

        Args:
            name: Dataset name, e.g. "600519" or "贵州茅台". A file whose
                stem equals ``name`` is preferred; otherwise the first file
                (in sorted order) whose stem contains ``name`` is used.

        Returns:
            DataFrame read from the CSV, with the ``date`` column parsed
            as dates.

        Raises:
            FileNotFoundError: If the test dataset directory does not exist
                or no CSV file matches ``name``.

        Example:
            cat.get_test_data("600519")
        """
        test_dir = self.config.test_datasets_dir
        if not test_dir.exists():
            raise FileNotFoundError(f"测试数据集目录不存在: {test_dir}")

        key = str(name)
        # Sort once: glob order is filesystem-dependent, and the same list
        # is reused for the error message below.
        candidates = sorted(test_dir.glob("*.csv"))
        exact = [fp for fp in candidates if fp.stem == key]
        partial = [fp for fp in candidates if key in fp.stem]
        matches = exact or partial
        if matches:
            return pd.read_csv(matches[0], parse_dates=["date"])

        raise FileNotFoundError(
            f"未找到测试数据集 '{name}',可用: "
            f"{[f.stem for f in candidates]}"
        )

    # ------------------------------------------------------------------
    # F2: list_available - overview of available data assets
    # ------------------------------------------------------------------

    def list_available(self) -> dict:
        """
        List the data assets that are present on disk.

        Returns:
            dict with up to three keys, each present only if its directory
            exists:
            - "daily_parquet": {"years": sorted year directory names,
              "path": str}
            - "stock_info": {"files": file names, "path": str}
            - "test_datasets": {"datasets": CSV stems, "path": str}
        """
        result = {}

        # Daily bars
        daily_dir = self.config.daily_parquet_dir
        if daily_dir.exists():
            result["daily_parquet"] = {
                "years": sorted(self._year_dir_names(daily_dir)),
                "path": str(daily_dir),
            }

        # Stock info
        info_dir = self.config.stock_info_dir
        if info_dir.exists():
            files = [f.name for f in info_dir.iterdir() if f.is_file()]
            result["stock_info"] = {"files": files, "path": str(info_dir)}

        # Test datasets
        test_dir = self.config.test_datasets_dir
        if test_dir.exists():
            datasets = [f.stem for f in test_dir.glob("*.csv")]
            result["test_datasets"] = {"datasets": datasets, "path": str(test_dir)}

        return result

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _year_dir_names(daily_dir) -> List[str]:
        """Return the names of all-digit sub-directories of ``daily_dir``."""
        return [
            d.name
            for d in daily_dir.iterdir()
            if d.is_dir() and d.name.isdigit()
        ]

    def _detect_years(self) -> List[int]:
        """Return the sorted years that have a directory under the daily dir.

        Falls back to ``[2024, 2025]`` when the daily directory itself does
        not exist.
        """
        daily_dir = self.config.daily_parquet_dir
        if not daily_dir.exists():
            return [2024, 2025]  # fallback when nothing is on disk yet
        return sorted(int(name) for name in self._year_dir_names(daily_dir))