diff --git a/data_platform/updater.py b/data_platform/updater.py new file mode 100644 index 00000000..3026e776 --- /dev/null +++ b/data_platform/updater.py @@ -0,0 +1,223 @@ +#!/usr/bin/env python3 +"""增量更新 - Parquet+vnpy DB双写""" +import os +import re +import sys +import json +import sqlite3 +import shutil +import logging +import time +import pandas as pd +from pathlib import Path +from datetime import datetime + +logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s') +logger = logging.getLogger(__name__) + +DAILY_DIR = "/Volumes/stock/A股数据/日线数据/daily/" +VNPY_DB_PATH = "/Volumes/stock/sanguo_vnpy/data/quant_trading.db" +LOCAL_DB_TMP = "/tmp/quant_trading_updater.db" +BATCH_SIZE = 50000 + + +def parse_filename(filename): + m = re.match(r'(sh|sz)(\d{6})_daily\.parquet', filename) + if not m: + return None, None + return m.group(2), 'SSE' if m.group(1) == 'sh' else 'SZSE' + + +def get_all_symbols(): + """扫描最新年份目录获取所有股票代码""" + latest_year = max(d.name for d in Path(DAILY_DIR).iterdir() if d.is_dir() and d.name.isdigit()) + symbols = [] + for f in (Path(DAILY_DIR) / latest_year).glob('*.parquet'): + code, exchange = parse_filename(f.name) + if code: + symbols.append((code, exchange, f.name)) + return symbols + + +def get_last_date(code: str, exchange: str) -> str: + """获取某只股票在NAS Parquet中的最后日期""" + prefix = 'sh' if exchange == 'SSE' else 'sz' + for year_dir in sorted(Path(DAILY_DIR).iterdir(), reverse=True): + if not year_dir.is_dir() or not year_dir.name.isdigit(): + continue + fpath = year_dir / f"{prefix}{code}_daily.parquet" + if fpath.exists(): + try: + df = pd.read_parquet(fpath, columns=['date']) + if not df.empty: + last = df['date'].max() + return str(last)[:10] + except Exception: + pass + return "" + + +def fetch_incremental(code: str, start_date: str, end_date: str): + """用akshare获取增量数据""" + try: + import akshare as ak + df = ak.stock_zh_a_hist( + symbol=code, period="daily", + start_date=start_date.replace("-", ""), + end_date=end_date.replace("-", ""), + adjust="" + ) + if df is None or df.empty: + return None + df = df.rename(columns={"日期": "date", "开盘": "open", "收盘": "close", + "最高": "high", "最低": "low", "成交量": "volume", + "成交额": "amount"}) + df["date"] = pd.to_datetime(df["date"]).dt.strftime("%Y-%m-%d") + for c in ["open", "high", "low", "close", "volume", "amount"]: + df[c] = pd.to_numeric(df[c], errors="coerce").fillna(0) + return df[["date", "open", "high", "low", "close", "volume", "amount"]] + except Exception as e: + logger.warning(f"akshare获取失败 {code}: {e}") + return None + + +def append_to_parquet(code: str, exchange: str, new_data: pd.DataFrame): + """原子写入:临时文件+rename,追加到对应年份目录""" + prefix = 'sh' if exchange == 'SSE' else 'sz' + for _, row in new_data.iterrows(): + year = row['date'][:4] + year_dir = Path(DAILY_DIR) / year + year_dir.mkdir(parents=True, exist_ok=True) + fpath = year_dir / f"{prefix}{code}_daily.parquet" + + if fpath.exists(): + existing = pd.read_parquet(fpath) + combined = pd.concat([existing, pd.DataFrame([row])], ignore_index=True) + combined = combined.drop_duplicates(subset=['date'], keep='last') + combined = combined.sort_values('date').reset_index(drop=True) + else: + combined = pd.DataFrame([row]) + + tmp_path = str(fpath) + ".tmp" + combined.to_parquet(tmp_path, index=False) + os.rename(tmp_path, str(fpath)) + + +def append_to_vnpy_db(code: str, exchange: str, new_data: pd.DataFrame): + """写入vnpy DB (先本地/tmp,完成后复制到NAS)""" + values = [] + for _, row in new_data.iterrows(): + values.append(( + code, exchange, str(row['date']), 'd', + float(row.get('volume', 0)), float(row.get('amount', 0)), 0.0, + float(row.get('open', 0)), float(row.get('high', 0)), + float(row.get('low', 0)), float(row.get('close', 0)), + )) + return values + + +def main(): + today = datetime.now().strftime("%Y-%m-%d") + max_end = today # 不超过今天 + + logger.info(f"=== 增量更新开始 {today} ===") + + # 获取所有股票 + symbols = get_all_symbols() + logger.info(f"扫描到 {len(symbols)} 只股票") + + updated = 0 + skipped = 0 + failed = 0 + new_records = 0 + all_db_values = [] + + for i, (code, exchange, fname) in enumerate(symbols): + # 获取最后日期 + last_date = get_last_date(code, exchange) + if not last_date: + skipped += 1 + continue + + # 计算需要补的起始日期(下一天) + next_day = (pd.Timestamp(last_date) + pd.Timedelta(days=1)).strftime("%Y-%m-%d") + if next_day > max_end: + skipped += 1 + continue + + # 获取增量数据 + data = fetch_incremental(code, next_day, max_end) + if data is None or data.empty: + skipped += 1 + continue + + # 校验(简单fatal检查) + if (data[['open', 'high', 'low', 'close']] <= 0).any().any(): + logger.warning(f"{code} 有非正价格,跳过") + failed += 1 + continue + + # 写Parquet + try: + append_to_parquet(code, exchange, data) + except Exception as e: + logger.error(f"Parquet写入失败 {code}: {e}") + failed += 1 + continue + + # 收集vnpy DB数据 + db_vals = append_to_vnpy_db(code, exchange, data) + all_db_values.extend(db_vals) + + new_records += len(data) + updated += 1 + + if (i + 1) % 500 == 0: + logger.info(f"进度: {i+1}/{len(symbols)} updated={updated} skipped={skipped} failed={failed}") + + # 限频:akshare 1秒间隔 + time.sleep(0.5) + + # 写vnpy DB + if all_db_values: + logger.info(f"写入vnpy DB: {len(all_db_values)} 条记录") + try: + # 复制NAS DB到本地 + shutil.copy2(VNPY_DB_PATH, LOCAL_DB_TMP) + conn = sqlite3.connect(LOCAL_DB_TMP) + c = conn.cursor() + for j in range(0, len(all_db_values), BATCH_SIZE): + c.executemany('''INSERT OR REPLACE INTO dbbardata + (symbol,exchange,datetime,interval,volume,turnover,open_interest, + open_price,high_price,low_price,close_price) + VALUES (?,?,?,?,?,?,?,?,?,?,?)''', all_db_values[j:j+BATCH_SIZE]) + conn.commit() + + # 重建overview + c.execute('''INSERT OR REPLACE INTO dbbaroverview (symbol,exchange,interval,count,start,end) + SELECT symbol,exchange,interval,COUNT(*),MIN(datetime),MAX(datetime) + FROM dbbardata GROUP BY symbol,exchange,interval''') + conn.commit() + conn.close() + + # 复制回NAS + shutil.copy2(LOCAL_DB_TMP, VNPY_DB_PATH) + logger.info("✅ vnpy DB更新完成") + except Exception as e: + logger.error(f"❌ vnpy DB更新失败: {e}") + + report = { + "date": today, + "total_symbols": len(symbols), + "updated": updated, + "skipped": skipped, + "failed": failed, + "new_records": new_records, + } + logger.info(f"=== 更新完成 ===") + logger.info(json.dumps(report, ensure_ascii=False, indent=2)) + return report + + +if __name__ == "__main__": + main()