Files

389 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# AKShare-vnPy数据适配器 - 赵云数据工程工具
# 将AKShare数据格式转换为vnPy兼容格式
import sys
import os
import json
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Union, Any
import logging
# Root logging configuration, applied once when this module is imported.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-wide logger shared by everything defined in this file.
logger = logging.getLogger(__name__)
class AKShareDataAdapter:
    """Convert AKShare daily market data into vnPy-style DataFrames.

    Data is fetched through the optional ``akshare`` package. When that package
    cannot be imported, randomly generated mock bars are used instead so the
    conversion and export steps can still be exercised.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Create the adapter.

        Args:
            config_path: Optional path to a JSON config file; its contents are
                deep-merged over the built-in defaults.
        """
        self.config = self._load_config(config_path)
        # NOTE(review): not read anywhere in this class; kept for compatibility.
        self.data_cache = {}
        # akshare is an optional dependency.
        try:
            import akshare as ak
        except ImportError:
            self.ak = None
            self.akshare_available = False
            logger.warning("AKShare未安装,将使用模拟数据")
        else:
            self.ak = ak
            self.akshare_available = True
            logger.info("AKShare已成功导入")

    def _load_config(self, config_path: Optional[str]) -> Dict:
        """Build the adapter configuration.

        Args:
            config_path: Optional path to a JSON file. Its top-level object is
                deep-merged over the defaults, so a partial section (for example
                only ``vnpy_format.numeric_precision``) keeps the other defaults.

        Returns:
            Dict: The merged configuration. The defaults are returned unchanged
            when the path is empty, missing, unreadable or not a JSON object.
        """
        default_config = {
            'data_sources': {
                'stock': {
                    'provider': 'akshare',
                    # vnPy field -> source column. AKShare's *_hist functions
                    # (and the mock generators below) emit Chinese column names.
                    'fields_mapping': {
                        'date': '日期',
                        'open': '开盘',
                        'high': '最高',
                        'low': '最低',
                        'close': '收盘',
                        'volume': '成交量',
                        'amount': '成交额',
                        'turnover': '换手率'
                    }
                },
                'index': {
                    'provider': 'akshare',
                    'fields_mapping': {
                        'date': '日期',
                        'open': '开盘',
                        'high': '最高',
                        'low': '最低',
                        'close': '收盘',
                        'volume': '成交量',
                        'amount': '成交额'
                    }
                }
            },
            'vnpy_format': {
                'datetime_format': '%Y-%m-%d',
                'numeric_precision': 6,
                'null_value': 0.0
            },
            'cache_settings': {
                'enabled': True,
                'ttl_hours': 24,
                'cache_dir': './data/running_data/cache'
            }
        }
        if not config_path:
            return default_config
        if not os.path.exists(config_path):
            logger.warning(f"配置文件不存在,使用默认配置: {config_path}")
            return default_config
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                user_config = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.error(f"加载配置文件失败 {config_path}: {e}")
            return default_config
        if not isinstance(user_config, dict):
            logger.error(f"加载配置文件失败 {config_path}: 顶层必须是JSON对象")
            return default_config
        return self._deep_merge(default_config, user_config)

    @staticmethod
    def _deep_merge(base: Dict, override: Dict) -> Dict:
        """Return a new dict: ``base`` with ``override`` merged in recursively.

        Nested dicts are merged key by key; any other value in ``override``
        replaces the one in ``base``. Neither input is modified.
        """
        merged = dict(base)
        for key, value in override.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                merged[key] = AKShareDataAdapter._deep_merge(merged[key], value)
            else:
                merged[key] = value
        return merged

    @staticmethod
    def _to_akshare_date(value: Any) -> str:
        """Normalise a date (``YYYY-MM-DD``, ``YYYYMMDD`` or date-like) to ``YYYYMMDD``.

        AKShare's ``stock_zh_a_hist`` / ``index_zh_a_hist`` take compact dates.
        """
        return pd.Timestamp(value).strftime('%Y%m%d')

    @staticmethod
    def _normalize_symbol(symbol: str) -> str:
        """Strip exchange tags so ``000001.SH`` / ``sh000001`` become ``000001``."""
        code = str(symbol).strip().split('.', 1)[0]
        if code[:2].lower() in ('sh', 'sz', 'bj') and code[2:].isdigit():
            code = code[2:]
        return code

    def get_stock_daily(self, symbol: str, start_date: str, end_date: str) -> pd.DataFrame:
        """Fetch forward-adjusted (qfq) daily bars for an A-share stock.

        Args:
            symbol: Stock code, e.g. ``000001`` (``000001.SZ`` is also accepted).
            start_date: Start date, ``YYYY-MM-DD`` or ``YYYYMMDD``.
            end_date: End date, ``YYYY-MM-DD`` or ``YYYYMMDD``.

        Returns:
            pd.DataFrame: vnPy-format bars, or an empty DataFrame on any failure
            (the error is logged with its traceback).
        """
        logger.info(f"获取股票日线数据: {symbol} [{start_date} - {end_date}]")
        try:
            if self.akshare_available:
                df = self.ak.stock_zh_a_hist(
                    symbol=self._normalize_symbol(symbol),
                    period="daily",
                    start_date=self._to_akshare_date(start_date),
                    end_date=self._to_akshare_date(end_date),
                    adjust="qfq"  # forward-adjusted prices
                )
            else:
                df = self._generate_mock_stock_data(symbol, start_date, end_date)
            vnpy_df = self._convert_to_vnpy_format(df, 'stock')
            logger.info(f"股票数据获取成功: {symbol}, 数据量: {len(vnpy_df)}")
            return vnpy_df
        except Exception as e:
            # Deliberate best-effort API boundary: log and return an empty frame.
            logger.exception(f"获取股票数据失败 {symbol}: {e}")
            return pd.DataFrame()

    @staticmethod
    def _mock_ohlc(count: int, price_low: float, price_high: float) -> Dict[str, np.ndarray]:
        """Random open/close in [price_low, price_high) with high/low enclosing them."""
        opens = np.random.uniform(price_low, price_high, count)
        closes = np.random.uniform(price_low, price_high, count)
        # Wick length: up to 2% of the price band beyond the candle body, so
        # every bar satisfies low <= min(open, close) <= max(open, close) <= high.
        wick = (price_high - price_low) * 0.02
        highs = np.maximum(opens, closes) + np.random.uniform(0, wick, count)
        lows = np.minimum(opens, closes) - np.random.uniform(0, wick, count)
        return {'开盘': opens, '收盘': closes, '最高': highs, '最低': lows}

    def _generate_mock_stock_data(self, symbol: str, start_date: str, end_date: str) -> pd.DataFrame:
        """Generate random stock bars (used when akshare is unavailable).

        Columns mimic AKShare's ``stock_zh_a_hist`` output; one row per business
        day in [start_date, end_date].

        Args:
            symbol: Stock code (unused by the generator).
            start_date: Start date.
            end_date: End date.

        Returns:
            pd.DataFrame: Mock data.
        """
        dates = pd.date_range(start=start_date, end=end_date, freq='B')
        count = len(dates)
        data = {
            '日期': dates,
            **self._mock_ohlc(count, 10, 100),
            '成交量': np.random.uniform(10000, 1000000, count),
            '成交额': np.random.uniform(100000, 10000000, count),
            '振幅': np.random.uniform(0.1, 5.0, count),
            '涨跌幅': np.random.uniform(-5.0, 5.0, count),
            '涨跌额': np.random.uniform(-5.0, 5.0, count),
            '换手率': np.random.uniform(0.1, 10.0, count)
        }
        return pd.DataFrame(data)

    def _convert_to_vnpy_format(self, df: pd.DataFrame, data_type: str) -> pd.DataFrame:
        """Rename/select columns per ``fields_mapping`` and normalise values.

        A mapped source column is used when present. Otherwise a column already
        named after the vnPy field is used. Failing both, the field is filled
        with ``vnpy_format.null_value``. Non-date columns have nulls filled and
        numeric columns are rounded to ``vnpy_format.numeric_precision``. Rows
        are sorted by date.

        Args:
            df: Raw source data.
            data_type: Key under ``data_sources`` (``stock``, ``index``, ...).

        Returns:
            pd.DataFrame: Converted data with a fresh RangeIndex. An empty input
            is returned as-is; ``None`` yields an empty DataFrame.
        """
        if df is None:
            return pd.DataFrame()
        if df.empty:
            return df
        mapping = (self.config.get('data_sources', {})
                   .get(data_type, {})
                   .get('fields_mapping', {}))
        # Fresh RangeIndex so placeholder Series align with the source rows.
        source = df.reset_index(drop=True)
        columns = {}
        for vnpy_field, source_field in mapping.items():
            if source_field in source.columns:
                columns[vnpy_field] = source[source_field]
            elif vnpy_field in source.columns:
                # Frame already uses vnPy-style (English) column names.
                columns[vnpy_field] = source[vnpy_field]
            else:
                logger.warning(f"字段 {source_field} 不存在,使用默认值填充 {vnpy_field}")
                columns[vnpy_field] = pd.Series(np.nan, index=source.index)
        vnpy_df = pd.DataFrame(columns, index=source.index)

        if 'date' in vnpy_df.columns:
            vnpy_df['date'] = pd.to_datetime(vnpy_df['date'])

        format_cfg = self.config.get('vnpy_format', {})
        null_value = format_cfg.get('null_value', 0.0)
        # Fill only value columns; filling NaT dates with a number corrupts them.
        value_columns = [col for col in vnpy_df.columns if col != 'date']
        if value_columns:
            vnpy_df[value_columns] = vnpy_df[value_columns].fillna(null_value)

        numeric_precision = format_cfg.get('numeric_precision', 6)
        for col in vnpy_df.select_dtypes(include=[np.number]).columns:
            vnpy_df[col] = vnpy_df[col].round(numeric_precision)

        if 'date' in vnpy_df.columns:
            vnpy_df = vnpy_df.sort_values('date').reset_index(drop=True)
        return vnpy_df

    def get_index_daily(self, index_symbol: str, start_date: str, end_date: str) -> pd.DataFrame:
        """Fetch daily bars for an A-share index.

        Args:
            index_symbol: Index code, e.g. ``000001.SH``; the exchange tag is
                stripped before calling AKShare.
            start_date: Start date, ``YYYY-MM-DD`` or ``YYYYMMDD``.
            end_date: End date, ``YYYY-MM-DD`` or ``YYYYMMDD``.

        Returns:
            pd.DataFrame: vnPy-format bars, or an empty DataFrame on any failure
            (the error is logged with its traceback).
        """
        logger.info(f"获取指数日线数据: {index_symbol} [{start_date} - {end_date}]")
        try:
            if self.akshare_available:
                df = self.ak.index_zh_a_hist(
                    symbol=self._normalize_symbol(index_symbol),
                    period="daily",
                    start_date=self._to_akshare_date(start_date),
                    end_date=self._to_akshare_date(end_date)
                )
            else:
                df = self._generate_mock_index_data(index_symbol, start_date, end_date)
            vnpy_df = self._convert_to_vnpy_format(df, 'index')
            logger.info(f"指数数据获取成功: {index_symbol}, 数据量: {len(vnpy_df)}")
            return vnpy_df
        except Exception as e:
            # Deliberate best-effort API boundary: log and return an empty frame.
            logger.exception(f"获取指数数据失败 {index_symbol}: {e}")
            return pd.DataFrame()

    def _generate_mock_index_data(self, index_symbol: str, start_date: str, end_date: str) -> pd.DataFrame:
        """Generate random index bars (used when akshare is unavailable).

        One row per business day in [start_date, end_date].

        Args:
            index_symbol: Index code (unused by the generator).
            start_date: Start date.
            end_date: End date.

        Returns:
            pd.DataFrame: Mock data.
        """
        dates = pd.date_range(start=start_date, end=end_date, freq='B')
        count = len(dates)
        data = {
            '日期': dates,
            **self._mock_ohlc(count, 3000, 4000),
            '成交量': np.random.uniform(1000000, 10000000, count),
            '成交额': np.random.uniform(10000000, 100000000, count)
        }
        return pd.DataFrame(data)

    def export_to_vnpy_csv(self, df: pd.DataFrame, symbol: str, output_dir: Optional[str] = None) -> str:
        """Write ``df`` to a timestamped CSV file.

        Args:
            df: Data to export.
            symbol: Instrument code, used in the file name.
            output_dir: Target directory (created if needed); defaults to
                ``./data/running_data/vnpy_import``.

        Returns:
            str: Path of the written file, or ``""`` when ``df`` is empty.
        """
        if df.empty:
            logger.warning(f"数据为空,跳过导出: {symbol}")
            return ""
        if output_dir is None:
            output_dir = './data/running_data/vnpy_import'
        os.makedirs(output_dir, exist_ok=True)
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        output_path = os.path.join(output_dir, f"vnpy_{symbol}_{timestamp}.csv")
        # utf-8-sig adds a BOM so spreadsheet tools detect the encoding.
        df.to_csv(output_path, index=False, encoding='utf-8-sig')
        logger.info(f"数据已导出为vnPy CSV格式: {output_path}")
        return output_path

    def _to_json_records(self, df: pd.DataFrame) -> List[Dict[str, Any]]:
        """Return rows of ``df`` as JSON-serializable dicts.

        Datetime columns are formatted with ``vnpy_format.datetime_format``.
        ``to_json`` then converts numpy scalars and missing values to plain JSON.
        """
        datetime_format = self.config.get('vnpy_format', {}).get('datetime_format', '%Y-%m-%d')
        serializable = df.copy()
        for col in serializable.select_dtypes(include=['datetime', 'datetimetz']).columns:
            serializable[col] = serializable[col].dt.strftime(datetime_format)
        return json.loads(serializable.to_json(orient='records', force_ascii=False))

    def export_to_vnpy_database(self, df: pd.DataFrame, symbol: str, table_name: Optional[str] = None,
                                output_dir: Optional[str] = None) -> bool:
        """Persist ``df`` as a JSON "table" (stand-in for a vnPy database write).

        Args:
            df: Data to export.
            symbol: Instrument code.
            table_name: Table/file name; defaults to ``vnpy_data_{symbol}``.
            output_dir: Target directory (created if needed); defaults to
                ``./data/running_data/vnpy_database``.

        Returns:
            bool: ``True`` once written, ``False`` when ``df`` is empty.
        """
        if df.empty:
            logger.warning(f"数据为空,跳过数据库导出: {symbol}")
            return False
        # Placeholder: a real vnPy database interface could be integrated here.
        if table_name is None:
            table_name = f"vnpy_data_{symbol}"
        if output_dir is None:
            output_dir = './data/running_data/vnpy_database'
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f"{table_name}.json")
        # Build the payload before opening the file so a serialization failure
        # cannot leave a truncated JSON file behind.
        data_dict = {
            'symbol': symbol,
            'table_name': table_name,
            'export_time': datetime.now().isoformat(),
            'data_count': len(df),
            'data': self._to_json_records(df)
        }
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(data_dict, f, ensure_ascii=False, indent=2)
        logger.info(f"数据已导出为vnPy数据库格式: {output_path}")
        return True
def main():
    """Demo: fetch one stock and one index for January 2024 and show the results.

    The stock bars are also exported to CSV; the index bars are only printed.
    """
    adapter = AKShareDataAdapter()
    period_start, period_end = '2024-01-01', '2024-01-31'

    stock_code = '000001'
    stock_bars = adapter.get_stock_daily(
        symbol=stock_code,
        start_date=period_start,
        end_date=period_end
    )
    if stock_bars.empty:
        print("股票数据获取失败")
    else:
        print(f"股票数据获取成功,数据量: {len(stock_bars)}")
        print(stock_bars.head())
        exported_to = adapter.export_to_vnpy_csv(stock_bars, stock_code)
        print(f"CSV导出路径: {exported_to}")

    index_bars = adapter.get_index_daily(
        index_symbol='000001.SH',
        start_date=period_start,
        end_date=period_end
    )
    if not index_bars.empty:
        print(f"\n指数数据获取成功,数据量: {len(index_bars)}")
        print(index_bars.head())
# Run the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    main()