auto-sync: 2026-04-30 23:07:47

This commit is contained in:
cfdaily
2026-04-30 23:07:47 +08:00
parent 89f546ad77
commit a0934b939b
+39 -33
View File
@@ -4,8 +4,9 @@ DataCatalog - 统一数据访问接口
策略开发者只需通过 DataCatalog 获取数据, 策略开发者只需通过 DataCatalog 获取数据,
无需关心底层文件路径和存储格式。 无需关心底层文件路径和存储格式。
4个核心API 核心API
- get_daily() 获取单只股票日线行情 - get_daily() 获取单只股票日线行情
- get_daily_batch() 批量获取多只股票日线行情
- get_stock_list() 获取股票基础信息/指数成分股 - get_stock_list() 获取股票基础信息/指数成分股
- get_test_data() 获取标准测试数据集 - get_test_data() 获取标准测试数据集
- list_available() 查看可用的数据资产 - list_available() 查看可用的数据资产
@@ -14,7 +15,7 @@ DataCatalog - 统一数据访问接口
import os import os
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Optional, List from typing import Dict, Optional, List
import pandas as pd import pandas as pd
@@ -39,7 +40,7 @@ class DataCatalog:
self.config = DataPlatformConfig(project_root) self.config = DataPlatformConfig(project_root)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# F1: get_daily — 单只股票日线行情 # get_daily — 单只股票日线行情
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def get_daily( def get_daily(
@@ -65,7 +66,6 @@ class DataCatalog:
prefix = "sh" if code.startswith("6") else "sz" prefix = "sh" if code.startswith("6") else "sz"
pattern = f"{prefix}{code}_daily.parquet" pattern = f"{prefix}{code}_daily.parquet"
# 确定要扫描的年份
if years is None: if years is None:
scan_years = self._detect_years() scan_years = self._detect_years()
else: else:
@@ -86,7 +86,6 @@ class DataCatalog:
df["date"] = pd.to_datetime(df["date"]) df["date"] = pd.to_datetime(df["date"])
df = df.sort_values("date").reset_index(drop=True) df = df.sort_values("date").reset_index(drop=True)
# 日期过滤
if start: if start:
df = df[df["date"] >= pd.Timestamp(start)] df = df[df["date"] >= pd.Timestamp(start)]
if end: if end:
@@ -95,7 +94,36 @@ class DataCatalog:
return df return df
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# F1: get_stock_list — 股票列表 / 指数成分股 # get_daily_batch — 批量获取多只股票日线
# ------------------------------------------------------------------
def get_daily_batch(
    self,
    codes: List[str],
    start: Optional[str] = None,
    end: Optional[str] = None,
) -> Dict[str, pd.DataFrame]:
    """
    Fetch daily bars for several stocks in one call.

    Each code is loaded through ``self.get_daily`` with the same ``start``
    and ``end`` arguments. A code whose lookup raises ``FileNotFoundError``
    is logged as a warning and left out of the result, so one missing stock
    does not abort the whole batch.

    Args:
        codes: Stock codes, e.g. ["600519", "000001"]. A single code passed
            as a bare string is treated as a one-element list instead of
            being iterated character by character. Repeated codes are
            loaded only once.
        start: Start date "YYYYMMDD", forwarded unchanged to get_daily.
        end: End date "YYYYMMDD", forwarded unchanged to get_daily.

    Returns:
        dict mapping stock code -> DataFrame, in first-seen order of
        ``codes``, containing only the codes that loaded successfully.
    """
    # A str is itself iterable; without this guard "600519" would be looked
    # up as the six separate codes "6", "0", "0", "5", "1", "9".
    if isinstance(codes, str):
        codes = [codes]

    result: Dict[str, pd.DataFrame] = {}
    # dict.fromkeys de-duplicates while keeping the caller's order, so a
    # repeated code is neither re-read nor warned about twice.
    for code in dict.fromkeys(codes):
        try:
            result[code] = self.get_daily(code, start=start, end=end)
        except FileNotFoundError:
            logger.warning("跳过 %s: 数据不存在", code)
    return result
# ------------------------------------------------------------------
# get_stock_list — 股票列表 / 指数成分股
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def get_stock_list(self, index: Optional[str] = None) -> pd.DataFrame: def get_stock_list(self, index: Optional[str] = None) -> pd.DataFrame:
@@ -104,9 +132,6 @@ class DataCatalog:
Args: Args:
index: 指数代码,如 "hs300";None 返回全部 A 股基础信息 index: 指数代码,如 "hs300";None 返回全部 A 股基础信息
Returns:
DataFrame
""" """
if index == "hs300": if index == "hs300":
fp = self.config.stock_info_dir / "hs300_constituents_latest.csv" fp = self.config.stock_info_dir / "hs300_constituents_latest.csv"
@@ -114,7 +139,6 @@ class DataCatalog:
raise FileNotFoundError(f"沪深300成分股文件不存在: {fp}") raise FileNotFoundError(f"沪深300成分股文件不存在: {fp}")
return pd.read_csv(fp) return pd.read_csv(fp)
# 全部 A 股基础信息 —— 找最新的 stock_basic_info 文件
info_dir = self.config.stock_info_dir info_dir = self.config.stock_info_dir
candidates = sorted(info_dir.glob("stock_basic_info_raw_*.csv")) candidates = sorted(info_dir.glob("stock_basic_info_raw_*.csv"))
if not candidates: if not candidates:
@@ -122,7 +146,7 @@ class DataCatalog:
return pd.read_csv(candidates[-1]) return pd.read_csv(candidates[-1])
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# F1: get_test_data — 标准测试数据集 # get_test_data — 标准测试数据集
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def get_test_data(self, name: str) -> pd.DataFrame: def get_test_data(self, name: str) -> pd.DataFrame:
@@ -131,18 +155,11 @@ class DataCatalog:
Args: Args:
name: 数据集名称,如 "600519" 或 "贵州茅台" name: 数据集名称,如 "600519" 或 "贵州茅台"
Returns:
DataFrame
Example:
cat.get_test_data("600519") # 茅台252日数据
""" """
test_dir = self.config.test_datasets_dir test_dir = self.config.test_datasets_dir
if not test_dir.exists(): if not test_dir.exists():
raise FileNotFoundError(f"测试数据集目录不存在: {test_dir}") raise FileNotFoundError(f"测试数据集目录不存在: {test_dir}")
# 模糊匹配:文件名包含 code 或 名称
for fp in test_dir.glob("*.csv"): for fp in test_dir.glob("*.csv"):
if name in fp.stem: if name in fp.stem:
return pd.read_csv(fp, parse_dates=["date"]) return pd.read_csv(fp, parse_dates=["date"])
@@ -153,36 +170,25 @@ class DataCatalog:
) )
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# F2: list_available — 查看可用数据资产 # list_available — 查看可用数据资产
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def list_available(self) -> dict: def list_available(self) -> dict:
""" """列出所有可用数据资产"""
列出所有可用数据资产
Returns:
dict,按类别列出数据概况
"""
result = {} result = {}
# 日线行情
daily_dir = self.config.daily_parquet_dir daily_dir = self.config.daily_parquet_dir
if daily_dir.exists(): if daily_dir.exists():
years = sorted( years = sorted(
[d.name for d in daily_dir.iterdir() if d.is_dir() and d.name.isdigit()] [d.name for d in daily_dir.iterdir() if d.is_dir() and d.name.isdigit()]
) )
result["daily_parquet"] = { result["daily_parquet"] = {"years": years, "path": str(daily_dir)}
"years": years,
"path": str(daily_dir),
}
# 股票信息
info_dir = self.config.stock_info_dir info_dir = self.config.stock_info_dir
if info_dir.exists(): if info_dir.exists():
files = [f.name for f in info_dir.iterdir() if f.is_file()] files = [f.name for f in info_dir.iterdir() if f.is_file()]
result["stock_info"] = {"files": files, "path": str(info_dir)} result["stock_info"] = {"files": files, "path": str(info_dir)}
# 测试数据
test_dir = self.config.test_datasets_dir test_dir = self.config.test_datasets_dir
if test_dir.exists(): if test_dir.exists():
datasets = [f.stem for f in test_dir.glob("*.csv")] datasets = [f.stem for f in test_dir.glob("*.csv")]
@@ -198,7 +204,7 @@ class DataCatalog:
"""自动检测可用的年份目录""" """自动检测可用的年份目录"""
daily_dir = self.config.daily_parquet_dir daily_dir = self.config.daily_parquet_dir
if not daily_dir.exists(): if not daily_dir.exists():
return [2024, 2025] # 合理兜底 return [2024, 2025]
return sorted( return sorted(
int(d.name) int(d.name)
for d in daily_dir.iterdir() for d in daily_dir.iterdir()