258 lines
7.9 KiB
Python
258 lines
7.9 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
批量下载全市场A股数据的主脚本
|
|
支持断点续传、失败重试、进度保存
|
|
作者:赵云(数据护军)
|
|
日期:2026-03-24
|
|
"""
|
|
|
|
import json
import logging
import os
import sys
from datetime import datetime, timedelta

import pandas as pd
from tqdm import tqdm

from akshare_vnpy_adapter import AkshareToVnpyAdapter
|
|
|
|
# Configure logging: INFO and above go to both a log file and the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        # Explicit UTF-8 so the Chinese messages and the check/cross marks
        # emitted by this script are written correctly regardless of the
        # platform's default locale encoding.
        logging.FileHandler('batch_downloader.log', encoding='utf-8'),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BatchDownloader:
    """Batch downloader with resumable, persisted progress.

    Progress (completed / failed stock codes plus running statistics) is
    written to a JSON file after every processed stock, so an interrupted
    run can be continued by a later invocation.
    """

    def __init__(
        self,
        db_path: str,
        progress_file: str = 'download_progress.json'
    ):
        """Create a downloader and load any previously saved progress.

        Args:
            db_path: Database path passed to ``AkshareToVnpyAdapter``.
            progress_file: JSON file used to persist progress between runs.
        """
        self.db_path = db_path
        self.progress_file = progress_file
        self.adapter = None  # created on first use by download() / verify()
        self.progress = self._load_progress()

    @staticmethod
    def _default_progress() -> dict:
        """Return a fresh, empty progress record."""
        return {
            'last_code': None,
            'completed': [],
            'failed': [],
            'start_time': None,
            'stats': {
                'total': 0,
                'success': 0,
                'failed': 0,
                'total_bars': 0,
            },
        }

    def _load_progress(self) -> dict:
        """Load saved progress, falling back to an empty record.

        Keys missing from the saved file are filled with their defaults. A
        file that cannot be parsed as a JSON object is moved aside to
        ``<progress_file>.corrupt`` and the download starts from scratch,
        instead of crashing on every subsequent start.

        Returns:
            The progress dictionary.
        """
        progress = self._default_progress()
        if not os.path.exists(self.progress_file):
            return progress

        try:
            with open(self.progress_file, 'r', encoding='utf-8') as f:
                saved = json.load(f)
            if not isinstance(saved, dict):
                raise ValueError('top-level JSON value is not an object')
        except ValueError as exc:  # json.JSONDecodeError subclasses ValueError
            backup = self.progress_file + '.corrupt'
            os.replace(self.progress_file, backup)
            logger.warning(f"进度文件无法解析,已备份到 {backup},将重新开始: {exc}")
            return progress

        # Merge instead of returning `saved` directly so that every key the
        # rest of the class indexes is guaranteed to exist.
        saved_stats = saved.pop('stats', None)
        progress.update(saved)
        if isinstance(saved_stats, dict):
            progress['stats'].update(saved_stats)
        return progress

    def _save_progress(self):
        """Persist the current progress to ``self.progress_file``.

        The data is written to a sibling temporary file which is then moved
        over the real file, so an interruption during the write cannot leave
        a truncated progress file behind.
        """
        tmp_path = self.progress_file + '.tmp'
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(self.progress, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, self.progress_file)

    def _update_progress(self, code: str, status: str, bars: int = 0):
        """Record the outcome for one stock and save the progress file.

        Args:
            code: Stock code.
            status: ``'success'`` or ``'failed'``; any other value only
                updates ``last_code``.
            bars: Number of bars inserted (counted on first success only).
        """
        completed = self.progress['completed']
        failed = self.progress['failed']
        stats = self.progress['stats']

        if status == 'success':
            if code not in completed:
                completed.append(code)
                stats['success'] += 1
                stats['total_bars'] += bars

            # A successful retry clears the earlier failure.
            if code in failed:
                failed.remove(code)
                stats['failed'] -= 1

        elif status == 'failed':
            # Count each code at most once, and never downgrade a stock that
            # has already completed.
            if code not in failed and code not in completed:
                failed.append(code)
                stats['failed'] += 1

        self.progress['last_code'] = code
        self._save_progress()

    def download(
        self,
        start_date: str = None,
        end_date: str = None,
        max_stocks: int = None,
        resume: bool = True,
        retry_failed: bool = True
    ) -> dict:
        """Download daily data for every stock in the adapter's stock list.

        Stocks that have not failed before are processed first; previously
        failed stocks are appended once at the end when ``retry_failed`` is
        true and skipped otherwise.

        Args:
            start_date: Start date, forwarded to the adapter.
            end_date: End date, forwarded to the adapter.
            max_stocks: If truthy, only the first ``max_stocks`` rows of the
                stock list are considered (test mode).
            resume: Skip stocks already recorded as completed.
            retry_failed: Retry stocks recorded as failed.

        Returns:
            The ``stats`` dictionary of the progress record.
        """
        # Reuse an adapter that is already open rather than leaking it.
        if self.adapter is None:
            self.adapter = AkshareToVnpyAdapter(self.db_path)
        self.adapter.initialize_database()

        # Record the start time only once across resumed runs.
        if self.progress['start_time'] is None:
            self.progress['start_time'] = datetime.now().isoformat()
            self._save_progress()

        # NOTE(review): the stock list is used as a pandas DataFrame with
        # 'code' and 'name' columns -- confirm against the adapter.
        stock_list = self.adapter.get_stock_list()
        self.progress['stats']['total'] = len(stock_list)

        # Test mode: limit the number of stocks.
        if max_stocks:
            stock_list = stock_list.head(max_stocks)
            logger.info(f"测试模式:只下载前 {max_stocks} 只股票")

        if resume and self.progress['last_code']:
            logger.info(f"断点续传:从 {self.progress['last_code']} 继续")

        # Resume: drop everything that already completed.
        if resume and self.progress['completed']:
            stock_list = stock_list[~stock_list['code'].isin(self.progress['completed'])]
            logger.info(f"跳过已完成的 {len(self.progress['completed'])} 只股票")

        # Split previously failed stocks out of the main queue so they are
        # either retried exactly once (at the end) or not at all.
        previously_failed = stock_list['code'].isin(self.progress['failed'])
        queue = stock_list[~previously_failed]
        if retry_failed and previously_failed.any():
            retry_stocks = stock_list[previously_failed]
            queue = pd.concat([queue, retry_stocks])
            logger.info(f"将重试 {len(retry_stocks)} 只失败的股票")

        logger.info(f"开始下载,队列中有 {len(queue)} 只股票")
        logger.info("=" * 60)

        for _, stock in tqdm(queue.iterrows(), total=len(queue)):
            code = stock['code']
            name = stock['name']

            try:
                inserted = self.adapter.download_and_insert_stock_daily(
                    code, start_date, end_date
                )

                if inserted > 0:
                    self._update_progress(code, 'success', inserted)
                    logger.info(f"✓ {code} {name}: {inserted} 条")
                else:
                    self._update_progress(code, 'failed', 0)
                    logger.warning(f"✗ {code} {name}: 无数据")

            except Exception as e:
                # Best effort: one bad stock must not abort the whole batch.
                self._update_progress(code, 'failed', 0)
                logger.error(f"✗ {code} {name}: {e}")

        self.progress['end_time'] = datetime.now().isoformat()
        self._save_progress()

        logger.info("=" * 60)
        logger.info("批量下载完成!")
        logger.info(f"总计: {self.progress['stats']['total']}")
        logger.info(f"成功: {self.progress['stats']['success']}")
        logger.info(f"失败: {self.progress['stats']['failed']}")
        logger.info(f"总K线: {self.progress['stats']['total_bars']}")

        return self.progress['stats']

    def verify(self) -> dict:
        """Run the adapter's integrity check and save the result as JSON.

        The result is written next to the progress file as
        ``<progress file without extension>_verify.json``.

        Returns:
            Whatever ``adapter.verify_data_integrity()`` returns.
        """
        if self.adapter is None:
            self.adapter = AkshareToVnpyAdapter(self.db_path)

        result = self.adapter.verify_data_integrity()

        # Derive the name from the extension only: a plain
        # str.replace('.json', ...) would return the progress path itself
        # when it has no '.json', and would also rewrite directory names.
        root, _ = os.path.splitext(self.progress_file)
        verify_file = root + '_verify.json'
        with open(verify_file, 'w', encoding='utf-8') as f:
            json.dump(result, f, ensure_ascii=False, indent=2)

        return result

    def close(self):
        """Close the adapter, if one is open, and forget it."""
        if self.adapter:
            self.adapter.close()
            # Drop the reference so later calls create a fresh adapter
            # instead of using a closed one.
            self.adapter = None
|
|
|
|
|
|
def main():
    """Download the configured data set, then run the integrity check.

    The downloader is always closed, even when the download or the
    verification raises.
    """
    # Configuration.
    # NOTE: both paths are absolute paths under one user's home directory;
    # adjust them before running on another machine.
    config = {
        'db_path': '/Users/chufeng/.openclaw/workspace-pangtong/sanguo_quant_live/running_data/database.db',
        'progress_file': '/Users/chufeng/.openclaw/workspace-pangtong/sanguo_quant_live/data-engineering/download_progress.json',
        'start_date': '20240101',  # start from 2024
        'max_stocks': None,  # None means all stocks; set e.g. 50 for testing
        'resume': True,  # skip stocks that already completed
        'retry_failed': True  # retry stocks that failed earlier
    }

    downloader = BatchDownloader(
        db_path=config['db_path'],
        progress_file=config['progress_file']
    )

    try:
        # Download; the statistics are logged by download() itself.
        downloader.download(
            start_date=config['start_date'],
            max_stocks=config['max_stocks'],
            resume=config['resume'],
            retry_failed=config['retry_failed']
        )

        # Verify; the result is written to disk by verify() itself.
        downloader.verify()

        logger.info("\n" + "=" * 60)
        logger.info("下载和验证全部完成!")
        logger.info("=" * 60)

    finally:
        downloader.close()
|
|
|
|
|
|
# Run the batch download only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
|