Files
2026-04-30 23:07:47 +08:00

213 lines
7.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
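DataCatalog - unified data access interface.

Strategy developers obtain data through DataCatalog only and do not need to
know the underlying file paths or storage formats.

Core API:
- get_daily()        daily bars for a single stock
- get_daily_batch()  daily bars for several stocks at once
- get_stock_list()   basic stock information / index constituents
- get_test_data()    standard test datasets
- list_available()   list the data assets that are available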
"""
import os
import logging
from pathlib import Path
from typing import Dict, Optional, List
import pandas as pd
from data_platform.config import DataPlatformConfig
# Module-level logger named after this module; DataCatalog uses it to warn
# about stock codes that are skipped during batch loading.
logger = logging.getLogger(__name__)
class DataCatalog:
    """Unified data catalog: the project's single entry point for data access.

    Callers ask the catalog for data and never build file paths themselves.
    Every location is taken from ``self.config`` (a ``DataPlatformConfig``),
    which must expose the path-like attributes ``daily_parquet_dir``,
    ``stock_info_dir`` and ``test_datasets_dir``.

    Usage:
        from data_platform import DataCatalog

        cat = DataCatalog()
        df = cat.get_daily("600519", start="20250101", end="20260101")
        stocks = cat.get_stock_list()
    """

    # Years scanned by get_daily() when the daily Parquet directory itself
    # does not exist (kept from the original implementation).
    _FALLBACK_YEARS = (2024, 2025)

    def __init__(self, project_root: Optional[str] = None):
        """Create a catalog.

        Args:
            project_root: Forwarded unchanged to ``DataPlatformConfig``.
        """
        self.config = DataPlatformConfig(project_root)

    # ------------------------------------------------------------------
    # get_daily - daily bars for a single stock
    # ------------------------------------------------------------------
    def get_daily(
        self,
        code: str,
        start: Optional[str] = None,
        end: Optional[str] = None,
        years: Optional[List[int]] = None,
    ) -> pd.DataFrame:
        """Load the daily bars of one stock from the local Parquet files.

        Files are looked up as
        ``<daily_parquet_dir>/<year>/<prefix><code>_daily.parquet`` where the
        prefix is ``"sh"`` for codes starting with "6" and ``"sz"`` otherwise.

        Args:
            code: Stock code; it is stripped and left-padded with zeros to
                6 characters, e.g. "600519".
            start: Inclusive lower date bound, e.g. "YYYYMMDD" (parsed with
                ``pd.Timestamp``). Falsy means no lower bound.
            end: Inclusive upper date bound, same format as ``start``.
            years: Year directories to scan, e.g. [2024, 2025]. ``None``
                scans every year directory that ``_detect_years`` finds.

        Returns:
            The concatenated rows of all files found, with ``date`` converted
            to datetime, sorted ascending by ``date``, restricted to
            [start, end] and indexed 0..n-1.

        Raises:
            FileNotFoundError: No file for ``code`` exists in any scanned year.
        """
        code = str(code).strip().zfill(6)
        prefix = "sh" if code.startswith("6") else "sz"
        filename = f"{prefix}{code}_daily.parquet"
        scan_years = self._detect_years() if years is None else years

        root = self.config.daily_parquet_dir
        paths = [root / str(year) / filename for year in sorted(scan_years)]
        frames = [pd.read_parquet(fp) for fp in paths if fp.exists()]
        if not frames:
            raise FileNotFoundError(
                f"未找到股票 {code} 的日线数据,扫描年份: {scan_years}"
            )
        merged = pd.concat(frames, ignore_index=True)
        return self._filter_by_date(merged, start, end)

    @staticmethod
    def _filter_by_date(
        df: pd.DataFrame,
        start: Optional[str] = None,
        end: Optional[str] = None,
    ) -> pd.DataFrame:
        """Return ``df`` sorted by ``date`` and clipped to [start, end].

        The input frame is not modified. The index is rebuilt *after* the
        range filter so the result is always numbered 0..n-1, whether or not
        a bound was applied.
        """
        out = df.assign(date=pd.to_datetime(df["date"])).sort_values("date")
        if start:
            out = out[out["date"] >= pd.Timestamp(start)]
        if end:
            out = out[out["date"] <= pd.Timestamp(end)]
        return out.reset_index(drop=True)

    # ------------------------------------------------------------------
    # get_daily_batch - daily bars for several stocks
    # ------------------------------------------------------------------
    def get_daily_batch(
        self,
        codes: List[str],
        start: Optional[str] = None,
        end: Optional[str] = None,
    ) -> Dict[str, pd.DataFrame]:
        """Load daily bars for several stocks.

        Args:
            codes: Stock codes, e.g. ["600519", "000001"]. A single string is
                treated as one code rather than iterated character by
                character.
            start: Inclusive lower date bound, "YYYYMMDD".
            end: Inclusive upper date bound, "YYYYMMDD".

        Returns:
            Dict mapping each code (exactly as passed in) to its DataFrame.
            Codes whose data is missing are logged and left out.
        """
        if isinstance(codes, str):
            # Iterating a bare string would look up "6", "0", "0", ... which
            # zero-pad into unrelated but valid-looking stock codes.
            codes = [codes]

        result: Dict[str, pd.DataFrame] = {}
        for code in codes:
            try:
                result[code] = self.get_daily(code, start=start, end=end)
            except FileNotFoundError:
                logger.warning("跳过 %s: 数据不存在", code)
        return result

    # ------------------------------------------------------------------
    # get_stock_list - stock list / index constituents
    # ------------------------------------------------------------------
    def get_stock_list(self, index: Optional[str] = None) -> pd.DataFrame:
        """Return basic stock information or the constituents of an index.

        Args:
            index: Index key such as "hs300" (case-insensitive). The
                constituents are read from
                ``<stock_info_dir>/<index>_constituents_latest.csv``.
                ``None`` or an empty string returns the basic information of
                all stocks instead.

        Returns:
            The CSV contents as a DataFrame.

        Raises:
            FileNotFoundError: The constituents file for ``index`` does not
                exist, or no ``stock_basic_info_raw_*.csv`` file exists.
        """
        # NOTE(review): read_csv infers dtypes, so a numeric-looking code
        # column would lose leading zeros - confirm against the CSV schema.
        info_dir = self.config.stock_info_dir
        key = "" if index is None else str(index).strip().lower()

        if key:
            # An unknown index must fail loudly instead of silently falling
            # through to the full stock list.
            fp = info_dir / f"{key}_constituents_latest.csv"
            if not fp.exists():
                if key == "hs300":
                    raise FileNotFoundError(f"沪深300成分股文件不存在: {fp}")
                raise FileNotFoundError(f"指数 {key} 成分股文件不存在: {fp}")
            return pd.read_csv(fp)

        candidates = sorted(info_dir.glob("stock_basic_info_raw_*.csv"))
        if not candidates:
            raise FileNotFoundError(f"未找到股票基础信息文件: {info_dir}")
        # The lexicographically last file is used; presumably the suffix is a
        # date stamp so this is the newest snapshot - TODO confirm.
        return pd.read_csv(candidates[-1])

    # ------------------------------------------------------------------
    # get_test_data - standard test datasets
    # ------------------------------------------------------------------
    def get_test_data(self, name: str) -> pd.DataFrame:
        """Load a standard test dataset from ``test_datasets_dir``.

        Args:
            name: Dataset name or a part of it, e.g. "600519". A CSV whose
                stem equals ``name`` wins; otherwise the first CSV (in sorted
                order) whose stem contains ``name`` is used.

        Returns:
            The CSV contents with the ``date`` column parsed as datetime.

        Raises:
            FileNotFoundError: The directory is missing or no CSV matches.
        """
        test_dir = self.config.test_datasets_dir
        if not test_dir.exists():
            raise FileNotFoundError(f"测试数据集目录不存在: {test_dir}")

        # glob() yields entries in arbitrary order; sort so the chosen file
        # does not depend on the filesystem when several stems match.
        files = sorted(test_dir.glob("*.csv"))
        match = next((fp for fp in files if fp.stem == name), None)
        if match is None:
            match = next((fp for fp in files if name in fp.stem), None)
        if match is None:
            raise FileNotFoundError(
                f"未找到测试数据集 '{name}',可用: "
                f"{[fp.stem for fp in files]}"
            )
        return pd.read_csv(match, parse_dates=["date"])

    # ------------------------------------------------------------------
    # list_available - overview of the available data assets
    # ------------------------------------------------------------------
    def list_available(self) -> dict:
        """List the data assets that exist on disk.

        Returns:
            Dict with up to three keys, each present only if its directory
            exists: ``"daily_parquet"`` (``years``, ``path``),
            ``"stock_info"`` (``files``, ``path``) and ``"test_datasets"``
            (``datasets``, ``path``). All lists are sorted.
        """
        result = {}

        daily_dir = self.config.daily_parquet_dir
        if daily_dir.exists():
            result["daily_parquet"] = {
                "years": self._year_dir_names(daily_dir),
                "path": str(daily_dir),
            }

        info_dir = self.config.stock_info_dir
        if info_dir.exists():
            files = sorted(f.name for f in info_dir.iterdir() if f.is_file())
            result["stock_info"] = {"files": files, "path": str(info_dir)}

        test_dir = self.config.test_datasets_dir
        if test_dir.exists():
            datasets = sorted(f.stem for f in test_dir.glob("*.csv"))
            result["test_datasets"] = {"datasets": datasets, "path": str(test_dir)}

        return result

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _year_dir_names(daily_dir) -> List[str]:
        """Return the names of the year sub-directories, oldest first.

        Only names made of ASCII digits count: ``str.isdigit`` alone also
        accepts characters such as "²" that ``int()`` cannot parse.
        """
        names = [
            d.name
            for d in daily_dir.iterdir()
            if d.is_dir() and d.name.isascii() and d.name.isdigit()
        ]
        return sorted(names, key=int)

    def _detect_years(self) -> List[int]:
        """Return the years that have a directory under ``daily_parquet_dir``.

        Falls back to ``_FALLBACK_YEARS`` when the directory does not exist.
        """
        daily_dir = self.config.daily_parquet_dir
        if not daily_dir.exists():
            return list(self._FALLBACK_YEARS)
        return [int(name) for name in self._year_dir_names(daily_dir)]