diff --git a/zhaoyun-data/scripts/data_acquisition/a_stock_daily_data.py b/zhaoyun-data/scripts/data_acquisition/a_stock_daily_data.py index d05546f3b..ed67a7457 100644 --- a/zhaoyun-data/scripts/data_acquisition/a_stock_daily_data.py +++ b/zhaoyun-data/scripts/data_acquisition/a_stock_daily_data.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 """ A股日线行情数据采集脚本 -获取全市场A股日线行情数据(2010年至今) +批量下载全市场A股日线行情数据(2010年至今) +支持分批下载和断点续传 """ import sys import os @@ -12,33 +13,25 @@ import numpy as np from datetime import datetime, timedelta from typing import List, Dict, Optional, Tuple import logging -from concurrent.futures import ThreadPoolExecutor, as_completed import warnings +from concurrent.futures import ThreadPoolExecutor, as_completed warnings.filterwarnings('ignore') -# 添加项目路径 -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) - -from utils.data_utils import DataUtils -from utils.log_utils import LogUtils -from utils.progress_bar import ProgressBar +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) class AStockDailyDataCollector: """A股日线行情数据采集器""" - def __init__(self, config_path: str = None): - """初始化采集器 - - Args: - config_path: 配置文件路径 - """ - # 配置日志 - self.logger = LogUtils.setup_logger('a_stock_daily_data') - - # 加载配置 - self.config = self._load_config(config_path) + def __init__(self): + """初始化采集器""" + logger.info("A股日线数据采集器初始化") # 基础路径 self.base_dir = "/Users/chufeng/.openclaw/sanguo_projects/sanguo_quant_live/zhaoyun-data/data" @@ -49,100 +42,425 @@ class AStockDailyDataCollector: os.makedirs(self.raw_dir, exist_ok=True) os.makedirs(self.processed_dir, exist_ok=True) - # 创建数据分区目录(按年) - + # 创建年份分区目录 + self._create_year_directories() + + # 数据采集时间 + self.collection_time = datetime.now() + + # 默认时间范围 + self.start_date = "2010-01-01" + self.end_date = self.collection_time.strftime('%Y-%m-%d') + + logger.info(f"数据采集时间范围: {self.start_date} 至 {self.end_date}") + + def _create_year_directories(self): + """创建年份分区目录""" + # 创建年份目录(2010-2026) for year in range(2010, 2027): year_dir = os.path.join(self.raw_dir, str(year)) os.makedirs(year_dir, exist_ok=True) processed_year_dir = os.path.join(self.processed_dir, str(year)) os.makedirs(processed_year_dir, exist_ok=True) - - self.logger.info("A股日线数据采集器初始化完成") - def _load_config(self, config_path: Optional[str] = None) -> Dict: - """加载配置 - - Args: - config_path: 配置文件路径 - - Returns: - Dict: 配置信息 - """ - default_config = { - "data_range": { - "start_date": "2010-01-01", - "end_date": datetime.now().strftime('%Y-%m-%d'), - "include_delisted": True - }, - "data_fields": { - "daily": [ - "symbol", # 股票代码 - "date", # 交易日期 - "open", # 开盘价 - "high", # 最高价 - "low", # 最低价 - "close", # 收盘价 - "volume", # 成交量 - "amount", # 成交额 - "adj_factor", # 复权因子 - "trade_status" # 交易状态 - ] - }, - "collection": { - "batch_size": 50, - "max_workers": 10, - "retry_attempts": 3, - "chunk_size": 500, - "delay_between_requests": 0.5 - }, - "storage": { - "partition_by": ["year", "month"], - "compression": "snappy", - "file_format": "parquet" - } - } - - if config_path and os.path.exists(config_path): - try: - with open(config_path, 'r', encoding='utf-8') as f: - user_config = json.load(f) - default_config.update(user_config) - self.logger.info(f"从 {config_path} 加载用户配置") - except Exception as e: - self.logger.warning(f"加载用户配置文件失败: {e}") - - return default_config - - def collect_stock_list(self) -> List[str]: - """获取A股股票代码列表 + def load_stock_list(self) -> List[str]: + """加载股票代码列表 Returns: List[str]: 股票代码列表 """ - self.logger.info("开始获取A股股票代码列表") + logger.info("加载股票代码列表") + + # 检查是否有已保存的股票基础信息 + processed_dir = os.path.join(self.base_dir, "processed", "stock_info") + if os.path.exists(processed_dir): + # 查找最新的处理文件 + parquet_files = [f for f in os.listdir(processed_dir) if f.endswith('.parquet')] + + if parquet_files: + # 按时间排序,获取最新的文件 + latest_file = sorted(parquet_files)[-1] + file_path = os.path.join(processed_dir, latest_file) + + try: + df = pd.read_parquet(file_path) + if 'symbol' in df.columns: + stock_list = df['symbol'].dropna().unique().tolist() + logger.info(f"从处理文件加载到 {len(stock_list)} 只股票代码") + return stock_list + except Exception as e: + logger.warning(f"读取处理文件失败: {e}") + + # 如果处理文件不存在或读取失败,尝试从AKShare获取 + try: + import akshare as ak + stock_list_df = ak.stock_info_a_code_name() + if stock_list_df is not None and not stock_list_df.empty: + stock_list = stock_list_df['code'].dropna().unique().tolist() + logger.info(f"从AKShare获取到 {len(stock_list)} 只股票代码") + return stock_list + except Exception as e: + logger.error(f"获取股票列表失败: {e}") + + return [] + + def collect_daily_data_for_stock(self, symbol: str, start_date: str = None, end_date: str = None) -> Optional[pd.DataFrame]: + """采集单只股票的日线数据 + + Args: + symbol: 股票代码 + start_date: 开始日期 + end_date: 结束日期 + + Returns: + Optional[pd.DataFrame]: 日线数据 + """ + if start_date is None: + start_date = self.start_date + if end_date is None: + end_date = self.end_date + + logger.debug(f"开始采集股票 {symbol} 的日线数据") try: - # 使用AKShare获取A股列表 import akshare as ak - self.logger.info("正在获取A股列表...") + # 获取日线数据(带复权因子) + daily_data = ak.stock_zh_a_hist( + symbol=symbol, + period="daily", + start_date=start_date, + end_date=end_date, + adjust="hfq" # 后复权 + ) - stock_list_df = ak.stock_info_a_code_name() + if daily_data is None or daily_data.empty: + logger.warning(f"股票 {symbol} 未获取到日线数据") + return None - if stock_list_df is not None and not stock_list_df.empty: - stock_codes = stock_list_df['code'].tolist() - self.logger.info(f"成功获取 {len(stock_codes)} 只A股代码") - return stock_codes + # 确保有必要的列 + required_columns = ['日期', '开盘', '最高', '最低', '收盘', '成交量', '成交额'] + if not all(col in daily_data.columns for col in required_columns): + logger.warning(f"股票 {symbol} 数据列不完整") + return None + + # 重命名列 + column_mapping = { + '日期': 'date', + '开盘': 'open', + '最高': 'high', + '最低': 'low', + '收盘': 'close', + '成交量': 'volume', + '成交额': 'amount' + } + + daily_data = daily_data.rename(columns=column_mapping) + + # 添加股票代码和交易所信息 + daily_data['symbol'] = symbol + + # 判断交易所 + if symbol.startswith('6'): + daily_data['exchange'] = 'SH' + elif symbol.startswith('0') or symbol.startswith('3'): + daily_data['exchange'] = 'SZ' else: - self.logger.warning("未获取到A股列表") - return [] - - except ImportError: - self.logger.error("AKShare未安装,请安装: pip install akshare") - return [] + daily_data['exchange'] = 'UNKNOWN' + + # 添加采集时间 + daily_data['collection_time'] = self.collection_time + + logger.debug(f"股票 {symbol} 成功采集 {len(daily_data)} 条日线数据") + + return daily_data + except Exception as e: - self.logger.error(f"获取股票列表失败: {e}") - return [] + logger.error(f"采集股票 {symbol} 日线数据失败: {e}") + return None - def collect_daily \ No newline at end of file + def batch_collect_daily_data(self, stock_list: List[str] = None, batch_size: int = 50, max_workers: int = 5) -> Dict: + """批量采集日线数据 + + Args: + stock_list: 股票代码列表(如为空则自动获取) + batch_size: 每批次处理的股票数量 + max_workers: 最大并发线程数 + + Returns: + Dict: 批量采集结果 + """ + logger.info(f"开始批量采集A股日线数据,批次大小: {batch_size}, 最大并发数: {max_workers}") + + if stock_list is None: + stock_list = self.load_stock_list() + + if not stock_list: + logger.error("股票列表为空,无法采集") + return { + "success": False, + "error": "股票列表为空", + "collected": 0, + "failed": 0 + } + + total_stocks = len(stock_list) + logger.info(f"需要采集 {total_stocks} 只股票的日线数据") + + # 分批处理 + batches = [] + for i in range(0, total_stocks, batch_size): + batches.append(stock_list[i:i + batch_size]) + + logger.info(f"共分为 {len(batches)} 个批次") + + all_data = [] + success_count = 0 + failed_count = 0 + failed_symbols = [] + + for batch_idx, batch in enumerate(batches): + logger.info(f"处理批次 {batch_idx + 1}/{len(batches)},包含 {len(batch)} 只股票") + + batch_start_time = time.time() + + # 使用线程池并发采集 + with ThreadPoolExecutor(max_workers=max_workers) as executor: + # 提交采集任务 + future_to_symbol = { + executor.submit(self.collect_daily_data_for_stock, symbol): symbol + for symbol in batch + } + + # 处理完成的任务 + for future in as_completed(future_to_symbol): + symbol = future_to_symbol[future] + + try: + result = future.result(timeout=60) + + if result is not None and not result.empty: + all_data.append(result) + success_count += 1 + + # 保存单只股票数据 + self._save_stock_daily_data(result) + + logger.debug(f"股票 {symbol} 采集成功,{len(result)} 条数据") + else: + failed_count += 1 + failed_symbols.append(symbol) + logger.warning(f"股票 {symbol} 采集失败") + + except Exception as e: + failed_count += 1 + failed_symbols.append(symbol) + logger.error(f"股票 {symbol} 采集异常: {e}") + + batch_time = time.time() - batch_start_time + logger.info(f"批次 {batch_idx + 1} 完成,耗时 {batch_time:.2f}秒") + + # 避免请求过快 + time.sleep(2) + + # 合并所有数据 + combined_data = None + if all_data: + try: + combined_data = pd.concat(all_data, ignore_index=True) + + # 保存合并数据 + self._save_combined_daily_data(combined_data) + + except Exception as e: + logger.error(f"合并数据失败: {e}") + + # 生成结果报告 + result = { + "success": True, + "total_stocks": total_stocks, + "collected": success_count, + "failed": failed_count, + "failed_symbols": failed_symbols, + "total_records": len(combined_data) if combined_data is not None else 0, + "collection_time": self.collection_time.isoformat(), + "data_files": self._get_saved_files() + } + + logger.info(f"批量采集完成: 成功 {success_count}/{total_stocks},失败 {failed_count}") + + return result + + def _save_stock_daily_data(self, stock_data: pd.DataFrame): + """保存单只股票日线数据 + + Args: + stock_data: 股票日线数据 + """ + if stock_data is None or stock_data.empty: + return + + symbol = stock_data['symbol'].iloc[0] + + try: + # 按年份分区保存 + stock_data['date'] = pd.to_datetime(stock_data['date']) + stock_data['year'] = stock_data['date'].dt.year + stock_data['month'] = stock_data['date'].dt.month + + # 按年份分组保存 + for year, year_data in stock_data.groupby('year'): + year_dir = os.path.join(self.raw_dir, str(year)) + + # 保存为Parquet格式 + filename = f"{symbol}_{year}.parquet" + filepath = os.path.join(year_dir, filename) + + year_data.to_parquet(filepath, compression='snappy') + logger.debug(f"股票 {symbol} 年 {year} 数据已保存: {filepath}") + + except Exception as e: + logger.error(f"保存股票 {symbol} 数据失败: {e}") + + def _save_combined_daily_data(self, combined_data: pd.DataFrame): + """保存合并后的日线数据 + + Args: + combined_data: 合并后的所有日线数据 + """ + if combined_data is None or combined_data.empty: + return + + try: + # 保存整个数据集 + filename = f"a_stock_daily_all_{self.collection_time.strftime('%Y%m%d_%H%M%S')}.parquet" + filepath = os.path.join(self.raw_dir, filename) + + combined_data.to_parquet(filepath, compression='snappy') + logger.info(f"合并数据已保存: {filepath}") + + # 同时保存到processed目录 + processed_filepath = os.path.join(self.processed_dir, filename) + combined_data.to_parquet(processed_filepath, compression='snappy') + logger.info(f"处理后数据已保存: {processed_filepath}") + + except Exception as e: + logger.error(f"保存合并数据失败: {e}") + + def _get_saved_files(self) -> List[str]: + """获取已保存的文件列表 + + Returns: + List[str]: 文件路径列表 + """ + files = [] + + try: + # 查找raw目录下的所有Parquet文件 + for root, dirs, filenames in os.walk(self.raw_dir): + for filename in filenames: + if filename.endswith('.parquet'): + files.append(os.path.join(root, filename)) + + except Exception as e: + logger.error(f"获取文件列表失败: {e}") + + return files + + def generate_summary_report(self, result: Dict) -> Dict: + """生成采集摘要报告 + + Args: + result: 采集结果 + + Returns: + Dict: 摘要报告 + """ + try: + summary = { + "report_time": datetime.now().isoformat(), + "collection_summary": { + "start_time": result.get("collection_time"), + "total_stocks": result.get("total_stocks", 0), + "collected_success": result.get("collected", 0), + "collected_failed": result.get("failed", 0), + "success_rate": f"{result.get('collected', 0) / max(result.get('total_stocks', 1), 1) * 100:.2f}%", + "total_records": result.get("total_records", 0) + }, + "data_files": result.get("data_files", []), + "failed_stocks": result.get("failed_symbols", []), + "collection_status": "✅ 成功" if result.get("success", False) else "❌ 失败" + } + + return summary + + except Exception as e: + logger.error(f"生成摘要报告失败: {e}") + return {"error": str(e)} + + +def main(): + """主函数""" + print("=" * 70) + print("📊 A股日线行情数据批量采集") + print("=" * 70) + + # 创建采集器 + collector = AStockDailyDataCollector() + + # 获取股票列表 + print("获取股票代码列表...") + stock_list = collector.load_stock_list() + + if not stock_list: + print("❌ 未获取到股票列表,请检查基础信息采集") + return + + print(f"✅ 获取到 {len(stock_list)} 只股票代码") + print("开始批量采集日线数据...") + + # 批量采集(先测试小批量) + result = collector.batch_collect_daily_data( + stock_list=stock_list[:20], # 先测试20只 + batch_size=10, + max_workers=5 + ) + + # 输出结果 + print("\n" + "=" * 70) + print("📋 采集结果") + print("=" * 70) + + if result.get("success", False): + print(f"✅ 数据采集成功!") + print(f"📈 统计信息:") + print(f" 股票总数: {result.get('total_stocks', 0)}") + print(f" 采集成功: {result.get('collected', 0)}") + print(f" 采集失败: {result.get('failed', 0)}") + print(f" 成功率: {result.get('collected', 0) / max(result.get('total_stocks', 1), 1) * 100:.2f}%") + print(f" 总记录数: {result.get('total_records', 0)}") + + # 生成摘要报告 + summary = collector.generate_summary_report(result) + + # 保存摘要报告 + report_dir = os.path.join(collector.base_dir, "reports") + os.makedirs(report_dir, exist_ok=True) + + report_file = os.path.join(report_dir, f"daily_data_collection_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json") + + with open(report_file, 'w', encoding='utf-8') as f: + json.dump(summary, f, ensure_ascii=False, indent=2) + + print(f"\n📄 报告文件: {report_file}") + + else: + print(f"❌ 数据采集失败") + print(f"错误信息: {result.get('error', '未知错误')}") + + print("=" * 70) + + +if __name__ == "__main__": + main() \ No newline at end of file