Coverage for src/pyhrp/hrp.py: 100%

75 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-26 16:50 +0000

1"""Hierarchical Risk Parity (HRP) algorithm implementation. 

2 

3This module implements the core HRP algorithm and related functions: 

4- hrp: Main function to compute HRP portfolio weights 

5- build_tree: Function to build hierarchical cluster tree from correlation matrix 

6- Dendrogram: Class to store and visualize hierarchical clustering results 

7""" 

8 

9from __future__ import annotations 

10 

11from dataclasses import dataclass 

12from typing import Literal 

13 

14import numpy as np 

15import pandas as pd 

16import scipy.cluster.hierarchy as sch 

17import scipy.spatial.distance as ssd 

18 

19from .algos import risk_parity 

20from .cluster import Cluster 

21 

22 

def hrp(
    prices: pd.DataFrame,
    node: Cluster | None = None,
    method: Literal["single", "complete", "average", "ward"] = "ward",
    bisection: bool = False,
) -> Cluster:
    """Compute the hierarchical risk parity portfolio weights.

    This is the main entry point for the HRP algorithm. It derives simple
    returns from the price series, estimates their covariance and correlation
    matrices, builds a hierarchical clustering tree from the correlations when
    no tree is supplied, and hands the tree and covariance to ``risk_parity``.

    Args:
        prices (pd.DataFrame): Asset price time series, one column per asset.
        node (Cluster | None): Root node of a pre-built hierarchical clustering
            tree. If None, a tree is built from the correlation matrix using
            ``build_tree`` with the given ``method`` and ``bisection``.
        method (Literal["single", "complete", "average", "ward"]): Linkage method
            forwarded to ``build_tree`` (ignored when ``node`` is supplied)
            - "single": minimum distance between points (nearest neighbor)
            - "complete": maximum distance between points (furthest neighbor)
            - "average": average distance between all points
            - "ward": Ward variance minimization
        bisection (bool): Whether ``build_tree`` should rebuild the tree by
            recursive bisection (ignored when ``node`` is supplied).

    Returns:
        Cluster: The result of ``risk_parity(root=node, cov=cov)``, i.e. the
        root cluster with portfolio weights assigned according to HRP.
    """
    # Rows that are NaN for every asset (such as the first row produced by
    # pct_change) are dropped; partially populated rows are kept.
    returns = prices.pct_change().dropna(axis=0, how="all")
    cov, cor = returns.cov(), returns.corr()

    # Compare against None explicitly: relying on truthiness would discard a
    # caller-supplied tree whenever the Cluster object happens to be falsy.
    if node is None:
        node = build_tree(cor, method=method, bisection=bisection).root

    return risk_parity(root=node, cov=cov)

53 

54 

@dataclass(frozen=True)
class Dendrogram:
    """Container for hierarchical clustering dendrogram data and visualization.

    This class stores the results of hierarchical clustering and provides methods
    for accessing and visualizing the dendrogram structure.

    Attributes:
        root (Cluster): The root node of the hierarchical clustering tree
        assets (list[str]): Names of the assets included in the clustering
        distance (pd.DataFrame | None): Distance matrix used for clustering,
            indexed and labelled by the assets in the same order
        linkage (np.ndarray | None): Linkage matrix in scipy format for plotting
        method (str | None): Linkage method used for clustering
    """

    root: Cluster
    assets: list[str]
    distance: pd.DataFrame | None = None
    linkage: np.ndarray | None = None
    method: str | None = None

    def __post_init__(self) -> None:
        """Validate dataclass fields after initialization.

        Ensures that the optional distance matrix, when provided, is a pandas
        DataFrame aligned with the asset order, and verifies that the number of
        leaves in the cluster tree matches the number of assets.

        Raises:
            TypeError: If ``distance`` is given but is not a pandas DataFrame.
            ValueError: If the index or columns of ``distance`` differ from
                ``assets``, or if the tree's leaf count differs from the
                number of assets.
        """
        if self.distance is not None:
            if not isinstance(self.distance, pd.DataFrame):
                raise TypeError("distance must be a pandas DataFrame.")

            # Both axes must carry exactly the asset labels, in the same order.
            expected = pd.Index(self.assets)
            if not self.distance.index.equals(expected) or not self.distance.columns.equals(expected):
                raise ValueError("Distance matrix index/columns must align with assets.")

        # Every asset must correspond to exactly one leaf of the tree.
        if len(self.root.leaves) != len(self.assets):
            raise ValueError("Number of leaves does not match number of assets.")

    def plot(self, **kwargs):
        """Plot the dendrogram with ``scipy.cluster.hierarchy.dendrogram``.

        Args:
            **kwargs: Extra keyword arguments forwarded to scipy's ``dendrogram``.
                They take precedence over the defaults used here
                (``leaf_rotation=90``, ``leaf_font_size=8``, ``labels=assets``).

        Returns:
            The value returned by ``scipy.cluster.hierarchy.dendrogram``.

        Raises:
            ValueError: If no linkage matrix is stored on this object.
        """
        if self.linkage is None:
            raise ValueError("Cannot plot a dendrogram without a linkage matrix.")

        # Merge defaults with caller overrides so that passing e.g. leaf_rotation
        # does not fail with a duplicate-keyword TypeError. Labels are handed
        # over as a plain list regardless of the container used for assets.
        options = {"leaf_rotation": 90, "leaf_font_size": 8, "labels": list(self.assets), **kwargs}
        return sch.dendrogram(self.linkage, **options)

    @property
    def ids(self):
        """Node values in the order left -> right as they appear in the dendrogram."""
        return [leaf.value for leaf in self.root.leaves]

    @property
    def names(self):
        """The asset names as induced by the order of ids."""
        return [self.assets[i] for i in self.ids]

111 

112 

def _compute_distance_matrix(corr: pd.DataFrame) -> pd.DataFrame:
    """Turn a correlation matrix into a distance matrix.

    Each correlation ``rho`` is mapped to ``sqrt((1 - rho) / 2)``, with the
    quantity under the root clipped to ``[0, 1]``. The diagonal is forced to
    exactly zero. The result is labelled with ``corr.index`` on both axes.
    """
    labels = corr.index
    half_gap = (1.0 - corr.values) / 2.0
    # Clipping guards the square root against values just outside [0, 1].
    distances = np.sqrt(np.clip(half_gap, 0.0, 1.0))
    np.fill_diagonal(distances, 0.0)
    return pd.DataFrame(distances, index=labels, columns=labels)

119 

120 

def build_tree(
    cor: pd.DataFrame, method: Literal["single", "complete", "average", "ward"] = "ward", bisection: bool = False
) -> Dendrogram:
    """Build hierarchical cluster tree from correlation matrix.

    This function converts a correlation matrix to a distance matrix, performs
    hierarchical clustering with scipy, converts the scipy tree into ``Cluster``
    nodes, and returns a Dendrogram object containing the resulting tree
    structure. With ``bisection=True`` the clustered leaf order is kept but the
    tree is rebuilt by recursively halving that order, and a matching linkage
    matrix is regenerated from the rebuilt tree.

    Args:
        cor (pd.DataFrame): Correlation matrix of asset returns
        method (Literal["single", "complete", "average", "ward"]): Linkage method for hierarchical clustering
            - "single": minimum distance between points (nearest neighbor)
            - "complete": maximum distance between points (furthest neighbor)
            - "average": average distance between all points
            - "ward": Ward variance minimization
        bisection (bool): Whether to use bisection method for tree construction

    Returns:
        Dendrogram: Object containing the hierarchical clustering tree, with:
            - root: Root cluster node
            - linkage: Linkage matrix for plotting
            - assets: The columns of ``cor``
            - method: Clustering method used
            - distance: Distance matrix

    Raises:
        AssertionError: If ``cor`` is not a pandas DataFrame.
    """
    # Raised explicitly instead of via `assert` so the check is not stripped
    # under `python -O`; AssertionError is kept so callers see the same type.
    if not isinstance(cor, pd.DataFrame):
        raise AssertionError("Correlation matrix must be a pandas DataFrame.")

    # Create distance matrix and linkage (scipy expects the condensed form).
    dist = _compute_distance_matrix(cor)
    links = sch.linkage(ssd.squareform(dist), method=method)

    def to_cluster(node: sch.ClusterNode) -> Cluster:
        """Convert a scipy ClusterNode (and its subtree) to our Cluster format.

        Args:
            node (sch.ClusterNode): A node from scipy's hierarchical clustering

        Returns:
            Cluster: Equivalent node in our Cluster format
        """
        if node.left is None:
            # Leaf: carries only its id.
            return Cluster(value=node.id)

        left = to_cluster(node.left)
        right = to_cluster(node.right)
        return Cluster(value=node.id, left=left, right=right)

    root = to_cluster(sch.to_tree(links, rd=False))

    # Apply bisection if requested
    if bisection:

        def bisect_tree(ids: list[int]) -> Cluster:
            """Build tree by recursive bisection.

            Splits the list of ids in half recursively. Each internal node gets
            the next unused id, assigned after both of its subtrees are built.

            Args:
                ids (list[int]): List of leaf node IDs to organize into a tree

            Returns:
                Cluster: Root node of the constructed tree
            """
            nonlocal nnn

            if len(ids) == 1:
                return Cluster(value=ids[0])

            mid = len(ids) // 2
            left = bisect_tree(ids[:mid])
            right = bisect_tree(ids[mid:])

            nnn += 1
            return Cluster(value=nnn, left=left, right=right)

        # Rebuild tree using bisection, preserving the clustered leaf order.
        leaf_ids = [leaf.value for leaf in root.leaves]
        nnn = max(leaf_ids)  # internal node ids continue after the largest leaf id
        root = bisect_tree(leaf_ids)

        # Convert back to linkage format for plotting
        rows = []

        def get_linkage(node: Cluster) -> None:
            """Append one linkage row per internal node, children before parents.

            Args:
                node (Cluster): Current node being processed
            """
            if node.left is None:
                return

            get_linkage(node.left)
            get_linkage(node.right)
            rows.append(
                [
                    node.left.value,
                    node.right.value,
                    # node.size stands in for the merge height
                    # NOTE(review): presumably the subtree node count - confirm against Cluster.
                    float(node.size),
                    len(node.left.leaves) + len(node.right.leaves),
                ]
            )

        get_linkage(root)
        links = np.array(rows)

    return Dendrogram(root=root, linkage=links, method=method, distance=dist, assets=cor.columns)