diff --git a/data_platform/catalog.py b/data_platform/catalog.py index 35c576108..e120e0e5b 100644 --- a/data_platform/catalog.py +++ b/data_platform/catalog.py @@ -4,8 +4,9 @@ DataCatalog - 统一数据访问接口 策略开发者只需通过 DataCatalog 获取数据, 无需关心底层文件路径和存储格式。 -4个核心API: +核心API: - get_daily() 获取单只股票日线行情 +- get_daily_batch() 批量获取多只股票日线行情 - get_stock_list() 获取股票基础信息/指数成分股 - get_test_data() 获取标准测试数据集 - list_available() 查看可用的数据资产 @@ -14,7 +15,7 @@ DataCatalog - 统一数据访问接口 import os import logging from pathlib import Path -from typing import Optional, List +from typing import Dict, Optional, List import pandas as pd @@ -39,7 +40,7 @@ class DataCatalog: self.config = DataPlatformConfig(project_root) # ------------------------------------------------------------------ - # F1: get_daily — 单只股票日线行情 + # get_daily — 单只股票日线行情 # ------------------------------------------------------------------ def get_daily( @@ -65,7 +66,6 @@ class DataCatalog: prefix = "sh" if code.startswith("6") else "sz" pattern = f"{prefix}{code}_daily.parquet" - # 确定要扫描的年份 if years is None: scan_years = self._detect_years() else: @@ -86,7 +86,6 @@ class DataCatalog: df["date"] = pd.to_datetime(df["date"]) df = df.sort_values("date").reset_index(drop=True) - # 日期过滤 if start: df = df[df["date"] >= pd.Timestamp(start)] if end: @@ -95,7 +94,36 @@ class DataCatalog: return df # ------------------------------------------------------------------ - # F1: get_stock_list — 股票列表 / 指数成分股 + # get_daily_batch — 批量获取多只股票日线 + # ------------------------------------------------------------------ + + def get_daily_batch( + self, + codes: List[str], + start: Optional[str] = None, + end: Optional[str] = None, + ) -> Dict[str, pd.DataFrame]: + """ + 批量获取多只股票日线行情 + + Args: + codes: 股票代码列表,如 ["600519", "000001"] + start: 起始日期 "YYYYMMDD" + end: 结束日期 "YYYYMMDD" + + Returns: + dict,key=股票代码, value=DataFrame + """ + result = {} + for code in codes: + try: + result[code] = self.get_daily(code, start=start, end=end) + except FileNotFoundError: + logger.warning("跳过 %s: 数据不存在", code) + return result + + # ------------------------------------------------------------------ + # get_stock_list — 股票列表 / 指数成分股 # ------------------------------------------------------------------ def get_stock_list(self, index: Optional[str] = None) -> pd.DataFrame: @@ -104,9 +132,6 @@ class DataCatalog: Args: index: 指数代码,如 "hs300";None 返回全部 A 股基础信息 - - Returns: - DataFrame """ if index == "hs300": fp = self.config.stock_info_dir / "hs300_constituents_latest.csv" @@ -114,7 +139,6 @@ class DataCatalog: raise FileNotFoundError(f"沪深300成分股文件不存在: {fp}") return pd.read_csv(fp) - # 全部 A 股基础信息 —— 找最新的 stock_basic_info 文件 info_dir = self.config.stock_info_dir candidates = sorted(info_dir.glob("stock_basic_info_raw_*.csv")) if not candidates: @@ -122,7 +146,7 @@ class DataCatalog: return pd.read_csv(candidates[-1]) # ------------------------------------------------------------------ - # F1: get_test_data — 标准测试数据集 + # get_test_data — 标准测试数据集 # ------------------------------------------------------------------ def get_test_data(self, name: str) -> pd.DataFrame: @@ -131,18 +155,11 @@ class DataCatalog: Args: name: 数据集名称,如 "600519" 或 "贵州茅台" - - Returns: - DataFrame - - Example: - cat.get_test_data("600519") # 茅台252日数据 """ test_dir = self.config.test_datasets_dir if not test_dir.exists(): raise FileNotFoundError(f"测试数据集目录不存在: {test_dir}") - # 模糊匹配:文件名包含 code 或 名称 for fp in test_dir.glob("*.csv"): if name in fp.stem: return pd.read_csv(fp, parse_dates=["date"]) @@ -153,36 +170,25 @@ class DataCatalog: ) # ------------------------------------------------------------------ - # F2: list_available — 查看可用数据资产 + # list_available — 查看可用数据资产 # ------------------------------------------------------------------ def list_available(self) -> dict: - """ - 列出所有可用数据资产 - - Returns: - dict,按类别列出数据概况 - """ + """列出所有可用数据资产""" result = {} - # 日线行情 daily_dir = self.config.daily_parquet_dir if daily_dir.exists(): years = sorted( [d.name for d in daily_dir.iterdir() if d.is_dir() and d.name.isdigit()] ) - result["daily_parquet"] = { - "years": years, - "path": str(daily_dir), - } + result["daily_parquet"] = {"years": years, "path": str(daily_dir)} - # 股票信息 info_dir = self.config.stock_info_dir if info_dir.exists(): files = [f.name for f in info_dir.iterdir() if f.is_file()] result["stock_info"] = {"files": files, "path": str(info_dir)} - # 测试数据 test_dir = self.config.test_datasets_dir if test_dir.exists(): datasets = [f.stem for f in test_dir.glob("*.csv")] @@ -198,7 +204,7 @@ class DataCatalog: """自动检测可用的年份目录""" daily_dir = self.config.daily_parquet_dir if not daily_dir.exists(): - return [2024, 2025] # 合理兜底 + return [2024, 2025] return sorted( int(d.name) for d in daily_dir.iterdir()