469 lines
16 KiB
Python
Executable File
469 lines
16 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
"""
|
||
vn.py本地数据适配器 - 姜维
|
||
功能:让vn.py优先加载赵云将军下载的本地数据,本地没有再去akshare接口下载
|
||
"""
|
||
|
||
import pandas as pd
|
||
import os
|
||
import glob
|
||
import logging
|
||
from datetime import datetime, timedelta
|
||
from typing import Optional, Dict, List, Tuple
|
||
import akshare as ak
|
||
|
||
# Configure logging.
# NOTE: this runs at import time and targets the root logger; per the
# standard library, basicConfig() does nothing if the root logger already
# has handlers configured.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        # Messages logged by this module contain Chinese text and emoji.
        # Force UTF-8 so writing them cannot fail on platforms whose default
        # file encoding cannot represent those characters.
        logging.FileHandler('vnpy_local_data_adapter.log', encoding='utf-8'),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class VnpyLocalDataAdapter:
    """Local-first daily-bar data adapter for vn.py.

    Strategy: read the locally downloaded parquet files first and fall back
    to the akshare API only when no local rows are found.
    """

    # Root directory of the locally downloaded ("Zhaoyun") data set.
    ZHAOYUN_DATA_BASE = "/Users/chufeng/nas/stock/sanguo_vnpy/zhaoyun-data/data"

    # Data category -> directory holding that category.
    DATA_DIRS = {
        'daily': os.path.join(ZHAOYUN_DATA_BASE, "raw/daily"),
        'financial': os.path.join(ZHAOYUN_DATA_BASE, "raw/financial"),
        'stock_info': os.path.join(ZHAOYUN_DATA_BASE, "raw/stock_info"),
        'minute': os.path.join(ZHAOYUN_DATA_BASE, "raw/minute_kline"),
    }

    # Local column name -> vn.py-style column name.
    VNPY_FIELD_MAP = {
        'date': 'datetime',
        'open': 'open_price',
        'high': 'high_price',
        'low': 'low_price',
        'close': 'close_price',
        'volume': 'volume',
        'amount': 'turnover',
        'turnover': 'turnover_rate',
    }

    # Columns every local daily file is expected to contain.
    _REQUIRED_DAILY_FIELDS = ('date', 'open', 'high', 'low', 'close', 'volume')

    # Exchange code -> prefix used in local daily file names.
    _EXCHANGE_FILE_PREFIX = {'SH': 'sh', 'SZ': 'sz', 'BJ': 'bj'}

    def __init__(self, use_local_first: bool = True):
        """Create the adapter and log which data directories exist.

        Args:
            use_local_first: Whether local files are consulted before akshare.
        """
        self.use_local_first = use_local_first
        self._validate_data_dirs()

    def _validate_data_dirs(self):
        """Log, for every configured data directory, whether it exists."""
        for name, path in self.DATA_DIRS.items():
            if os.path.exists(path):
                logger.info(f"✅ 赵云数据目录 {name}: {path}")
            else:
                logger.warning(f"⚠️ 赵云数据目录不存在 {name}: {path}")

    def _parse_symbol(self, symbol: str) -> Tuple[str, str]:
        """Split a stock symbol into its code and exchange.

        Args:
            symbol: Stock symbol such as "000001.SZ" or "600000".

        Returns:
            (symbol_code, exchange), e.g. ("000001", "SZ"). An explicit
            suffix is upper-cased and returned as-is. Without a suffix (or
            with an empty one, as in "600000.") the exchange is inferred
            from the first digit of the code, defaulting to "SZ".

        Raises:
            ValueError: If the symbol contains more than one ".".
        """
        # Tolerate surrounding whitespace so " 000001.sz " still maps to SZ.
        symbol_code, _, suffix = symbol.strip().partition('.')
        if '.' in suffix:
            raise ValueError(f"invalid symbol (more than one '.'): {symbol!r}")

        symbol_code = symbol_code.strip()
        exchange = suffix.strip().upper()
        if exchange:
            return symbol_code, exchange

        # No usable suffix: infer the exchange from the code prefix.
        if symbol_code.startswith('6'):
            exchange = 'SH'
        elif symbol_code.startswith(('0', '3')):
            exchange = 'SZ'
        elif symbol_code.startswith('8'):
            exchange = 'BJ'
        else:
            exchange = 'SZ'  # default to the Shenzhen exchange
        return symbol_code, exchange

    def _get_local_daily_file_path(self, symbol: str, year: int) -> Optional[str]:
        """Locate the local daily parquet file for a symbol and year.

        Args:
            symbol: Stock symbol.
            year: Calendar year; files are looked up in a per-year directory.

        Returns:
            Path of the first existing candidate file, or None if neither
            the exchange-prefixed nor the bare-code file name exists.
        """
        symbol_code, exchange = self._parse_symbol(symbol)
        year_dir = os.path.join(self.DATA_DIRS['daily'], str(year))
        prefix = self._EXCHANGE_FILE_PREFIX.get(exchange, '')

        # Try "<sh|sz|bj><code>_daily.parquet" first, then "<code>_daily.parquet".
        for stem in (f"{prefix}{symbol_code}", symbol_code):
            candidate = os.path.join(year_dir, f"{stem}_daily.parquet")
            if os.path.exists(candidate):
                return candidate
        return None

    def load_local_daily_data(self, symbol: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """Load daily bars for a date range from the local parquet files.

        Args:
            symbol: Stock symbol.
            start_date: Start date, "YYYY-MM-DD" (inclusive).
            end_date: End date, "YYYY-MM-DD" (inclusive).

        Returns:
            DataFrame with columns renamed through VNPY_FIELD_MAP plus
            ``symbol``, ``exchange`` and ``interval`` columns, or None when
            local loading is disabled, no rows are found, or an error occurs
            (errors are logged, not raised).

        NOTE(review): the ``datetime`` column here holds pandas timestamps,
        while fetch_akshare_daily_data returns formatted strings; callers of
        get_daily_data may see either.
        """
        if not self.use_local_first:
            return None

        try:
            start_dt = pd.to_datetime(start_date)
            end_dt = pd.to_datetime(end_date)

            # Files are stored per year, so visit every year the range touches.
            yearly_frames = []
            for year in range(start_dt.year, end_dt.year + 1):
                file_path = self._get_local_daily_file_path(symbol, year)
                if not file_path:
                    continue

                df = pd.read_parquet(file_path)
                df['date'] = pd.to_datetime(df['date'])
                in_range = (df['date'] >= start_dt) & (df['date'] <= end_dt)
                df_filtered = df[in_range]

                if not df_filtered.empty:
                    yearly_frames.append(df_filtered)
                    logger.debug(f"✅ 从本地加载 {symbol} {year}年数据: {len(df_filtered)} 条")

            if not yearly_frames:
                logger.info(f"⚠️ 本地没有找到 {symbol} 的数据")
                return None

            # Merge the years, order chronologically and keep a clean
            # 0..n-1 index that matches the sorted row order.
            result = pd.concat(yearly_frames, ignore_index=True)
            result = result.sort_values('date').reset_index(drop=True)

            result = result.rename(columns=self.VNPY_FIELD_MAP)

            symbol_code, exchange = self._parse_symbol(symbol)
            result['symbol'] = symbol_code
            result['exchange'] = exchange
            result['interval'] = '1d'

            logger.info(f"✅ 成功从本地加载 {symbol} 数据: {len(result)} 条 ({start_date} 到 {end_date})")
            return result

        except Exception as e:
            logger.error(f"❌ 加载本地数据失败 {symbol}: {e}")
            return None

    def fetch_akshare_daily_data(self, symbol: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """Fetch unadjusted daily bars from akshare (the fallback source).

        Args:
            symbol: Stock symbol.
            start_date: Start date, "YYYY-MM-DD".
            end_date: End date, "YYYY-MM-DD".

        Returns:
            DataFrame with the main columns renamed to vn.py-style names,
            ``datetime`` formatted as "%Y-%m-%d %H:%M:%S" strings and
            ``symbol``/``exchange``/``interval`` columns added, or None when
            akshare returns nothing or the call fails (errors are logged).
        """
        try:
            symbol_code, exchange = self._parse_symbol(symbol)

            # The akshare call below is given dates without dashes.
            start_date_ak = start_date.replace('-', '')
            end_date_ak = end_date.replace('-', '')

            logger.info(f"📡 从akshare获取 {symbol} 数据 ({start_date} 到 {end_date})")

            df = ak.stock_zh_a_hist(
                symbol=symbol_code,
                period="daily",
                start_date=start_date_ak,
                end_date=end_date_ak,
                adjust="",  # no price adjustment
            )

            if df is None or df.empty:
                logger.warning(f"⚠️ akshare没有 {symbol} 的数据")
                return None

            # Map akshare's Chinese column names to vn.py-style names.
            df = df.rename(columns={
                '日期': 'datetime',
                '开盘': 'open_price',
                '收盘': 'close_price',
                '最高': 'high_price',
                '最低': 'low_price',
                '成交量': 'volume',
                '成交额': 'turnover',
            })

            df['datetime'] = pd.to_datetime(df['datetime']).dt.strftime('%Y-%m-%d %H:%M:%S')

            df['symbol'] = symbol_code
            df['exchange'] = exchange
            df['interval'] = '1d'

            logger.info(f"✅ 从akshare获取 {symbol} 数据成功: {len(df)} 条")
            return df

        except Exception as e:
            logger.error(f"❌ 从akshare获取数据失败 {symbol}: {e}")
            return None

    def get_daily_data(self, symbol: str, start_date: str, end_date: str) -> pd.DataFrame:
        """Return daily bars, preferring local files over akshare.

        Args:
            symbol: Stock symbol.
            start_date: Start date, "YYYY-MM-DD".
            end_date: End date, "YYYY-MM-DD".

        Returns:
            The first non-empty result from the local files or akshare, or
            an empty DataFrame when both sources fail.
        """
        # 1. Local files first (when enabled).
        if self.use_local_first:
            local_data = self.load_local_daily_data(symbol, start_date, end_date)
            if local_data is not None and not local_data.empty:
                return local_data

        # 2. Fall back to akshare.
        akshare_data = self.fetch_akshare_daily_data(symbol, start_date, end_date)
        if akshare_data is not None and not akshare_data.empty:
            return akshare_data

        # 3. Both sources failed.
        logger.error(f"❌ 无法获取 {symbol} 的数据")
        return pd.DataFrame()

    def verify_local_data_structure(
        self,
        symbol: str,
        start_year: int = 2010,
        end_year: Optional[int] = None,
    ) -> Dict:
        """Check which years have local daily files and whether they are complete.

        Args:
            symbol: Stock symbol.
            start_year: First year to probe (inclusive).
            end_year: Last year to probe (inclusive). Defaults to 2026 or
                the current year, whichever is later, so the check does not
                silently stop seeing newer data.

        Returns:
            Dict with keys ``symbol``, ``has_local_data``, ``data_years``,
            ``missing_fields`` (each field listed once), ``recommendations``
            and ``status`` ("OK", "INCOMPLETE", "NO_DATA" or "ERROR").
        """
        if end_year is None:
            end_year = max(2026, datetime.now().year)

        result = {
            'symbol': symbol,
            'has_local_data': False,
            'data_years': [],
            'missing_fields': [],
            'recommendations': [],
            'status': 'UNKNOWN',
        }

        try:
            data_years = []
            for year in range(start_year, end_year + 1):
                file_path = self._get_local_daily_file_path(symbol, year)
                if not file_path:
                    continue
                data_years.append(year)

                # Record each missing column only once, even when several
                # year files lack the same field.
                columns = pd.read_parquet(file_path).columns
                for field in self._REQUIRED_DAILY_FIELDS:
                    if field not in columns and field not in result['missing_fields']:
                        result['missing_fields'].append(field)

            result['data_years'] = data_years
            result['has_local_data'] = len(data_years) > 0

            if not result['has_local_data']:
                result['status'] = 'NO_DATA'
                result['recommendations'].append("建议:联系赵云将军下载该股票数据")
            elif result['missing_fields']:
                result['status'] = 'INCOMPLETE'
                result['recommendations'].append(f"缺少字段: {result['missing_fields']}")
                result['recommendations'].append("建议:使用data_convert_tool.py转换数据格式")
            else:
                result['status'] = 'OK'
                result['recommendations'].append(f"✅ 数据结构完整,覆盖 {min(data_years)}-{max(data_years)} 年")

        except Exception as e:
            result['status'] = 'ERROR'
            result['recommendations'].append(f"验证错误: {e}")

        return result
|
||
|
||
|
||
class DataConvertTool:
    """Format conversion tool.

    Converts a local ("Zhaoyun") daily parquet file into the column layout
    used for vn.py.
    """

    @staticmethod
    def convert_zhaoyun_to_vnpy(input_path: str, output_path: str, symbol: str):
        """Convert one local daily parquet file to the vn.py layout.

        Args:
            input_path: Path of the source parquet file.
            output_path: Path the converted parquet file is written to.
            symbol: Stock symbol, used to fill the ``symbol`` and
                ``exchange`` columns.

        Raises:
            ValueError: If a required column is missing from the input.
            Exception: Any other read/convert/write failure is logged and
                re-raised unchanged.
        """
        try:
            df = pd.read_parquet(input_path)

            # Refuse to convert files that lack any of the core columns.
            required = ['date', 'open', 'high', 'low', 'close', 'volume']
            missing = [field for field in required if field not in df.columns]
            if missing:
                raise ValueError(f"缺少必要字段: {missing}")

            vnpy_df = pd.DataFrame()
            vnpy_df['datetime'] = pd.to_datetime(df['date']).dt.strftime('%Y-%m-%d %H:%M:%S')
            vnpy_df['open_price'] = df['open']
            vnpy_df['high_price'] = df['high']
            vnpy_df['low_price'] = df['low']
            vnpy_df['close_price'] = df['close']
            vnpy_df['volume'] = df['volume']

            if 'amount' in df.columns:
                vnpy_df['turnover'] = df['amount']
            else:
                # No traded-amount column: approximate it as volume * close.
                vnpy_df['turnover'] = df['volume'] * df['close']

            if 'turnover' in df.columns:
                vnpy_df['turnover_rate'] = df['turnover']

            # Parse the symbol through an uninitialised adapter instance:
            # running VnpyLocalDataAdapter.__init__ would validate and log
            # every data directory on each conversion, which is unrelated
            # to symbol parsing.
            parser = VnpyLocalDataAdapter.__new__(VnpyLocalDataAdapter)
            symbol_code, exchange = parser._parse_symbol(symbol)
            vnpy_df['symbol'] = symbol_code
            vnpy_df['exchange'] = exchange
            vnpy_df['interval'] = '1d'

            vnpy_df.to_parquet(output_path, index=False)
            logger.info(f"✅ 数据转换完成: {input_path} → {output_path}")

        except Exception as e:
            logger.error(f"❌ 数据转换失败: {e}")
            raise
|
||
|
||
|
||
# Wrapper around a vn.py data manager.
class VnpyDataManagerWrapper:
    """Wrapper that routes daily-bar requests through the local data adapter.

    Holds the original vn.py data manager alongside the adapter; daily bar
    requests made via get_daily_bar_data are served by the adapter.
    """

    def __init__(self, original_data_manager, adapter: VnpyLocalDataAdapter):
        """Store both collaborators and run the patch step.

        Args:
            original_data_manager: The original vn.py data manager.
            adapter: The local data adapter used to serve daily bars.
        """
        self.original_dm, self.adapter = original_data_manager, adapter
        self._patch_methods()

    def _patch_methods(self):
        """Patch hook for the vn.py data-fetch methods.

        Placeholder: the actual patching depends on the vn.py version and
        API in use, so this currently only emits a log message.
        """
        logger.info("✅ vn.py数据管理器已修补为优先使用本地数据")

    def get_daily_bar_data(self, symbol: str, start_date: str, end_date: str):
        """Return daily bars for the symbol and date range via the adapter."""
        bars = self.adapter.get_daily_data(symbol, start_date, end_date)
        return bars
|
||
|
||
|
||
# Usage example.
def _run_demo():
    """Run the command-line demo: verify local data, then fetch a sample range."""
    banner = "=" * 60

    # 1. Create the adapter (logs which data directories exist).
    adapter = VnpyLocalDataAdapter(use_local_first=True)

    # 2. Sample request parameters.
    test_symbol = "000001.SZ"  # Ping An Bank
    start_date, end_date = "2024-01-01", "2024-03-01"

    print(banner)
    print("vn.py本地数据适配器测试")
    print(banner)

    # 3. Verify the local data structure.
    print("\n1. 验证本地数据结构:")
    verification = adapter.verify_local_data_structure(test_symbol)
    for key in verification:
        print(f" {key}: {verification[key]}")

    # 4. Fetch the data.
    print(f"\n2. 获取 {test_symbol} 数据 ({start_date} 到 {end_date}):")
    data = adapter.get_daily_data(test_symbol, start_date, end_date)

    if data.empty:
        print("❌ 获取数据失败")
    else:
        # The source is guessed from the presence of an 'outstanding_share' column.
        source = '本地' if 'outstanding_share' in data.columns else 'akshare'
        print(f"✅ 成功获取 {len(data)} 条数据")
        print(f"数据字段: {list(data.columns)}")
        print(f"时间范围: {data['datetime'].min()} 到 {data['datetime'].max()}")
        print(f"数据来源: {source}")

    print("\n3. 使用建议:")
    suggestions = (
        " a) 在vn.py策略中导入此适配器",
        " b) 替换原有的数据获取逻辑",
        " c) 配置赵云数据目录路径",
        " d) 定期更新本地数据(联系赵云将军)",
    )
    for line in suggestions:
        print(line)

    print(banner)


if __name__ == "__main__":
    _run_demo()