#!/usr/bin/env python3
"""Simple, stable downloader for A-share minute-level K-line data.

Fetches minute bars through AKShare one stock at a time and stores each
stock as a Parquet file plus a CSV copy, writing JSON progress/final
reports along the way.
"""

import json
import logging
import os
import sys
import time
import warnings
from datetime import datetime, timedelta
from typing import List, Optional

# Installed before the third-party imports, as in the original script, so
# that warnings raised while importing them are silenced as well.
warnings.filterwarnings('ignore')

import pandas as pd

try:
    import akshare as ak
    AKSHARE_AVAILABLE = True
except ImportError:
    # Do not terminate the interpreter at import time: that would kill any
    # process that merely imports this module.  main() performs the check
    # and exits with status 1 when the script is actually run.
    ak = None
    AKSHARE_AVAILABLE = False

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class SimpleMinuteDownloader:
    """Sequential downloader for minute-level data of A-share stocks.

    Files are written to ``<base_dir>/<timeframe>/`` and JSON reports to
    ``<base_dir>/reports/``.  A stock counts as "already downloaded" when
    its ``<code>_<timeframe>.parquet`` file exists in the data directory.

    NOTE(review): ``start_date``/``end_date`` are stored and logged but are
    not applied to the downloaded data, and ``batch_size``/``max_workers``
    are stored but unused (downloads run strictly one after another).
    """

    # Maps the public timeframe name to the ``period`` argument passed to
    # ``ak.stock_zh_a_minute``.
    _PERIOD_MAP = {
        '1min': '1',
        '5min': '5',
        '15min': '15',
        '30min': '30',
        '60min': '60'
    }
    # Period used when ``timeframe`` is not a key of ``_PERIOD_MAP``.
    _DEFAULT_PERIOD = '15'

    def __init__(
        self,
        base_dir: str = "/Users/chufeng/nas/stock/minute_kline",
        timeframe: str = "15min",
        start_date: str = "2021-01-01",
        end_date: Optional[str] = None,
        batch_size: int = 100,
        max_workers: int = 5,
        retry_count: int = 3
    ):
        """Store the configuration and create the output directories.

        Args:
            base_dir: Root directory for data, logs and reports.
            timeframe: Bar size; one of the keys of ``_PERIOD_MAP``.  Any
                other value falls back to the 15-minute period.
            start_date: Start of the requested range (``YYYY-MM-DD``).
            end_date: End of the requested range; defaults to today.
            batch_size: Stored only (currently unused).
            max_workers: Stored only (currently unused).
            retry_count: Maximum number of attempts per stock.
        """
        self.base_dir = base_dir
        self.timeframe = timeframe
        self.start_date = start_date
        self.end_date = end_date or datetime.now().strftime("%Y-%m-%d")
        self.batch_size = batch_size
        self.max_workers = max_workers
        self.retry_count = retry_count

        # Create the output directories.
        self.data_dir = os.path.join(self.base_dir, self.timeframe)
        self.log_dir = os.path.join(self.base_dir, "logs")
        self.report_dir = os.path.join(self.base_dir, "reports")
        for directory in (self.data_dir, self.log_dir, self.report_dir):
            os.makedirs(directory, exist_ok=True)

        # Download statistics.
        self.download_stats = {
            "total_stocks": 0,
            "downloaded_stocks": 0,
            "failed_stocks": 0,
            "start_time": datetime.now(),
            "end_time": None
        }

        logger.info(f"初始化下载器: {self.timeframe} 数据")
        logger.info(f"数据目录: {self.data_dir}")
        logger.info(f"时间范围: {self.start_date} 至 {self.end_date}")
        if self.timeframe not in self._PERIOD_MAP:
            # The fallback itself is kept for backward compatibility, but it
            # must not be silent: files would be labelled with a timeframe
            # that does not match their content.
            logger.warning(
                f"未知的时间粒度 {self.timeframe},"
                f"将使用 {self._DEFAULT_PERIOD} 分钟周期下载"
            )

    @staticmethod
    def _add_market_prefix(code: str) -> Optional[str]:
        """Return ``code`` with its ``sh``/``sz`` prefix, or None if unmapped."""
        if code.startswith('6'):
            return f"sh{code}"
        if code.startswith(('0', '3')):
            return f"sz{code}"
        return None

    def get_all_a_stock_codes(self) -> List[str]:
        """Return all A-share codes with a market prefix.

        Codes starting with ``6`` get ``sh``; codes starting with ``0`` or
        ``3`` get ``sz``; every other code is dropped.  If AKShare is
        unavailable, raises, or returns no data, the predefined list from
        ``_get_default_stock_codes`` is returned instead.
        """
        logger.info("获取A股代码列表...")

        try:
            stock_info = ak.stock_info_a_code_name()

            if stock_info is None or stock_info.empty:
                logger.warning("AKShare返回空数据,使用预定义列表")
                return self._get_default_stock_codes()

            # A missing 'code' column yields an empty result rather than
            # the fallback list, matching the previous row-wise behaviour.
            raw_codes = stock_info['code'] if 'code' in stock_info.columns else []
            prefixed = (self._add_market_prefix(str(raw)) for raw in raw_codes)
            codes = [code for code in prefixed if code is not None]

            logger.info(f"获取到 {len(codes)} 只A股代码")
            return codes

        except Exception as e:
            # Deliberately broad: any failure of the remote listing (or a
            # missing AKShare install) degrades to the built-in list.
            logger.error(f"获取A股代码失败: {e}")
            return self._get_default_stock_codes()

    def _get_default_stock_codes(self) -> List[str]:
        """Return the built-in fallback list of 997 generated codes.

        The list is ``sh600001``-``sh600599``, ``sz000001``-``sz000299``
        and ``sz300001``-``sz300099``, in that order.
        """
        codes = [f"sh600{i:03d}" for i in range(1, 600)]
        codes.extend(f"sz000{i:03d}" for i in range(1, 300))
        codes.extend(f"sz300{i:03d}" for i in range(1, 100))

        logger.info(f"使用默认股票代码列表: {len(codes)} 只")
        return codes

    @staticmethod
    def _write_atomically(target_path: str, write_func) -> None:
        """Write via ``write_func(tmp_path)`` and then rename onto the target.

        A failed or interrupted write therefore never leaves a truncated
        file under the final name.
        """
        tmp_path = f"{target_path}.tmp"
        try:
            write_func(tmp_path)
            os.replace(tmp_path, target_path)
        finally:
            # After a successful replace the temp file is gone; this only
            # cleans up after a failure.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def download_stock_data(self, stock_code: str) -> bool:
        """Download and store the minute data of one stock.

        Args:
            stock_code: Market-prefixed code such as ``sh600000``.

        Returns:
            True when both output files were written.  False when AKShare
            is unavailable, when it returned ``None``/empty data (not
            retried), or when every attempt raised.
        """
        if ak is None:
            # Retrying cannot help when the library is missing.
            logger.error(f"❌ {stock_code}: 下载失败 - AKShare未安装")
            return False

        period = self._PERIOD_MAP.get(self.timeframe, self._DEFAULT_PERIOD)
        parquet_file = os.path.join(self.data_dir, f"{stock_code}_{self.timeframe}.parquet")
        csv_file = os.path.join(self.data_dir, f"{stock_code}_{self.timeframe}.csv")

        for attempt in range(self.retry_count):
            try:
                logger.info(f"下载 {stock_code} {self.timeframe} 数据...")

                data = ak.stock_zh_a_minute(
                    symbol=stock_code,
                    period=period,
                    adjust='hfq'
                )

                # Missing data is treated as final, not as a retryable error.
                if data is None:
                    logger.warning(f"{stock_code}: 返回None数据,跳过")
                    return False

                if hasattr(data, 'empty') and data.empty:
                    logger.warning(f"{stock_code}: 数据为空,跳过")
                    return False

                # Clean-up: strip whitespace around string column names.
                data = data.copy()
                if hasattr(data, 'columns'):
                    data.columns = [col.strip() if isinstance(col, str) else col for col in data.columns]

                # CSV copy first (for easy inspection).  The Parquet file is
                # written last because its existence marks the stock as
                # downloaded when a run is resumed.
                self._write_atomically(
                    csv_file, lambda path: data.to_csv(path, index=False)
                )
                self._write_atomically(
                    parquet_file, lambda path: data.to_parquet(path, compression='snappy')
                )

                logger.info(f"✅ {stock_code}: 下载成功 {len(data)} 条记录")
                logger.info(f" 保存文件: {parquet_file} ({os.path.getsize(parquet_file) // 1024} KB)")

                return True

            except Exception as e:
                logger.error(f"❌ {stock_code}: 下载失败 (尝试 {attempt + 1}/{self.retry_count}) - {str(e)[:100]}")
                if attempt < self.retry_count - 1:
                    time.sleep(2)  # wait before the next attempt

        return False

    def _find_downloaded_stocks(self) -> set:
        """Return the codes that already have a Parquet file for this timeframe."""
        suffix = f"_{self.timeframe}.parquet"
        if not os.path.exists(self.data_dir):
            return set()
        return {
            name.replace(suffix, "")
            for name in os.listdir(self.data_dir)
            if name.endswith(suffix)
        }

    def download_all_stocks(self, all_stocks: bool = True):
        """Download every stock and write progress and final reports.

        Args:
            all_stocks: When True, download every code even if a Parquet
                file already exists.  When False, skip codes that already
                have one (resume mode).
        """
        logger.info("="*70)
        logger.info("🚀 开始全量分钟数据下载")
        logger.info("="*70)

        stock_codes = self.get_all_a_stock_codes()
        if not stock_codes:
            logger.error("❌ 无法获取股票代码列表")
            return

        # Resume mode: drop the stocks that are already on disk.
        if not all_stocks:
            downloaded_stocks = self._find_downloaded_stocks()
            if downloaded_stocks:
                remaining_stocks = [code for code in stock_codes if code not in downloaded_stocks]
                logger.info(f"已下载 {len(downloaded_stocks)} 只,剩余 {len(remaining_stocks)} 只")
                stock_codes = remaining_stocks

        # Recorded after filtering so that progress, ETA and success rate
        # refer to the stocks this run actually attempts.
        total = len(stock_codes)
        self.download_stats["total_stocks"] = total

        logger.info(f"📊 开始下载 {total} 只股票")

        success_count = 0
        fail_count = 0

        for i, stock_code in enumerate(stock_codes, 1):
            logger.info(f"\n📈 进度: {i}/{total} ({i/total*100:.1f}%)")

            if self.download_stock_data(stock_code):
                success_count += 1
            else:
                fail_count += 1

            self.download_stats["downloaded_stocks"] = success_count
            self.download_stats["failed_stocks"] = fail_count

            # Save a progress report every 10 stocks and after the last one.
            if i % 10 == 0 or i == total:
                self._save_progress_report()

            # Throttle requests; no need to wait after the final stock.
            if i < total:
                time.sleep(0.5)

        self.download_stats["end_time"] = datetime.now()
        self._save_final_report()

        # Nothing may have been attempted (everything already downloaded),
        # so the rate must not divide by zero.
        attempted = success_count + fail_count
        success_rate = success_count / attempted * 100 if attempted else 0.0

        logger.info("="*70)
        logger.info("✅ 下载完成!")
        logger.info(f" 成功: {success_count} 只")
        logger.info(f" 失败: {fail_count} 只")
        logger.info(f" 成功率: {success_rate:.1f}%")
        logger.info(f" 总耗时: {self.download_stats['end_time'] - self.download_stats['start_time']}")
        logger.info("="*70)

    def _serializable_stats(self) -> dict:
        """Return a copy of ``download_stats`` with datetimes as ISO strings."""
        return {
            key: value.isoformat() if isinstance(value, datetime) else value
            for key, value in self.download_stats.items()
        }

    def _processed_count(self) -> int:
        """Return how many stocks have been attempted (successes + failures)."""
        stats = self.download_stats
        return stats.get("downloaded_stocks", 0) + stats.get("failed_stocks", 0)

    def _save_progress_report(self):
        """Write a timestamped JSON progress report to the report directory.

        ``progress.percentage`` is the share of stocks attempted so far,
        whether they succeeded or failed.
        """
        report = {
            "timestamp": datetime.now().isoformat(),
            "timeframe": self.timeframe,
            "stats": self._serializable_stats(),
            "progress": {
                "percentage": self._processed_count() / max(1, self.download_stats["total_stocks"]) * 100,
                "estimated_time_left": self._estimate_time_left()
            }
        }

        report_file = os.path.join(self.report_dir, f"progress_{self.timeframe}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json")
        with open(report_file, 'w', encoding='utf-8') as f:
            json.dump(report, f, ensure_ascii=False, indent=2)

    def _save_final_report(self):
        """Write the timestamped JSON final report to the report directory."""
        stats = self.download_stats
        end_time = stats["end_time"]

        report = {
            "timestamp": datetime.now().isoformat(),
            "timeframe": self.timeframe,
            "summary": self._serializable_stats(),
            "duration": str(end_time - stats["start_time"]) if end_time else None,
            "success_rate": stats["downloaded_stocks"] / max(1, stats["total_stocks"]) * 100
        }

        report_file = os.path.join(self.report_dir, f"final_{self.timeframe}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json")
        with open(report_file, 'w', encoding='utf-8') as f:
            json.dump(report, f, ensure_ascii=False, indent=2)

        logger.info(f"📄 最终报告已保存: {report_file}")

    def _estimate_time_left(self) -> str:
        """Estimate the remaining time from the average time per attempted stock.

        Returns:
            A human-readable duration, or ``"未知"`` while nothing has been
            attempted yet.
        """
        # Failed stocks consumed time too and will not be attempted again,
        # so they count as processed rather than as remaining work.
        processed = self._processed_count()
        if processed == 0:
            return "未知"

        elapsed = (datetime.now() - self.download_stats["start_time"]).total_seconds()
        remaining_stocks = max(0, self.download_stats["total_stocks"] - processed)
        remaining_seconds = elapsed / processed * remaining_stocks

        if remaining_seconds < 60:
            return f"{int(remaining_seconds)}秒"
        if remaining_seconds < 3600:
            return f"{int(remaining_seconds/60)}分钟"
        return f"{int(remaining_seconds/3600)}小时{int((remaining_seconds%3600)/60)}分钟"


def main():
    """Parse the command line and run a full download."""
    import argparse

    parser = argparse.ArgumentParser(description="简单稳定的分钟数据下载器")
    parser.add_argument("--timeframe", default="15min",
                        help="时间粒度: 1min, 5min, 15min, 30min, 60min")
    parser.add_argument("--start-date", default="2021-01-01",
                        help="开始日期")
    parser.add_argument("--end-date", default=None,
                        help="结束日期")
    parser.add_argument("--batch-size", type=int, default=100,
                        help="批次大小")
    parser.add_argument("--max-workers", type=int, default=5,
                        help="最大工作线程数")
    parser.add_argument("--retry-count", type=int, default=3,
                        help="单只股票最大重试次数")
    parser.add_argument("--all-stocks", action="store_true",
                        help="下载所有股票(包括已下载的)")

    args = parser.parse_args()

    # Same message and exit status the script used to produce at import time.
    if not AKSHARE_AVAILABLE:
        print("❌ AKShare未安装,请运行: pip install akshare")
        sys.exit(1)

    print("="*70)
    print("🚀 赵云分钟数据下载器启动")
    print("="*70)

    downloader = SimpleMinuteDownloader(
        timeframe=args.timeframe,
        start_date=args.start_date,
        end_date=args.end_date,
        batch_size=args.batch_size,
        max_workers=args.max_workers,
        retry_count=args.retry_count
    )

    downloader.download_all_stocks(all_stocks=args.all_stocks)


if __name__ == "__main__":
    main()