Files
2026-05-02 22:41:17 +08:00

206 lines
6.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
导入15分钟线Parquet到vnpy SQLite DB
复用P1导入逻辑(pandas向量化+批量INSERT OR REPLACE
与 import_vnpy_daily_fast.py 的区别:
- interval = '15m'(而非 'd'
- datetime已是 "YYYY-MM-DD HH:MM:SS" 格式
- 从单个parquet文件导入(非按年份目录)
- volume和amount是object类型需转float
用法:
python3 import_vnpy_minute.py --scope hs300
python3 import_vnpy_minute.py --scope all
python3 import_vnpy_minute.py --codes 000001 600519
python3 import_vnpy_minute.py --db /path/to/quant_trading.db
"""
import sqlite3
import re
import sys
import time
import argparse
import logging
from pathlib import Path
import pandas as pd
# Configure logging once at import time; every record carries a timestamp and level.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)
# --- Configuration ---
# Directory scanned for "*_15min.parquet" files (default of --minute-dir).
MINUTE_DIR = Path("/Volumes/stock/minute_kline/15min")
# Default SQLite database path (default of --db).
DB_PATH = "/tmp/quant_trading_import.db"
# NOTE(review): the source data is 15-minute bars but is stored under '1m';
# presumably no real 1-minute bars share this table -- confirm, since they
# would collide on (symbol, exchange, interval, datetime).
INTERVAL = "1m" # vnpy 4.x Interval.MINUTE.value = '1m'; all minute bars are stored under '1m'
# Maximum number of rows handed to a single executemany() call.
BATCH_SIZE = 50000
# CSV files that list the stock codes for the "hs300" and "all" scopes.
HS300_FILE = Path("/Volumes/stock/A股数据/stock_info/hs300_constituents_latest.csv")
ALL_STOCKS_FILE = Path("/Volumes/stock/A股数据/stock_info/stock_basic_info_raw_20260326_113530.csv")
def parse_filename(filename: str):
    """Parse a 15-minute Parquet file name into ``(code, exchange)``.

    Example: ``sz000001_15min.parquet`` -> ``("000001", "SZSE")``.

    Args:
        filename: Bare file name, without any directory part.

    Returns:
        ``(code, exchange)`` where ``code`` is the six-digit symbol and
        ``exchange`` is ``"SSE"`` for the ``sh`` prefix or ``"SZSE"`` for
        the ``sz`` prefix. ``(None, None)`` when the name does not follow
        the ``<sh|sz><6 digits>_15min.parquet`` pattern.
    """
    # fullmatch (not match): the whole name must follow the pattern, so
    # leftovers such as "sz000001_15min.parquet.bak" are rejected instead
    # of being silently treated as valid data files.
    m = re.fullmatch(r"(sh|sz)(\d{6})_15min\.parquet", filename)
    if m is None:
        return None, None
    prefix, code = m.groups()
    exchange = "SSE" if prefix == "sh" else "SZSE"
    return code, exchange
def import_file(conn, filepath: Path) -> int:
    """Import one 15-minute Parquet file into the ``dbbardata`` table.

    The symbol and exchange are taken from the file name. Price, volume and
    turnover columns are coerced to numbers; rows with a missing price are
    dropped, and missing volume/turnover are stored as 0.0. All rows of the
    file are written in one transaction, so a failure part-way through
    leaves none of the file's rows behind.

    Args:
        conn: Open ``sqlite3`` connection to the target database.
        filepath: Path of a ``<sh|sz><code>_15min.parquet`` file.

    Returns:
        Number of rows written. 0 when the file name is not recognised, the
        file cannot be read, is empty, or lacks a required column.

    Raises:
        sqlite3.Error: If the insert fails; the file's transaction is rolled
            back before the exception propagates.
    """
    code, exchange = parse_filename(filepath.name)
    if code is None:
        return 0

    try:
        df = pd.read_parquet(filepath)
    except Exception as e:
        # Best-effort: one unreadable file must not abort the whole run.
        logger.warning("读取失败 %s: %s", filepath.name, e)
        return 0
    if df.empty:
        return 0

    # A file with an unexpected schema is skipped like an unreadable one,
    # instead of raising KeyError and stopping the remaining imports.
    required = ("day", "open", "high", "low", "close", "volume", "amount")
    missing = [col for col in required if col not in df.columns]
    if missing:
        logger.warning("跳过 %s: 缺少列 %s", filepath.name, missing)
        return 0

    # Map the source columns onto the dbbardata column names.
    df = df.rename(columns={
        "open": "open_price",
        "high": "high_price",
        "low": "low_price",
        "close": "close_price",
        "amount": "turnover",
    })
    df["symbol"] = code
    df["exchange"] = exchange
    df["interval"] = INTERVAL
    df["open_interest"] = 0.0
    df["datetime"] = df["day"].astype(str)

    # Type conversion: volume/amount may arrive as object dtype; anything
    # unparsable becomes NaN.
    price_cols = ["open_price", "high_price", "low_price", "close_price"]
    for col in ["volume", "turnover", *price_cols]:
        df[col] = pd.to_numeric(df[col], errors="coerce")

    # Drop rows with any NaN price: they indicate broken source data, and a
    # NaN would be bound as SQL NULL.
    # NOTE(review): the price columns are presumably NOT NULL in vnpy's
    # schema, which would make such an insert fail -- confirm.
    rows_before = len(df)
    df = df.dropna(subset=price_cols).copy()
    dropped = rows_before - len(df)
    if dropped:
        logger.warning("%s 丢弃 %d 条NaN行", filepath.name, dropped)
    df["volume"] = df["volume"].fillna(0.0)
    df["turnover"] = df["turnover"].fillna(0.0)

    values = df[[
        "symbol", "exchange", "datetime", "interval", "volume", "turnover",
        "open_interest", "open_price", "high_price", "low_price", "close_price",
    ]].values.tolist()

    insert_sql = """INSERT OR REPLACE INTO dbbardata
           (symbol, exchange, datetime, interval, volume, turnover, open_interest,
            open_price, high_price, low_price, close_price)
           VALUES (?,?,?,?,?,?,?,?,?,?,?)"""
    # The connection context manager commits on success and rolls back on
    # error, so a failed batch cannot leave partial rows for a later commit.
    with conn:
        for start in range(0, len(values), BATCH_SIZE):
            conn.executemany(insert_sql, values[start:start + BATCH_SIZE])
    return len(values)
def get_stock_list(scope: str):
    """Return the zero-padded six-character stock codes for *scope*.

    Args:
        scope: ``"hs300"`` to read ``HS300_FILE`` or ``"all"`` to read
            ``ALL_STOCKS_FILE``.

    Returns:
        List of code strings, each left-padded with zeros to six characters.

    Raises:
        ValueError: If *scope* is unknown, or the CSV has none of the
            recognised code columns.
    """
    if scope not in ("hs300", "all"):
        raise ValueError(f"Unknown scope: {scope}")

    is_hs300 = scope == "hs300"
    if is_hs300:
        frame = pd.read_csv(HS300_FILE)
        candidates = ("成分券代码", "代码", "code")
    else:
        frame = pd.read_csv(ALL_STOCKS_FILE)
        candidates = ("代码", "code", "股票代码")

    # The first candidate column present in the CSV wins.
    code_col = next((name for name in candidates if name in frame.columns), None)
    if code_col is None:
        if is_hs300:
            raise ValueError(f"HS300文件找不到代码列: {list(frame.columns)}")
        raise ValueError(f"全市场文件找不到代码列: {list(frame.columns)}")

    # Codes read as integers lose their leading zeros; restore them.
    return [str(raw).zfill(6) for raw in frame[code_col].tolist()]
def main():
    """Command-line entry point: import 15-minute Parquet files into the DB.

    Steps: parse arguments, resolve the set of wanted stock codes, select
    the matching ``*_15min.parquet`` files, import each file, rebuild the
    ``dbbaroverview`` rows from ``dbbardata`` and log summary counts.

    Exits with status 1 when the Parquet directory does not exist or no
    file matches; exits through ``parser.error`` when neither ``--scope``
    nor ``--codes`` is given.
    """
    parser = argparse.ArgumentParser(description="导入15分钟线到vnpy DB")
    parser.add_argument("--scope", choices=["hs300", "all"], help="导入范围")
    parser.add_argument("--codes", nargs="+", help="指定股票代码")
    parser.add_argument("--db", default=DB_PATH, help="SQLite DB路径")
    parser.add_argument("--minute-dir", default=str(MINUTE_DIR), help="分钟线Parquet目录")
    args = parser.parse_args()

    minute_dir = Path(args.minute_dir)
    if not minute_dir.exists():
        logger.error("分钟线目录不存在: %s", minute_dir)
        sys.exit(1)

    # Resolve the wanted codes; --codes takes precedence over --scope.
    if args.codes:
        # Pad like get_stock_list does, so "--codes 1" matches sz000001.
        codes = {str(c).zfill(6) for c in args.codes}
    elif args.scope:
        codes = set(get_stock_list(args.scope))
    else:
        parser.error("必须指定 --scope 或 --codes")

    # Keep only files whose parsed code was requested; unparsable names
    # yield None, which is never in the set.
    files = sorted(minute_dir.glob("*_15min.parquet"))
    matched = [f for f in files if parse_filename(f.name)[0] in codes]
    logger.info("匹配文件: %d / %d", len(matched), len(files))
    if not matched:
        logger.error("没有匹配的Parquet文件")
        sys.exit(1)

    conn = sqlite3.connect(args.db)
    try:
        total_rows = 0
        t_start = time.time()
        for i, f in enumerate(matched, start=1):
            t0 = time.time()
            rows = import_file(conn, f)
            t1 = time.time()
            total_rows += rows
            logger.info("[%d/%d] %s: %d rows (%.1fs) total=%d", i, len(matched), f.name, rows, t1 - t0, total_rows)

        # Rebuild the overview rows for every (symbol, exchange, interval)
        # group found in dbbardata.
        c = conn.cursor()
        c.execute(
            """INSERT OR REPLACE INTO dbbaroverview (symbol, exchange, interval, count, start, end)
               SELECT symbol, exchange, interval, COUNT(*), MIN(datetime), MAX(datetime)
               FROM dbbardata GROUP BY symbol, exchange, interval"""
        )
        conn.commit()

        # Summary statistics.
        c.execute("SELECT COUNT(*) FROM dbbardata WHERE interval = ?", (INTERVAL,))
        minute_rows = c.fetchone()[0]
        c.execute("SELECT COUNT(*) FROM dbbardata")
        all_rows = c.fetchone()[0]
        c.execute("SELECT COUNT(*) FROM dbbaroverview")
        overview_count = c.fetchone()[0]
        elapsed = time.time() - t_start
    finally:
        # Release the database handle even when an import or query fails.
        conn.close()

    logger.info("=" * 50)
    logger.info("导入完成")
    logger.info("15分钟线: %d", minute_rows)
    logger.info("总数据量: %d 行 (含日线)", all_rows)
    logger.info("Overview: %d", overview_count)
    logger.info("耗时: %.1f 秒 (%.1f 分钟)", elapsed, elapsed / 60)
    logger.info("DB路径: %s", args.db)
# Run the importer only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()