#!/usr/bin/env python3
"""
简单稳定的分钟数据下载器
"""

import json
import logging
import os
import sys
import time
import warnings
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional

# Suppress every warning category process-wide; set before the third-party
# imports below, as in the original ordering.
warnings.filterwarnings('ignore')

import pandas as pd

try:
    import akshare as ak
    AKSHARE_AVAILABLE = True
except ImportError:
    AKSHARE_AVAILABLE = False
    print("❌ AKShare未安装,请运行: pip install akshare")
    sys.exit(1)

# Configure logging: INFO level, "timestamp - level - message" format.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
# Module-level logger used by everything below.
logger = logging.getLogger(__name__)


class SimpleMinuteDownloader:
    """Download minute-level K-line data for A-share stocks through AKShare.

    Each stock is fetched with ``ak.stock_zh_a_minute`` and written to
    ``<base_dir>/<timeframe>/`` as both a Parquet and a CSV file.  JSON
    progress and summary reports are written to ``<base_dir>/reports``.
    """

    # Public timeframe name -> ``period`` argument passed to AKShare.
    _PERIOD_MAP = {
        '1min': '1',
        '5min': '5',
        '15min': '15',
        '30min': '30',
        '60min': '60',
    }

    # Pause before retrying a failed download, in seconds.
    _RETRY_DELAY_SECONDS = 2
    # Pause between two consecutive stocks, to avoid hammering the API.
    _REQUEST_INTERVAL_SECONDS = 0.5
    # A progress report is written every this many processed stocks.
    _PROGRESS_REPORT_EVERY = 10

    def __init__(
        self,
        base_dir: str = "/Users/chufeng/nas/stock/minute_kline",
        timeframe: str = "15min",
        start_date: str = "2021-01-01",
        end_date: Optional[str] = None,
        batch_size: int = 100,
        max_workers: int = 5,
        retry_count: int = 3
    ):
        """Initialise the downloader and create its output directories.

        Args:
            base_dir: Root directory for data, logs and reports.
            timeframe: One of ``1min``, ``5min``, ``15min``, ``30min``,
                ``60min``.
            start_date: Start of the requested range (``YYYY-MM-DD``).
            end_date: End of the requested range; defaults to today.
            batch_size: Stored on the instance.
            max_workers: Stored on the instance.
            retry_count: Number of download attempts per stock.

        Raises:
            ValueError: If ``timeframe`` is not a supported value.  Failing
                here prevents 15-minute data from being saved under a
                different timeframe label.

        Note:
            ``start_date`` / ``end_date`` are only logged, and ``batch_size``
            / ``max_workers`` are only stored; no method of this class uses
            them to filter data or to parallelise downloads.
        """
        if timeframe not in self._PERIOD_MAP:
            supported = ", ".join(self._PERIOD_MAP)
            raise ValueError(
                f"Unsupported timeframe {timeframe!r}; expected one of: {supported}"
            )

        self.base_dir = base_dir
        self.timeframe = timeframe
        self.start_date = start_date
        self.end_date = end_date or datetime.now().strftime("%Y-%m-%d")
        self.batch_size = batch_size
        self.max_workers = max_workers
        self.retry_count = retry_count

        # Create the output directories.
        self.data_dir = os.path.join(self.base_dir, self.timeframe)
        self.log_dir = os.path.join(self.base_dir, "logs")
        self.report_dir = os.path.join(self.base_dir, "reports")
        for directory in (self.data_dir, self.log_dir, self.report_dir):
            os.makedirs(directory, exist_ok=True)

        # Run statistics.  ``total_stocks`` is the size of the full code
        # list; ``skipped_stocks`` counts codes left out because their
        # Parquet file already exists (resume mode).
        self.download_stats: Dict[str, Any] = {
            "total_stocks": 0,
            "skipped_stocks": 0,
            "downloaded_stocks": 0,
            "failed_stocks": 0,
            "start_time": datetime.now(),
            "end_time": None
        }

        logger.info(f"初始化下载器: {self.timeframe} 数据")
        logger.info(f"数据目录: {self.data_dir}")
        logger.info(f"时间范围: {self.start_date} 至 {self.end_date}")

    @staticmethod
    def _with_market_prefix(code: str) -> Optional[str]:
        """Return ``code`` with its ``sh``/``sz`` prefix, or None if unsupported."""
        if not code:
            return None
        if code.startswith('6'):
            return f"sh{code}"
        if code.startswith(('0', '3')):
            return f"sz{code}"
        # Any other leading digit is ignored, as in the original logic.
        return None

    @staticmethod
    def _success_rate(succeeded: int, attempted: int) -> float:
        """Return the success percentage, or 0.0 when nothing was attempted."""
        return succeeded / attempted * 100 if attempted else 0.0

    def get_all_a_stock_codes(self) -> List[str]:
        """Return all A-share codes with market prefix (e.g. ``sh600000``).

        Falls back to ``_get_default_stock_codes`` when AKShare returns no
        data or raises.
        """
        logger.info("获取A股代码列表...")

        try:
            stock_info = ak.stock_info_a_code_name()
            if stock_info is None or stock_info.empty:
                logger.warning("AKShare返回空数据,使用预定义列表")
                return self._get_default_stock_codes()

            # A missing 'code' column yields an empty list, matching the
            # original row-by-row lookup with an empty default.
            raw_codes = stock_info['code'] if 'code' in stock_info.columns else []
            prefixed = (self._with_market_prefix(str(raw).strip()) for raw in raw_codes)
            codes = [code for code in prefixed if code is not None]

            logger.info(f"获取到 {len(codes)} 只A股代码")
            return codes

        except Exception as e:
            # Best-effort boundary: any failure falls back to the built-in list.
            logger.error(f"获取A股代码失败: {e}")
            return self._get_default_stock_codes()

    def _get_default_stock_codes(self) -> List[str]:
        """Return a generated fallback list of 997 candidate codes.

        The list is synthesised from number ranges (``sh600001``-``sh600599``,
        ``sz000001``-``sz000299``, ``sz300001``-``sz300099``); it is not
        checked against real listings, so some codes may not exist.
        """
        codes = [f"sh600{i:03d}" for i in range(1, 600)]   # Shanghai
        codes += [f"sz000{i:03d}" for i in range(1, 300)]  # Shenzhen
        codes += [f"sz300{i:03d}" for i in range(1, 100)]

        logger.info(f"使用默认股票代码列表: {len(codes)} 只")
        return codes

    def _save_stock_frame(self, stock_code: str, data) -> str:
        """Write ``data`` as CSV and Parquet and return the Parquet path.

        The Parquet file doubles as the "already downloaded" marker used by
        the resume scan, so it is written to a temporary name and renamed
        last: an interrupted write never leaves a marker behind.
        """
        stem = os.path.join(self.data_dir, f"{stock_code}_{self.timeframe}")
        parquet_file = f"{stem}.parquet"
        csv_file = f"{stem}.csv"
        parquet_tmp = f"{parquet_file}.tmp"

        try:
            # CSV copy for easy inspection.
            data.to_csv(csv_file, index=False)
            data.to_parquet(parquet_tmp, compression='snappy')
            os.replace(parquet_tmp, parquet_file)
        finally:
            # Only present if something failed before the rename.
            if os.path.exists(parquet_tmp):
                os.remove(parquet_tmp)
        return parquet_file

    def download_stock_data(self, stock_code: str) -> bool:
        """Download and store minute data for one stock.

        Retries up to ``self.retry_count`` times when an exception is raised.
        A ``None`` or empty result is not retried.

        Returns:
            True if the data was saved, False otherwise.
        """
        # __init__ validates the timeframe; the fallback only matters if the
        # attribute was changed afterwards and mirrors the original default.
        period = self._PERIOD_MAP.get(self.timeframe, '15')

        for attempt in range(self.retry_count):
            try:
                logger.info(f"下载 {stock_code} {self.timeframe} 数据...")

                data = ak.stock_zh_a_minute(
                    symbol=stock_code,
                    period=period,
                    adjust='hfq'
                )

                if data is None:
                    logger.warning(f"{stock_code}: 返回None数据,跳过")
                    return False

                if hasattr(data, 'empty') and data.empty:
                    logger.warning(f"{stock_code}: 数据为空,跳过")
                    return False

                # Work on a copy and strip whitespace from string column names.
                data = data.copy()
                if hasattr(data, 'columns'):
                    data.columns = [
                        col.strip() if isinstance(col, str) else col
                        for col in data.columns
                    ]

                parquet_file = self._save_stock_frame(stock_code, data)

                logger.info(f"✅ {stock_code}: 下载成功 {len(data)} 条记录")
                logger.info(f" 保存文件: {parquet_file} ({os.path.getsize(parquet_file) // 1024} KB)")
                return True

            except Exception as e:
                # Retry boundary: log a truncated message, then retry or give up.
                logger.error(f"❌ {stock_code}: 下载失败 (尝试 {attempt + 1}/{self.retry_count}) - {str(e)[:100]}")
                if attempt < self.retry_count - 1:
                    time.sleep(self._RETRY_DELAY_SECONDS)
                else:
                    return False

        # Reached only when retry_count <= 0.
        return False

    def _find_downloaded_codes(self) -> set:
        """Return the codes that already have a Parquet file for this timeframe."""
        if not os.path.exists(self.data_dir):
            return set()
        suffix = f"_{self.timeframe}.parquet"
        return {
            name[:-len(suffix)]
            for name in os.listdir(self.data_dir)
            if name.endswith(suffix)
        }

    def download_all_stocks(self, all_stocks: bool = True):
        """Download every stock, writing progress and final reports.

        Args:
            all_stocks: When True, download every code, including those that
                already have a Parquet file.  When False, skip codes whose
                Parquet file already exists (resume mode).
        """
        logger.info("="*70)
        logger.info("🚀 开始全量分钟数据下载")
        logger.info("="*70)

        stock_codes = self.get_all_a_stock_codes()
        if not stock_codes:
            logger.error("❌ 无法获取股票代码列表")
            return

        stats = self.download_stats
        stats["total_stocks"] = len(stock_codes)
        stats["skipped_stocks"] = 0
        stats["downloaded_stocks"] = 0
        stats["failed_stocks"] = 0

        # Resume mode: drop codes that already have a Parquet file.
        if not all_stocks:
            downloaded_stocks = self._find_downloaded_codes()
            if downloaded_stocks:
                remaining_stocks = [code for code in stock_codes if code not in downloaded_stocks]
                logger.info(f"已下载 {len(downloaded_stocks)} 只,剩余 {len(remaining_stocks)} 只")
                stats["skipped_stocks"] = len(stock_codes) - len(remaining_stocks)
                stock_codes = remaining_stocks

        pending_total = len(stock_codes)
        logger.info(f"📊 开始下载 {pending_total} 只股票")

        success_count = 0
        fail_count = 0

        for i, stock_code in enumerate(stock_codes, 1):
            logger.info(f"\n📈 进度: {i}/{pending_total} ({i/pending_total*100:.1f}%)")

            if self.download_stock_data(stock_code):
                success_count += 1
            else:
                fail_count += 1

            stats["downloaded_stocks"] = success_count
            stats["failed_stocks"] = fail_count

            # Persist progress periodically and after the last stock.
            if i % self._PROGRESS_REPORT_EVERY == 0 or i == pending_total:
                self._save_progress_report()

            # Throttle requests; no need to wait after the last one.
            if i < pending_total:
                time.sleep(self._REQUEST_INTERVAL_SECONDS)

        stats["end_time"] = datetime.now()
        self._save_final_report()

        # _success_rate guards the empty case (nothing left to download),
        # which previously raised ZeroDivisionError here.
        rate = self._success_rate(success_count, success_count + fail_count)

        logger.info("="*70)
        logger.info("✅ 下载完成!")
        logger.info(f" 成功: {success_count} 只")
        logger.info(f" 失败: {fail_count} 只")
        logger.info(f" 成功率: {rate:.1f}%")
        logger.info(f" 总耗时: {stats['end_time'] - stats['start_time']}")
        logger.info("="*70)

    def _serializable_stats(self) -> Dict[str, Any]:
        """Return a copy of the stats with datetimes converted to ISO strings."""
        return {
            key: value.isoformat() if isinstance(value, datetime) else value
            for key, value in self.download_stats.items()
        }

    def _write_json_report(self, prefix: str, report: Dict[str, Any]) -> str:
        """Write ``report`` to a timestamped JSON file and return its path."""
        stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        report_file = os.path.join(self.report_dir, f"{prefix}_{self.timeframe}_{stamp}.json")
        with open(report_file, 'w', encoding='utf-8') as f:
            json.dump(report, f, ensure_ascii=False, indent=2)
        return report_file

    def _save_progress_report(self):
        """Write a JSON snapshot of the current progress.

        The percentage is the share of stocks scheduled for this run
        (total minus skipped) that have been processed, successfully or not,
        which matches the progress figure logged during the download loop.
        """
        stats = self.download_stats
        processed = stats["downloaded_stocks"] + stats["failed_stocks"]
        pending_total = stats["total_stocks"] - stats.get("skipped_stocks", 0)

        report = {
            "timestamp": datetime.now().isoformat(),
            "timeframe": self.timeframe,
            "stats": self._serializable_stats(),
            "progress": {
                "percentage": processed / max(1, pending_total) * 100,
                "estimated_time_left": self._estimate_time_left()
            }
        }
        self._write_json_report("progress", report)

    def _save_final_report(self):
        """Write the final JSON summary of the run and log its path.

        ``success_rate`` is successes over attempted downloads, the same
        figure that is logged at the end of ``download_all_stocks``.
        """
        stats = self.download_stats
        attempted = stats["downloaded_stocks"] + stats["failed_stocks"]
        duration = (
            str(stats["end_time"] - stats["start_time"])
            if stats["end_time"] else None
        )

        report = {
            "timestamp": datetime.now().isoformat(),
            "timeframe": self.timeframe,
            "summary": self._serializable_stats(),
            "duration": duration,
            "success_rate": self._success_rate(stats["downloaded_stocks"], attempted)
        }
        report_file = self._write_json_report("final", report)

        logger.info(f"📄 最终报告已保存: {report_file}")

    def _estimate_time_left(self) -> str:
        """Estimate the remaining time as a human-readable string.

        Uses the average time per processed stock.  Failed downloads count as
        processed because they consume time too, and skipped stocks are
        excluded from the remaining work.
        """
        stats = self.download_stats
        processed = stats["downloaded_stocks"] + stats["failed_stocks"]
        if processed == 0:
            return "未知"

        pending_total = stats["total_stocks"] - stats.get("skipped_stocks", 0)
        remaining_stocks = max(0, pending_total - processed)

        elapsed = (datetime.now() - stats["start_time"]).total_seconds()
        remaining_seconds = elapsed / processed * remaining_stocks

        if remaining_seconds < 60:
            return f"{int(remaining_seconds)}秒"
        if remaining_seconds < 3600:
            return f"{int(remaining_seconds / 60)}分钟"
        hours, rest = divmod(int(remaining_seconds), 3600)
        return f"{hours}小时{rest // 60}分钟"


def main():
    """Parse command-line options and run the downloader.

    ``--timeframe`` is restricted to the supported values so that a typo is
    rejected by the parser instead of silently falling back to 15-minute
    data.  Without ``--all-stocks`` the run skips stocks whose Parquet file
    already exists.
    """
    import argparse

    parser = argparse.ArgumentParser(description="简单稳定的分钟数据下载器")
    parser.add_argument(
        "--timeframe",
        default="15min",
        choices=("1min", "5min", "15min", "30min", "60min"),
        help="时间粒度: 1min, 5min, 15min, 30min, 60min",
    )
    parser.add_argument("--start-date", default="2021-01-01", help="开始日期")
    parser.add_argument("--end-date", default=None, help="结束日期")
    parser.add_argument("--batch-size", type=int, default=100, help="批次大小")
    parser.add_argument("--max-workers", type=int, default=5, help="最大工作线程数")
    # Exposes the constructor's existing retry_count parameter; default unchanged.
    parser.add_argument("--retry-count", type=int, default=3, help="每只股票的最大尝试次数")
    parser.add_argument("--all-stocks", action="store_true", help="下载所有股票(包括已下载的)")

    args = parser.parse_args()

    print("="*70)
    print("🚀 赵云分钟数据下载器启动")
    print("="*70)

    downloader = SimpleMinuteDownloader(
        timeframe=args.timeframe,
        start_date=args.start_date,
        end_date=args.end_date,
        batch_size=args.batch_size,
        max_workers=args.max_workers,
        retry_count=args.retry_count,
    )

    downloader.download_all_stocks(all_stocks=args.all_stocks)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()