Coverage for src/pyhrp/hrp.py: 97%

74 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-08 06:19 +0000

1"""Hierarchical Risk Parity (HRP) algorithm implementation. 

2 

3This module implements the core HRP algorithm and related functions: 

4- hrp: Main function to compute HRP portfolio weights 

5- build_tree: Function to build hierarchical cluster tree from correlation matrix 

6- Dendrogram: Class to store and visualize hierarchical clustering results 

7""" 

8 

9from __future__ import annotations 

10 

11from dataclasses import dataclass 

12from typing import Literal 

13 

14import numpy as np 

15import pandas as pd 

16import scipy.cluster.hierarchy as sch 

17import scipy.spatial.distance as ssd 

18 

19from .algos import risk_parity 

20from .cluster import Asset, Cluster 

21 

22 

def hrp(
    prices: pd.DataFrame,
    node: Cluster | None = None,
    method: Literal["single", "complete", "average", "ward"] = "ward",
    bisection: bool = False,
) -> Cluster:
    """Compute the hierarchical risk parity portfolio weights.

    This is the main entry point for the HRP algorithm. It calculates returns from prices,
    builds a hierarchical clustering tree if not provided, and applies risk parity weights.

    Args:
        prices (pd.DataFrame): Asset price time series, one column per asset.
        node (Cluster | None, optional): Root node of the hierarchical clustering tree.
            If None, a tree is built from the correlation matrix of the returns.
        method (Literal["single", "complete", "average", "ward"]): Linkage method used when
            a tree has to be built:
            - "single": minimum distance between points (nearest neighbor)
            - "complete": maximum distance between points (furthest neighbor)
            - "average": average distance between all points
            - "ward": Ward variance minimization
        bisection (bool): Whether to rebuild the tree by recursive bisection of the
            clustered leaf order (only used when a tree has to be built).

    Returns:
        Cluster: The root cluster with portfolio weights assigned according to HRP.
    """
    # Forward-fill gaps explicitly before computing simple returns. This gives the same
    # result as the historical implicit ``pct_change(fill_method="pad")`` default, which
    # pandas 2.1 deprecated in favour of calling ``ffill`` first.
    returns = prices.ffill().pct_change().dropna(axis=0, how="all")

    cov = returns.cov()

    # Explicit None check: a caller-supplied tree must never be replaced just because
    # the Cluster object happens to evaluate as falsy.
    if node is None:
        node = build_tree(returns.corr(), method=method, bisection=bisection).root

    return risk_parity(root=node, cov=cov)

54 

55 

@dataclass(frozen=True)
class Dendrogram:
    """Result of a hierarchical clustering, with helpers to inspect and plot it.

    Attributes:
        root (Cluster): Root node of the clustering tree.
        assets (list[Asset]): The clustered assets; leaf ``value`` ids index into this list.
        linkage (np.ndarray | None): scipy-style linkage matrix, consumed by ``plot``.
        distance (np.ndarray | None): Distance matrix the clustering was computed from.
        method (str | None): Name of the linkage method that was used.
    """

    root: Cluster
    assets: list[Asset]
    linkage: np.ndarray | None = None
    distance: np.ndarray | None = None
    method: str | None = None

    def __post_init__(self):
        """Check that the tree and the asset list describe the same number of assets.

        Raises:
            ValueError: If ``root`` has a different number of leaves than there are assets.
            AssertionError: If an entry of ``assets`` is not an ``Asset``
                (only checked while assertions are enabled).
        """
        leaf_count = len(self.root.leaves)
        if leaf_count != len(self.assets):
            raise ValueError("Inconsistent number of assets and leaves")

        assert all(isinstance(item, Asset) for item in self.assets)

    def plot(self, **kwargs):
        """Draw the dendrogram with ``scipy.cluster.hierarchy.dendrogram``.

        Leaves are labelled with each asset's ``name``. If any asset lacks a ``name``
        attribute, the asset entries themselves are used as labels. Extra keyword
        arguments are forwarded to scipy.
        """
        try:
            leaf_labels = [item.name for item in self.assets]
        except AttributeError:
            # Fall back to the raw entries (e.g. plain column labels) when ``.name`` is missing.
            leaf_labels = list(self.assets)

        sch.dendrogram(
            self.linkage,
            leaf_rotation=90,
            leaf_font_size=8,
            labels=leaf_labels,
            **kwargs,
        )

    @property
    def ids(self):
        """Leaf ``value`` ids, ordered left to right as drawn in the dendrogram."""
        return [leaf.value for leaf in self.root.leaves]

    @property
    def names(self):
        """Asset names, listed in the same left-to-right order as ``ids``."""
        assets = self.assets
        return [assets[idx].name for idx in self.ids]

108 

109 

def _compute_distance_matrix(corr: np.ndarray) -> np.ndarray:
    """Map a correlation matrix to the distance matrix ``sqrt((1 - corr) / 2)``.

    The radicand is clipped to ``[0, 1]`` so that correlations marginally outside
    ``[-1, 1]`` (floating-point noise) never produce a negative value under the
    square root. The diagonal is then forced to exactly zero.
    """
    half_gap = np.clip((1.0 - corr) / 2.0, 0.0, 1.0)
    distances = np.sqrt(half_gap)
    np.fill_diagonal(distances, 0.0)
    return distances

115 

116 

def _linkage_to_cluster(links: np.ndarray) -> Cluster:
    """Build a Cluster tree from a scipy linkage matrix, iteratively.

    Node ids follow ``scipy.cluster.hierarchy.to_tree``: leaves are ``0 .. n-1`` and the
    cluster formed by row ``k`` of the linkage matrix gets id ``n + k``, with
    ``links[k, 0]`` as its left child and ``links[k, 1]`` as its right child. Building
    bottom-up avoids a recursive walk, which could exceed Python's recursion limit on
    very deep (chained) trees.
    """
    n_leaves = links.shape[0] + 1
    nodes = [Cluster(value=i) for i in range(n_leaves)]
    for k, (left_id, right_id) in enumerate(links[:, :2].astype(int)):
        nodes.append(Cluster(value=n_leaves + k, left=nodes[left_id], right=nodes[right_id]))
    return nodes[-1]


def _bisect(ids: list[int], next_id: int) -> tuple[Cluster, int]:
    """Split ``ids`` recursively in halves; return the subtree and the next unused node id.

    Internal nodes are numbered in post-order (left subtree, right subtree, then the node
    itself), which matches the row order ``_tree_to_linkage`` emits.
    """
    if len(ids) == 1:
        return Cluster(value=ids[0]), next_id

    mid = len(ids) // 2
    left, next_id = _bisect(ids[:mid], next_id)
    right, next_id = _bisect(ids[mid:], next_id)
    return Cluster(value=next_id, left=left, right=right), next_id + 1


def _tree_to_linkage(root: Cluster) -> np.ndarray:
    """Convert a post-order-numbered Cluster tree into a scipy-style linkage matrix.

    Each internal node contributes one row ``[left id, right id, size, leaf count]``
    in post-order. The node's ``size`` attribute fills the distance column, which scipy
    uses as the drawn merge height when plotting.
    """
    rows: list[list[float]] = []

    def visit(node: Cluster) -> None:
        if node.left is None:
            return
        visit(node.left)
        visit(node.right)
        rows.append(
            [
                node.left.value,
                node.right.value,
                float(node.size),
                len(node.left.leaves) + len(node.right.leaves),
            ]
        )

    visit(root)
    return np.array(rows)


def build_tree(
    cor: pd.DataFrame, method: Literal["single", "complete", "average", "ward"] = "ward", bisection: bool = False
) -> Dendrogram:
    """Build hierarchical cluster tree from correlation matrix.

    This function converts a correlation matrix to a distance matrix, performs
    hierarchical clustering, and returns a Dendrogram object containing the
    resulting tree structure.

    Args:
        cor (pd.DataFrame): Correlation matrix of asset returns; its columns become the assets.
        method (Literal["single", "complete", "average", "ward"]): Linkage method for hierarchical clustering
            - "single": minimum distance between points (nearest neighbor)
            - "complete": maximum distance between points (furthest neighbor)
            - "average": average distance between all points
            - "ward": Ward variance minimization
        bisection (bool): If True, keep the clustered leaf order but replace the tree
            with one built by recursively splitting that order into halves.

    Returns:
        Dendrogram: Object containing the hierarchical clustering tree, with:
            - root: Root cluster node
            - linkage: Linkage matrix for plotting
            - assets: List of assets (the columns of ``cor``)
            - method: Clustering method used
            - distance: Distance matrix
    """
    # Correlation -> distance -> agglomerative linkage (scipy expects the condensed form).
    dist = _compute_distance_matrix(cor.values)
    links = sch.linkage(ssd.squareform(dist), method=method)

    root = _linkage_to_cluster(links)

    if bisection:
        # Internal ids continue after the largest leaf id, as in the scipy convention.
        leaf_ids = [leaf.value for leaf in root.leaves]
        root, _ = _bisect(leaf_ids, max(leaf_ids) + 1)
        # The scipy linkage no longer describes this tree; rebuild one for plotting.
        links = _tree_to_linkage(root)

    return Dendrogram(root=root, linkage=links, method=method, distance=dist, assets=list(cor.columns))

225 

226 return Dendrogram(root=root, linkage=links, method=method, distance=dist, assets=list(cor.columns))