#!/usr/bin/env python3
"""
Fast import of NAS daily-bar Parquet files into the vnpy SQLite DB.

Uses pandas vectorized operations instead of row-by-row iteration
(10x+ faster).
"""
import argparse
import contextlib
import sqlite3
import pandas as pd
import numpy as np
import os
import re
import sys
import time
from pathlib import Path

DB_PATH = '/tmp/quant_trading_import.db'
DAILY_DIR = '/Volumes/stock/A股数据/日线数据/daily/'

# Year range imported by default; the end year is inclusive.
DEFAULT_START_YEAR = 2017
DEFAULT_END_YEAR = 2026

# Number of rows handed to each executemany() call.
BATCH_SIZE = 50000

# Compiled once; used with fullmatch() so trailing junk such as
# "sh600000_daily.parquet.tmp" is rejected.
_FILENAME_RE = re.compile(r'(sh|sz)(\d{6})_daily\.parquet')

# Columns read from every Parquet file.
_SOURCE_COLUMNS = ['date', 'open', 'high', 'low', 'close', 'volume', 'amount']

# Parquet column name -> dbbardata column name.
_RENAMES = {
    'open': 'open_price',
    'high': 'high_price',
    'low': 'low_price',
    'close': 'close_price',
    'amount': 'turnover',
}

# Columns whose missing values are replaced by 0.0 before insertion.
_NUMERIC_COLUMNS = [
    'volume', 'turnover', 'open_price', 'high_price', 'low_price',
    'close_price',
]

# Column order used for the INSERT below; must match _INSERT_SQL.
_DB_COLUMNS = [
    'symbol', 'exchange', 'datetime', 'interval', 'volume', 'turnover',
    'open_interest', 'open_price', 'high_price', 'low_price', 'close_price',
]

_INSERT_SQL = '''INSERT OR REPLACE INTO dbbardata
    (symbol,exchange,datetime,interval,volume,turnover,open_interest,
     open_price,high_price,low_price,close_price)
    VALUES (?,?,?,?,?,?,?,?,?,?,?)'''

_OVERVIEW_SQL = '''INSERT OR REPLACE INTO dbbaroverview (symbol,exchange,interval,count,start,end)
    SELECT symbol,exchange,interval,COUNT(*),MIN(datetime),MAX(datetime)
    FROM dbbardata GROUP BY symbol,exchange,interval'''


def parse_filename(filename):
    """Split a daily-bar file name into ``(code, exchange)``.

    Args:
        filename: Bare file name, e.g. ``'sh600000_daily.parquet'``.

    Returns:
        ``(code, exchange)`` where ``code`` is the six-digit symbol and
        ``exchange`` is ``'SSE'`` for the ``sh`` prefix or ``'SZSE'`` for
        the ``sz`` prefix. ``(None, None)`` when the whole name does not
        match the ``<sh|sz><6 digits>_daily.parquet`` pattern.
    """
    m = _FILENAME_RE.fullmatch(filename)
    if m is None:
        return None, None
    prefix, code = m.groups()
    return code, 'SSE' if prefix == 'sh' else 'SZSE'


def import_year(conn, year, daily_dir=None):
    """Load every daily Parquet file of one year into ``dbbardata``.

    Files whose names do not parse, that are empty, or that cannot be read
    are skipped; unreadable files are reported on stderr instead of being
    dropped silently. All rows of the year are written in one transaction
    that is rolled back if any batch fails.

    Args:
        conn: Open ``sqlite3`` connection; only used once there is data
            to insert.
        year: Year whose sub-directory ``<daily_dir>/<year>`` is scanned.
        daily_dir: Root directory of the per-year folders. Defaults to the
            module-level ``DAILY_DIR``.

    Returns:
        ``(files_imported, rows_written)``; ``(0, 0)`` when the year
        directory is missing or yields no usable data.
    """
    root = Path(DAILY_DIR if daily_dir is None else daily_dir)
    year_dir = root / str(year)
    if not year_dir.is_dir():
        return 0, 0

    frames = []
    for path in sorted(year_dir.glob('*.parquet')):
        code, exchange = parse_filename(path.name)
        if code is None:
            continue
        try:
            df = pd.read_parquet(path, columns=_SOURCE_COLUMNS)
        except Exception as exc:
            # Best-effort import: one bad file must not abort the year,
            # but the operator needs to know that data is missing.
            print(f'WARNING: skipping unreadable file {path}: {exc}',
                  file=sys.stderr)
            continue
        if df.empty:
            continue
        df['symbol'] = code
        df['exchange'] = exchange
        frames.append(df)

    if not frames:
        return 0, 0

    combined = pd.concat(frames, ignore_index=True)

    # Vectorized conversion to the dbbardata layout.
    # NOTE(review): the stored datetime text is whatever str() yields for
    # the 'date' column; confirm this matches the format vnpy reads back.
    combined['datetime'] = combined['date'].astype(str)
    combined['interval'] = 'd'
    combined['open_interest'] = 0.0
    combined = combined.rename(columns=_RENAMES)

    # Missing numeric values are stored as 0.0.
    for col in _NUMERIC_COLUMNS:
        combined[col] = combined[col].fillna(0.0).astype(float)

    rows = combined[_DB_COLUMNS].values.tolist()

    # "with conn" commits on success and rolls back on any exception, so a
    # failed batch never leaves a half-imported year pending.
    with conn:
        cur = conn.cursor()
        for offset in range(0, len(rows), BATCH_SIZE):
            cur.executemany(_INSERT_SQL, rows[offset:offset + BATCH_SIZE])

    return len(frames), len(combined)


def main():
    """Import all years in the requested range and refresh the overview.

    Command-line options:
        --start-year  First year to import (default 2017).
        --end-year    Last year to import, inclusive (default 2026).

    Unrecognized arguments are ignored.
    """
    parser = argparse.ArgumentParser(
        description='Import daily Parquet bars into the vnpy SQLite DB.')
    parser.add_argument('--start-year', type=int, default=DEFAULT_START_YEAR)
    parser.add_argument('--end-year', type=int, default=DEFAULT_END_YEAR)
    args, _unknown = parser.parse_known_args()
    start_year = args.start_year

    print(f'Importing from {start_year} to local DB: {DB_PATH}')

    # closing() guarantees the connection is released even if an import
    # step raises.
    with contextlib.closing(sqlite3.connect(DB_PATH)) as conn:
        total_rows = 0
        t_start = time.time()

        for year in range(start_year, args.end_year + 1):
            t0 = time.time()
            files, rows = import_year(conn, year)
            t1 = time.time()
            total_rows += rows
            print(f'{year}: {files} files, {rows} rows ({t1-t0:.1f}s) total={total_rows}')

        elapsed = time.time() - t_start

        # Rebuild the per-symbol overview from the bar table.
        c = conn.cursor()
        c.execute(_OVERVIEW_SQL)
        conn.commit()

        c.execute('SELECT COUNT(*) FROM dbbardata')
        final = c.fetchone()[0]
        c.execute('SELECT COUNT(*) FROM dbbaroverview')
        overview = c.fetchone()[0]

        print(f'\nDone in {elapsed:.1f}s ({elapsed/60:.1f}min)')
        print(f'Total rows: {final}, Overview entries: {overview}')


if __name__ == '__main__':
    main()