"""
DataCatalog - 统一数据访问接口

策略开发者只需通过 DataCatalog 获取数据,
无需关心底层文件路径和存储格式。

核心API:
    - get_daily()        获取单只股票日线行情
    - get_daily_batch()  批量获取多只股票日线行情
    - get_stock_list()   获取股票基础信息/指数成分股
    - get_test_data()    获取标准测试数据集
    - list_available()   查看可用的数据资产
"""

import os
|
||
import logging
|
||
from pathlib import Path
|
||
from typing import Dict, Optional, List
|
||
|
||
import pandas as pd
|
||
|
||
from data_platform.config import DataPlatformConfig
|
||
|
||
# Module-level logger, named after this module (used for skip warnings in batch loads).
logger = logging.getLogger(name=__name__)
|
||
|
||
|
||
class DataCatalog:
    """
    Unified data catalog -- the project's single entry point for data access.

    Strategy code obtains data through this class instead of dealing with
    file paths or storage formats directly. All locations come from
    ``DataPlatformConfig`` (``daily_parquet_dir``, ``stock_info_dir``,
    ``test_datasets_dir``).

    Usage:
        from data_platform import DataCatalog

        cat = DataCatalog()
        df = cat.get_daily("600519", start="20250101", end="20260101")
        stocks = cat.get_stock_list()
    """

    def __init__(self, project_root: Optional[str] = None):
        # project_root is forwarded unchanged to DataPlatformConfig, which
        # resolves the concrete data directories.
        self.config = DataPlatformConfig(project_root)

    # ------------------------------------------------------------------
    # get_daily -- daily bars for a single stock
    # ------------------------------------------------------------------

    def get_daily(
        self,
        code: str,
        start: Optional[str] = None,
        end: Optional[str] = None,
        years: Optional[List[int]] = None,
    ) -> pd.DataFrame:
        """
        Load daily bars for one stock from the local Parquet store.

        For every scanned year the file
        ``<daily_parquet_dir>/<year>/<sh|sz><code>_daily.parquet`` is read if
        it exists. The yearly pieces are concatenated, sorted by ``date`` and
        optionally clipped to [start, end].

        Args:
            code: 6-digit stock code, e.g. "600519". Shorter codes are
                left-padded with zeros. Codes starting with "6" use the "sh"
                file prefix; all others use "sz".
            start: Inclusive start date, "YYYYMMDD".
            end: Inclusive end date, "YYYYMMDD".
            years: Year directories to scan, e.g. [2024, 2025]. Defaults to
                the numeric year directories found by ``_detect_years``.

        Returns:
            DataFrame sorted by ``date`` (converted via ``pd.to_datetime``)
            with a fresh 0..n-1 index. The remaining columns are whatever the
            Parquet files contain (e.g. open, high, low, close, volume, amount).

        Raises:
            FileNotFoundError: If no file for ``code`` exists in any scanned year.
        """
        code = str(code).strip().zfill(6)
        prefix = "sh" if code.startswith("6") else "sz"
        pattern = f"{prefix}{code}_daily.parquet"

        scan_years = self._detect_years() if years is None else years

        frames = []
        for year in sorted(scan_years):
            fp = self.config.daily_parquet_dir / str(year) / pattern
            if fp.exists():
                frames.append(pd.read_parquet(fp))

        if not frames:
            raise FileNotFoundError(
                f"未找到股票 {code} 的日线数据,扫描年份: {scan_years}"
            )

        df = pd.concat(frames, ignore_index=True)
        df["date"] = pd.to_datetime(df["date"])
        df = df.sort_values("date")

        if start:
            df = df[df["date"] >= pd.Timestamp(start)]
        if end:
            df = df[df["date"] <= pd.Timestamp(end)]

        # Reset only after filtering so callers always get a contiguous
        # 0..n-1 index (and a standalone frame rather than a filtered slice).
        return df.reset_index(drop=True)

    # ------------------------------------------------------------------
    # get_daily_batch -- daily bars for many stocks
    # ------------------------------------------------------------------

    def get_daily_batch(
        self,
        codes: List[str],
        start: Optional[str] = None,
        end: Optional[str] = None,
        years: Optional[List[int]] = None,
    ) -> Dict[str, pd.DataFrame]:
        """
        Load daily bars for several stocks via ``get_daily``.

        Codes with no data on disk are skipped with a warning instead of
        aborting the whole batch.

        Args:
            codes: Stock codes, e.g. ["600519", "000001"].
            start: Inclusive start date, "YYYYMMDD".
            end: Inclusive end date, "YYYYMMDD".
            years: Optional year directories to scan; defaults to all
                detected years (same as ``get_daily``).

        Returns:
            dict mapping each code *as passed in* to its DataFrame.
        """
        # Resolve the year list once rather than re-listing the directory per code.
        scan_years = self._detect_years() if years is None else years

        result = {}
        for code in codes:
            try:
                result[code] = self.get_daily(
                    code, start=start, end=end, years=scan_years
                )
            except FileNotFoundError:
                logger.warning("跳过 %s: 数据不存在", code)
        return result

    # ------------------------------------------------------------------
    # get_stock_list -- stock universe / index constituents
    # ------------------------------------------------------------------

    def get_stock_list(self, index: Optional[str] = None) -> pd.DataFrame:
        """
        Return basic info for all A-share stocks, or an index's constituents.

        Args:
            index: Index identifier. Only "hs300" is supported (case and
                surrounding whitespace are ignored). None (or "") returns the
                newest ``stock_basic_info_raw_*.csv`` file.

        Raises:
            ValueError: If ``index`` names an unsupported index. Previously
                such values silently returned the full stock list.
            FileNotFoundError: If the required CSV file does not exist.
        """
        if index:
            if str(index).strip().lower() != "hs300":
                raise ValueError(f"不支持的指数: {index!r},目前仅支持 'hs300'")
            fp = self.config.stock_info_dir / "hs300_constituents_latest.csv"
            if not fp.exists():
                raise FileNotFoundError(f"沪深300成分股文件不存在: {fp}")
            return pd.read_csv(fp)

        info_dir = self.config.stock_info_dir
        # Lexicographic sort; presumably the filename suffix is a date stamp,
        # so the last entry is the newest snapshot -- TODO confirm naming scheme.
        candidates = sorted(info_dir.glob("stock_basic_info_raw_*.csv"))
        if not candidates:
            raise FileNotFoundError(f"未找到股票基础信息文件: {info_dir}")
        # NOTE(review): read_csv infers dtypes, so a numeric stock-code column
        # would lose leading zeros (e.g. "000001" -> 1); consider dtype=str for it.
        return pd.read_csv(candidates[-1])

    # ------------------------------------------------------------------
    # get_test_data -- standard test datasets
    # ------------------------------------------------------------------

    def get_test_data(self, name: str) -> pd.DataFrame:
        """
        Load a standard test dataset by (partial) name.

        Args:
            name: Substring of the CSV file stem, e.g. "600519" or "贵州茅台".
                The first matching file in sorted filename order is returned.

        Returns:
            DataFrame read from the CSV, with the ``date`` column parsed as dates.

        Raises:
            FileNotFoundError: If the test dataset directory does not exist or
                no file stem contains ``name``.
        """
        test_dir = self.config.test_datasets_dir
        if not test_dir.exists():
            raise FileNotFoundError(f"测试数据集目录不存在: {test_dir}")

        # glob() order is filesystem-dependent; sort so the chosen match is deterministic.
        csv_files = sorted(test_dir.glob("*.csv"))
        for fp in csv_files:
            if name in fp.stem:
                return pd.read_csv(fp, parse_dates=["date"])

        raise FileNotFoundError(
            f"未找到测试数据集 '{name}',可用: "
            f"{[f.stem for f in csv_files]}"
        )

    # ------------------------------------------------------------------
    # list_available -- inventory of data assets
    # ------------------------------------------------------------------

    def list_available(self) -> dict:
        """
        List the data assets present on disk.

        Returns:
            dict with (only for directories that exist):
              - "daily_parquet": {"years": sorted year-directory names (str), "path"}
              - "stock_info": {"files": sorted file names, "path"}
              - "test_datasets": {"datasets": sorted CSV stems, "path"}
        """
        result = {}

        daily_dir = self.config.daily_parquet_dir
        if daily_dir.exists():
            years = sorted(
                d.name for d in daily_dir.iterdir() if d.is_dir() and d.name.isdecimal()
            )
            result["daily_parquet"] = {"years": years, "path": str(daily_dir)}

        info_dir = self.config.stock_info_dir
        if info_dir.exists():
            files = sorted(f.name for f in info_dir.iterdir() if f.is_file())
            result["stock_info"] = {"files": files, "path": str(info_dir)}

        test_dir = self.config.test_datasets_dir
        if test_dir.exists():
            datasets = sorted(f.stem for f in test_dir.glob("*.csv"))
            result["test_datasets"] = {"datasets": datasets, "path": str(test_dir)}

        return result

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _detect_years(self) -> List[int]:
        """
        Return the numeric year directories under ``daily_parquet_dir``, ascending.

        If that directory does not exist, ``[2024, 2025]`` is returned
        (presumably a placeholder default -- callers then fail to find files).
        """
        daily_dir = self.config.daily_parquet_dir
        if not daily_dir.exists():
            return [2024, 2025]
        # isdecimal() (not isdigit()) guarantees int() accepts the name.
        return sorted(
            int(d.name)
            for d in daily_dir.iterdir()
            if d.is_dir() and d.name.isdecimal()
        )
|