auto-sync: 2026-05-02 11:21:13
This commit is contained in:
@@ -0,0 +1,172 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
将NAS日线Parquet数据导入vnpy SQLite数据库
|
||||
用法: python3 import_vnpy_daily.py [--start-year 2013] [--dry-run]
|
||||
"""
|
||||
import sqlite3
|
||||
import pandas as pd
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
DB_PATH = '/Volumes/stock/sanguo_vnpy/data/quant_trading.db'
|
||||
DAILY_DIR = '/Volumes/stock/A股数据/日线数据/daily/'
|
||||
|
||||
BATCH_SIZE = 50000 # 每批插入行数
|
||||
|
||||
|
||||
def parse_filename(filename):
|
||||
"""sh600519_daily.parquet → ('600519', 'SSE')"""
|
||||
m = re.match(r'(sh|sz)(\d{6})_daily\.parquet', filename)
|
||||
if not m:
|
||||
return None, None
|
||||
prefix, code = m.groups()
|
||||
exchange = 'SSE' if prefix == 'sh' else 'SZSE'
|
||||
return code, exchange
|
||||
|
||||
|
||||
def import_year(conn, year, dry_run=False):
|
||||
"""导入指定年份的所有日线数据"""
|
||||
year_dir = Path(DAILY_DIR) / str(year)
|
||||
if not year_dir.exists():
|
||||
print(f' ⚠️ {year} 目录不存在')
|
||||
return 0, 0, 0
|
||||
|
||||
files = sorted(year_dir.glob('*.parquet'))
|
||||
if not files:
|
||||
print(f' ⚠️ {year} 无parquet文件')
|
||||
return 0, 0, 0
|
||||
|
||||
c = conn.cursor()
|
||||
imported_files = 0
|
||||
imported_rows = 0
|
||||
failed = 0
|
||||
|
||||
# 收集所有数据
|
||||
all_values = []
|
||||
|
||||
for f in files:
|
||||
code, exchange = parse_filename(f.name)
|
||||
if code is None:
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
df = pd.read_parquet(f, columns=['date', 'open', 'high', 'low', 'close', 'volume', 'amount'])
|
||||
if df.empty:
|
||||
continue
|
||||
|
||||
for _, row in df.iterrows():
|
||||
all_values.append((
|
||||
code,
|
||||
exchange,
|
||||
str(row['date']),
|
||||
'd',
|
||||
float(row['volume']) if pd.notna(row['volume']) else 0.0,
|
||||
float(row['amount']) if pd.notna(row['amount']) else 0.0,
|
||||
0.0, # open_interest
|
||||
float(row['open']) if pd.notna(row['open']) else 0.0,
|
||||
float(row['high']) if pd.notna(row['high']) else 0.0,
|
||||
float(row['low']) if pd.notna(row['low']) else 0.0,
|
||||
float(row['close']) if pd.notna(row['close']) else 0.0,
|
||||
))
|
||||
|
||||
imported_files += 1
|
||||
imported_rows += len(df)
|
||||
|
||||
except Exception as e:
|
||||
failed += 1
|
||||
if failed <= 3:
|
||||
print(f' ❌ {f.name}: {e}')
|
||||
|
||||
if dry_run:
|
||||
print(f' [DRY RUN] Would insert {len(all_values)} rows from {imported_files} files')
|
||||
return imported_files, imported_rows, failed
|
||||
|
||||
# 批量插入
|
||||
if all_values:
|
||||
for i in range(0, len(all_values), BATCH_SIZE):
|
||||
batch = all_values[i:i+BATCH_SIZE]
|
||||
c.executemany('''
|
||||
INSERT OR REPLACE INTO dbbardata
|
||||
(symbol, exchange, datetime, interval, volume, turnover, open_interest,
|
||||
open_price, high_price, low_price, close_price)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
''', batch)
|
||||
conn.commit()
|
||||
|
||||
print(f' ✅ {year}: {imported_files} files, {imported_rows} rows inserted, {failed} failed')
|
||||
return imported_files, imported_rows, failed
|
||||
|
||||
|
||||
def update_overview(conn):
|
||||
"""更新 dbbaroverview 汇总表"""
|
||||
c = conn.cursor()
|
||||
c.execute('''
|
||||
INSERT OR REPLACE INTO dbbaroverview (symbol, exchange, interval, count, start, end)
|
||||
SELECT symbol, exchange, interval,
|
||||
COUNT(*) as count,
|
||||
MIN(datetime) as start,
|
||||
MAX(datetime) as end
|
||||
FROM dbbardata
|
||||
GROUP BY symbol, exchange, interval
|
||||
''')
|
||||
conn.commit()
|
||||
c.execute('SELECT COUNT(*) FROM dbbaroverview')
|
||||
print(f' Overview: {c.fetchone()[0]} entries')
|
||||
|
||||
|
||||
def main():
|
||||
start_year = 2013 # 默认从2013开始(2010-2012已导入)
|
||||
dry_run = '--dry-run' in sys.argv
|
||||
|
||||
for i, arg in enumerate(sys.argv):
|
||||
if arg == '--start-year' and i+1 < len(sys.argv):
|
||||
start_year = int(sys.argv[i+1])
|
||||
|
||||
print(f'=== vnpy Daily Import ===')
|
||||
print(f'DB: {DB_PATH}')
|
||||
print(f'Start year: {start_year}')
|
||||
print(f'Dry run: {dry_run}')
|
||||
print()
|
||||
|
||||
conn = sqlite3.connect(DB_PATH, timeout=60)
|
||||
|
||||
total_files = 0
|
||||
total_rows = 0
|
||||
total_failed = 0
|
||||
start_time = time.time()
|
||||
|
||||
for year in range(start_year, 2027):
|
||||
t0 = time.time()
|
||||
f, r, fail = import_year(conn, year, dry_run)
|
||||
t1 = time.time()
|
||||
total_files += f
|
||||
total_rows += r
|
||||
total_failed += fail
|
||||
print(f' ({t1-t0:.1f}s, total: {total_rows} rows)')
|
||||
|
||||
if not dry_run:
|
||||
print('\nUpdating overview...')
|
||||
update_overview(conn)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
print(f'\n=== Summary ===')
|
||||
print(f'Files: {total_files}')
|
||||
print(f'Rows: {total_rows}')
|
||||
print(f'Failed: {total_failed}')
|
||||
print(f'Time: {elapsed:.1f}s ({elapsed/60:.1f}min)')
|
||||
|
||||
# Final verification
|
||||
c = conn.cursor()
|
||||
c.execute('SELECT COUNT(*) FROM dbbardata')
|
||||
print(f'DB total rows: {c.fetchone()[0]}')
|
||||
|
||||
conn.close()
|
||||
print('✅ Done')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user