#!/usr/bin/env python3
|
||
"""
|
||
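Whole-market daily incremental update - daily bars + 15-minute bars (v2.0)

Features:
  1. Daily bars: multi-source fallback -> update Parquet + local vnpy DB
  2. 15-minute bars: multi-source fallback -> incremental merge into Parquet + local vnpy DB
  3. After the local DB is built it is moved onto the NAS with an atomic rename

Data-source fallback chain (ordered by quality; BaoStock's T+1 delay is accounted for):
  Daily increment (same-day realtime): Eastmoney (realtime, 4s rate limit) -> BaoStock (T+1, no anti-scraping) -> Tencent (amount is sometimes 0)
  15min increment (same-day realtime): Eastmoney (realtime, ~7 weeks) -> BaoStock (T+1, no anti-scraping) -> Sina (currently down, kept)

Design principles:
  - Multi-source fallback: ordered by quality, use the first that succeeds, otherwise try the next
  - Incremental updates, never re-download
  - vnpy DB is built locally -> moved to the NAS by atomic rename (avoids SMB locks)
  - Failure-rate detection (sliding window of 100 symbols, abort at >= 80% failures)
  - Rotating DB backups (7 days kept)

Usage:
  python3 daily_all_update.py                # full update (daily + 15min)
  python3 daily_all_update.py --skip-daily   # 15min only
  python3 daily_all_update.py --skip-15min   # daily only

Changelog:
  v1.0 (2026-05-03) - initial version
  v1.1 (2026-05-03) - Sima Yi review: interval -> 15m, strict increments, progress file, source detection, DB backup
  v1.2 (2026-05-05) - Eastmoney integration: daily primary source switched to Eastmoney
  v2.0 (2026-05-06) - major architecture change (approved by Sima Yi + Jiang Wei review):
    - BaoStock replaces all primary sources, multi-source fallback mechanism
    - vnpy DB writes changed to local build + atomic rename (fixes SMB locks)
    - interval unified to 1m (hard constraint of vnpy 4.x Interval.MINUTE)
    - fix for daily bars written across a year boundary
    - progress file now carries the date
    - overview updated incrementally (no full-table aggregation)
    - failure-rate detection replaces the fixed-count pause
    - Eastmoney same-day realtime + BaoStock T+1 back-fill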
|
||
"""
|
||
|
||
import argparse
|
||
import json
|
||
import os
|
||
import re
|
||
import shutil
|
||
import sqlite3
|
||
import sys
|
||
import time
|
||
import logging
|
||
import random
|
||
import urllib.request
|
||
import urllib.error
|
||
from datetime import datetime, timedelta
|
||
from pathlib import Path
|
||
from typing import Optional, List, Tuple, Callable
|
||
from collections import deque
|
||
|
||
import pandas as pd
|
||
|
||
try:
|
||
import baostock as bs
|
||
HAS_BAOSTOCK = True
|
||
except ImportError:
|
||
HAS_BAOSTOCK = False
|
||
|
||
try:
|
||
import requests as _requests
|
||
HAS_REQUESTS = True
|
||
except ImportError:
|
||
HAS_REQUESTS = False
|
||
|
||
# ======================== Configuration ========================

# NAS locations (all under the /Volumes/stock mount).
LOG_DIR = Path("/Volumes/stock/logs/daily_update")
DAILY_DIR = Path("/Volumes/stock/A股数据/日线数据/daily")
MINUTE_15_DIR = Path("/Volumes/stock/minute_kline/15min")
VNPY_DB_PATH = Path("/Volumes/stock/sanguo_vnpy/data/quant_trading.db")
# Local working copy of the vnpy DB; built here and then published to the NAS.
LOCAL_DB_PATH = Path("/tmp/quant_trading_new.db")
ALL_STOCKS_FILE = Path("/Volumes/stock/A股数据/stock_info/stock_basic_info_raw_20260326_113530.csv")
PROGRESS_DIR = Path("/Volumes/stock/logs/daily_update/progress")

REQUEST_INTERVAL_EM = 4.0  # Eastmoney: 4s + random jitter of +/-1s
REQUEST_INTERVAL_SINA = 0.3  # Sina (currently down, kept)
REQUEST_INTERVAL_BS = 0.0  # BaoStock: no rate limiting needed
EM_JITTER = 1.0  # Eastmoney jitter range (+/- seconds)
MAX_RETRIES = 3

# Failure-rate detection: sliding window
GLOBAL_FAIL_WINDOW = 100  # most recent N symbols
GLOBAL_FAIL_THRESHOLD = 0.8  # failure-rate threshold

# DB
DB_BACKUP_KEEP_DAYS = 7
BATCH_SIZE = 50000

# vnpy interval
# interval='1m' - hard constraint of vnpy 4.x Interval.MINUTE; the rows
# actually stored under it are 15-minute bars.
INTERVAL_MINUTE = "1m"
INTERVAL_DAILY = "d"

HEADERS = {"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)"}
HEADERS_EM = {
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36",
    "Referer": "https://quote.eastmoney.com/",
    "Accept": "*/*",
    "Accept-Language": "zh-CN,zh;q=0.9",
}
|
||
|
||
|
||
def setup_logging():
    """Configure logging to a timestamped file under LOG_DIR plus the console.

    Creates LOG_DIR if it is missing, installs both handlers on the root
    logger at INFO level, and returns this module's logger.
    """
    LOG_DIR.mkdir(parents=True, exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    file_handler = logging.FileHandler(LOG_DIR / f"update_{stamp}.log", encoding="utf-8")
    console_handler = logging.StreamHandler()
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
        handlers=[file_handler, console_handler],
    )
    return logging.getLogger(__name__)


# Module-wide logger; note that this runs setup_logging() at import time.
logger = setup_logging()
|
||
|
||
|
||
def _make_opener():
|
||
return urllib.request.build_opener(urllib.request.ProxyHandler({}))
|
||
|
||
|
||
# ======================== 工具函数 ========================
|
||
|
||
def get_market_prefix(code: str) -> Tuple[str, str]:
    """Return ``(prefix, clean_code)`` for a stock code.

    All non-digit characters are removed and the result is left-padded with
    zeros to six characters. Codes starting with 60, 68 or 51 map to "sh";
    everything else maps to "sz".
    """
    digits = re.sub(r"[^0-9]", "", code).zfill(6)
    market = "sh" if digits[:2] in ("60", "68", "51") else "sz"
    return market, digits
|
||
|
||
|
||
def code_to_baostock(code: str) -> str:
    """Convert a bare code into BaoStock form, e.g. ``sh.600000``."""
    return ".".join(get_market_prefix(code))
|
||
|
||
|
||
def get_all_codes() -> List[str]:
    """Load every stock code from ALL_STOCKS_FILE as a six-character string.

    The first column found among "代码", "code" and "股票代码" is used.
    The file is read with ``dtype=str`` so that leading zeros survive and a
    single blank cell cannot turn the whole column into floats (which used to
    produce codes such as "0001.0" and "000nan"). Blank cells are skipped.

    Raises:
        ValueError: if none of the known code columns is present.
    """
    df = pd.read_csv(ALL_STOCKS_FILE, dtype=str)
    for col in ("代码", "code", "股票代码"):
        if col in df.columns:
            codes = df[col].dropna().str.strip()
            return [c.zfill(6) for c in codes if c]
    raise ValueError(f"找不到代码列: {list(df.columns)}")
|
||
|
||
|
||
def nas_mounted() -> bool:
    """Return True when both the daily and the 15-minute data directories exist."""
    return all(folder.exists() for folder in (DAILY_DIR, MINUTE_15_DIR))
|
||
|
||
|
||
# ======================== DB备份 ========================
|
||
|
||
def rotate_db_backup():
    """Create today's backup of the NAS vnpy DB and prune expired backups.

    Backups are written next to the DB as ``quant_trading_YYYYMMDD.db.bak``.
    If today's backup already exists, or the copy fails, nothing is pruned.
    Backups whose date is older than DB_BACKUP_KEEP_DAYS are deleted.
    """
    name_prefix = "quant_trading_"
    name_suffix = ".db.bak"
    backup_dir = VNPY_DB_PATH.parent
    today = datetime.now().strftime("%Y%m%d")
    backup_file = backup_dir / f"{name_prefix}{today}{name_suffix}"

    if backup_file.exists():
        logger.info("DB今日已备份: %s", backup_file)
        return

    logger.info("开始DB备份: %s → %s", VNPY_DB_PATH.name, backup_file.name)
    # Copy under a temporary name first: an interrupted copy must not leave a
    # truncated file that the next run would accept as today's backup.
    partial = backup_file.with_name(backup_file.name + ".part")
    try:
        shutil.copy2(str(VNPY_DB_PATH), str(partial))
        os.replace(partial, backup_file)
        logger.info("✅ DB备份完成 (%.1f MB)", backup_file.stat().st_size / 1024 / 1024)
    except Exception as e:
        logger.error("❌ DB备份失败: %s", e)
        try:
            partial.unlink()
        except OSError:
            pass
        return

    cutoff = datetime.now() - timedelta(days=DB_BACKUP_KEEP_DAYS)
    for f in backup_dir.glob(f"{name_prefix}*{name_suffix}"):
        # Path.stem only removes ".bak", which left "YYYYMMDD.db" and made
        # strptime fail for every file; slice the date out of the full name.
        date_str = f.name[len(name_prefix):-len(name_suffix)]
        try:
            file_date = datetime.strptime(date_str, "%Y%m%d")
        except ValueError:
            continue  # not one of our dated backups
        if file_date < cutoff:
            try:
                f.unlink()
                logger.info("清理过期备份: %s", f.name)
            except OSError as e:
                logger.warning("清理过期备份失败 %s: %s", f.name, e)
|
||
|
||
|
||
# ======================== 进度文件 ========================
|
||
|
||
def load_progress(name: str) -> set:
    """Return the set of codes recorded as done in today's progress file.

    The file is ``<name>_<YYYYMMDD>_progress.json`` under PROGRESS_DIR (which
    is created if missing). A missing or unreadable file yields an empty set.
    """
    stamp = datetime.now().strftime("%Y%m%d")
    PROGRESS_DIR.mkdir(parents=True, exist_ok=True)
    progress_file = PROGRESS_DIR / f"{name}_{stamp}_progress.json"
    if not progress_file.exists():
        return set()
    try:
        payload = json.loads(progress_file.read_text())
        return set(payload.get("done", []))
    except Exception:
        # Best effort: a corrupt file is treated as "nothing done yet".
        return set()
|
||
|
||
|
||
def save_progress(name: str, done_set: set):
    """Persist ``done_set`` to today's progress file for ``name``.

    The file is ``<name>_<YYYYMMDD>_progress.json`` under PROGRESS_DIR and
    holds the sorted codes under "done" plus an ISO timestamp under "ts".
    The payload is written to a sibling temp file and then swapped in with
    os.replace, so a crash mid-write cannot leave a truncated file (which
    load_progress would silently read as "nothing done").
    """
    today = datetime.now().strftime("%Y%m%d")
    # Do not rely on load_progress having created the directory earlier.
    PROGRESS_DIR.mkdir(parents=True, exist_ok=True)
    progress_file = PROGRESS_DIR / f"{name}_{today}_progress.json"
    payload = json.dumps({
        "done": sorted(done_set),
        "ts": datetime.now().isoformat(),
    })
    tmp_file = progress_file.with_name(progress_file.name + ".tmp")
    tmp_file.write_text(payload, encoding="utf-8")
    os.replace(tmp_file, progress_file)
|
||
|
||
|
||
# ======================== 失败率检测(v2.0) ========================
|
||
|
||
class SourceHealthMonitor:
    """Sliding-window failure-rate detector for the data sources.

    Keeps the outcome of the most recent ``window`` symbols; once at least 20
    outcomes are recorded and the failure share reaches ``threshold``, the
    sources are considered unavailable.
    """

    def __init__(self, window: int = GLOBAL_FAIL_WINDOW, threshold: float = GLOBAL_FAIL_THRESHOLD):
        self.window = window
        self.threshold = threshold
        # 1 = failed, 0 = succeeded; the deque drops the oldest entry itself.
        self.history = deque(maxlen=window)

    def report(self, code: str, failed: bool) -> bool:
        """Record one symbol's outcome.

        Returns True while the sources look healthy and False when the
        caller should abort the update. ``code`` is accepted but not used.
        """
        self.history.append(int(bool(failed)))
        sample_count = len(self.history)
        if sample_count < 20:  # too few samples to judge
            return True
        fail_rate = sum(self.history) / sample_count
        if not fail_rate >= self.threshold:
            return True
        logger.error(
            "⚠️ 源不可用检测触发:最近%d只失败率 %.0f%%,终止更新",
            sample_count, fail_rate * 100,
        )
        return False
|
||
|
||
|
||
# ======================== 数据源:BaoStock ========================
|
||
|
||
def fetch_baostock_daily(code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
    """Fetch daily bars from BaoStock (adjustflag "2").

    Args:
        code: stock code; converted with code_to_baostock.
        start_date, end_date: inclusive range; "YYYY-MM-DD" and "YYYYMMDD"
            are both accepted.

    Returns:
        DataFrame with columns date, open, high, low, close, volume, amount
        (numeric except date), or None when BaoStock is not installed, the
        query fails, or no usable row comes back.
    """
    if not HAS_BAOSTOCK:
        return None
    fields = ["date", "open", "high", "low", "close", "volume", "amount"]
    try:
        # BaoStock validates dates against "YYYY-MM-DD"; the previous code
        # stripped the dashes, so every query was rejected and this fallback
        # could never return data.
        rs = bs.query_history_k_data_plus(
            code_to_baostock(code),
            ",".join(fields),
            start_date=pd.Timestamp(start_date).strftime("%Y-%m-%d"),
            end_date=pd.Timestamp(end_date).strftime("%Y-%m-%d"),
            frequency="d",
            adjustflag="2",
        )
        if rs.error_code != "0":
            logger.debug("BaoStock日线失败 %s: %s %s", code, rs.error_code, rs.error_msg)
            return None
        rows = []
        while (rs.error_code == "0") and rs.next():
            rows.append(rs.get_row_data())
        if not rows:
            return None
        df = pd.DataFrame(rows, columns=fields)
        for c in fields[1:]:
            df[c] = pd.to_numeric(df[c], errors="coerce")
        df = df.dropna(subset=["close"])
        return None if df.empty else df
    except Exception as e:
        logger.debug("BaoStock日线失败 %s: %s", code, e)
        return None
|
||
|
||
|
||
def fetch_baostock_15min(code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
    """Fetch 15-minute bars from BaoStock (adjustflag "2").

    Args:
        code: stock code; converted with code_to_baostock.
        start_date, end_date: inclusive range; "YYYY-MM-DD" and "YYYYMMDD"
            are both accepted.

    Returns:
        DataFrame with columns day ("YYYY-MM-DD HH:MM:00"), open, high, low,
        close, volume, amount, or None when BaoStock is not installed, the
        query fails, or no usable row comes back.
    """
    if not HAS_BAOSTOCK:
        return None
    fields = ["date", "time", "open", "high", "low", "close", "volume", "amount"]
    try:
        # BaoStock validates dates against "YYYY-MM-DD"; stripping the dashes
        # (as before) made it reject every query.
        rs = bs.query_history_k_data_plus(
            code_to_baostock(code),
            ",".join(fields),
            start_date=pd.Timestamp(start_date).strftime("%Y-%m-%d"),
            end_date=pd.Timestamp(end_date).strftime("%Y-%m-%d"),
            frequency="15",
            adjustflag="2",
        )
        if rs.error_code != "0":
            logger.debug("BaoStock 15min失败 %s: %s %s", code, rs.error_code, rs.error_msg)
            return None
        rows = []
        while (rs.error_code == "0") and rs.next():
            rows.append(rs.get_row_data())
        if not rows:
            return None
        df = pd.DataFrame(rows)
        df.columns = fields[:len(df.columns)]
        # The "time" field is a digit string starting with YYYYMMDDHHMM;
        # rebuild it as "YYYY-MM-DD HH:MM:00".
        t = df["time"].astype(str)
        df["day"] = (
            t.str[:4] + "-" + t.str[4:6] + "-" + t.str[6:8]
            + " " + t.str[8:10] + ":" + t.str[10:12] + ":00"
        )
        for c in ["open", "high", "low", "close", "volume", "amount"]:
            df[c] = pd.to_numeric(df[c], errors="coerce")
        df = df.dropna(subset=["close"])
        if df.empty:
            return None
        return df[["day", "open", "high", "low", "close", "volume", "amount"]]
    except Exception as e:
        logger.debug("BaoStock 15min失败 %s: %s", code, e)
        return None
|
||
|
||
|
||
# ======================== 数据源:东方财富 ========================
|
||
|
||
def _get_em_secid(code: str) -> str:
|
||
if code.startswith(("60", "68", "51")):
|
||
return f"1.{code}"
|
||
return f"0.{code}"
|
||
|
||
|
||
def _parse_em_klines(klines: list) -> Optional[pd.DataFrame]:
|
||
"""解析东方财富K线数据(日线和15min通用)"""
|
||
if not klines:
|
||
return None
|
||
rows = []
|
||
for line in klines:
|
||
parts = line.split(",")
|
||
if len(parts) < 7:
|
||
continue
|
||
rows.append({
|
||
"date": parts[0],
|
||
"open": float(parts[1]),
|
||
"close": float(parts[2]),
|
||
"high": float(parts[3]),
|
||
"low": float(parts[4]),
|
||
"volume": float(parts[5]),
|
||
"amount": float(parts[6]),
|
||
})
|
||
if not rows:
|
||
return None
|
||
return pd.DataFrame(rows)
|
||
|
||
|
||
def fetch_eastmoney_daily(code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
    """Fetch daily bars (klt=101, fqt=1) from the Eastmoney kline API.

    Args:
        code: six-digit stock code.
        start_date, end_date: inclusive range as "YYYY-MM-DD".

    Returns:
        DataFrame with columns date, open, high, low, close, volume, amount
        restricted to the requested range, or None when ``requests`` is not
        installed, the request fails, or no row falls in the range.
    """
    if not HAS_REQUESTS:
        return None
    secid = _get_em_secid(code)
    ts = str(int(time.time() * 1000))
    url = (
        f"https://push2his.eastmoney.com/api/qt/stock/kline/get?"
        f"secid={secid}&klt=101&fqt=1&"
        f"beg={start_date.replace('-', '')}&end={end_date.replace('-', '')}&"
        f"fields1=f1,f2,f3,f4,f5,f6,f7,f8&"
        f"fields2=f51,f52,f53,f54,f55,f56,f57,f58,f59,f60,f61&"
        f"ut=b2884a393a59ad64002292a3e90d46a5&lmt=10000&"
        f"cb=jQuery_em_{ts}&_={ts}"
    )
    try:
        # Close the session deterministically: this runs once per symbol and
        # per retry, and an unclosed Session leaks its connection pool.
        with _requests.Session() as session:
            session.trust_env = False  # ignore proxy settings from the environment
            # NOTE(review): verify=False disables TLS certificate checks;
            # kept as in the original - confirm it is still required.
            r = session.get(url, headers=HEADERS_EM, timeout=15, verify=False)
        if r.status_code != 200:
            return None
        text = r.text
        # The response is JSONP: take the JSON between the outer parentheses.
        data = json.loads(text[text.index("(") + 1:text.rindex(")")])
        if data.get("rc") != 0:
            return None
        # "data" may be null, so guard before reading "klines".
        klines = (data.get("data") or {}).get("klines") or []
        df = _parse_em_klines(klines)
        if df is None:
            return None
        df["date"] = pd.to_datetime(df["date"]).dt.strftime("%Y-%m-%d")
        mask = (df["date"] >= start_date) & (df["date"] <= end_date)
        result = df.loc[mask, ["date", "open", "high", "low", "close", "volume", "amount"]]
        return result if not result.empty else None
    except Exception as e:
        logger.debug("东方财富日线失败 %s: %s", code, e)
        return None
|
||
|
||
|
||
def fetch_eastmoney_15min(code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
    """Fetch 15-minute bars (klt=15, fqt=1) from the Eastmoney kline API.

    Args:
        code: six-digit stock code.
        start_date, end_date: range as "YYYY-MM-DD"; sent to the API as
            beg/end, no further filtering is applied locally.

    Returns:
        DataFrame with columns day ("YYYY-MM-DD HH:MM:SS"), open, high, low,
        close, volume, amount, or None when ``requests`` is not installed,
        the request fails, or nothing usable comes back.
    """
    if not HAS_REQUESTS:
        return None
    secid = _get_em_secid(code)
    ts = str(int(time.time() * 1000))
    url = (
        f"https://push2his.eastmoney.com/api/qt/stock/kline/get?"
        f"secid={secid}&klt=15&fqt=1&"
        f"beg={start_date.replace('-', '')}&end={end_date.replace('-', '')}&"
        f"fields1=f1,f2,f3,f4,f5,f6,f7,f8&"
        f"fields2=f51,f52,f53,f54,f55,f56,f57,f58,f59,f60,f61&"
        f"ut=b2884a393a59ad64002292a3e90d46a5&lmt=100000&"
        f"cb=jQuery_em_{ts}&_={ts}"
    )

    def _with_seconds(stamp) -> str:
        # Eastmoney returns "2026-04-30 15:00" or "2026-04-30"; normalise
        # both to "YYYY-MM-DD HH:MM:SS".
        stamp = str(stamp)
        if " " not in stamp:
            return f"{stamp} 00:00:00"
        return stamp if stamp.count(":") == 2 else f"{stamp}:00"

    try:
        # Close the session deterministically to avoid leaking a connection
        # pool on every call.
        with _requests.Session() as session:
            session.trust_env = False  # ignore proxy settings from the environment
            # NOTE(review): verify=False disables TLS certificate checks;
            # kept as in the original - confirm it is still required.
            r = session.get(url, headers=HEADERS_EM, timeout=15, verify=False)
        if r.status_code != 200:
            return None
        text = r.text
        # The response is JSONP: take the JSON between the outer parentheses.
        data = json.loads(text[text.index("(") + 1:text.rindex(")")])
        if data.get("rc") != 0:
            return None
        # "data" may be null, so guard before reading "klines".
        klines = (data.get("data") or {}).get("klines") or []
        # Same wire format as the daily endpoint: reuse the shared parser
        # instead of a second hand-written copy.
        df = _parse_em_klines(klines)
        if df is None:
            return None
        df = df.rename(columns={"date": "day"})
        df["day"] = df["day"].map(_with_seconds)
        df = df.dropna(subset=["close"])
        if df.empty:
            return None
        return df[["day", "open", "high", "low", "close", "volume", "amount"]]
    except Exception as e:
        logger.debug("东方财富15min失败 %s: %s", code, e)
        return None
|
||
|
||
|
||
# ======================== 数据源:腾讯 ========================
|
||
|
||
def fetch_tencent_daily(code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
    """Fetch daily bars from the Tencent fqkline endpoint.

    Args:
        code: stock code; normalised with get_market_prefix.
        start_date, end_date: inclusive range as "YYYY-MM-DD".

    Returns:
        DataFrame with columns date, open, high, low, close, volume, amount
        restricted to the range, or None on any failure / empty result.
        ``amount`` is 0 when the response carries no numeric seventh field.
    """
    prefix, clean = get_market_prefix(code)
    tq = f"{prefix}{clean}"
    # Ask for a few more rows than the calendar span; the range filter below
    # trims the surplus.
    days = (pd.Timestamp(end_date) - pd.Timestamp(start_date)).days + 10
    url = f"https://web.ifzq.gtimg.cn/appstock/app/fqkline/get?param={tq},day,{start_date},,{days},"
    opener = _make_opener()
    try:
        req = urllib.request.Request(url, headers=HEADERS)
        with opener.open(req, timeout=10) as r:
            raw = r.read().decode("utf-8", errors="replace")
        data = json.loads(raw)
        d = data.get("data")
        if not isinstance(d, dict):
            return None
        klines = (d.get(tq) or {}).get("day") or []
        if not klines:
            return None
        df = pd.DataFrame(klines)
        if df.shape[1] < 6:
            return None  # not even date/open/close/high/low/volume
        # Rows may carry more than seven fields; keep the first seven only.
        # Assigning seven names to a wider frame used to raise and made this
        # source return None.
        df = df.iloc[:, :7].copy()
        names = ["date", "open", "close", "high", "low", "volume", "amount"]
        df.columns = names[:df.shape[1]]
        if "amount" not in df.columns:
            df["amount"] = 0.0
        for c in names[1:]:
            # Non-numeric cells (including a non-numeric seventh field) become 0.
            df[c] = pd.to_numeric(df[c], errors="coerce").fillna(0)
        df["date"] = pd.to_datetime(df["date"]).dt.strftime("%Y-%m-%d")
        mask = (df["date"] >= start_date) & (df["date"] <= end_date)
        result = df.loc[mask, ["date", "open", "high", "low", "close", "volume", "amount"]]
        return result if not result.empty else None
    except Exception as e:
        logger.debug("腾讯日线失败 %s: %s", code, e)
        return None
|
||
|
||
|
||
# ======================== 数据源:新浪(已挂,保留代码) ========================
|
||
|
||
def try_sina_15min(symbol: str, datalen: int = 800) -> Optional[pd.DataFrame]:
    """Query Sina's 15-minute kline API (kept as a last-resort fallback).

    Args:
        symbol: prefixed symbol such as "sh600000".
        datalen: number of bars requested.

    Returns:
        DataFrame with columns day, open, high, low, close, volume, amount
        exactly as decoded from the JSON payload, or None on any failure,
        empty payload, or missing column.
    """
    wanted = ["day", "open", "high", "low", "close", "volume", "amount"]
    url = (
        f"https://quotes.sina.cn/cn/api/jsonp_v2.php/var%20=min15_{symbol}=/"
        f"CN_MarketDataService.getKLineData?symbol={symbol}&scale=15&ma=no&datalen={datalen}"
    )
    opener = _make_opener()
    try:
        request = urllib.request.Request(url, headers=HEADERS)
        with opener.open(request, timeout=15) as resp:
            payload = resp.read().decode("utf-8", errors="replace")
        # JSONP wrapper: the JSON array sits between "(" and ")".
        match = re.search(r"\((\[.*\])\)", payload, re.DOTALL)
        records = json.loads(match.group(1)) if match else None
        if not records:
            return None
        frame = pd.DataFrame(records)
        if any(col not in frame.columns for col in wanted):
            return None
        return frame[wanted]
    except Exception:
        return None
|
||
|
||
|
||
# ======================== Fallback机制 ========================
|
||
|
||
def fetch_with_fallback(
    sources: List[Tuple[str, Callable, float]],
    code: str,
    start_date: str,
    end_date: str,
    is_daily: bool = True,
) -> Tuple[Optional[pd.DataFrame], str]:
    """Try each source in order and return the first non-empty result.

    Args:
        sources: ``[(name, fetch_fn, interval_seconds), ...]``; each
            ``fetch_fn`` is called as ``fetch_fn(code, start_date, end_date)``.
        code, start_date, end_date: forwarded unchanged to the fetchers.
        is_daily: accepted for compatibility; it does not affect the call.

    Returns:
        ``(data, source_name)`` for the first source that yields data,
        otherwise ``(None, "")``.
    """
    final_attempt = MAX_RETRIES - 1
    for name, fetch_fn, interval in sources:
        if interval > 0:
            # Rate limit: wait the source's interval, shifted by up to 10%
            # using a hash-derived offset.
            jitter = (hash(code + name) % 200 - 100) / 100.0 * (interval * 0.1)
            time.sleep(max(0, interval + jitter))
        for attempt in range(MAX_RETRIES):
            try:
                result = fetch_fn(code, start_date, end_date)
                if result is not None and len(result) > 0:
                    return result, name
            except Exception:
                pass
            if attempt != final_attempt:
                time.sleep(1)
    return None, ""
|
||
|
||
|
||
# Daily-bar sources (same-day realtime first)
SOURCES_DAILY = [
    ("eastmoney", fetch_eastmoney_daily, REQUEST_INTERVAL_EM),
    ("baostock", fetch_baostock_daily, REQUEST_INTERVAL_BS),
    ("tencent", fetch_tencent_daily, 0),
]


def _sina_15min_source(code, s, e):
    """Adapter giving try_sina_15min the (code, start, end) call shape; the dates are not used."""
    return try_sina_15min("".join(get_market_prefix(code)))


# 15-minute sources (same-day realtime first)
SOURCES_15MIN = [
    ("eastmoney", fetch_eastmoney_15min, REQUEST_INTERVAL_EM),
    ("baostock", fetch_baostock_15min, REQUEST_INTERVAL_BS),
    ("sina", _sina_15min_source, REQUEST_INTERVAL_SINA),
]
|
||
|
||
|
||
# ======================== 日线更新 ========================
|
||
|
||
def get_daily_last_date(code: str) -> str:
    """Return the latest date ("YYYY-MM-DD") stored in the symbol's daily Parquet files.

    Year directories are scanned from the current year down to 2010; the
    first file that can be read and is not empty decides the result. Returns
    "" when no such file is found.
    """
    prefix, clean = get_market_prefix(code)
    filename = f"{prefix}{clean}_daily.parquet"
    year = datetime.now().year
    while year > 2009:
        candidate = DAILY_DIR / str(year) / filename
        if candidate.exists():
            try:
                dates = pd.read_parquet(candidate, columns=["date"])
                if not dates.empty:
                    return str(dates["date"].max())[:10]
            except Exception:
                # Unreadable file: fall through to the previous year.
                pass
        year -= 1
    return ""
|
||
|
||
|
||
def update_daily_parquet(code: str, new_data: pd.DataFrame) -> int:
    """Merge ``new_data`` into the symbol's per-year daily Parquet files.

    Rows are routed to ``DAILY_DIR/<year>/<prefix><code>_daily.parquet`` by
    the first four characters of their ``date``. Within a file, dates are
    unique (the incoming row wins) and sorted ascending. Each file is written
    to a ``.tmp`` sibling and then swapped into place.

    Returns:
        Number of dates that were not on disk before this call.
    """
    prefix, clean = get_market_prefix(code)
    incoming = new_data.copy()
    incoming["date"] = incoming["date"].astype(str)
    incoming = incoming.drop_duplicates(subset=["date"], keep="last")

    added = 0
    for year, year_rows in incoming.groupby(incoming["date"].str[:4]):
        parquet_path = DAILY_DIR / str(year) / f"{prefix}{clean}_daily.parquet"

        if parquet_path.exists():
            existing = pd.read_parquet(parquet_path)
            existing["date"] = existing["date"].astype(str)
            # Count only dates that are really new, not re-downloaded ones.
            added += int((~year_rows["date"].isin(existing["date"])).sum())
            combined = pd.concat([existing, year_rows], ignore_index=True)
        else:
            parquet_path.parent.mkdir(parents=True, exist_ok=True)
            added += len(year_rows)
            combined = year_rows

        # Apply the same dedupe/sort to a brand-new file as to a merged one.
        combined = (
            combined.drop_duplicates(subset=["date"], keep="last")
            .sort_values("date")
            .reset_index(drop=True)
        )

        tmp = parquet_path.with_suffix(".tmp")
        combined.to_parquet(tmp, index=False)
        # replace() overwrites an existing target on every platform.
        tmp.replace(parquet_path)

    return added
|
||
|
||
|
||
def _store_daily_bars(code: str, data: pd.DataFrame, stats: dict, pending_db_rows: list) -> None:
    """Validate one symbol's new daily bars, write them to Parquet and queue the DB rows."""
    data = data.copy()  # fetchers may hand back a slice; never mutate it
    for col in ("open", "high", "low", "close"):
        data[col] = pd.to_numeric(data[col], errors="coerce")
    # Reject the batch when any open/close is non-positive or NaN (a NaN
    # would slip through a plain "<= 0" test).
    if not (data[["close", "open"]] > 0).all().all():
        stats["failed"] += 1
        return

    try:
        written = update_daily_parquet(code, data)
        prefix, clean = get_market_prefix(code)
        exchange = "SSE" if prefix == "sh" else "SZSE"
        rows = [
            (
                clean, exchange, str(row["date"]), INTERVAL_DAILY,
                float(row.get("volume", 0)), float(row.get("amount", 0)), 0.0,
                float(row.get("open", 0)), float(row.get("high", 0)),
                float(row.get("low", 0)), float(row.get("close", 0)),
            )
            for _, row in data.iterrows()
        ]
    except Exception as e:
        stats["failed"] += 1
        logger.warning("日线写入失败 %s: %s", code, e)
        return

    stats["updated"] += 1
    stats["records"] += written
    stats["db_records"] += len(rows)
    pending_db_rows.extend(rows)


def run_daily_update(codes: List[str], local_conn: sqlite3.Connection) -> dict:
    """Incrementally update daily bars for ``codes``.

    For each symbol not yet in today's progress file, fetch the bars after
    the last stored date through the fallback chain, merge them into Parquet
    and queue rows for the local vnpy DB. Symbols without any stored history
    are skipped.

    Args:
        codes: stock codes to process.
        local_conn: connection to the local vnpy SQLite DB.

    Returns:
        Stats dict with updated / skipped / failed / records / db_records,
        plus ``source_aborted`` when the health monitor stopped the run.
    """
    logger.info("=" * 60)
    logger.info("日线增量更新开始,共 %d 只", len(codes))
    today = datetime.now().strftime("%Y-%m-%d")

    stats = {"updated": 0, "skipped": 0, "failed": 0, "records": 0, "db_records": 0}
    pending_db_rows = []

    done_set = load_progress("daily")
    todo = [c for c in codes if c not in done_set]
    logger.info("待更新: %d(已完成: %d)", len(todo), len(done_set))

    def checkpoint():
        # Flush the DB rows BEFORE saving progress. Previously the rows were
        # only written at the very end, so a crash left symbols marked done
        # whose bars never reached the DB, and a resume skipped them.
        if pending_db_rows:
            _write_local_db(local_conn, pending_db_rows, "日线")
            pending_db_rows.clear()
        save_progress("daily", done_set)

    health = SourceHealthMonitor()

    for idx, code in enumerate(todo, start=1):
        source = ""
        last_date = get_daily_last_date(code)
        next_day = ""
        if last_date:
            next_day = (pd.Timestamp(last_date) + timedelta(days=1)).strftime("%Y-%m-%d")

        if not last_date or next_day > today:
            # No local history, or already up to date.
            stats["skipped"] += 1
        else:
            data, source = fetch_with_fallback(SOURCES_DAILY, code, next_day, today, is_daily=True)
            failed = data is None or data.empty

            if not health.report(code, failed):
                logger.error("❌ 日线所有源不可用,终止日线更新")
                stats["source_aborted"] = True
                break

            if failed:
                stats["failed"] += 1
            else:
                _store_daily_bars(code, data, stats, pending_db_rows)

        # NOTE(review): failed symbols are also recorded as done (as before),
        # so a same-day re-run does not retry them - confirm this is intended.
        done_set.add(code)

        # Reached on every path; the earlier `continue`s could skip it.
        if idx % 500 == 0:
            logger.info("日线进度: %d/%d updated=%d failed=%d src=%s",
                        idx, len(todo), stats["updated"], stats["failed"], source)
            checkpoint()

    checkpoint()
    logger.info("日线完成: %s", json.dumps(stats, ensure_ascii=False))
    return stats
|
||
|
||
|
||
# ======================== 15分钟线更新 ========================
|
||
|
||
def get_15min_last_date(parquet_path: Path) -> str:
    """Return the largest ``day`` value in a 15-minute Parquet file as a string.

    Returns "" when the file is missing, empty, or cannot be read.
    """
    if parquet_path.exists():
        try:
            days = pd.read_parquet(parquet_path, columns=["day"])
            if not days.empty:
                return str(days["day"].max())
        except Exception:
            # Unreadable file: treated the same as "no data".
            pass
    return ""
|
||
|
||
|
||
def fetch_15min_with_fallback(code: str, start_date: str, end_date: str) -> Tuple[Optional[pd.DataFrame], str]:
    """Fetch 15-minute bars, trying Eastmoney, then BaoStock, then Sina.

    Sina is queried by prefixed symbol and ignores the date range. Every
    source gets up to MAX_RETRIES attempts.

    Returns:
        ``(data, source_name)`` from the first source that yields rows,
        otherwise ``(None, "")``.
    """
    prefix, clean = get_market_prefix(code)
    symbol = f"{prefix}{clean}"
    em_jitter = (hash(code) % 200 - 100) / 100.0 * EM_JITTER

    # (name, zero-argument fetch, delay before first attempt or None, delay between retries)
    plan = (
        ("eastmoney",
         lambda: fetch_eastmoney_15min(code, start_date, end_date),
         max(0, REQUEST_INTERVAL_EM + em_jitter), 1),
        ("baostock",
         lambda: fetch_baostock_15min(code, start_date, end_date),
         None, 0.5),
        ("sina",
         lambda: try_sina_15min(symbol),
         REQUEST_INTERVAL_SINA, 0.5),
    )

    for name, fetch, lead_delay, retry_delay in plan:
        if lead_delay is not None:
            time.sleep(lead_delay)
        for attempt in range(MAX_RETRIES):
            try:
                frame = fetch()
                if frame is not None and len(frame) > 0:
                    return frame, name
            except Exception:
                pass
            if attempt < MAX_RETRIES - 1:
                time.sleep(retry_delay)

    return None, ""
|
||
|
||
|
||
def _merge_15min_bars(code: str, parquet_path: Path, last_date: str, df_new: pd.DataFrame,
                      stats: dict, pending_db_rows: list) -> None:
    """Validate new 15-minute bars, merge them into the Parquet file and queue the DB rows."""
    prefix, clean = get_market_prefix(code)
    df_new = df_new.copy()  # never mutate the fetcher's frame
    for col in ("open", "high", "low", "close"):
        df_new[col] = pd.to_numeric(df_new[col], errors="coerce")
    df_new["volume"] = pd.to_numeric(df_new["volume"], errors="coerce").fillna(0)
    df_new["amount"] = pd.to_numeric(df_new["amount"], errors="coerce").fillna(0)
    df_new["day"] = df_new["day"].astype(str)

    # Keep only rows whose open and close are positive numbers (drops NaN too).
    df_new = df_new[(df_new[["close", "open"]] > 0).all(axis=1)]
    if df_new.empty:
        stats["failed"] += 1
        return

    if last_date:
        # Strictly incremental: only bars later than what is already stored.
        df_increment = df_new[df_new["day"] > last_date].copy()
        if df_increment.empty:
            stats["skipped"] += 1
            return
        existing = pd.read_parquet(parquet_path)
        existing["day"] = existing["day"].astype(str)
        combined = pd.concat([existing, df_increment], ignore_index=True)
    else:
        df_increment = df_new
        combined = df_new
    combined = (
        combined.drop_duplicates(subset=["day"], keep="last")
        .sort_values("day")
        .reset_index(drop=True)
    )

    # Write to a sibling temp file, then swap it into place.
    tmp = parquet_path.with_suffix(".tmp")
    combined.to_parquet(tmp, index=False)
    tmp.replace(parquet_path)

    # interval='1m' - hard constraint of vnpy 4.x Interval.MINUTE; the rows
    # stored under it are 15-minute bars.
    exchange = "SSE" if prefix == "sh" else "SZSE"
    rows = [
        (
            clean, exchange, str(row["day"]), INTERVAL_MINUTE,
            float(row.get("volume", 0)), float(row.get("amount", 0)), 0.0,
            float(row.get("open", 0)), float(row.get("high", 0)),
            float(row.get("low", 0)), float(row.get("close", 0)),
        )
        for _, row in df_increment.iterrows()
    ]

    stats["updated"] += 1
    stats["records"] += len(df_increment)
    stats["db_records"] += len(rows)
    pending_db_rows.extend(rows)


def run_15min_update(codes: List[str], local_conn: sqlite3.Connection) -> dict:
    """Incrementally update 15-minute bars for ``codes``.

    For each symbol not yet in today's progress file, fetch bars from the day
    of the last stored bar (or 2024-01-02 when there is no file), append the
    strictly newer ones to the Parquet file and queue rows for the local
    vnpy DB.

    Args:
        codes: stock codes to process.
        local_conn: connection to the local vnpy SQLite DB.

    Returns:
        Stats dict with updated / skipped / failed / records / db_records,
        plus ``source_aborted`` when the health monitor stopped the run.
    """
    logger.info("=" * 60)
    logger.info("15分钟线增量更新开始,共 %d 只", len(codes))

    stats = {"updated": 0, "skipped": 0, "failed": 0, "records": 0, "db_records": 0}
    pending_db_rows = []

    done_set = load_progress("15min")
    todo = [c for c in codes if c not in done_set]
    logger.info("待更新: %d(已完成: %d)", len(todo), len(done_set))

    def checkpoint():
        # Flush the DB rows BEFORE saving progress so the progress file never
        # claims work the DB has not received (a crash would otherwise make a
        # resume skip symbols whose bars were never inserted).
        if pending_db_rows:
            _write_local_db(local_conn, pending_db_rows, "15min")
            pending_db_rows.clear()
        save_progress("15min", done_set)

    health = SourceHealthMonitor()

    for idx, code in enumerate(todo, start=1):
        source = ""
        prefix, clean = get_market_prefix(code)
        parquet_path = MINUTE_15_DIR / f"{prefix}{clean}_15min.parquet"

        # The last stored bar decides where the increment starts.
        last_date = get_15min_last_date(parquet_path)
        if last_date:
            next_dt = (pd.Timestamp(last_date) + timedelta(minutes=15)).strftime("%Y-%m-%d")
        else:
            next_dt = "2024-01-02"
        today = datetime.now().strftime("%Y-%m-%d")

        if next_dt > today:
            stats["skipped"] += 1
        else:
            df_new, source = fetch_15min_with_fallback(code, next_dt, today)
            failed = df_new is None or df_new.empty

            if not health.report(code, failed):
                logger.error("❌ 15min所有源不可用,终止15min更新")
                stats["source_aborted"] = True
                break

            if failed:
                stats["failed"] += 1
            else:
                try:
                    _merge_15min_bars(code, parquet_path, last_date, df_new, stats, pending_db_rows)
                except Exception as e:
                    # One unreadable/unwritable file must not abort the whole
                    # run (same policy as the daily update).
                    stats["failed"] += 1
                    logger.warning("15min写入失败 %s: %s", code, e)

        # NOTE(review): failed symbols are also recorded as done (as before),
        # so a same-day re-run does not retry them - confirm this is intended.
        done_set.add(code)

        # Reached on every path; the earlier `continue`s could skip it.
        if idx % 500 == 0:
            logger.info("15min进度: %d/%d updated=%d failed=%d src=%s",
                        idx, len(todo), stats["updated"], stats["failed"], source)
            checkpoint()

    checkpoint()
    logger.info("15min完成: %s", json.dumps(stats, ensure_ascii=False))
    return stats
|
||
|
||
|
||
# ======================== 本地vnpy DB写入(v2.0) ========================
|
||
|
||
def init_local_db() -> sqlite3.Connection:
    """Open the local working copy of the vnpy DB, creating it if necessary.

    Order of preference: an existing LOCAL_DB_PATH is reused; otherwise the
    NAS DB is copied down; otherwise a new DB is created. On every path the
    two vnpy tables are guaranteed to exist and the connection is switched
    to WAL journaling with synchronous=NORMAL (previously a reused DB got
    neither).

    Returns:
        An open sqlite3 connection to LOCAL_DB_PATH.
    """
    local_path = str(LOCAL_DB_PATH)

    if LOCAL_DB_PATH.exists():
        logger.info("使用已有本地DB: %s", local_path)
    elif VNPY_DB_PATH.exists():
        logger.info("从NAS复制DB到本地: %s → %s", VNPY_DB_PATH, local_path)
        # Copy under a temporary name: an interrupted copy must not leave a
        # truncated file that the next run would happily reuse.
        partial = local_path + ".part"
        shutil.copy2(str(VNPY_DB_PATH), partial)
        os.replace(partial, local_path)
    else:
        logger.info("创建新本地DB: %s", local_path)

    conn = sqlite3.connect(local_path, timeout=30)
    try:
        conn.execute("""CREATE TABLE IF NOT EXISTS dbbardata (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            symbol TEXT NOT NULL,
            exchange TEXT NOT NULL,
            datetime TEXT NOT NULL,
            interval TEXT NOT NULL,
            volume REAL DEFAULT 0,
            turnover REAL DEFAULT 0,
            open_interest REAL DEFAULT 0,
            open_price REAL,
            high_price REAL,
            low_price REAL,
            close_price REAL,
            UNIQUE(symbol, exchange, datetime, interval)
        )""")
        conn.execute("""CREATE TABLE IF NOT EXISTS dbbaroverview (
            symbol TEXT NOT NULL,
            exchange TEXT NOT NULL,
            interval TEXT NOT NULL,
            count INTEGER,
            start TEXT,
            end TEXT,
            UNIQUE(symbol, exchange, interval)
        )""")
        # These pragmas must not run inside an open transaction.
        conn.commit()
        conn.execute("PRAGMA journal_mode=WAL")
        conn.execute("PRAGMA synchronous=NORMAL")
        conn.commit()
    except Exception:
        conn.close()
        raise
    return conn
|
||
|
||
|
||
def _write_local_db(conn: sqlite3.Connection, values: list, label: str):
    """Bulk-upsert bar rows into ``dbbardata`` and refresh the affected overview rows.

    Args:
        conn: connection to the local vnpy DB.
        values: 11-tuples in the column order of the INSERT below.
        label: text used in the log lines only.
    """
    total = len(values)
    logger.info("写入本地DB [%s]: %d 条记录", label, total)
    insert_sql = """INSERT OR REPLACE INTO dbbardata
        (symbol,exchange,datetime,interval,volume,turnover,open_interest,
        open_price,high_price,low_price,close_price)
        VALUES (?,?,?,?,?,?,?,?,?,?,?)"""
    cursor = conn.cursor()
    offset = 0
    # Insert in BATCH_SIZE slices; everything is committed together below.
    while offset < total:
        cursor.executemany(insert_sql, values[offset:offset + BATCH_SIZE])
        offset += BATCH_SIZE
    conn.commit()
    logger.info("✅ 本地DB [%s] 写入完成: %d条", label, total)

    # Refresh the overview only for the groups touched by this write.
    _update_overview_incremental(conn, values)
|
||
|
||
|
||
def _update_overview_incremental(conn: sqlite3.Connection, values: list):
|
||
"""增量更新overview:只更新本次涉及的(symbol, exchange, interval)"""
|
||
c = conn.cursor()
|
||
affected = set((v[0], v[1], v[3]) for v in values) # symbol, exchange, interval
|
||
for sym, exc, ivl in affected:
|
||
c.execute(
|
||
"""INSERT OR REPLACE INTO dbbaroverview (symbol, exchange, interval, count, start, end)
|
||
SELECT ?, ?, ?,
|
||
COUNT(*),
|
||
MIN(datetime),
|
||
MAX(datetime)
|
||
FROM dbbardata
|
||
WHERE symbol=? AND exchange=? AND interval=?""",
|
||
(sym, exc, ivl, sym, exc, ivl),
|
||
)
|
||
conn.commit()
|
||
logger.info(" overview增量更新: %d 组", len(affected))
|
||
|
||
|
||
def sync_db_to_nas():
    """Publish the local vnpy DB to the NAS.

    Steps:
      1. Fold the WAL into the main file and switch the DB back to the
         rollback journal, so the copied file is self-contained and is not
         opened in WAL mode on the network share.
      2. Copy it next to the target as ``<db>.new`` and check the size.
      3. Swap it in with a single os.replace; if the share refuses that,
         fall back to the three-step rename via ``<db>.old`` and restore the
         old DB when the fallback fails.

    Nothing is raised for sync problems; they are logged and the NAS DB is
    left as it was.
    """
    if not LOCAL_DB_PATH.exists():
        logger.warning("本地DB不存在,跳过同步")
        return

    local_path = str(LOCAL_DB_PATH)
    try:
        conn = sqlite3.connect(local_path, timeout=30)
        try:
            conn.execute("PRAGMA wal_checkpoint(TRUNCATE)")
            conn.execute("PRAGMA journal_mode=DELETE")
        finally:
            conn.close()
    except sqlite3.Error as e:
        logger.error("❌ DB同步失败(WAL合并失败): %s", e)
        return

    local_size = LOCAL_DB_PATH.stat().st_size
    logger.info("同步DB到NAS: %.1f MB", local_size / 1024 / 1024)

    nas_path = str(VNPY_DB_PATH)
    new_path = nas_path + ".new"
    old_path = nas_path + ".old"

    # 1. Copy to the NAS under a temporary name.
    logger.info(" 复制到 %s", new_path)
    try:
        shutil.copy2(local_path, new_path)
        copied_size = os.path.getsize(new_path)
        if copied_size != local_size:
            raise OSError(f"size mismatch after copy: {copied_size} != {local_size}")
    except OSError as e:
        logger.error("❌ DB同步失败: %s", e)
        try:
            if os.path.exists(new_path):
                os.remove(new_path)
        except OSError:
            pass
        return

    # 2. Swap into place. os.replace is one rename, so the target path never
    #    disappears in between.
    try:
        os.replace(new_path, nas_path)
    except OSError as replace_error:
        logger.warning("直接替换失败(%s),改用 .old 方式重命名", replace_error)
        try:
            if os.path.exists(old_path):
                os.remove(old_path)
            if os.path.exists(nas_path):
                os.rename(nas_path, old_path)
            os.rename(new_path, nas_path)
        except Exception as e:
            logger.error("❌ DB同步失败: %s,尝试恢复", e)
            # Recovery: put the previous DB back if the target is gone.
            if os.path.exists(old_path) and not os.path.exists(nas_path):
                os.rename(old_path, nas_path)
                logger.info(" 已恢复旧DB")
            return
        if os.path.exists(old_path):
            os.remove(old_path)
            logger.info(" 清理旧文件: %s", old_path)

    logger.info("✅ DB同步完成(mv原子重命名)")
|
||
|
||
|
||
# ======================== 告警与报告 ========================
|
||
|
||
def check_failure_rate(stats: dict, label: str) -> bool:
    """Return True when the run described by ``stats`` needs an alert.

    An alert is raised when more than 5% of the processed symbols (updated +
    failed + skipped) failed, or when the run was aborted because the
    sources were unavailable. With nothing processed the result is False.
    """
    failed = stats.get("failed", 0)
    total = stats.get("updated", 0) + failed + stats.get("skipped", 0)
    if total == 0:
        return False

    rate = failed / total
    if rate > 0.05:
        logger.error("❌ [%s] 失败率 %.1f%% (%d/%d)", label, rate * 100, failed, total)
        return True

    aborted = bool(stats.get("source_aborted"))
    if aborted:
        logger.error("❌ [%s] 源不可用导致终止", label)
    return aborted
|
||
|
||
|
||
# ======================== 主入口 ========================
|
||
|
||
def main():
    """Command-line entry point: run the daily and/or 15-minute update.

    Flow: check the NAS mount, load the code list, log in to BaoStock (when
    installed), back up the NAS DB, update into the local DB, publish the
    local DB to the NAS, and write a JSON report under LOG_DIR.

    Returns:
        The report dict (per-task stats, elapsed seconds, alert flag).
        Exits with status 1 when the NAS is not mounted.
    """
    parser = argparse.ArgumentParser(description="全市场每日增量更新 v2.0")
    parser.add_argument("--skip-daily", action="store_true", help="跳过日线更新")
    parser.add_argument("--skip-15min", action="store_true", help="跳过15分钟线更新")
    parser.add_argument("--fresh-db", action="store_true", help="强制从NAS重新复制DB")
    args = parser.parse_args()

    if not nas_mounted():
        logger.error("❌ NAS未挂载,退出")
        sys.exit(1)

    codes = get_all_codes()
    logger.info("全市场股票数: %d", len(codes))
    logger.info("更新时间: %s", datetime.now().isoformat())
    logger.info("版本: v2.0 (多源fallback + 本地DB构建)")

    # --fresh-db: drop the local working copy so it is re-copied from the NAS.
    if args.fresh_db and LOCAL_DB_PATH.exists():
        LOCAL_DB_PATH.unlink()
        logger.info("已删除旧本地DB(--fresh-db)")

    # BaoStock login (when the package is available).
    baostock_logged_in = False
    if HAS_BAOSTOCK:
        lg = bs.login()
        logger.info("BaoStock login: %s", lg.error_msg)
        baostock_logged_in = lg.error_code == "0"
        if not baostock_logged_in:
            logger.warning("BaoStock 登录失败 (%s),BaoStock 源本次可能不可用", lg.error_code)

    t_start = time.time()
    report = {}
    has_alert = False

    try:
        # Back up the NAS DB before touching anything.
        rotate_db_backup()

        local_conn = init_local_db()
        try:
            if not args.skip_daily:
                report["daily"] = run_daily_update(codes, local_conn)
                if check_failure_rate(report["daily"], "日线"):
                    has_alert = True

            if not args.skip_15min:
                report["15min"] = run_15min_update(codes, local_conn)
                if check_failure_rate(report["15min"], "15min"):
                    has_alert = True
        finally:
            # Always release the local DB, also when an update raised.
            local_conn.close()

        # Publish the local DB to the NAS.
        sync_db_to_nas()
    finally:
        # Log out even when something above raised; previously the BaoStock
        # session was left open on any exception.
        if baostock_logged_in:
            bs.logout()
            logger.info("BaoStock logout")

    elapsed = time.time() - t_start
    report["elapsed_sec"] = round(elapsed, 1)
    report["has_alert"] = has_alert

    logger.info("=" * 60)
    if has_alert:
        logger.error("⚠️ 本次更新存在异常,请检查日志")
    else:
        logger.info("✅ 全部完成,耗时 %.1f 秒", elapsed)
    logger.info(json.dumps(report, ensure_ascii=False, indent=2))

    # Report file (one per day; a later run on the same day overwrites it).
    report_file = LOG_DIR / f"report_{datetime.now().strftime('%Y%m%d')}.json"
    report_file.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")

    return report


if __name__ == "__main__":
    main()
|