auto-sync: 2026-05-02 18:47:55
This commit is contained in:
@@ -0,0 +1,223 @@
|
||||
#!/usr/bin/env python3
|
||||
"""增量更新 - Parquet+vnpy DB双写"""
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import json
|
||||
import sqlite3
|
||||
import shutil
|
||||
import logging
|
||||
import time
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DAILY_DIR = "/Volumes/stock/A股数据/日线数据/daily/"
|
||||
VNPY_DB_PATH = "/Volumes/stock/sanguo_vnpy/data/quant_trading.db"
|
||||
LOCAL_DB_TMP = "/tmp/quant_trading_updater.db"
|
||||
BATCH_SIZE = 50000
|
||||
|
||||
|
||||
def parse_filename(filename):
|
||||
m = re.match(r'(sh|sz)(\d{6})_daily\.parquet', filename)
|
||||
if not m:
|
||||
return None, None
|
||||
return m.group(2), 'SSE' if m.group(1) == 'sh' else 'SZSE'
|
||||
|
||||
|
||||
def get_all_symbols():
|
||||
"""扫描最新年份目录获取所有股票代码"""
|
||||
latest_year = max(d.name for d in Path(DAILY_DIR).iterdir() if d.is_dir() and d.name.isdigit())
|
||||
symbols = []
|
||||
for f in (Path(DAILY_DIR) / latest_year).glob('*.parquet'):
|
||||
code, exchange = parse_filename(f.name)
|
||||
if code:
|
||||
symbols.append((code, exchange, f.name))
|
||||
return symbols
|
||||
|
||||
|
||||
def get_last_date(code: str, exchange: str) -> str:
|
||||
"""获取某只股票在NAS Parquet中的最后日期"""
|
||||
prefix = 'sh' if exchange == 'SSE' else 'sz'
|
||||
for year_dir in sorted(Path(DAILY_DIR).iterdir(), reverse=True):
|
||||
if not year_dir.is_dir() or not year_dir.name.isdigit():
|
||||
continue
|
||||
fpath = year_dir / f"{prefix}{code}_daily.parquet"
|
||||
if fpath.exists():
|
||||
try:
|
||||
df = pd.read_parquet(fpath, columns=['date'])
|
||||
if not df.empty:
|
||||
last = df['date'].max()
|
||||
return str(last)[:10]
|
||||
except Exception:
|
||||
pass
|
||||
return ""
|
||||
|
||||
|
||||
def fetch_incremental(code: str, start_date: str, end_date: str):
|
||||
"""用akshare获取增量数据"""
|
||||
try:
|
||||
import akshare as ak
|
||||
df = ak.stock_zh_a_hist(
|
||||
symbol=code, period="daily",
|
||||
start_date=start_date.replace("-", ""),
|
||||
end_date=end_date.replace("-", ""),
|
||||
adjust=""
|
||||
)
|
||||
if df is None or df.empty:
|
||||
return None
|
||||
df = df.rename(columns={"日期": "date", "开盘": "open", "收盘": "close",
|
||||
"最高": "high", "最低": "low", "成交量": "volume",
|
||||
"成交额": "amount"})
|
||||
df["date"] = pd.to_datetime(df["date"]).dt.strftime("%Y-%m-%d")
|
||||
for c in ["open", "high", "low", "close", "volume", "amount"]:
|
||||
df[c] = pd.to_numeric(df[c], errors="coerce").fillna(0)
|
||||
return df[["date", "open", "high", "low", "close", "volume", "amount"]]
|
||||
except Exception as e:
|
||||
logger.warning(f"akshare获取失败 {code}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def append_to_parquet(code: str, exchange: str, new_data: pd.DataFrame):
|
||||
"""原子写入:临时文件+rename,追加到对应年份目录"""
|
||||
prefix = 'sh' if exchange == 'SSE' else 'sz'
|
||||
for _, row in new_data.iterrows():
|
||||
year = row['date'][:4]
|
||||
year_dir = Path(DAILY_DIR) / year
|
||||
year_dir.mkdir(parents=True, exist_ok=True)
|
||||
fpath = year_dir / f"{prefix}{code}_daily.parquet"
|
||||
|
||||
if fpath.exists():
|
||||
existing = pd.read_parquet(fpath)
|
||||
combined = pd.concat([existing, pd.DataFrame([row])], ignore_index=True)
|
||||
combined = combined.drop_duplicates(subset=['date'], keep='last')
|
||||
combined = combined.sort_values('date').reset_index(drop=True)
|
||||
else:
|
||||
combined = pd.DataFrame([row])
|
||||
|
||||
tmp_path = str(fpath) + ".tmp"
|
||||
combined.to_parquet(tmp_path, index=False)
|
||||
os.rename(tmp_path, str(fpath))
|
||||
|
||||
|
||||
def append_to_vnpy_db(code: str, exchange: str, new_data: pd.DataFrame):
|
||||
"""写入vnpy DB (先本地/tmp,完成后复制到NAS)"""
|
||||
values = []
|
||||
for _, row in new_data.iterrows():
|
||||
values.append((
|
||||
code, exchange, str(row['date']), 'd',
|
||||
float(row.get('volume', 0)), float(row.get('amount', 0)), 0.0,
|
||||
float(row.get('open', 0)), float(row.get('high', 0)),
|
||||
float(row.get('low', 0)), float(row.get('close', 0)),
|
||||
))
|
||||
return values
|
||||
|
||||
|
||||
def main():
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
max_end = today # 不超过今天
|
||||
|
||||
logger.info(f"=== 增量更新开始 {today} ===")
|
||||
|
||||
# 获取所有股票
|
||||
symbols = get_all_symbols()
|
||||
logger.info(f"扫描到 {len(symbols)} 只股票")
|
||||
|
||||
updated = 0
|
||||
skipped = 0
|
||||
failed = 0
|
||||
new_records = 0
|
||||
all_db_values = []
|
||||
|
||||
for i, (code, exchange, fname) in enumerate(symbols):
|
||||
# 获取最后日期
|
||||
last_date = get_last_date(code, exchange)
|
||||
if not last_date:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
# 计算需要补的起始日期(下一天)
|
||||
next_day = (pd.Timestamp(last_date) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
if next_day > max_end:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
# 获取增量数据
|
||||
data = fetch_incremental(code, next_day, max_end)
|
||||
if data is None or data.empty:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
# 校验(简单fatal检查)
|
||||
if (data[['open', 'high', 'low', 'close']] <= 0).any().any():
|
||||
logger.warning(f"{code} 有非正价格,跳过")
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
# 写Parquet
|
||||
try:
|
||||
append_to_parquet(code, exchange, data)
|
||||
except Exception as e:
|
||||
logger.error(f"Parquet写入失败 {code}: {e}")
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
# 收集vnpy DB数据
|
||||
db_vals = append_to_vnpy_db(code, exchange, data)
|
||||
all_db_values.extend(db_vals)
|
||||
|
||||
new_records += len(data)
|
||||
updated += 1
|
||||
|
||||
if (i + 1) % 500 == 0:
|
||||
logger.info(f"进度: {i+1}/{len(symbols)} updated={updated} skipped={skipped} failed={failed}")
|
||||
|
||||
# 限频:akshare 1秒间隔
|
||||
time.sleep(0.5)
|
||||
|
||||
# 写vnpy DB
|
||||
if all_db_values:
|
||||
logger.info(f"写入vnpy DB: {len(all_db_values)} 条记录")
|
||||
try:
|
||||
# 复制NAS DB到本地
|
||||
shutil.copy2(VNPY_DB_PATH, LOCAL_DB_TMP)
|
||||
conn = sqlite3.connect(LOCAL_DB_TMP)
|
||||
c = conn.cursor()
|
||||
for j in range(0, len(all_db_values), BATCH_SIZE):
|
||||
c.executemany('''INSERT OR REPLACE INTO dbbardata
|
||||
(symbol,exchange,datetime,interval,volume,turnover,open_interest,
|
||||
open_price,high_price,low_price,close_price)
|
||||
VALUES (?,?,?,?,?,?,?,?,?,?,?)''', all_db_values[j:j+BATCH_SIZE])
|
||||
conn.commit()
|
||||
|
||||
# 重建overview
|
||||
c.execute('''INSERT OR REPLACE INTO dbbaroverview (symbol,exchange,interval,count,start,end)
|
||||
SELECT symbol,exchange,interval,COUNT(*),MIN(datetime),MAX(datetime)
|
||||
FROM dbbardata GROUP BY symbol,exchange,interval''')
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
# 复制回NAS
|
||||
shutil.copy2(LOCAL_DB_TMP, VNPY_DB_PATH)
|
||||
logger.info("✅ vnpy DB更新完成")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ vnpy DB更新失败: {e}")
|
||||
|
||||
report = {
|
||||
"date": today,
|
||||
"total_symbols": len(symbols),
|
||||
"updated": updated,
|
||||
"skipped": skipped,
|
||||
"failed": failed,
|
||||
"new_records": new_records,
|
||||
}
|
||||
logger.info(f"=== 更新完成 ===")
|
||||
logger.info(json.dumps(report, ensure_ascii=False, indent=2))
|
||||
return report
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user