173 lines
5.2 KiB
Python
173 lines
5.2 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Import daily-bar Parquet data from the NAS into the vnpy SQLite database.
|
|
Usage: python3 import_vnpy_daily.py [--start-year 2013] [--dry-run]
|
|
"""
|
|
import sqlite3
|
|
import pandas as pd
|
|
import os
|
|
import re
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
|
|
# SQLite database file to import into; the VNPY_DB_PATH environment variable
# overrides the default path.
DB_PATH = os.environ.get('VNPY_DB_PATH', '/tmp/quant_trading_import.db')
|
|
# Root directory of the daily-bar Parquet files; import_year() looks for one
# sub-directory per year underneath it.
DAILY_DIR = '/Volumes/stock/A股数据/日线数据/daily/'
|
|
|
|
BATCH_SIZE = 50000 # number of rows per executemany() insert batch
|
|
|
|
|
|
def parse_filename(filename):
    """Split a daily-bar Parquet file name into stock code and exchange.

    Example: ``sh600519_daily.parquet`` -> ``('600519', 'SSE')``.

    Args:
        filename: Base name of the file (no directory component).

    Returns:
        ``(code, exchange)`` where ``exchange`` is ``'SSE'`` for the ``sh``
        prefix and ``'SZSE'`` for the ``sz`` prefix, or ``(None, None)`` when
        the name does not have the expected form.
    """
    # fullmatch (not match) so that names with trailing junk, such as
    # "sh600519_daily.parquet.tmp", are rejected instead of half-parsed.
    m = re.fullmatch(r'(sh|sz)(\d{6})_daily\.parquet', filename)
    if m is None:
        return None, None

    prefix, code = m.groups()
    exchange = 'SSE' if prefix == 'sh' else 'SZSE'
    return code, exchange
|
|
|
|
|
|
def _frame_to_rows(df, code, exchange):
    """Convert one stock's daily-bar frame into ``dbbardata`` value tuples."""

    def as_float(value):
        # Missing values (NaN/None) are stored as 0.0.
        return float(value) if pd.notna(value) else 0.0

    # Column-wise zip avoids building a Series per row as iterrows() does.
    return [
        (
            code,
            exchange,
            str(date),
            'd',
            as_float(volume),
            as_float(amount),
            0.0,  # open_interest is always written as 0.0
            as_float(open_),
            as_float(high),
            as_float(low),
            as_float(close),
        )
        for date, open_, high, low, close, volume, amount in zip(
            df['date'], df['open'], df['high'], df['low'], df['close'],
            df['volume'], df['amount'],
        )
    ]


def import_year(conn, year, dry_run=False):
    """Import all daily-bar Parquet files of one year into ``dbbardata``.

    Files are read from ``DAILY_DIR/<year>/*.parquet``. The rows of every
    readable file are collected first and then written with
    ``INSERT OR REPLACE`` in batches of ``BATCH_SIZE``, committing after
    each batch.

    Args:
        conn: Open ``sqlite3`` connection; not used when ``dry_run`` is true.
        year: Year whose sub-directory should be imported.
        dry_run: When true, read and convert the files but write nothing.

    Returns:
        ``(imported_files, imported_rows, failed)``. ``failed`` counts files
        whose name could not be parsed plus files that raised while being
        read or converted. ``(0, 0, 0)`` is returned when the year directory
        is missing or contains no Parquet file.
    """
    year_dir = Path(DAILY_DIR) / str(year)
    if not year_dir.exists():
        print(f' ⚠️ {year} 目录不存在')
        return 0, 0, 0

    files = sorted(year_dir.glob('*.parquet'))
    if not files:
        print(f' ⚠️ {year} 无parquet文件')
        return 0, 0, 0

    imported_files = 0
    imported_rows = 0
    failed = 0
    all_values = []

    for f in files:
        code, exchange = parse_filename(f.name)
        if code is None:
            failed += 1
            continue

        try:
            df = pd.read_parquet(
                f,
                columns=['date', 'open', 'high', 'low', 'close', 'volume', 'amount'],
            )
            if df.empty:
                continue
            # Stage the rows per file: if conversion fails half-way, none of
            # this file's rows may leak into the insert list.
            file_values = _frame_to_rows(df, code, exchange)
        except Exception as e:
            failed += 1
            # Only the first few errors are printed to keep the log short.
            if failed <= 3:
                print(f' ❌ {f.name}: {e}')
            continue

        all_values.extend(file_values)
        imported_files += 1
        imported_rows += len(file_values)

    if dry_run:
        print(f' [DRY RUN] Would insert {len(all_values)} rows from {imported_files} files')
        return imported_files, imported_rows, failed

    # Batched insert; the cursor is only needed once we actually write.
    cursor = conn.cursor()
    for start in range(0, len(all_values), BATCH_SIZE):
        cursor.executemany('''
            INSERT OR REPLACE INTO dbbardata
            (symbol, exchange, datetime, interval, volume, turnover, open_interest,
            open_price, high_price, low_price, close_price)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        ''', all_values[start:start + BATCH_SIZE])
        conn.commit()

    print(f' ✅ {year}: {imported_files} files, {imported_rows} rows inserted, {failed} failed')
    return imported_files, imported_rows, failed
|
|
|
|
|
|
def update_overview(conn):
    """Rebuild the ``dbbaroverview`` summary rows from ``dbbardata``.

    For every (symbol, exchange, interval) group the bar count and the
    earliest and latest ``datetime`` are written with ``INSERT OR REPLACE``.
    The change is committed and the resulting number of overview rows is
    printed.
    """
    refresh_sql = '''
        INSERT OR REPLACE INTO dbbaroverview (symbol, exchange, interval, count, start, end)
        SELECT symbol, exchange, interval,
        COUNT(*) as count,
        MIN(datetime) as start,
        MAX(datetime) as end
        FROM dbbardata
        GROUP BY symbol, exchange, interval
    '''
    cursor = conn.cursor()
    cursor.execute(refresh_sql)
    conn.commit()

    cursor.execute('SELECT COUNT(*) FROM dbbaroverview')
    entry_count = cursor.fetchone()[0]
    print(f' Overview: {entry_count} entries')
|
|
|
|
|
|
def main():
    """Command-line entry point: import every year from ``--start-year`` on.

    Options:
        --start-year N  First year to import (default 2013; the original
                        note says 2010-2012 were already imported).
        --dry-run       Read and count the data without writing to the DB.

    Each year is imported with ``import_year``; unless this is a dry run the
    overview table is refreshed afterwards. A summary and the total row
    count of ``dbbardata`` are printed at the end.
    """
    import argparse  # local import: only this entry point needs it

    parser = argparse.ArgumentParser(
        description='Import daily-bar Parquet data into the vnpy SQLite database.',
        allow_abbrev=False,
    )
    parser.add_argument('--start-year', type=int, default=2013)
    parser.add_argument('--dry-run', action='store_true')
    # Unknown arguments are ignored, as the hand-rolled parser used to do.
    args, _ignored = parser.parse_known_args()
    start_year = args.start_year
    dry_run = args.dry_run

    print('=== vnpy Daily Import ===')
    print(f'DB: {DB_PATH}')
    print(f'Start year: {start_year}')
    print(f'Dry run: {dry_run}')
    print()

    # Always cover up to 2026 (the previous hard-coded bound) and keep
    # extending with the calendar so later years are not silently skipped.
    last_year = max(2026, time.localtime().tm_year)

    conn = sqlite3.connect(DB_PATH, timeout=60)
    try:
        total_files = 0
        total_rows = 0
        total_failed = 0
        start_time = time.time()

        for year in range(start_year, last_year + 1):
            t0 = time.time()
            f, r, fail = import_year(conn, year, dry_run)
            t1 = time.time()
            total_files += f
            total_rows += r
            total_failed += fail
            print(f' ({t1-t0:.1f}s, total: {total_rows} rows)')

        if not dry_run:
            print('\nUpdating overview...')
            update_overview(conn)

        elapsed = time.time() - start_time
        print('\n=== Summary ===')
        print(f'Files: {total_files}')
        print(f'Rows: {total_rows}')
        print(f'Failed: {total_failed}')
        print(f'Time: {elapsed:.1f}s ({elapsed/60:.1f}min)')

        # Final verification
        c = conn.cursor()
        c.execute('SELECT COUNT(*) FROM dbbardata')
        print(f'DB total rows: {c.fetchone()[0]}')
    finally:
        # Release the database handle even when an import step raises.
        conn.close()

    print('✅ Done')
|
|
|
|
|
|
# Run the import only when this file is executed as a script, not on import.
if __name__ == '__main__':
|
|
main()
|