auto-sync: 2026-05-02 21:26:32

This commit is contained in:
cfdaily
2026-05-02 21:26:32 +08:00
parent 6ad71f54ec
commit 6ba8540f41
+198
View File
@@ -0,0 +1,198 @@
#!/usr/bin/env python3
"""
导入15分钟线Parquet到vnpy SQLite DB
复用P1导入逻辑(pandas向量化+批量INSERT OR REPLACE
与 import_vnpy_daily_fast.py 的区别:
- interval = '15m'(而非 'd'
- datetime已是 "YYYY-MM-DD HH:MM:SS" 格式
- 从单个parquet文件导入(非按年份目录)
- volume和amount是object类型需转float
用法:
python3 import_vnpy_minute.py --scope hs300
python3 import_vnpy_minute.py --scope all
python3 import_vnpy_minute.py --codes 000001 600519
python3 import_vnpy_minute.py --db /path/to/quant_trading.db
"""
import sqlite3
import re
import sys
import time
import argparse
import logging
from pathlib import Path
import pandas as pd
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)
# --- 配置 ---
MINUTE_DIR = Path("/Volumes/stock/minute_kline/15min")
DB_PATH = "/tmp/quant_trading_import.db"
INTERVAL = "15m"
BATCH_SIZE = 50000
HS300_FILE = Path("/Users/chufeng/.openclaw/sanguo_projects/sanguo_quant_live/zhaoyun-data/data/raw/stock_info/hs300_constituents_latest.csv")
ALL_STOCKS_FILE = Path("/Users/chufeng/.openclaw/sanguo_projects/sanguo_quant_live/zhaoyun-data/data/raw/stock_info/stock_basic_info_raw_20260326_113530.csv")
def parse_filename(filename: str):
"""解析文件名: sz000001_15min.parquet -> (code, exchange)"""
m = re.match(r"(sh|sz)(\d{6})_15min\.parquet", filename)
if not m:
return None, None
prefix, code = m.groups()
return code, "SSE" if prefix == "sh" else "SZSE"
def import_file(conn, filepath: Path) -> int:
"""导入单个Parquet文件,返回导入行数"""
code, exchange = parse_filename(filepath.name)
if code is None:
return 0
try:
df = pd.read_parquet(filepath)
if df.empty:
return 0
except Exception as e:
logger.warning("读取失败 %s: %s", filepath.name, e)
return 0
# 构建导入数据
df = df.rename(columns={
"open": "open_price",
"high": "high_price",
"low": "low_price",
"close": "close_price",
"amount": "turnover",
})
df["symbol"] = code
df["exchange"] = exchange
df["interval"] = INTERVAL
df["open_interest"] = 0.0
df["datetime"] = df["day"].astype(str)
# 类型转换(volume/amount可能是object
for col in ["volume", "turnover", "open_price", "high_price", "low_price", "close_price"]:
df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0.0)
values = df[[
"symbol", "exchange", "datetime", "interval", "volume", "turnover",
"open_interest", "open_price", "high_price", "low_price", "close_price"
]].values.tolist()
c = conn.cursor()
for i in range(0, len(values), BATCH_SIZE):
c.executemany(
"""INSERT OR REPLACE INTO dbbardata
(symbol, exchange, datetime, interval, volume, turnover, open_interest,
open_price, high_price, low_price, close_price)
VALUES (?,?,?,?,?,?,?,?,?,?,?)""",
values[i : i + BATCH_SIZE],
)
conn.commit()
return len(values)
def get_stock_list(scope: str):
"""获取股票代码列表"""
if scope == "hs300":
df = pd.read_csv(HS300_FILE)
for col in ["成分券代码", "代码", "code"]:
if col in df.columns:
return [str(c).zfill(6) for c in df[col].tolist()]
raise ValueError(f"HS300文件找不到代码列: {list(df.columns)}")
if scope == "all":
df = pd.read_csv(ALL_STOCKS_FILE)
for col in ["代码", "code", "股票代码"]:
if col in df.columns:
return [str(c).zfill(6) for c in df[col].tolist()]
raise ValueError(f"全市场文件找不到代码列: {list(df.columns)}")
raise ValueError(f"Unknown scope: {scope}")
def main():
parser = argparse.ArgumentParser(description="导入15分钟线到vnpy DB")
parser.add_argument("--scope", choices=["hs300", "all"], help="导入范围")
parser.add_argument("--codes", nargs="+", help="指定股票代码")
parser.add_argument("--db", default=DB_PATH, help="SQLite DB路径")
parser.add_argument("--minute-dir", default=str(MINUTE_DIR), help="分钟线Parquet目录")
args = parser.parse_args()
minute_dir = Path(args.minute_dir)
if not minute_dir.exists():
logger.error("分钟线目录不存在: %s", minute_dir)
sys.exit(1)
# 获取代码列表
if args.codes:
codes = set(args.codes)
elif args.scope:
codes = set(get_stock_list(args.scope))
else:
parser.error("必须指定 --scope 或 --codes")
# 匹配文件
files = sorted(minute_dir.glob("*_15min.parquet"))
matched = []
for f in files:
code, _ = parse_filename(f.name)
if code and code in codes:
matched.append(f)
logger.info("匹配文件: %d / %d", len(matched), len(files))
if not matched:
logger.error("没有匹配的Parquet文件")
sys.exit(1)
# 导入
conn = sqlite3.connect(args.db)
total_rows = 0
t_start = time.time()
for i, f in enumerate(matched):
t0 = time.time()
rows = import_file(conn, f)
t1 = time.time()
total_rows += rows
logger.info("[%d/%d] %s: %d rows (%.1fs) total=%d", i + 1, len(matched), f.name, rows, t1 - t0, total_rows)
# 更新 overview
c = conn.cursor()
c.execute(
"""INSERT OR REPLACE INTO dbbaroverview (symbol, exchange, interval, count, start, end)
SELECT symbol, exchange, interval, COUNT(*), MIN(datetime), MAX(datetime)
FROM dbbardata GROUP BY symbol, exchange, interval"""
)
conn.commit()
# 统计
c.execute("SELECT COUNT(*) FROM dbbardata WHERE interval = ?", (INTERVAL,))
minute_rows = c.fetchone()[0]
c.execute("SELECT COUNT(*) FROM dbbardata")
all_rows = c.fetchone()[0]
c.execute("SELECT COUNT(*) FROM dbbaroverview")
overview_count = c.fetchone()[0]
elapsed = time.time() - t_start
conn.close()
logger.info("=" * 50)
logger.info("导入完成")
logger.info("15分钟线: %d", minute_rows)
logger.info("总数据量: %d 行 (含日线)", all_rows)
logger.info("Overview: %d", overview_count)
logger.info("耗时: %.1f 秒 (%.1f 分钟)", elapsed, elapsed / 60)
logger.info("DB路径: %s", args.db)
if __name__ == "__main__":
main()