""" DataCatalog - 统一数据访问接口 策略开发者只需通过 DataCatalog 获取数据, 无需关心底层文件路径和存储格式。 核心API: - get_daily() 获取单只股票日线行情 - get_daily_batch() 批量获取多只股票日线行情 - get_stock_list() 获取股票基础信息/指数成分股 - get_test_data() 获取标准测试数据集 - list_available() 查看可用的数据资产 """ import os import logging from pathlib import Path from typing import Dict, Optional, List import pandas as pd from data_platform.config import DataPlatformConfig logger = logging.getLogger(__name__) class DataCatalog: """ 统一数据目录 —— 项目唯一数据入口 Usage: from data_platform import DataCatalog cat = DataCatalog() df = cat.get_daily("600519", start="20250101", end="20260101") stocks = cat.get_stock_list() """ def __init__(self, project_root: Optional[str] = None): self.config = DataPlatformConfig(project_root) # ------------------------------------------------------------------ # get_daily — 单只股票日线行情 # ------------------------------------------------------------------ def get_daily( self, code: str, start: Optional[str] = None, end: Optional[str] = None, years: Optional[List[int]] = None, ) -> pd.DataFrame: """ 获取单只股票日线行情(从本地 Parquet 读取) Args: code: 6位股票代码,如 "600519" start: 起始日期 "YYYYMMDD" end: 结束日期 "YYYYMMDD" years: 指定年份列表,如 [2024, 2025];默认自动推断 Returns: DataFrame,列: date, open, high, low, close, volume, amount, ... """ code = str(code).strip().zfill(6) prefix = "sh" if code.startswith("6") else "sz" pattern = f"{prefix}{code}_daily.parquet" if years is None: scan_years = self._detect_years() else: scan_years = years frames = [] for year in sorted(scan_years): fp = self.config.daily_parquet_dir / str(year) / pattern if fp.exists(): frames.append(pd.read_parquet(fp)) if not frames: raise FileNotFoundError( f"未找到股票 {code} 的日线数据,扫描年份: {scan_years}" ) df = pd.concat(frames, ignore_index=True) df["date"] = pd.to_datetime(df["date"]) df = df.sort_values("date").reset_index(drop=True) if start: df = df[df["date"] >= pd.Timestamp(start)] if end: df = df[df["date"] <= pd.Timestamp(end)] return df # ------------------------------------------------------------------ # get_daily_batch — 批量获取多只股票日线 # ------------------------------------------------------------------ def get_daily_batch( self, codes: List[str], start: Optional[str] = None, end: Optional[str] = None, ) -> Dict[str, pd.DataFrame]: """ 批量获取多只股票日线行情 Args: codes: 股票代码列表,如 ["600519", "000001"] start: 起始日期 "YYYYMMDD" end: 结束日期 "YYYYMMDD" Returns: dict,key=股票代码, value=DataFrame """ result = {} for code in codes: try: result[code] = self.get_daily(code, start=start, end=end) except FileNotFoundError: logger.warning("跳过 %s: 数据不存在", code) return result # ------------------------------------------------------------------ # get_stock_list — 股票列表 / 指数成分股 # ------------------------------------------------------------------ def get_stock_list(self, index: Optional[str] = None) -> pd.DataFrame: """ 获取股票基础信息或指数成分股 Args: index: 指数代码,如 "hs300";None 返回全部 A 股基础信息 """ if index == "hs300": fp = self.config.stock_info_dir / "hs300_constituents_latest.csv" if not fp.exists(): raise FileNotFoundError(f"沪深300成分股文件不存在: {fp}") return pd.read_csv(fp) info_dir = self.config.stock_info_dir candidates = sorted(info_dir.glob("stock_basic_info_raw_*.csv")) if not candidates: raise FileNotFoundError(f"未找到股票基础信息文件: {info_dir}") return pd.read_csv(candidates[-1]) # ------------------------------------------------------------------ # get_test_data — 标准测试数据集 # ------------------------------------------------------------------ def get_test_data(self, name: str) -> pd.DataFrame: """ 获取标准测试数据集 Args: name: 数据集名称,如 "600519" 或 "贵州茅台" """ test_dir = self.config.test_datasets_dir if not test_dir.exists(): raise FileNotFoundError(f"测试数据集目录不存在: {test_dir}") for fp in test_dir.glob("*.csv"): if name in fp.stem: return pd.read_csv(fp, parse_dates=["date"]) raise FileNotFoundError( f"未找到测试数据集 '{name}',可用: " f"{[f.stem for f in test_dir.glob('*.csv')]}" ) # ------------------------------------------------------------------ # list_available — 查看可用数据资产 # ------------------------------------------------------------------ def list_available(self) -> dict: """列出所有可用数据资产""" result = {} daily_dir = self.config.daily_parquet_dir if daily_dir.exists(): years = sorted( [d.name for d in daily_dir.iterdir() if d.is_dir() and d.name.isdigit()] ) result["daily_parquet"] = {"years": years, "path": str(daily_dir)} info_dir = self.config.stock_info_dir if info_dir.exists(): files = [f.name for f in info_dir.iterdir() if f.is_file()] result["stock_info"] = {"files": files, "path": str(info_dir)} test_dir = self.config.test_datasets_dir if test_dir.exists(): datasets = [f.stem for f in test_dir.glob("*.csv")] result["test_datasets"] = {"datasets": datasets, "path": str(test_dir)} return result # ------------------------------------------------------------------ # 内部工具 # ------------------------------------------------------------------ def _detect_years(self) -> List[int]: """自动检测可用的年份目录""" daily_dir = self.config.daily_parquet_dir if not daily_dir.exists(): return [2024, 2025] return sorted( int(d.name) for d in daily_dir.iterdir() if d.is_dir() and d.name.isdigit() )