auto-sync: 2026-05-02 21:26:32
This commit is contained in:
@@ -0,0 +1,198 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
导入15分钟线Parquet到vnpy SQLite DB
|
||||
|
||||
复用P1导入逻辑(pandas向量化+批量INSERT OR REPLACE)
|
||||
与 import_vnpy_daily_fast.py 的区别:
|
||||
- interval = '15m'(而非 'd')
|
||||
- datetime已是 "YYYY-MM-DD HH:MM:SS" 格式
|
||||
- 从单个parquet文件导入(非按年份目录)
|
||||
- volume和amount是object类型需转float
|
||||
|
||||
用法:
|
||||
python3 import_vnpy_minute.py --scope hs300
|
||||
python3 import_vnpy_minute.py --scope all
|
||||
python3 import_vnpy_minute.py --codes 000001 600519
|
||||
python3 import_vnpy_minute.py --db /path/to/quant_trading.db
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
|
||||
# Configure the root logger once at import time; all messages in this script
# are emitted through the module-level `logger` below.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)

# --- Configuration ---
# Default directory scanned for "*_15min.parquet" files (default of --minute-dir).
MINUTE_DIR = Path("/Volumes/stock/minute_kline/15min")
# Default SQLite database file (default of --db).
DB_PATH = "/tmp/quant_trading_import.db"
# Value written to the `interval` column of every imported row, and used to
# count the imported rows in the final statistics.
INTERVAL = "15m"
# Maximum number of rows passed to a single executemany() call.
BATCH_SIZE = 50000

# CSV read by get_stock_list() for --scope hs300.
HS300_FILE = Path("/Users/chufeng/.openclaw/sanguo_projects/sanguo_quant_live/zhaoyun-data/data/raw/stock_info/hs300_constituents_latest.csv")
# CSV read by get_stock_list() for --scope all.
ALL_STOCKS_FILE = Path("/Users/chufeng/.openclaw/sanguo_projects/sanguo_quant_live/zhaoyun-data/data/raw/stock_info/stock_basic_info_raw_20260326_113530.csv")
|
||||
|
||||
|
||||
def parse_filename(filename: str):
    """Parse a 15-minute Parquet file name into ``(code, exchange)``.

    Example: ``"sz000001_15min.parquet"`` -> ``("000001", "SZSE")``.

    Args:
        filename: Bare file name (no directory component).

    Returns:
        A ``(code, exchange)`` tuple where ``code`` is the 6-digit string
        taken from the name and ``exchange`` is ``"SSE"`` for the ``sh``
        prefix or ``"SZSE"`` for the ``sz`` prefix. ``(None, None)`` is
        returned when the name does not have the expected form.
    """
    # fullmatch (not match) so that names carrying trailing text, e.g.
    # "sz000001_15min.parquet.bak", are rejected instead of being accepted.
    m = re.fullmatch(r"(sh|sz)(\d{6})_15min\.parquet", filename)
    if m is None:
        return None, None
    prefix, code = m.groups()
    return code, "SSE" if prefix == "sh" else "SZSE"
|
||||
|
||||
|
||||
def import_file(conn, filepath: Path) -> int:
    """Import one 15-minute Parquet file into the ``dbbardata`` table.

    The symbol and exchange are derived from the file name via
    ``parse_filename``. The Parquet columns ``open``/``high``/``low``/
    ``close``/``amount`` are renamed to the table's column names, the ``day``
    column is stringified into ``datetime``, and all numeric columns are
    coerced to numbers (unparseable values become 0.0).

    Args:
        conn: Open ``sqlite3`` connection to the target database.
        filepath: Path of the Parquet file to import.

    Returns:
        Number of rows written, or 0 when the file name is not recognised,
        the file cannot be read, is empty, or lacks a required column.

    Raises:
        sqlite3.Error: If the insert fails; the file's rows are rolled back.
    """
    code, exchange = parse_filename(filepath.name)
    if code is None:
        return 0

    # Only the read itself is expected to fail on a bad file; keep the try
    # body minimal so unrelated errors are not swallowed.
    try:
        df = pd.read_parquet(filepath)
    except Exception as e:
        logger.warning("读取失败 %s: %s", filepath.name, e)
        return 0
    if df.empty:
        return 0

    # Map the Parquet column names onto the dbbardata column names.
    df = df.rename(columns={
        "open": "open_price",
        "high": "high_price",
        "low": "low_price",
        "close": "close_price",
        "amount": "turnover",
    })

    numeric_cols = ["volume", "turnover", "open_price", "high_price", "low_price", "close_price"]

    # A file without the expected columns cannot be imported. Skip it with a
    # warning (as for unreadable files) instead of aborting the whole run
    # with a KeyError.
    missing = [col for col in ["day", *numeric_cols] if col not in df.columns]
    if missing:
        logger.warning("缺少必要列 %s: %s", filepath.name, missing)
        return 0

    df["symbol"] = code
    df["exchange"] = exchange
    df["interval"] = INTERVAL
    df["open_interest"] = 0.0
    df["datetime"] = df["day"].astype(str)

    # volume/amount may be stored as object dtype; coerce everything numeric.
    for col in numeric_cols:
        df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0.0)

    rows = df[[
        "symbol", "exchange", "datetime", "interval", "volume", "turnover",
        "open_interest", "open_price", "high_price", "low_price", "close_price",
    ]].values.tolist()

    insert_sql = """INSERT OR REPLACE INTO dbbardata
        (symbol, exchange, datetime, interval, volume, turnover, open_interest,
        open_price, high_price, low_price, close_price)
        VALUES (?,?,?,?,?,?,?,?,?,?,?)"""

    # `with conn` commits when the block succeeds and rolls back when it
    # raises, so a failure in a later batch never leaves a half-imported file
    # pending for the next commit.
    with conn:
        cur = conn.cursor()
        for start in range(0, len(rows), BATCH_SIZE):
            cur.executemany(insert_sql, rows[start:start + BATCH_SIZE])
    return len(rows)
|
||||
|
||||
|
||||
def get_stock_list(scope: str):
    """Return the stock codes for *scope* as zero-padded 6-character strings.

    Args:
        scope: ``"hs300"`` to read ``HS300_FILE`` or ``"all"`` to read
            ``ALL_STOCKS_FILE``.

    Returns:
        List of code strings, left-padded with zeros to 6 characters, in
        file order. Blank cells are skipped.

    Raises:
        ValueError: If *scope* is unknown, or the CSV has none of the
            expected code columns.
    """
    if scope == "hs300":
        csv_path = HS300_FILE
        candidates = ["成分券代码", "代码", "code"]
        label = "HS300文件"
    elif scope == "all":
        csv_path = ALL_STOCKS_FILE
        candidates = ["代码", "code", "股票代码"]
        label = "全市场文件"
    else:
        raise ValueError(f"Unknown scope: {scope}")

    # Read everything as text: with dtype inference a single blank cell turns
    # the code column into floats, and str(1.0).zfill(6) yields "0001.0".
    df = pd.read_csv(csv_path, dtype=str)
    for col in candidates:
        if col in df.columns:
            stripped = (str(c).strip() for c in df[col].dropna())
            return [c.zfill(6) for c in stripped if c]
    raise ValueError(f"{label}找不到代码列: {list(df.columns)}")
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: import 15-minute bars into the SQLite DB.

    Steps:
      1. Parse ``--scope`` / ``--codes`` / ``--db`` / ``--minute-dir``.
      2. Select the ``*_15min.parquet`` files in the minute directory whose
         code is in the requested set (``--codes`` wins over ``--scope``).
      3. Import each selected file with ``import_file``.
      4. Rebuild ``dbbaroverview`` from ``dbbardata`` and log row counts.

    Exits with status 1 when the minute directory does not exist or no file
    matches; exits through ``parser.error`` when neither ``--scope`` nor
    ``--codes`` is given.
    """
    parser = argparse.ArgumentParser(description="导入15分钟线到vnpy DB")
    parser.add_argument("--scope", choices=["hs300", "all"], help="导入范围")
    parser.add_argument("--codes", nargs="+", help="指定股票代码")
    parser.add_argument("--db", default=DB_PATH, help="SQLite DB路径")
    parser.add_argument("--minute-dir", default=str(MINUTE_DIR), help="分钟线Parquet目录")
    args = parser.parse_args()

    minute_dir = Path(args.minute_dir)
    if not minute_dir.exists():
        logger.error("分钟线目录不存在: %s", minute_dir)
        sys.exit(1)

    # Resolve the set of wanted codes; explicit --codes takes precedence.
    if args.codes:
        codes = set(args.codes)
    elif args.scope:
        codes = set(get_stock_list(args.scope))
    else:
        parser.error("必须指定 --scope 或 --codes")

    # Keep only files whose parsed code is wanted. Unparseable names give
    # None, which is never in `codes`.
    files = sorted(minute_dir.glob("*_15min.parquet"))
    matched = [f for f in files if parse_filename(f.name)[0] in codes]

    logger.info("匹配文件: %d / %d", len(matched), len(files))

    if not matched:
        logger.error("没有匹配的Parquet文件")
        sys.exit(1)

    conn = sqlite3.connect(args.db)
    # try/finally so the connection is released even when an import or a
    # query raises part-way through.
    try:
        total_rows = 0
        t_start = time.time()

        for i, f in enumerate(matched):
            t0 = time.time()
            rows = import_file(conn, f)
            t1 = time.time()
            total_rows += rows
            logger.info("[%d/%d] %s: %d rows (%.1fs) total=%d", i + 1, len(matched), f.name, rows, t1 - t0, total_rows)

        # Rebuild the per-(symbol, exchange, interval) overview from all bar
        # rows currently in the table.
        c = conn.cursor()
        c.execute(
            """INSERT OR REPLACE INTO dbbaroverview (symbol, exchange, interval, count, start, end)
            SELECT symbol, exchange, interval, COUNT(*), MIN(datetime), MAX(datetime)
            FROM dbbardata GROUP BY symbol, exchange, interval"""
        )
        conn.commit()

        # Summary statistics.
        c.execute("SELECT COUNT(*) FROM dbbardata WHERE interval = ?", (INTERVAL,))
        minute_rows = c.fetchone()[0]
        c.execute("SELECT COUNT(*) FROM dbbardata")
        all_rows = c.fetchone()[0]
        c.execute("SELECT COUNT(*) FROM dbbaroverview")
        overview_count = c.fetchone()[0]

        elapsed = time.time() - t_start
    finally:
        conn.close()

    logger.info("=" * 50)
    logger.info("导入完成")
    logger.info("15分钟线: %d 行", minute_rows)
    logger.info("总数据量: %d 行 (含日线)", all_rows)
    logger.info("Overview: %d 条", overview_count)
    logger.info("耗时: %.1f 秒 (%.1f 分钟)", elapsed, elapsed / 60)
    logger.info("DB路径: %s", args.db)
|
||||
|
||||
|
||||
# Run the importer only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user