71 lines
2.3 KiB
Python
71 lines
2.3 KiB
Python
"""
因子合成工具 (factor combination utility)

将多个因子按权重合成最终得分 — combines several factors into one final score by weight.
"""
|
|
import pandas as pd
|
|
import numpy as np
|
|
from typing import Dict, List, Optional
|
|
from factors.base_factor import BaseFactor
|
|
|
|
|
|
class FactorCombiner:
    """Combine several factors into one final score using normalized weights."""

    def __init__(self, factors: Dict[str, "BaseFactor"], weights: Optional[Dict[str, float]] = None):
        """
        Args:
            factors: mapping ``{factor name: factor instance}``.
            weights: mapping ``{factor name: weight}``. If None, every factor
                gets an equal weight. Supplied weights are rescaled so that
                their values sum to 1.

        Raises:
            ValueError: if the supplied weights sum to zero and therefore
                cannot be normalized.
        """
        self.factors = factors

        if weights is None:
            # Equal weighting. Guard the empty case so an empty factor dict
            # yields empty weights instead of dividing by zero.
            count = len(factors)
            self.weights = {name: 1.0 / count for name in factors} if count else {}
        else:
            self.weights = self._normalize(weights)

    @staticmethod
    def _normalize(weights: Dict[str, float]) -> Dict[str, float]:
        """Return a new dict with ``weights`` rescaled so its values sum to 1.

        Every key contributes to the total, including keys that do not name a
        registered factor. Such extra keys therefore dilute the weights that
        ``combine`` actually uses.

        Raises:
            ValueError: if the weights sum to zero.
        """
        if not weights:
            return {}
        total = sum(weights.values())
        if total == 0:
            raise ValueError("factor weights sum to zero and cannot be normalized")
        return {name: value / total for name, value in weights.items()}

    def combine(self, data: pd.DataFrame) -> pd.Series:
        """
        Compute the weighted sum of all factor scores.

        Args:
            data: raw market / financial data, passed unchanged to each
                factor's ``process`` method.

        Returns:
            Series of combined scores. Factor outputs are aligned on their
            index before summing. Where only one side of a pairwise addition
            has a value (the label is absent or NaN on the other side), the
            missing side counts as 0. If there are no factors, a zero-valued
            Series on ``data.index`` is returned.

        Raises:
            KeyError: if a registered factor has no configured weight.
        """
        combined = None

        for name, factor in self.factors.items():
            try:
                weight = self.weights[name]
            except KeyError:
                raise KeyError(f"no weight configured for factor {name!r}") from None

            # Per the original author's note, process() already applies the
            # factor's standardization and ranking.
            weighted = factor.process(data) * weight

            # Align on index when summing; fill_value=0 keeps a label that is
            # missing from one factor's output instead of turning it into NaN.
            combined = weighted if combined is None else combined.add(weighted, fill_value=0)

        if combined is None:
            # No factors: the empty weighted sum is 0 for every row of data.
            return pd.Series(0.0, index=data.index)
        return combined

    def get_factors(self) -> List[str]:
        """Return the names of all registered factors, in insertion order."""
        return list(self.factors.keys())

    def update_weights(self, new_weights: Dict[str, float]) -> None:
        """Replace the factor weights (e.g. for dynamic weighting).

        The new weights are normalized to sum to 1. If normalization fails,
        the current weights are left untouched.

        Raises:
            ValueError: if ``new_weights`` sums to zero.
        """
        self.weights = self._normalize(new_weights)

    def get_weights(self) -> Dict[str, float]:
        """Return a shallow copy of the current normalized weights."""
        return self.weights.copy()
|