Files
2026-03-25 23:50:25 +08:00

258 lines
7.9 KiB
Python

#!/usr/bin/env python3
"""
批量下载全市场A股数据的主脚本
支持断点续传、失败重试、进度保存
作者:赵云(数据护军)
日期:2026-03-24
"""
import json
import logging
import os
import sys
from datetime import datetime, timedelta

import pandas as pd
from tqdm import tqdm

from akshare_vnpy_adapter import AkshareToVnpyAdapter
# Configure logging: everything at INFO and above goes both to a log file in
# the current working directory and to the console.
# encoding='utf-8' is set explicitly because the log messages are Chinese;
# without it FileHandler uses the platform's locale encoding, which may not
# be able to represent them.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('batch_downloader.log', encoding='utf-8'),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)
class BatchDownloader:
    """Batch downloader with resumable progress.

    Progress (completed / failed stock codes and running statistics) is
    persisted to a JSON file after every stock, so an interrupted run can be
    resumed later.
    """

    def __init__(
        self,
        db_path: str,
        progress_file: str = 'download_progress.json'
    ):
        """
        Initialize the batch downloader.

        Args:
            db_path: Database path handed to ``AkshareToVnpyAdapter``.
            progress_file: JSON file used to persist download progress.
        """
        self.db_path = db_path
        self.progress_file = progress_file
        # The adapter is created on demand by download() / verify().
        self.adapter = None
        self.progress = self._load_progress()

    @staticmethod
    def _default_progress() -> dict:
        """Return a fresh, empty progress record."""
        return {
            'last_code': None,
            'completed': [],
            'failed': [],
            'start_time': None,
            'stats': {
                'total': 0,
                'success': 0,
                'failed': 0,
                'total_bars': 0
            }
        }

    def _load_progress(self) -> dict:
        """Load saved progress from ``self.progress_file``.

        Returns an empty record when the file does not exist. Keys missing
        from a partial or older progress file (including individual ``stats``
        counters) are filled in from the defaults, so later code can index
        them without a KeyError.
        """
        progress = self._default_progress()
        if not os.path.exists(self.progress_file):
            return progress
        with open(self.progress_file, 'r', encoding='utf-8') as f:
            saved = json.load(f)
        merged_stats = {**progress['stats'], **saved.get('stats', {})}
        progress.update(saved)
        progress['stats'] = merged_stats
        return progress

    def _save_progress(self):
        """Persist ``self.progress`` to ``self.progress_file`` atomically.

        The JSON is written to a temporary sibling file first and then moved
        over the real file with ``os.replace``. A crash mid-write therefore
        cannot leave a truncated progress file that breaks the next resume.
        """
        tmp_path = self.progress_file + '.tmp'
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(self.progress, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, self.progress_file)

    def _update_progress(self, code: str, status: str, bars: int = 0):
        """
        Record the outcome for one stock and save progress.

        Args:
            code: Stock code.
            status: Either 'success' or 'failed'.
            bars: Number of bars inserted (only counted on success).
        """
        if status == 'success':
            if code not in self.progress['completed']:
                self.progress['completed'].append(code)
                self.progress['stats']['success'] += 1
                self.progress['stats']['total_bars'] += bars
            # A successful retry removes the code from the failed list.
            if code in self.progress['failed']:
                self.progress['failed'].remove(code)
                self.progress['stats']['failed'] -= 1
        elif status == 'failed':
            # Never mark an already-completed stock as failed, and count each
            # failed code only once.
            if code not in self.progress['failed'] and code not in self.progress['completed']:
                self.progress['failed'].append(code)
                self.progress['stats']['failed'] += 1
        self.progress['last_code'] = code
        self._save_progress()

    @staticmethod
    def _verify_file_path(progress_file: str) -> str:
        """Return the path of the verification report for ``progress_file``.

        The file extension is swapped for ``_verify.json``
        (``progress.json`` -> ``progress_verify.json``). Unlike
        ``str.replace('.json', ...)``, this never returns the progress file
        itself when it lacks a ``.json`` suffix. It also leaves ``.json``
        inside directory names alone.
        """
        root, _ext = os.path.splitext(progress_file)
        return root + '_verify.json'

    def _select_stocks(self, stock_list, resume: bool, retry_failed: bool):
        """Build the download queue from ``stock_list``.

        Already-completed stocks are dropped when ``resume`` is true.
        Previously failed stocks are moved to the end of the queue when
        ``retry_failed`` is true, and dropped otherwise. Each stock therefore
        appears at most once.

        Args:
            stock_list: DataFrame with at least a ``code`` column.
            resume: Whether to skip stocks recorded as completed.
            retry_failed: Whether to retry stocks recorded as failed.

        Returns:
            The filtered (and possibly reordered) DataFrame.
        """
        completed = set(self.progress['completed'])
        failed = set(self.progress['failed'])
        if resume and completed:
            stock_list = stock_list[~stock_list['code'].isin(completed)]
        is_failed = stock_list['code'].isin(failed)
        pending = stock_list[~is_failed]
        if not retry_failed or not is_failed.any():
            return pending
        # Retry previous failures after the fresh stocks.
        return pd.concat([pending, stock_list[is_failed]])

    def download(
        self,
        start_date: str = None,
        end_date: str = None,
        max_stocks: int = None,
        resume: bool = True,
        retry_failed: bool = True
    ) -> dict:
        """
        Download daily bars for every stock in the adapter's stock list.

        Args:
            start_date: Start date, passed through to the adapter.
            end_date: End date, passed through to the adapter.
            max_stocks: If truthy, only the first ``max_stocks`` stocks are
                considered (test mode).
            resume: Skip stocks already recorded as completed.
            retry_failed: Retry stocks recorded as failed (queued last). When
                false, they are skipped.

        Returns:
            The ``stats`` dict of the progress record.
        """
        # Reuse an existing adapter (e.g. one created by verify()) instead of
        # leaking it.
        if self.adapter is None:
            self.adapter = AkshareToVnpyAdapter(self.db_path)
        self.adapter.initialize_database()

        # Record the start time of the very first run only.
        if self.progress['start_time'] is None:
            self.progress['start_time'] = datetime.now().isoformat()
            self._save_progress()

        stock_list = self.adapter.get_stock_list()
        # Test mode: limit the number of stocks.
        if max_stocks:
            stock_list = stock_list.head(max_stocks)
            logger.info(f"测试模式:只下载前 {max_stocks} 只股票")
        # Count after the limit so the summary reflects what was considered.
        self.progress['stats']['total'] = len(stock_list)

        if resume and self.progress['last_code']:
            logger.info(f"断点续传:从 {self.progress['last_code']} 继续")
        if resume and self.progress['completed']:
            logger.info(f"跳过已完成的 {len(self.progress['completed'])} 只股票")
        if retry_failed and self.progress['failed']:
            logger.info(f"将重试 {len(self.progress['failed'])} 只失败的股票")
        queue = self._select_stocks(stock_list, resume, retry_failed)

        logger.info(f"开始下载,队列中有 {len(queue)} 只股票")
        logger.info("=" * 60)
        for _, stock in tqdm(queue.iterrows(), total=len(queue)):
            code = stock['code']
            name = stock['name']
            try:
                inserted = self.adapter.download_and_insert_stock_daily(
                    code, start_date, end_date
                )
                if inserted > 0:
                    self._update_progress(code, 'success', inserted)
                    logger.info(f"{code} {name}: {inserted}")
                else:
                    # No rows inserted is treated as a failure so it can be retried.
                    self._update_progress(code, 'failed', 0)
                    logger.warning(f"{code} {name}: 无数据")
            except Exception as e:
                # One bad stock must not abort the whole batch.
                self._update_progress(code, 'failed', 0)
                logger.error(f"{code} {name}: {e}")

        self.progress['end_time'] = datetime.now().isoformat()
        self._save_progress()
        logger.info("=" * 60)
        logger.info("批量下载完成!")
        logger.info(f"总计: {self.progress['stats']['total']}")
        logger.info(f"成功: {self.progress['stats']['success']}")
        logger.info(f"失败: {self.progress['stats']['failed']}")
        logger.info(f"总K线: {self.progress['stats']['total_bars']}")
        return self.progress['stats']

    def verify(self) -> dict:
        """Run the adapter's integrity check and save the result as JSON.

        The report is written next to the progress file, with its extension
        replaced by ``_verify.json``.

        Returns:
            Whatever ``verify_data_integrity()`` returned.
        """
        if self.adapter is None:
            self.adapter = AkshareToVnpyAdapter(self.db_path)
        result = self.adapter.verify_data_integrity()
        verify_file = self._verify_file_path(self.progress_file)
        with open(verify_file, 'w', encoding='utf-8') as f:
            json.dump(result, f, ensure_ascii=False, indent=2)
        return result

    def close(self):
        """Close the adapter, if any, and drop the reference (safe to call twice)."""
        if self.adapter is not None:
            self.adapter.close()
            self.adapter = None
def main(overrides=None):
    """Run a batch download followed by a data-integrity check.

    Args:
        overrides: Optional dict whose entries replace keys of the default
            configuration below (``db_path``, ``progress_file``,
            ``start_date``, ``end_date``, ``max_stocks``, ``resume``,
            ``retry_failed``).
    """
    # Default configuration. NOTE: the paths are machine-specific absolute
    # paths; pass ``overrides`` to run elsewhere.
    config = {
        'db_path': '/Users/chufeng/.openclaw/workspace-pangtong/sanguo_quant_live/running_data/database.db',
        'progress_file': '/Users/chufeng/.openclaw/workspace-pangtong/sanguo_quant_live/data-engineering/download_progress.json',
        'start_date': '20240101',  # start from 2024
        'end_date': None,  # passed through to the adapter unchanged
        'max_stocks': None,  # None means all stocks; e.g. 50 for a test run
        'resume': True,  # skip stocks already completed
        'retry_failed': True,  # retry stocks that failed previously
    }
    if overrides:
        config.update(overrides)

    downloader = BatchDownloader(
        db_path=config['db_path'],
        progress_file=config['progress_file']
    )
    try:
        downloader.download(
            start_date=config['start_date'],
            end_date=config['end_date'],
            max_stocks=config['max_stocks'],
            resume=config['resume'],
            retry_failed=config['retry_failed']
        )
        downloader.verify()
        logger.info("\n" + "=" * 60)
        logger.info("下载和验证全部完成!")
        logger.info("=" * 60)
    finally:
        # Always release the adapter, even if download/verify raised.
        downloader.close()
if __name__ == '__main__':
    # Run the downloader only when executed as a script, not when imported.
    main()