Files
sanguo_quant_live/zhaoyun-data/scripts/data_acquisition/simple_downloader.py
T
2026-03-28 00:14:34 +08:00

352 lines
13 KiB
Python

#!/usr/bin/env python3
"""
简单稳定的分钟数据下载器
"""
import os
import sys
import time
import json
import logging
from datetime import datetime, timedelta
from typing import List, Optional
import warnings
warnings.filterwarnings('ignore')
import pandas as pd
try:
import akshare as ak
AKSHARE_AVAILABLE = True
except ImportError:
AKSHARE_AVAILABLE = False
print("❌ AKShare未安装,请运行: pip install akshare")
sys.exit(1)
# Configure root logging once for the whole script: each record shows its
# timestamp, level name and message.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
class SimpleMinuteDownloader:
    """Simple, sequential downloader for A-share minute-bar data.

    Each stock's bars are fetched with ``ak.stock_zh_a_minute`` and written as
    one Parquet file and one CSV file under ``<base_dir>/<timeframe>/``.
    JSON progress and final reports go to ``<base_dir>/reports/``.
    """

    # Timeframe label -> ``period`` argument for ``ak.stock_zh_a_minute``.
    _PERIOD_MAP = {
        '1min': '1',
        '5min': '5',
        '15min': '15',
        '30min': '30',
        '60min': '60',
    }
    # Period used when ``timeframe`` is not a key of ``_PERIOD_MAP``.
    _DEFAULT_PERIOD = '15'

    def __init__(
        self,
        base_dir: str = "/Users/chufeng/nas/stock/minute_kline",
        timeframe: str = "15min",
        start_date: str = "2021-01-01",
        end_date: Optional[str] = None,
        batch_size: int = 100,
        max_workers: int = 5,
        retry_count: int = 3
    ):
        """Configure the downloader and create its output directories.

        Args:
            base_dir: Root directory holding data, logs and reports.
            timeframe: Bar-size label (a key of ``_PERIOD_MAP``). It is also
                used in the data directory and file names.
            start_date: Start of the requested range (``YYYY-MM-DD``).
            end_date: End of the requested range; defaults to today.
            batch_size: Stored only; the download loop does not use it.
            max_workers: Stored only; downloads run sequentially.
            retry_count: Attempts per stock when an exception is raised.
        """
        self.base_dir = base_dir
        self.timeframe = timeframe
        self.start_date = start_date
        self.end_date = end_date or datetime.now().strftime("%Y-%m-%d")
        self.batch_size = batch_size
        self.max_workers = max_workers
        self.retry_count = retry_count

        # Output directories.
        self.data_dir = os.path.join(self.base_dir, self.timeframe)
        self.log_dir = os.path.join(self.base_dir, "logs")
        self.report_dir = os.path.join(self.base_dir, "reports")
        for directory in (self.data_dir, self.log_dir, self.report_dir):
            os.makedirs(directory, exist_ok=True)

        # Run statistics; counters are reset at the start of each run.
        self.download_stats = {
            "total_stocks": 0,
            "downloaded_stocks": 0,
            "failed_stocks": 0,
            "skipped_stocks": 0,  # already on disk and skipped when resuming
            "start_time": datetime.now(),
            "end_time": None
        }

        logger.info(f"初始化下载器: {self.timeframe} 数据")
        logger.info(f"数据目录: {self.data_dir}")
        logger.info(f"时间范围: {self.start_date} 至 {self.end_date}")
        if self.timeframe not in self._PERIOD_MAP:
            # An unknown label falls back to the default period while files
            # are still named after the label; make that visible.
            logger.warning(
                f"未知时间粒度 {self.timeframe},将按 {self._DEFAULT_PERIOD} 分钟周期下载"
            )

    @staticmethod
    def _format_stock_code(raw_code) -> Optional[str]:
        """Return *raw_code* with its exchange prefix, or None if unsupported.

        Codes starting with ``6`` get ``sh``; codes starting with ``0`` or
        ``3`` get ``sz``. All-digit codes are left-padded to six digits, so an
        integer-typed code such as ``1`` maps to ``sz000001``. Any other value
        (empty, or a code outside these prefixes) yields None.
        """
        code = str(raw_code).strip()
        if code.isdigit():
            code = code.zfill(6)
        if code.startswith('6'):
            return f"sh{code}"
        if code.startswith(('0', '3')):
            return f"sz{code}"
        return None

    def get_all_a_stock_codes(self) -> List[str]:
        """Return prefixed A-share codes (``sh600000``, ``sz000001``, ...).

        Uses ``ak.stock_info_a_code_name()``. Falls back to
        ``_get_default_stock_codes()`` when that call raises or returns
        nothing.
        """
        logger.info("获取A股代码列表...")
        try:
            stock_info = ak.stock_info_a_code_name()
            if stock_info is None or stock_info.empty:
                logger.warning("AKShare返回空数据,使用预定义列表")
                return self._get_default_stock_codes()
            # Iterate the column directly instead of iterrows(); a missing
            # 'code' column yields no codes.
            raw_codes = stock_info['code'] if 'code' in stock_info.columns else []
            codes = [
                formatted
                for formatted in map(self._format_stock_code, raw_codes)
                if formatted is not None
            ]
        except Exception as e:
            logger.error(f"获取A股代码失败: {e}")
            return self._get_default_stock_codes()
        logger.info(f"获取到 {len(codes)} 只A股代码")
        return codes

    def _get_default_stock_codes(self) -> List[str]:
        """Return a hard-coded fallback list of 997 codes.

        The list covers sh600001-sh600599, sz000001-sz000299 and
        sz300001-sz300099.
        """
        codes = (
            [f"sh600{i:03d}" for i in range(1, 600)]
            + [f"sz000{i:03d}" for i in range(1, 300)]
            + [f"sz300{i:03d}" for i in range(1, 100)]
        )
        logger.info(f"使用默认股票代码列表: {len(codes)}")
        return codes

    def download_stock_data(self, stock_code: str) -> bool:
        """Download one stock's minute bars and save them as Parquet and CSV.

        An exception (during fetch or save) is retried up to ``retry_count``
        times, with a 2-second pause between attempts. A ``None`` or empty
        result is not retried.

        Args:
            stock_code: Prefixed code such as ``sh600000``.

        Returns:
            True if the data was saved, False otherwise.
        """
        # NOTE(review): start_date/end_date are neither passed to nor applied
        # on the AKShare result; the saved range is whatever the API returns.
        period = self._PERIOD_MAP.get(self.timeframe, self._DEFAULT_PERIOD)
        parquet_file = os.path.join(self.data_dir, f"{stock_code}_{self.timeframe}.parquet")
        csv_file = os.path.join(self.data_dir, f"{stock_code}_{self.timeframe}.csv")

        for attempt in range(1, self.retry_count + 1):
            try:
                logger.info(f"下载 {stock_code} {self.timeframe} 数据...")
                data = ak.stock_zh_a_minute(
                    symbol=stock_code,
                    period=period,
                    adjust='hfq'
                )
                if data is None:
                    logger.warning(f"{stock_code}: 返回None数据,跳过")
                    return False
                if hasattr(data, 'empty') and data.empty:
                    logger.warning(f"{stock_code}: 数据为空,跳过")
                    return False

                # Strip stray whitespace from column names before saving.
                data = data.copy()
                if hasattr(data, 'columns'):
                    data.columns = [col.strip() if isinstance(col, str) else col for col in data.columns]

                data.to_parquet(parquet_file, compression='snappy')
                # The CSV copy is for easy manual inspection.
                data.to_csv(csv_file, index=False)
                logger.info(f"{stock_code}: 下载成功 {len(data)} 条记录")
                logger.info(f" 保存文件: {parquet_file} ({os.path.getsize(parquet_file) // 1024} KB)")
                return True
            except Exception as e:
                logger.error(f"{stock_code}: 下载失败 (尝试 {attempt}/{self.retry_count}) - {str(e)[:100]}")
                if attempt < self.retry_count:
                    time.sleep(2)  # back off before retrying
        return False

    def _find_downloaded_stocks(self) -> set:
        """Return codes that already have a ``<code>_<timeframe>.parquet`` file."""
        suffix = f"_{self.timeframe}.parquet"
        if not os.path.isdir(self.data_dir):
            return set()
        return {
            name[:-len(suffix)]
            for name in os.listdir(self.data_dir)
            if name.endswith(suffix)
        }

    def download_all_stocks(self, all_stocks: bool = True):
        """Download every A-share code sequentially and write the reports.

        Args:
            all_stocks: When False, skip codes that already have a Parquet
                file in ``data_dir``. Skipped codes are counted in
                ``skipped_stocks``.
        """
        logger.info("="*70)
        logger.info("🚀 开始全量分钟数据下载")
        logger.info("="*70)

        # Reset statistics so that elapsed time and ETA are measured from the
        # start of this run.
        self.download_stats.update({
            "total_stocks": 0,
            "downloaded_stocks": 0,
            "failed_stocks": 0,
            "skipped_stocks": 0,
            "start_time": datetime.now(),
            "end_time": None,
        })

        stock_codes = self.get_all_a_stock_codes()
        if not stock_codes:
            logger.error("❌ 无法获取股票代码列表")
            return
        self.download_stats["total_stocks"] = len(stock_codes)

        if not all_stocks:
            downloaded_stocks = self._find_downloaded_stocks()
            if downloaded_stocks:
                remaining_stocks = [code for code in stock_codes if code not in downloaded_stocks]
                logger.info(f"已下载 {len(downloaded_stocks)} 只,剩余 {len(remaining_stocks)}")
                self.download_stats["skipped_stocks"] = len(stock_codes) - len(remaining_stocks)
                stock_codes = remaining_stocks

        total = len(stock_codes)
        logger.info(f"📊 开始下载 {total} 只股票")

        success_count = 0
        fail_count = 0
        for i, stock_code in enumerate(stock_codes, 1):
            logger.info(f"\n📈 进度: {i}/{total} ({i/total*100:.1f}%)")
            if self.download_stock_data(stock_code):
                success_count += 1
            else:
                fail_count += 1

            self.download_stats["downloaded_stocks"] = success_count
            self.download_stats["failed_stocks"] = fail_count

            # Persist a progress report every 10 stocks and after the last one.
            if i % 10 == 0 or i == total:
                self._save_progress_report()

            time.sleep(0.5)  # throttle requests to the data source

        self.download_stats["end_time"] = datetime.now()
        self._save_final_report()

        logger.info("="*70)
        logger.info("✅ 下载完成!")
        logger.info(f" 成功: {success_count}")
        logger.info(f" 失败: {fail_count}")
        # _success_rate() guards the case where nothing was attempted
        # (e.g. everything was already downloaded).
        logger.info(f" 成功率: {self._success_rate():.1f}%")
        logger.info(f" 总耗时: {self.download_stats['end_time'] - self.download_stats['start_time']}")
        logger.info("="*70)

    def _serializable_stats(self) -> dict:
        """Return a copy of ``download_stats`` with datetimes as ISO-8601 strings."""
        return {
            key: value.isoformat() if isinstance(value, datetime) else value
            for key, value in self.download_stats.items()
        }

    def _success_rate(self) -> float:
        """Return successful downloads as a percentage of attempted downloads.

        Skipped stocks are excluded. Returns 0.0 when nothing was attempted.
        """
        stats = self.download_stats
        downloaded = stats.get("downloaded_stocks", 0)
        attempted = downloaded + stats.get("failed_stocks", 0)
        return downloaded / attempted * 100 if attempted else 0.0

    def _processed_count(self) -> int:
        """Return how many codes have been handled: downloaded, failed or skipped."""
        stats = self.download_stats
        return (
            stats.get("downloaded_stocks", 0)
            + stats.get("failed_stocks", 0)
            + stats.get("skipped_stocks", 0)
        )

    def _save_progress_report(self):
        """Write a timestamped JSON progress report into ``report_dir``."""
        report = {
            "timestamp": datetime.now().isoformat(),
            "timeframe": self.timeframe,
            "stats": self._serializable_stats(),
            "progress": {
                "percentage": self._processed_count() / max(1, self.download_stats["total_stocks"]) * 100,
                "estimated_time_left": self._estimate_time_left()
            }
        }
        report_file = os.path.join(self.report_dir, f"progress_{self.timeframe}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json")
        with open(report_file, 'w', encoding='utf-8') as f:
            json.dump(report, f, ensure_ascii=False, indent=2)

    def _save_final_report(self):
        """Write the end-of-run JSON report into ``report_dir`` and log its path."""
        stats = self.download_stats
        end_time = stats["end_time"]
        report = {
            "timestamp": datetime.now().isoformat(),
            "timeframe": self.timeframe,
            "summary": self._serializable_stats(),
            "duration": str(end_time - stats["start_time"]) if end_time else None,
            "success_rate": self._success_rate()
        }
        report_file = os.path.join(self.report_dir, f"final_{self.timeframe}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json")
        with open(report_file, 'w', encoding='utf-8') as f:
            json.dump(report, f, ensure_ascii=False, indent=2)
        logger.info(f"📄 最终报告已保存: {report_file}")

    def _estimate_time_left(self) -> str:
        """Estimate the remaining run time as a human-readable string.

        Uses the average time per attempted stock (successes and failures),
        scaled by the number of codes not yet processed. Skipped codes count
        as processed. Returns "未知" ("unknown") before any attempt.
        """
        stats = self.download_stats
        attempted = stats.get("downloaded_stocks", 0) + stats.get("failed_stocks", 0)
        if attempted == 0:
            return "未知"
        elapsed = (datetime.now() - stats["start_time"]).total_seconds()
        remaining_stocks = max(0, stats["total_stocks"] - attempted - stats.get("skipped_stocks", 0))
        remaining_seconds = elapsed / attempted * remaining_stocks
        if remaining_seconds < 60:
            return f"{int(remaining_seconds)}秒"
        if remaining_seconds < 3600:
            return f"{int(remaining_seconds / 60)}分钟"
        return f"{int(remaining_seconds / 3600)}小时{int((remaining_seconds % 3600) / 60)}分钟"
def main():
    """Parse command-line options and run a full download.

    The options map onto ``SimpleMinuteDownloader`` constructor arguments.
    ``--all-stocks`` additionally re-downloads stocks that already have a
    Parquet file.
    """
    import argparse

    # Restrict timeframes to the labels the downloader maps to an AKShare
    # period. Any other label falls back to the 15-minute period, while the
    # files would still be named after the label.
    timeframes = ["1min", "5min", "15min", "30min", "60min"]

    parser = argparse.ArgumentParser(description="简单稳定的分钟数据下载器")
    parser.add_argument("--base-dir", default="/Users/chufeng/nas/stock/minute_kline", help="数据根目录")
    parser.add_argument("--timeframe", default="15min", choices=timeframes, help="时间粒度: 1min, 5min, 15min, 30min, 60min")
    parser.add_argument("--start-date", default="2021-01-01", help="开始日期")
    parser.add_argument("--end-date", default=None, help="结束日期")
    # NOTE(review): batch size and worker count are forwarded to the
    # downloader, whose download loop is sequential.
    parser.add_argument("--batch-size", type=int, default=100, help="批次大小")
    parser.add_argument("--max-workers", type=int, default=5, help="最大工作线程数")
    parser.add_argument("--retry-count", type=int, default=3, help="单只股票最大尝试次数")
    parser.add_argument("--all-stocks", action="store_true", help="下载所有股票(包括已下载的)")
    args = parser.parse_args()

    print("="*70)
    print("🚀 赵云分钟数据下载器启动")
    print("="*70)

    downloader = SimpleMinuteDownloader(
        base_dir=args.base_dir,
        timeframe=args.timeframe,
        start_date=args.start_date,
        end_date=args.end_date,
        batch_size=args.batch_size,
        max_workers=args.max_workers,
        retry_count=args.retry_count
    )
    downloader.download_all_stocks(all_stocks=args.all_stocks)
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()