#!/usr/bin/env python3
"""Import 15-minute bar Parquet files into a vnpy SQLite database.

Reuses the P1 import approach (pandas vectorization + batched
INSERT OR REPLACE).

Differences from import_vnpy_daily_fast.py:
- bars are stored under the interval held in INTERVAL ('1m') rather than 'd'
- datetime is already in "YYYY-MM-DD HH:MM:SS" format
- each symbol is imported from a single parquet file (not per-year directories)
- volume and amount are object dtype and must be converted to float

Usage:
    python3 import_vnpy_minute.py --scope hs300
    python3 import_vnpy_minute.py --scope all
    python3 import_vnpy_minute.py --codes 000001 600519
    python3 import_vnpy_minute.py --db /path/to/quant_trading.db
"""

import sqlite3
import re
import sys
import time
import argparse
import logging
from pathlib import Path

import pandas as pd

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)

# --- Configuration ---
MINUTE_DIR = Path("/Volumes/stock/minute_kline/15min")
DB_PATH = "/tmp/quant_trading_import.db"
# vnpy 4.x Interval.MINUTE.value = '1m'; all minute bars are stored under '1m'.
INTERVAL = "1m"
BATCH_SIZE = 50000
HS300_FILE = Path("/Volumes/stock/A股数据/stock_info/hs300_constituents_latest.csv")
ALL_STOCKS_FILE = Path("/Volumes/stock/sanguo_vnpy/data/all_stocks.csv")

# Compiled once; used with fullmatch so trailing junk (e.g. ".bak") is rejected.
_FILENAME_RE = re.compile(r"(sh|sz)(\d{6})_15min\.parquet")

_PRICE_COLUMNS = ["open_price", "high_price", "low_price", "close_price"]
_NUMERIC_COLUMNS = ["volume", "turnover"] + _PRICE_COLUMNS

# Column order must match the placeholders in _INSERT_SQL.
_INSERT_COLUMNS = [
    "symbol", "exchange", "datetime", "interval",
    "volume", "turnover", "open_interest",
    "open_price", "high_price", "low_price", "close_price",
]

_INSERT_SQL = """INSERT OR REPLACE INTO dbbardata
               (symbol, exchange, datetime, interval, volume, turnover,
                open_interest, open_price, high_price, low_price, close_price)
               VALUES (?,?,?,?,?,?,?,?,?,?,?)"""


def parse_filename(filename: str):
    """Parse a bar file name such as ``sz000001_15min.parquet``.

    Args:
        filename: Bare file name (no directory part).

    Returns:
        ``(code, exchange)`` where ``exchange`` is ``"SSE"`` for the ``sh``
        prefix and ``"SZSE"`` for ``sz``; ``(None, None)`` when the whole name
        does not match the expected pattern.
    """
    m = _FILENAME_RE.fullmatch(filename)
    if m is None:
        return None, None
    prefix, code = m.groups()
    return code, "SSE" if prefix == "sh" else "SZSE"


def import_file(conn, filepath: Path) -> int:
    """Import one Parquet file into ``dbbardata``.

    Args:
        conn: Open ``sqlite3`` connection to the target database.
        filepath: Path to a ``<sh|sz><code>_15min.parquet`` file.

    Returns:
        Number of rows written. 0 when the file name is not recognised, the
        file cannot be read, or it contains no usable rows.

    Raises:
        sqlite3.Error: If the insert fails; the pending transaction is rolled
            back first so no partial file is committed later.
    """
    code, exchange = parse_filename(filepath.name)
    if code is None:
        return 0

    # Best effort: an unreadable file is logged and skipped, not fatal.
    try:
        df = pd.read_parquet(filepath)
    except Exception as e:
        logger.warning("读取失败 %s: %s", filepath.name, e)
        return 0
    if df.empty:
        return 0

    # Map source column names onto the dbbardata schema.
    df = df.rename(columns={
        "open": "open_price",
        "high": "high_price",
        "low": "low_price",
        "close": "close_price",
        "amount": "turnover",
    })
    df["symbol"] = code
    df["exchange"] = exchange
    df["interval"] = INTERVAL
    df["open_interest"] = 0.0
    df["datetime"] = df["day"].astype(str)

    # volume/amount may arrive as object dtype; unparsable values become NaN.
    for col in _NUMERIC_COLUMNS:
        df[col] = pd.to_numeric(df[col], errors="coerce")

    # Drop rows where ANY price is NaN (corrupt source data). A NaN would be
    # bound as NULL on insert — NOTE(review): assumes the vnpy schema declares
    # the price columns NOT NULL, in which case the whole batch would abort.
    rows_before = len(df)
    df = df.dropna(subset=_PRICE_COLUMNS).copy()
    if len(df) < rows_before:
        logger.warning("%s 丢弃 %d 条NaN行", filepath.name, rows_before - len(df))

    df["volume"] = df["volume"].fillna(0.0)
    df["turnover"] = df["turnover"].fillna(0.0)

    values = df[_INSERT_COLUMNS].values.tolist()

    cur = conn.cursor()
    try:
        for start in range(0, len(values), BATCH_SIZE):
            cur.executemany(_INSERT_SQL, values[start:start + BATCH_SIZE])
    except sqlite3.Error:
        # Do not leave a half-written file in the open transaction.
        conn.rollback()
        raise
    conn.commit()
    return len(values)


def get_stock_list(scope: str):
    """Return the list of 6-digit stock codes for ``scope``.

    Args:
        scope: ``"hs300"`` (read from HS300_FILE) or ``"all"`` (read from
            ALL_STOCKS_FILE).

    Returns:
        Codes as strings, left-padded with zeros to six characters.

    Raises:
        ValueError: If ``scope`` is unknown, or the CSV has none of the
            expected code columns.
    """
    # scope -> (csv path, candidate code columns in priority order, error prefix)
    sources = {
        "hs300": (HS300_FILE, ["成分券代码", "代码", "code"], "HS300文件找不到代码列"),
        "all": (ALL_STOCKS_FILE, ["代码", "code", "股票代码"], "全市场文件找不到代码列"),
    }
    if scope not in sources:
        raise ValueError(f"Unknown scope: {scope}")

    csv_path, candidates, missing_msg = sources[scope]
    df = pd.read_csv(csv_path)
    for col in candidates:
        if col in df.columns:
            return [str(c).zfill(6) for c in df[col].tolist()]
    raise ValueError(f"{missing_msg}: {list(df.columns)}")


def main():
    """Command-line entry point: select files, import them, refresh overview."""
    parser = argparse.ArgumentParser(description="导入15分钟线到vnpy DB")
    parser.add_argument("--scope", choices=["hs300", "all"], help="导入范围")
    parser.add_argument("--codes", nargs="+", help="指定股票代码")
    parser.add_argument("--db", default=DB_PATH, help="SQLite DB路径")
    parser.add_argument("--minute-dir", default=str(MINUTE_DIR), help="分钟线Parquet目录")
    args = parser.parse_args()

    minute_dir = Path(args.minute_dir)
    if not minute_dir.exists():
        logger.error("分钟线目录不存在: %s", minute_dir)
        sys.exit(1)

    # Resolve the set of wanted codes; --codes takes precedence over --scope.
    if args.codes:
        codes = set(args.codes)
    elif args.scope:
        codes = set(get_stock_list(args.scope))
    else:
        parser.error("必须指定 --scope 或 --codes")

    # Keep only files whose parsed code is in the wanted set.
    files = sorted(minute_dir.glob("*_15min.parquet"))
    matched = [f for f in files if parse_filename(f.name)[0] in codes]

    logger.info("匹配文件: %d / %d", len(matched), len(files))
    if not matched:
        logger.error("没有匹配的Parquet文件")
        sys.exit(1)

    conn = sqlite3.connect(args.db)
    try:
        total_rows = 0
        t_start = time.time()
        for idx, f in enumerate(matched, start=1):
            t0 = time.time()
            rows = import_file(conn, f)
            t1 = time.time()
            total_rows += rows
            logger.info("[%d/%d] %s: %d rows (%.1fs) total=%d",
                        idx, len(matched), f.name, rows, t1 - t0, total_rows)

        # Rebuild the overview rows from the bar table (all symbols/intervals).
        c = conn.cursor()
        c.execute(
            """INSERT OR REPLACE INTO dbbaroverview (symbol, exchange, interval, count, start, end)
               SELECT symbol, exchange, interval, COUNT(*), MIN(datetime), MAX(datetime)
               FROM dbbardata
               GROUP BY symbol, exchange, interval"""
        )
        conn.commit()

        # Summary statistics.
        c.execute("SELECT COUNT(*) FROM dbbardata WHERE interval = ?", (INTERVAL,))
        minute_rows = c.fetchone()[0]
        c.execute("SELECT COUNT(*) FROM dbbardata")
        all_rows = c.fetchone()[0]
        c.execute("SELECT COUNT(*) FROM dbbaroverview")
        overview_count = c.fetchone()[0]

        elapsed = time.time() - t_start
    finally:
        # Always release the database handle, even when an import step raises.
        conn.close()

    logger.info("=" * 50)
    logger.info("导入完成")
    logger.info("15分钟线: %d 行", minute_rows)
    logger.info("总数据量: %d 行 (含日线)", all_rows)
    logger.info("Overview: %d 条", overview_count)
    logger.info("耗时: %.1f 秒 (%.1f 分钟)", elapsed, elapsed / 60)
    logger.info("DB路径: %s", args.db)


if __name__ == "__main__":
    main()