diff --git a/data_platform/import_vnpy_daily.py b/data_platform/import_vnpy_daily.py new file mode 100644 index 00000000..3498ea8b --- /dev/null +++ b/data_platform/import_vnpy_daily.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 +""" +将NAS日线Parquet数据导入vnpy SQLite数据库 +用法: python3 import_vnpy_daily.py [--start-year 2013] [--dry-run] +""" +import sqlite3 +import pandas as pd +import os +import re +import sys +import time +from pathlib import Path + +DB_PATH = '/Volumes/stock/sanguo_vnpy/data/quant_trading.db' +DAILY_DIR = '/Volumes/stock/A股数据/日线数据/daily/' + +BATCH_SIZE = 50000 # 每批插入行数 + + +def parse_filename(filename): + """sh600519_daily.parquet → ('600519', 'SSE')""" + m = re.match(r'(sh|sz)(\d{6})_daily\.parquet', filename) + if not m: + return None, None + prefix, code = m.groups() + exchange = 'SSE' if prefix == 'sh' else 'SZSE' + return code, exchange + + +def import_year(conn, year, dry_run=False): + """导入指定年份的所有日线数据""" + year_dir = Path(DAILY_DIR) / str(year) + if not year_dir.exists(): + print(f' ⚠️ {year} 目录不存在') + return 0, 0, 0 + + files = sorted(year_dir.glob('*.parquet')) + if not files: + print(f' ⚠️ {year} 无parquet文件') + return 0, 0, 0 + + c = conn.cursor() + imported_files = 0 + imported_rows = 0 + failed = 0 + + # 收集所有数据 + all_values = [] + + for f in files: + code, exchange = parse_filename(f.name) + if code is None: + failed += 1 + continue + + try: + df = pd.read_parquet(f, columns=['date', 'open', 'high', 'low', 'close', 'volume', 'amount']) + if df.empty: + continue + + for _, row in df.iterrows(): + all_values.append(( + code, + exchange, + str(row['date']), + 'd', + float(row['volume']) if pd.notna(row['volume']) else 0.0, + float(row['amount']) if pd.notna(row['amount']) else 0.0, + 0.0, # open_interest + float(row['open']) if pd.notna(row['open']) else 0.0, + float(row['high']) if pd.notna(row['high']) else 0.0, + float(row['low']) if pd.notna(row['low']) else 0.0, + float(row['close']) if pd.notna(row['close']) else 0.0, + )) + + imported_files += 1 + imported_rows += len(df) + + except Exception as e: + failed += 1 + if failed <= 3: + print(f' ❌ {f.name}: {e}') + + if dry_run: + print(f' [DRY RUN] Would insert {len(all_values)} rows from {imported_files} files') + return imported_files, imported_rows, failed + + # 批量插入 + if all_values: + for i in range(0, len(all_values), BATCH_SIZE): + batch = all_values[i:i+BATCH_SIZE] + c.executemany(''' + INSERT OR REPLACE INTO dbbardata + (symbol, exchange, datetime, interval, volume, turnover, open_interest, + open_price, high_price, low_price, close_price) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ''', batch) + conn.commit() + + print(f' ✅ {year}: {imported_files} files, {imported_rows} rows inserted, {failed} failed') + return imported_files, imported_rows, failed + + +def update_overview(conn): + """更新 dbbaroverview 汇总表""" + c = conn.cursor() + c.execute(''' + INSERT OR REPLACE INTO dbbaroverview (symbol, exchange, interval, count, start, end) + SELECT symbol, exchange, interval, + COUNT(*) as count, + MIN(datetime) as start, + MAX(datetime) as end + FROM dbbardata + GROUP BY symbol, exchange, interval + ''') + conn.commit() + c.execute('SELECT COUNT(*) FROM dbbaroverview') + print(f' Overview: {c.fetchone()[0]} entries') + + +def main(): + start_year = 2013 # 默认从2013开始(2010-2012已导入) + dry_run = '--dry-run' in sys.argv + + for i, arg in enumerate(sys.argv): + if arg == '--start-year' and i+1 < len(sys.argv): + start_year = int(sys.argv[i+1]) + + print(f'=== vnpy Daily Import ===') + print(f'DB: {DB_PATH}') + print(f'Start year: {start_year}') + print(f'Dry run: {dry_run}') + print() + + conn = sqlite3.connect(DB_PATH, timeout=60) + + total_files = 0 + total_rows = 0 + total_failed = 0 + start_time = time.time() + + for year in range(start_year, 2027): + t0 = time.time() + f, r, fail = import_year(conn, year, dry_run) + t1 = time.time() + total_files += f + total_rows += r + total_failed += fail + print(f' ({t1-t0:.1f}s, total: {total_rows} rows)') + + if not dry_run: + print('\nUpdating overview...') + update_overview(conn) + + elapsed = time.time() - start_time + print(f'\n=== Summary ===') + print(f'Files: {total_files}') + print(f'Rows: {total_rows}') + print(f'Failed: {total_failed}') + print(f'Time: {elapsed:.1f}s ({elapsed/60:.1f}min)') + + # Final verification + c = conn.cursor() + c.execute('SELECT COUNT(*) FROM dbbardata') + print(f'DB total rows: {c.fetchone()[0]}') + + conn.close() + print('✅ Done') + + +if __name__ == '__main__': + main()