# Files
# 2026-03-25 23:50:25 +08:00
#
# 559 lines
# 17 KiB
# Python

#!/usr/bin/env python3
"""
akshare → vn.py 数据适配器
功能:将akshare获取的A股数据转换为vn.py格式并入库到SQLite数据库
作者:赵云(数据护军)
日期:2026-03-24
"""
import akshare as ak
import pandas as pd
import sqlite3
import os
import time
import requests
from datetime import datetime, timedelta
from typing import List, Dict, Optional
import logging
from tqdm import tqdm
# Logging configuration: every record goes both to a log file in the current
# working directory and to the console (StreamHandler default stream).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        # Explicit UTF-8: the log messages in this module contain non-ASCII
        # text, and FileHandler otherwise uses the locale's default encoding.
        logging.FileHandler('akshare_vnpy_adapter.log', encoding='utf-8'),
        logging.StreamHandler(),
    ],
)

# Module-level logger used by everything in this file.
logger = logging.getLogger(__name__)
class AkshareToVnpyAdapter:
    """Adapter that downloads A-share data via akshare and stores it in SQLite.

    The adapter owns a single ``sqlite3`` connection (``self.conn``) opened on
    ``db_path``. Daily bars are written to the ``dbbardata`` table; a
    ``dbtickdata`` table is created as well but nothing in this class writes
    to it.

    NOTE(review): rows are stored with exchange codes 'SH' / 'SZ' / 'BJ' and a
    default interval string of '1d'. Whether vn.py's own database layer reads
    these values as-is is not established by this file -- confirm before
    loading the data from vn.py.
    """

    # Exchange-name mapping. Declared for reference only: no method of this
    # class reads it.
    EXCHANGE_MAP = {
        'SSE': 'SH',   # Shanghai Stock Exchange
        'SZSE': 'SZ',  # Shenzhen Stock Exchange
        'BJSE': 'BJ',  # Beijing Stock Exchange
    }

    def __init__(self, db_path: str = 'running_data/database.db'):
        """Create the database directory if needed and open the connection.

        Args:
            db_path: Path of the SQLite database file.
        """
        self.db_path = db_path
        self._ensure_db_path()
        self.conn = self._create_connection()

    def _ensure_db_path(self):
        """Create the parent directory of ``self.db_path`` when it is missing."""
        db_dir = os.path.dirname(self.db_path)
        # An empty dirname means the file lives in the current directory.
        if db_dir and not os.path.exists(db_dir):
            os.makedirs(db_dir, exist_ok=True)
            logger.info(f"创建数据库目录: {db_dir}")

    def _create_connection(self) -> sqlite3.Connection:
        """Open a connection to ``self.db_path`` whose rows are ``sqlite3.Row``."""
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        return conn

    def initialize_database(self):
        """Create the bar/tick tables and the bar indexes if they do not exist.

        All statements run inside one transaction (``with self.conn``).
        """
        with self.conn:
            # Bar (K-line) table; one row per (symbol, exchange, datetime, interval).
            self.conn.execute('''
                CREATE TABLE IF NOT EXISTS dbbardata (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    symbol TEXT NOT NULL,
                    exchange TEXT NOT NULL,
                    datetime TEXT NOT NULL,
                    interval TEXT NOT NULL,
                    volume REAL,
                    turnover REAL,
                    open_interest REAL,
                    open_price REAL,
                    high_price REAL,
                    low_price REAL,
                    close_price REAL,
                    UNIQUE(symbol, exchange, datetime, interval)
                )
            ''')
            # Indexes for per-symbol range queries and for date scans.
            self.conn.execute('''
                CREATE INDEX IF NOT EXISTS idx_bardata_symbol
                ON dbbardata(symbol, exchange, interval, datetime)
            ''')
            self.conn.execute('''
                CREATE INDEX IF NOT EXISTS idx_bardata_datetime
                ON dbbardata(datetime)
            ''')
            # Tick table with five levels of bid/ask price and volume.
            self.conn.execute('''
                CREATE TABLE IF NOT EXISTS dbtickdata (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    symbol TEXT NOT NULL,
                    exchange TEXT NOT NULL,
                    datetime TEXT NOT NULL,
                    name TEXT,
                    volume REAL,
                    turnover REAL,
                    open_interest REAL,
                    last_price REAL,
                    last_volume REAL,
                    limit_up REAL,
                    limit_down REAL,
                    open_price REAL,
                    high_price REAL,
                    low_price REAL,
                    pre_close REAL,
                    bid_price_1 REAL, bid_price_2 REAL, bid_price_3 REAL, bid_price_4 REAL, bid_price_5 REAL,
                    bid_volume_1 REAL, bid_volume_2 REAL, bid_volume_3 REAL, bid_volume_4 REAL, bid_volume_5 REAL,
                    ask_price_1 REAL, ask_price_2 REAL, ask_price_3 REAL, ask_price_4 REAL, ask_price_5 REAL,
                    ask_volume_1 REAL, ask_volume_2 REAL, ask_volume_3 REAL, ask_volume_4 REAL, ask_volume_5 REAL,
                    UNIQUE(symbol, exchange, datetime)
                )
            ''')
        logger.info("数据库表结构初始化完成")

    def get_stock_list(self, max_retries: int = 3, retry_delay: int = 5) -> pd.DataFrame:
        """Fetch the A-share spot list from akshare, retrying on network errors.

        Args:
            max_retries: Maximum number of attempts. Values below 1 are treated
                as 1 so that at least one attempt is always made.
            retry_delay: Seconds to sleep between attempts.

        Returns:
            DataFrame with the columns ``code``, ``name`` and ``price``.

        Raises:
            Exception: The last connection/timeout error once all attempts are
                used up, or any other error immediately (no retry).
        """
        attempts = max(1, max_retries)
        for attempt in range(attempts):
            try:
                logger.info(f"获取股票列表 (尝试 {attempt + 1}/{attempts})...")
                stock_list = ak.stock_zh_a_spot_em()
                # Keep only code / name / latest price and give them English names.
                stock_list = stock_list[['代码', '名称', '最新价']].rename(
                    columns={'代码': 'code', '名称': 'name', '最新价': 'price'}
                )
                logger.info(f"✓ 获取到 {len(stock_list)} 只A股股票")
                return stock_list
            except (requests.exceptions.ConnectionError,
                    requests.exceptions.Timeout,
                    ConnectionError) as e:
                # Only network-level failures are retried.
                if attempt < attempts - 1:
                    logger.warning(f"获取股票列表失败: {e}")
                    logger.info(f"等待 {retry_delay} 秒后重试...")
                    time.sleep(retry_delay)
                else:
                    logger.error(f"获取股票列表失败,已重试 {attempts} 次: {e}")
                    raise
            except Exception as e:
                logger.error(f"获取股票列表失败: {e}")
                raise

    def parse_symbol(self, code: str) -> tuple:
        """Split a stock code into ``(symbol, exchange)``.

        The exchange is derived from the code prefix:
        ``6`` -> 'SH'; ``4``, ``8`` and ``92`` -> 'BJ'; everything else
        (including ``0`` and ``3``) -> 'SZ'.

        Args:
            code: Stock code such as "000001" or "600000".

        Returns:
            ``(symbol, exchange)``, e.g. ("000001", "SZ") or ("600000", "SH").
            The symbol is the code itself, unchanged.
        """
        if code.startswith('6'):
            return code, 'SH'
        # Beijing Stock Exchange listings use 4xxxxx / 8xxxxx / 92xxxx codes.
        if code.startswith(('4', '8', '92')):
            return code, 'BJ'
        # '0' / '3' prefixes, and the fallback for anything unrecognised.
        return code, 'SZ'

    def fetch_stock_daily(
        self,
        code: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        adjust: str = ''
    ) -> pd.DataFrame:
        """Fetch the daily bar history of one stock from akshare.

        Args:
            code: Stock code.
            start_date: Start date "YYYYMMDD"; dashes are stripped, so
                "YYYY-MM-DD" is accepted too. When omitted the argument is not
                forwarded, so akshare's own default applies.
            end_date: End date, same format and same omission rule.
            adjust: Price adjustment: "" none, "qfq" forward, "hfq" backward.

        Returns:
            DataFrame with English column names and ``date`` formatted as
            "YYYY-MM-DD HH:MM:SS". An empty DataFrame when akshare returns no
            rows or when any error occurs (the error is logged, not raised).
        """
        try:
            request_kwargs = {'symbol': code, 'period': 'daily', 'adjust': adjust}
            # Forward only the bounds that were supplied; passing None through
            # would override akshare's defaults with an invalid value.
            if start_date:
                request_kwargs['start_date'] = start_date.replace('-', '')
            if end_date:
                request_kwargs['end_date'] = end_date.replace('-', '')

            df = ak.stock_zh_a_hist(**request_kwargs)
            if df is None or len(df) == 0:
                logger.warning(f"股票 {code} 无数据")
                return pd.DataFrame()

            # Normalise the Chinese column names to English.
            df = df.rename(columns={
                '日期': 'date',
                '开盘': 'open',
                '收盘': 'close',
                '最高': 'high',
                '最低': 'low',
                '成交量': 'volume',
                '成交额': 'turnover',
                '振幅': 'amplitude',
                '涨跌幅': 'change_pct',
                '涨跌额': 'change',
                '换手率': 'turnover_rate'
            })
            df['date'] = pd.to_datetime(df['date']).dt.strftime('%Y-%m-%d %H:%M:%S')
            return df
        except Exception as e:
            # Deliberate best-effort: one failing stock must not abort a bulk run.
            logger.error(f"获取股票 {code} 数据失败: {e}")
            return pd.DataFrame()

    def convert_bar_to_vnpy(
        self,
        df: pd.DataFrame,
        symbol: str,
        exchange: str,
        interval: str = '1d'
    ) -> List[Dict]:
        """Convert a bar DataFrame into a list of row dicts for ``dbbardata``.

        Args:
            df: DataFrame with the columns ``date``, ``open``, ``high``,
                ``low``, ``close``, ``volume`` and optionally ``turnover``.
            symbol: Stock code stored in every row.
            exchange: Exchange code stored in every row.
            interval: Interval label stored in every row.

        Returns:
            One dict per input row, keyed by the ``dbbardata`` column names.
            Empty list when ``df`` is ``None`` or has no rows.
        """
        if df is None or len(df) == 0:
            return []
        return [
            {
                'symbol': symbol,
                'exchange': exchange,
                'datetime': record['date'],
                'interval': interval,
                'open_price': float(record['open']),
                'high_price': float(record['high']),
                'low_price': float(record['low']),
                'close_price': float(record['close']),
                'volume': float(record['volume']),
                # Passed through without unit conversion; 0 when the column is absent.
                'turnover': float(record.get('turnover', 0)),
                'open_interest': 0.0
            }
            for record in df.to_dict('records')
        ]

    def insert_bars_bulk(self, bars: List[Dict], batch_size: int = 1000) -> int:
        """Insert bar dicts into ``dbbardata`` in batches, skipping existing rows.

        Each batch runs in its own transaction. Rows that violate the table's
        UNIQUE constraint are ignored (``INSERT OR IGNORE``) and not counted.

        Args:
            bars: Row dicts as produced by ``convert_bar_to_vnpy``.
            batch_size: Number of rows per transaction; must be positive.

        Returns:
            Number of rows actually inserted by this call.

        Raises:
            ValueError: If ``batch_size`` is not positive (and ``bars`` is non-empty).
        """
        if not bars:
            return 0
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")

        total_inserted = 0
        total_failed = 0
        for offset in range(0, len(bars), batch_size):
            batch = bars[offset:offset + batch_size]
            # total_changes is cumulative for the whole connection, so the
            # per-batch count must be a delta against a snapshot taken here.
            changes_before = self.conn.total_changes
            try:
                with self.conn:
                    self.conn.executemany('''
                        INSERT OR IGNORE INTO dbbardata (
                            symbol, exchange, datetime, interval,
                            open_price, high_price, low_price, close_price,
                            volume, turnover, open_interest
                        ) VALUES (
                            :symbol, :exchange, :datetime, :interval,
                            :open_price, :high_price, :low_price, :close_price,
                            :volume, :turnover, :open_interest
                        )
                    ''', batch)
                total_inserted += self.conn.total_changes - changes_before
            except Exception as e:
                # The failed batch is rolled back by the context manager;
                # later batches are still attempted.
                logger.error(f"批量插入失败: {e}")
                total_failed += len(batch)

        logger.info(f"批量插入完成: 成功 {total_inserted} 条, 失败 {total_failed}")
        return total_inserted

    def download_and_insert_stock_daily(
        self,
        code: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        interval: str = '1d'
    ) -> int:
        """Download one stock's daily bars and store them.

        Args:
            code: Stock code.
            start_date: Start date, forwarded to ``fetch_stock_daily``.
            end_date: End date, forwarded to ``fetch_stock_daily``.
            interval: Interval label stored with every bar.

        Returns:
            Number of rows inserted; 0 when nothing was fetched or converted.
        """
        symbol, exchange = self.parse_symbol(code)

        df = self.fetch_stock_daily(code, start_date, end_date)
        if df is None or len(df) == 0:
            return 0

        bars = self.convert_bar_to_vnpy(df, symbol, exchange, interval)
        if not bars:
            return 0

        return self.insert_bars_bulk(bars)

    def download_all_stock_daily(
        self,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        max_stocks: Optional[int] = None,
        resume_from: Optional[str] = None
    ) -> Dict:
        """Download daily bars for every stock in the spot list.

        Args:
            start_date: Start date forwarded for every stock.
            end_date: End date forwarded for every stock.
            max_stocks: Keep only the first N stocks of the list (for testing).
            resume_from: Stock code to restart from; stocks listed before it
                are skipped. Applied after ``max_stocks``. An unknown code is
                logged and ignored.

        Returns:
            Dict with ``total``, ``success``, ``failed`` and ``total_bars``.
            A stock counts as failed when it raised or inserted zero rows.
        """
        stock_list = self.get_stock_list()
        if max_stocks:
            stock_list = stock_list.head(max_stocks)

        if resume_from:
            matches = stock_list[stock_list['code'] == resume_from].index
            if len(matches) > 0:
                stock_list = stock_list.loc[matches[0]:]
            else:
                logger.warning(f"未找到恢复起点 {resume_from},将下载全部股票")

        stats = {
            'total': len(stock_list),
            'success': 0,
            'failed': 0,
            'total_bars': 0
        }

        logger.info(f"开始下载全市场A股日线数据,共 {len(stock_list)} 只股票")
        pairs = zip(stock_list['code'], stock_list['name'])
        for code, name in tqdm(pairs, total=len(stock_list)):
            try:
                inserted = self.download_and_insert_stock_daily(code, start_date, end_date)
                if inserted > 0:
                    stats['success'] += 1
                    stats['total_bars'] += inserted
                    logger.debug(f"{code} {name}: 插入 {inserted}")
                else:
                    stats['failed'] += 1
                    logger.warning(f"{code} {name}: 无数据或失败")
            except Exception as e:
                # Keep going: one broken stock must not stop the whole run.
                stats['failed'] += 1
                logger.error(f"{code} {name}: 错误 - {e}")

        logger.info("=" * 60)
        logger.info("下载完成!")
        logger.info(f"总计股票: {stats['total']}")
        logger.info(f"成功: {stats['success']}, 失败: {stats['failed']}")
        logger.info(f"总K线数: {stats['total_bars']}")
        logger.info("=" * 60)
        return stats

    def verify_data_integrity(self) -> Dict:
        """Collect summary statistics and duplicate checks for ``dbbardata``.

        Returns:
            Dict with ``total_bars``, ``total_stocks``, ``min_date``,
            ``max_date``, ``low_count_samples`` (number of sampled series with
            fewer than 100 bars, at most 10), ``has_duplicates``,
            ``duplicates_count`` (at most 10) and ``status``
            ('OK' or 'HAS_DUPLICATES').
        """
        with self.conn:
            cursor = self.conn.cursor()

            cursor.execute('SELECT COUNT(*) FROM dbbardata')
            total_bars = cursor.fetchone()[0]

            cursor.execute('SELECT COUNT(DISTINCT symbol || exchange) FROM dbbardata')
            total_stocks = cursor.fetchone()[0]

            cursor.execute('SELECT MIN(datetime), MAX(datetime) FROM dbbardata')
            min_date, max_date = cursor.fetchone()

            # Sample of series that look incomplete (fewer than 100 bars).
            cursor.execute('''
                SELECT symbol, exchange, interval, COUNT(*) as count
                FROM dbbardata
                GROUP BY symbol, exchange, interval
                HAVING count < 100
                ORDER BY count ASC
                LIMIT 10
            ''')
            low_count = cursor.fetchall()

            # Group by the same key as the table's UNIQUE constraint; leaving
            # out interval would flag bars of different intervals that share a
            # timestamp as duplicates.
            cursor.execute('''
                SELECT symbol, exchange, datetime, interval, COUNT(*) as count
                FROM dbbardata
                GROUP BY symbol, exchange, datetime, interval
                HAVING count > 1
                LIMIT 10
            ''')
            duplicates = cursor.fetchall()

        result = {
            'total_bars': total_bars,
            'total_stocks': total_stocks,
            'min_date': min_date,
            'max_date': max_date,
            'low_count_samples': len(low_count),
            'has_duplicates': len(duplicates) > 0,
            'duplicates_count': len(duplicates),
            'status': 'OK' if len(duplicates) == 0 else 'HAS_DUPLICATES'
        }

        logger.info("=" * 60)
        logger.info("数据完整性验证")
        logger.info(f"总K线记录: {total_bars}")
        logger.info(f"股票数量: {total_stocks}")
        logger.info(f"时间范围: {min_date} ~ {max_date}")
        logger.info(f"状态: {result['status']}")
        if len(duplicates) > 0:
            logger.warning(f"发现重复记录: {len(duplicates)}")
        logger.info("=" * 60)
        return result

    def close(self):
        """Close the database connection, if there is one."""
        if self.conn:
            self.conn.close()
            logger.info("数据库连接已关闭")
def main(
    db_path: str = '/Users/chufeng/.openclaw/workspace-pangtong/sanguo_quant_live/running_data/database.db',
    start_date: str = "20240101",
    max_stocks: Optional[int] = None,
    resume_from: Optional[str] = None,
):
    """Download daily bars for the whole A-share market into the database.

    Creates the adapter, initialises the tables, runs the bulk download and
    then the integrity check. The connection is closed even if a step fails.
    Called without arguments it behaves exactly as the original hard-coded
    script did.

    Args:
        db_path: SQLite database file to write to.
        start_date: First date to download, "YYYYMMDD" (e.g. "20000101" for a
            longer history).
        max_stocks: Limit the run to the first N stocks (useful for testing);
            ``None`` downloads everything.
        resume_from: Stock code to resume from; ``None`` starts at the top.
    """
    adapter = AkshareToVnpyAdapter(db_path=db_path)
    try:
        adapter.initialize_database()
        # Both calls log their own summary, so the return values are not kept.
        adapter.download_all_stock_daily(
            start_date=start_date,
            max_stocks=max_stocks,
            resume_from=resume_from,
        )
        adapter.verify_data_integrity()
    finally:
        adapter.close()
# Run the full download only when executed as a script, not when imported.
if __name__ == '__main__':
    main()