auto-sync: 2026-04-30 23:07:47
This commit is contained in:
+39
-33
@@ -4,8 +4,9 @@ DataCatalog - 统一数据访问接口
|
|||||||
策略开发者只需通过 DataCatalog 获取数据,
|
策略开发者只需通过 DataCatalog 获取数据,
|
||||||
无需关心底层文件路径和存储格式。
|
无需关心底层文件路径和存储格式。
|
||||||
|
|
||||||
4个核心API:
|
核心API:
|
||||||
- get_daily() 获取单只股票日线行情
|
- get_daily() 获取单只股票日线行情
|
||||||
|
- get_daily_batch() 批量获取多只股票日线行情
|
||||||
- get_stock_list() 获取股票基础信息/指数成分股
|
- get_stock_list() 获取股票基础信息/指数成分股
|
||||||
- get_test_data() 获取标准测试数据集
|
- get_test_data() 获取标准测试数据集
|
||||||
- list_available() 查看可用的数据资产
|
- list_available() 查看可用的数据资产
|
||||||
@@ -14,7 +15,7 @@ DataCatalog - 统一数据访问接口
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, List
|
from typing import Dict, Optional, List
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
@@ -39,7 +40,7 @@ class DataCatalog:
|
|||||||
self.config = DataPlatformConfig(project_root)
|
self.config = DataPlatformConfig(project_root)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# F1: get_daily — 单只股票日线行情
|
# get_daily — 单只股票日线行情
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def get_daily(
|
def get_daily(
|
||||||
@@ -65,7 +66,6 @@ class DataCatalog:
|
|||||||
prefix = "sh" if code.startswith("6") else "sz"
|
prefix = "sh" if code.startswith("6") else "sz"
|
||||||
pattern = f"{prefix}{code}_daily.parquet"
|
pattern = f"{prefix}{code}_daily.parquet"
|
||||||
|
|
||||||
# 确定要扫描的年份
|
|
||||||
if years is None:
|
if years is None:
|
||||||
scan_years = self._detect_years()
|
scan_years = self._detect_years()
|
||||||
else:
|
else:
|
||||||
@@ -86,7 +86,6 @@ class DataCatalog:
|
|||||||
df["date"] = pd.to_datetime(df["date"])
|
df["date"] = pd.to_datetime(df["date"])
|
||||||
df = df.sort_values("date").reset_index(drop=True)
|
df = df.sort_values("date").reset_index(drop=True)
|
||||||
|
|
||||||
# 日期过滤
|
|
||||||
if start:
|
if start:
|
||||||
df = df[df["date"] >= pd.Timestamp(start)]
|
df = df[df["date"] >= pd.Timestamp(start)]
|
||||||
if end:
|
if end:
|
||||||
@@ -95,7 +94,36 @@ class DataCatalog:
|
|||||||
return df
|
return df
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# F1: get_stock_list — 股票列表 / 指数成分股
|
# get_daily_batch — 批量获取多只股票日线
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def get_daily_batch(
    self,
    codes: List[str],
    start: Optional[str] = None,
    end: Optional[str] = None,
) -> Dict[str, pd.DataFrame]:
    """
    Fetch daily bars for several stocks in one call.

    Args:
        codes: Stock codes, e.g. ["600519", "000001"]. A single code passed
            as a plain string is treated as a one-element list rather than
            being iterated character by character. Duplicate codes are
            loaded only once.
        start: Start date "YYYYMMDD", forwarded to get_daily().
        end: End date "YYYYMMDD", forwarded to get_daily().

    Returns:
        dict mapping stock code -> the DataFrame returned by get_daily().
        Codes for which get_daily() raises FileNotFoundError are logged
        with a warning and left out of the result.
    """
    # A bare string is itself iterable; without this guard "600519" would
    # be looked up as the six separate codes "6", "0", "0", "5", "1", "9".
    if isinstance(codes, str):
        codes = [codes]

    result: Dict[str, pd.DataFrame] = {}
    # dict.fromkeys drops duplicates while keeping first-seen order, so each
    # code is read (and, if missing, warned about) at most once.
    for code in dict.fromkeys(codes):
        try:
            result[code] = self.get_daily(code, start=start, end=end)
        except FileNotFoundError:
            logger.warning("跳过 %s: 数据不存在", code)
    return result
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# get_stock_list — 股票列表 / 指数成分股
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def get_stock_list(self, index: Optional[str] = None) -> pd.DataFrame:
|
def get_stock_list(self, index: Optional[str] = None) -> pd.DataFrame:
|
||||||
@@ -104,9 +132,6 @@ class DataCatalog:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
index: 指数代码,如 "hs300";None 返回全部 A 股基础信息
|
index: 指数代码,如 "hs300";None 返回全部 A 股基础信息
|
||||||
|
|
||||||
Returns:
|
|
||||||
DataFrame
|
|
||||||
"""
|
"""
|
||||||
if index == "hs300":
|
if index == "hs300":
|
||||||
fp = self.config.stock_info_dir / "hs300_constituents_latest.csv"
|
fp = self.config.stock_info_dir / "hs300_constituents_latest.csv"
|
||||||
@@ -114,7 +139,6 @@ class DataCatalog:
|
|||||||
raise FileNotFoundError(f"沪深300成分股文件不存在: {fp}")
|
raise FileNotFoundError(f"沪深300成分股文件不存在: {fp}")
|
||||||
return pd.read_csv(fp)
|
return pd.read_csv(fp)
|
||||||
|
|
||||||
# 全部 A 股基础信息 —— 找最新的 stock_basic_info 文件
|
|
||||||
info_dir = self.config.stock_info_dir
|
info_dir = self.config.stock_info_dir
|
||||||
candidates = sorted(info_dir.glob("stock_basic_info_raw_*.csv"))
|
candidates = sorted(info_dir.glob("stock_basic_info_raw_*.csv"))
|
||||||
if not candidates:
|
if not candidates:
|
||||||
@@ -122,7 +146,7 @@ class DataCatalog:
|
|||||||
return pd.read_csv(candidates[-1])
|
return pd.read_csv(candidates[-1])
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# F1: get_test_data — 标准测试数据集
|
# get_test_data — 标准测试数据集
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def get_test_data(self, name: str) -> pd.DataFrame:
|
def get_test_data(self, name: str) -> pd.DataFrame:
|
||||||
@@ -131,18 +155,11 @@ class DataCatalog:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: 数据集名称,如 "600519" 或 "贵州茅台"
|
name: 数据集名称,如 "600519" 或 "贵州茅台"
|
||||||
|
|
||||||
Returns:
|
|
||||||
DataFrame
|
|
||||||
|
|
||||||
Example:
|
|
||||||
cat.get_test_data("600519") # 茅台252日数据
|
|
||||||
"""
|
"""
|
||||||
test_dir = self.config.test_datasets_dir
|
test_dir = self.config.test_datasets_dir
|
||||||
if not test_dir.exists():
|
if not test_dir.exists():
|
||||||
raise FileNotFoundError(f"测试数据集目录不存在: {test_dir}")
|
raise FileNotFoundError(f"测试数据集目录不存在: {test_dir}")
|
||||||
|
|
||||||
# 模糊匹配:文件名包含 code 或 名称
|
|
||||||
for fp in test_dir.glob("*.csv"):
|
for fp in test_dir.glob("*.csv"):
|
||||||
if name in fp.stem:
|
if name in fp.stem:
|
||||||
return pd.read_csv(fp, parse_dates=["date"])
|
return pd.read_csv(fp, parse_dates=["date"])
|
||||||
@@ -153,36 +170,25 @@ class DataCatalog:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# F2: list_available — 查看可用数据资产
|
# list_available — 查看可用数据资产
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def list_available(self) -> dict:
|
def list_available(self) -> dict:
|
||||||
"""
|
"""列出所有可用数据资产"""
|
||||||
列出所有可用数据资产
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict,按类别列出数据概况
|
|
||||||
"""
|
|
||||||
result = {}
|
result = {}
|
||||||
|
|
||||||
# 日线行情
|
|
||||||
daily_dir = self.config.daily_parquet_dir
|
daily_dir = self.config.daily_parquet_dir
|
||||||
if daily_dir.exists():
|
if daily_dir.exists():
|
||||||
years = sorted(
|
years = sorted(
|
||||||
[d.name for d in daily_dir.iterdir() if d.is_dir() and d.name.isdigit()]
|
[d.name for d in daily_dir.iterdir() if d.is_dir() and d.name.isdigit()]
|
||||||
)
|
)
|
||||||
result["daily_parquet"] = {
|
result["daily_parquet"] = {"years": years, "path": str(daily_dir)}
|
||||||
"years": years,
|
|
||||||
"path": str(daily_dir),
|
|
||||||
}
|
|
||||||
|
|
||||||
# 股票信息
|
|
||||||
info_dir = self.config.stock_info_dir
|
info_dir = self.config.stock_info_dir
|
||||||
if info_dir.exists():
|
if info_dir.exists():
|
||||||
files = [f.name for f in info_dir.iterdir() if f.is_file()]
|
files = [f.name for f in info_dir.iterdir() if f.is_file()]
|
||||||
result["stock_info"] = {"files": files, "path": str(info_dir)}
|
result["stock_info"] = {"files": files, "path": str(info_dir)}
|
||||||
|
|
||||||
# 测试数据
|
|
||||||
test_dir = self.config.test_datasets_dir
|
test_dir = self.config.test_datasets_dir
|
||||||
if test_dir.exists():
|
if test_dir.exists():
|
||||||
datasets = [f.stem for f in test_dir.glob("*.csv")]
|
datasets = [f.stem for f in test_dir.glob("*.csv")]
|
||||||
@@ -198,7 +204,7 @@ class DataCatalog:
|
|||||||
"""自动检测可用的年份目录"""
|
"""自动检测可用的年份目录"""
|
||||||
daily_dir = self.config.daily_parquet_dir
|
daily_dir = self.config.daily_parquet_dir
|
||||||
if not daily_dir.exists():
|
if not daily_dir.exists():
|
||||||
return [2024, 2025] # 合理兜底
|
return [2024, 2025]
|
||||||
return sorted(
|
return sorted(
|
||||||
int(d.name)
|
int(d.name)
|
||||||
for d in daily_dir.iterdir()
|
for d in daily_dir.iterdir()
|
||||||
|
|||||||
Reference in New Issue
Block a user