100 lines
2.9 KiB
Python
100 lines
2.9 KiB
Python
"""
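Data platform configuration layer.

Configuration priority:
1. The SANGUO_QUANT_ROOT environment variable
2. The default project root directory (found by searching upward for the sanguo_quant_live marker file)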
"""
|
|
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
import yaml
|
|
|
|
|
|
# Marker file that identifies the project root directory
|
|
# NOTE(review): the module docstring refers to a "sanguo_quant_live" marker file,
# while the value used for the upward search is ".gitignore" — confirm which is intended.
_ROOT_MARKER = ".gitignore"
|
|
|
|
|
|
def _find_project_root() -> Path:
    """Locate the project root directory.

    Resolution order:
      1. If the ``SANGUO_QUANT_ROOT`` environment variable is set to a
         non-empty value, return it with ``~`` expanded and made absolute.
         No marker check is performed on that path.
      2. Otherwise, starting at the directory that contains this file,
         inspect that directory and its ancestors (five levels in total) and
         return the first one that contains ``_ROOT_MARKER``.
      3. If no marker is found, fall back to the parent of this file's
         directory.

    Returns:
        Path: The detected project root (always absolute).
    """
    env_root = os.environ.get("SANGUO_QUANT_ROOT")
    if env_root:
        return Path(env_root).expanduser().absolute()

    # Anchor the search on an absolute path. With a relative ``__file__`` the
    # walk would collapse to "." (whose parent is "." again) and could never
    # climb above the current working directory.
    module_dir = Path(__file__).absolute().parent

    # Walk upward from this file's directory, checking five levels in total.
    max_levels = 5
    candidate = module_dir
    for _ in range(max_levels):
        if (candidate / _ROOT_MARKER).exists():
            return candidate
        candidate = candidate.parent

    # Fallback: the parent of this module's directory (the original note
    # describes it as the parent of the data_platform package).
    return module_dir.parent
|
|
|
|
|
|
class DataPlatformConfig:
    """Path configuration for the data platform.

    Every directory exposed by this class is derived from a single project
    root, either passed in explicitly or discovered by ``_find_project_root``.
    """

    def __init__(self, project_root: Optional[str] = None):
        """Build the directory layout and load optional YAML overrides.

        Args:
            project_root: Project root directory. When given (and non-empty)
                it is ``~``-expanded and made absolute; otherwise the root is
                auto-detected via ``_find_project_root``.
        """
        self.root = (
            Path(project_root).expanduser().absolute()
            if project_root
            else _find_project_root()
        )

        # Data root directory (Zhao Yun data area).
        data_root = self.root / "zhaoyun-data" / "data"
        self.data_root = data_root
        self.raw_dir = data_root / "raw"
        self.processed_dir = data_root / "processed"
        self.running_dir = data_root / "running_data"

        # Optional overrides from a config.yaml placed next to this module.
        # NOTE(review): the overrides are only stored here; nothing in the
        # code visible in this class reads them back — confirm intended use.
        self._overrides = {}
        overrides_file = Path(__file__).parent / "config.yaml"
        if overrides_file.exists():
            with overrides_file.open("r", encoding="utf-8") as stream:
                loaded = yaml.safe_load(stream)
            # An empty file (or any falsy document) yields an empty mapping.
            self._overrides = loaded if loaded else {}

    # --- Core path properties ---

    @property
    def daily_parquet_dir(self) -> Path:
        """Root directory of daily-bar Parquet data (one sub-directory per year)."""
        return self.raw_dir / "daily"

    @property
    def stock_info_dir(self) -> Path:
        """Directory holding basic stock information."""
        return self.raw_dir / "stock_info"

    @property
    def test_datasets_dir(self) -> Path:
        """Directory holding test datasets."""
        return self.processed_dir / "test_datasets"

    @property
    def financial_dir(self) -> Path:
        """Directory holding financial data."""
        return self.raw_dir / "financial"

    # --- Helper methods ---

    def get_daily_parquet(self, year: int) -> Path:
        """Return the daily Parquet directory for the given year."""
        year_folder = str(year)
        return self.daily_parquet_dir / year_folder

    def get_stock_parquet(self, code: str, year: int) -> Path:
        """Return the Parquet file path for one stock in one year.

        Codes starting with "6" get the "sh" prefix; every other code gets
        "sz".
        NOTE(review): presumably exchange prefixes — confirm that codes from
        other exchanges are never passed in.
        """
        exchange = "sz"
        if code.startswith("6"):
            exchange = "sh"
        return self.get_daily_parquet(year) / f"{exchange}{code}_daily.parquet"

    def to_dict(self) -> dict:
        """Export the main configured paths as a dictionary of strings."""
        exported = ("root", "data_root", "daily_parquet_dir", "stock_info_dir")
        return {name: str(getattr(self, name)) for name in exported}
|