Coverage for src/pyhrp/hrp.py: 97%

119 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-16 16:29 +0000

1"""Hierarchical Risk Parity (HRP) algorithm implementation. 

2 

3This module implements the core HRP algorithm and related functions: 

4- hrp: Main function to compute HRP portfolio weights 

5- build_tree: Function to build hierarchical cluster tree from correlation matrix 

6- compute_cov: Function to compute a covariance matrix from returns 

7- compute_corr: Function to compute a correlation matrix from returns 

8- Dendrogram: Class to store and visualize hierarchical clustering results 

9""" 

10 

11from __future__ import annotations 

12 

13from dataclasses import dataclass 

14from typing import Literal 

15 

16import numpy as np 

17import plotly.graph_objects as go 

18import polars as pl 

19import scipy.cluster.hierarchy as sch 

20import scipy.spatial.distance as ssd 

21 

22from .algos import risk_parity, schur_risk_parity 

23from .cluster import Cluster 

24 

25__all__ = ["Dendrogram", "build_tree", "compute_corr", "compute_cov", "hrp", "schur_hrp"] 

26 

27 

def compute_cov(df: pl.DataFrame) -> pl.DataFrame:
    """Build the covariance matrix of the columns of a returns table.

    Args:
        df (pl.DataFrame): Returns, one column per asset and one row per observation.

    Returns:
        pl.DataFrame: Covariance matrix whose columns carry the names of ``df``'s
            columns, in the same order.
    """
    names = df.columns
    observations = df.to_numpy()
    # np.cov treats rows as variables, hence the transpose. atleast_2d keeps the
    # result two-dimensional even if np.cov hands back something smaller.
    matrix = np.atleast_2d(np.cov(observations.T))
    by_asset = dict(zip(names, matrix, strict=True))
    return pl.DataFrame(by_asset)

33 

34 

def compute_corr(df: pl.DataFrame) -> pl.DataFrame:
    """Build the correlation matrix of the columns of a returns table.

    Args:
        df (pl.DataFrame): Returns, one column per asset and one row per observation.

    Returns:
        pl.DataFrame: Correlation matrix whose columns carry the names of ``df``'s
            columns, in the same order.
    """
    names = df.columns
    observations = df.to_numpy()
    # np.corrcoef treats rows as variables, hence the transpose. atleast_2d keeps
    # the result two-dimensional even if np.corrcoef hands back something smaller.
    matrix = np.atleast_2d(np.corrcoef(observations.T))
    by_asset = dict(zip(names, matrix, strict=True))
    return pl.DataFrame(by_asset)

40 

41 

def _returns(prices: pl.DataFrame) -> pl.DataFrame:
    """Turn a price table into simple period-over-period returns.

    Every column is replaced by its ``pct_change``. Rows in which all columns
    are null (such as the first row, which has no predecessor) are removed;
    any null or NaN left in the remaining rows is replaced by a zero return.
    """
    changes = prices.select(pl.all().pct_change())
    # Keep a row as soon as at least one asset has a value in it.
    has_any_value = pl.any_horizontal(pl.all().is_not_null())
    return changes.filter(has_any_value).fill_null(0.0).fill_nan(0.0)

54 

55 

def hrp(
    prices: pl.DataFrame,
    node: Cluster | None = None,
    method: Literal["single", "complete", "average", "ward"] = "ward",
    bisection: bool = False,
) -> Cluster:
    """Compute the hierarchical risk parity portfolio weights.

    This is the main entry point for the HRP algorithm. It calculates returns from prices,
    builds a hierarchical clustering tree if not provided, and applies risk parity weights.

    Args:
        prices (pl.DataFrame): Asset price time series (columns are assets, rows are dates)
        node (Cluster, optional): Root node of the hierarchical clustering tree.
            If None, a tree will be built from the correlation matrix of the returns.
        method (Literal["single", "complete", "average", "ward"]): Linkage method to use
            for distance calculation. Only used when ``node`` is None.
            - "single": minimum distance between points (nearest neighbor)
            - "complete": maximum distance between points (furthest neighbor)
            - "average": average distance between all points
            - "ward": Ward variance minimization
        bisection (bool): Whether to use bisection method for tree construction.
            Only used when ``node`` is None.

    Returns:
        Cluster: The root cluster with portfolio weights assigned according to HRP

    Examples:
        >>> import polars as pl
        >>> from pyhrp.hrp import hrp
        >>> prices = pl.DataFrame({"A": [100.0, 101.0, 99.0, 102.0], "B": [50.0, 51.0, 49.0, 52.0]})
        >>> root = hrp(prices, method="ward")
        >>> round(sum(root.portfolio.weights.values()), 6)
        1.0
    """
    returns = _returns(prices)
    cov = compute_cov(returns)

    # Test against None explicitly: a caller-supplied tree must be honoured even
    # if the Cluster object happens to evaluate as falsy. The correlation matrix
    # is only needed to build a tree, so it is not computed when one is given.
    if node is None:
        cor = compute_corr(returns)
        node = build_tree(cor, method=method, bisection=bisection).root

    return risk_parity(root=node, cov=cov)

95 

96 

def schur_hrp(
    prices: pl.DataFrame,
    node: Cluster | None = None,
    method: Literal["single", "complete", "average", "ward"] = "ward",
    bisection: bool = False,
    gamma: float = 0.5,
) -> Cluster:
    """Compute Schur Complementary Allocation portfolio weights.

    Extends HRP by augmenting each sub-covariance block with off-diagonal information
    via Schur complements before splitting risk between clusters. Introduced by Peter Cotton
    (arXiv:2411.05807). At gamma=0 this is identical to HRP; at gamma=1 it recovers the
    global minimum-variance portfolio through the same recursive hierarchy.

    Args:
        prices (pl.DataFrame): Asset price time series (columns are assets, rows are dates)
        node (Cluster, optional): Root node of the hierarchical clustering tree.
            If None, a tree will be built from the correlation matrix of the returns.
        method (Literal["single", "complete", "average", "ward"]): Linkage method for clustering.
            Only used when ``node`` is None.
        bisection (bool): Whether to use bisection method for tree construction.
            Only used when ``node`` is None.
        gamma (float): Schur interpolation parameter in [0, 1].
            0 recovers standard HRP; 1 recovers minimum-variance portfolio.

    Returns:
        Cluster: The root cluster with portfolio weights assigned

    Examples:
        >>> import polars as pl
        >>> from pyhrp.hrp import schur_hrp
        >>> prices = pl.DataFrame({"A": [100.0, 101.0, 99.0, 102.0], "B": [50.0, 51.0, 49.0, 52.0]})
        >>> root = schur_hrp(prices, method="ward", gamma=0.5)
        >>> round(sum(root.portfolio.weights.values()), 6)
        1.0
    """
    returns = _returns(prices)
    cov = compute_cov(returns)

    # Test against None explicitly: a caller-supplied tree must be honoured even
    # if the Cluster object happens to evaluate as falsy. The correlation matrix
    # is only needed to build a tree, so it is not computed when one is given.
    if node is None:
        cor = compute_corr(returns)
        node = build_tree(cor, method=method, bisection=bisection).root

    return schur_risk_parity(root=node, cov=cov, gamma=gamma)

137 

138 

@dataclass(frozen=True)
class Dendrogram:
    """Immutable record of a hierarchical clustering result, with plotting support.

    Attributes:
        root (Cluster): Root node of the cluster tree
        assets (list[str]): Names of the clustered assets
        distance (pl.DataFrame | None): Distance matrix the clustering was based on
        linkage (np.ndarray | None): Linkage matrix passed to scipy when plotting
        method (str | None): Name of the linkage method
    """

    root: Cluster
    assets: list[str]
    distance: pl.DataFrame | None = None
    linkage: np.ndarray | None = None
    method: str | None = None

    def __post_init__(self) -> None:
        """Check that the fields are mutually consistent.

        Raises:
            TypeError: If ``distance`` is given but is not a polars DataFrame.
            ValueError: If the columns of ``distance`` differ from ``assets``, or
                if the tree does not have exactly one leaf per asset.
        """
        matrix = self.distance
        if matrix is not None:
            if not isinstance(matrix, pl.DataFrame):
                msg = "distance must be a polars DataFrame."
                raise TypeError(msg)

            if matrix.columns != list(self.assets):
                msg = "Distance matrix index/columns must align with assets."
                raise ValueError(msg)

        leaf_total = len(self.root.leaves)
        if leaf_total != len(self.assets):
            msg = "Number of leaves does not match number of assets."
            raise ValueError(msg)

    def plot(self, **kwargs: object) -> go.Figure:
        """Draw the dendrogram as a plotly figure.

        Args:
            **kwargs: Extra keyword arguments forwarded to ``scipy.cluster.hierarchy.dendrogram``.

        Returns:
            go.Figure: One line trace per dendrogram link, with the x axis ticks
                labelled by the leaf labels scipy reports.

        Raises:
            ValueError: If no linkage matrix is stored.
        """
        if self.linkage is None:
            raise ValueError("Dendrogram has no linkage matrix to plot.")  # noqa: TRY003

        # no_plot=True: only the coordinates are wanted, the drawing is done with plotly.
        layout = sch.dendrogram(self.linkage, labels=self.assets, no_plot=True, **kwargs)

        fig = go.Figure()
        for link_x, link_y in zip(layout["icoord"], layout["dcoord"], strict=False):
            trace = go.Scatter(x=link_x, y=link_y, mode="lines", line={"color": "steelblue"}, showlegend=False)
            fig.add_trace(trace)

        # One tick per asset at x = 5, 15, 25, ...
        tick_positions = list(range(5, 5 + 10 * len(self.assets), 10))
        x_axis = {
            "tickmode": "array",
            "tickvals": tick_positions,
            "ticktext": layout["ivl"],
            "tickangle": -90,
        }
        fig.update_layout(xaxis=x_axis)
        return fig

    @property
    def ids(self) -> list[int]:
        """Values of the leaf nodes, in the order ``root.leaves`` yields them."""
        return [leaf.value for leaf in self.root.leaves]

    @property
    def names(self) -> list[str]:
        """Asset names looked up by leaf value, in the order of ``ids``."""
        labels = self.assets
        return [labels[leaf_id] for leaf_id in self.ids]

206 

207 

def _compute_distance_matrix(corr: pl.DataFrame) -> pl.DataFrame:
    """Map a correlation matrix to distances via ``sqrt((1 - corr) / 2)``.

    The value under the square root is clipped to [0, 1] and the diagonal of the
    result is forced to exactly zero. Column names are taken over from ``corr``.
    """
    half_gap = (1.0 - corr.to_numpy()) / 2.0
    distances = np.sqrt(np.clip(half_gap, 0.0, 1.0))
    # Overwrite the diagonal in place so each asset is at distance zero from itself.
    np.fill_diagonal(distances, 0.0)
    by_asset = dict(zip(corr.columns, distances, strict=True))
    return pl.DataFrame(by_asset)

215 

216 

def _bisect_tree(ids: list[int], next_id: int) -> tuple[Cluster, int]:
    """Build a tree over ``ids`` by repeatedly halving the list.

    Args:
        ids (list[int]): Leaf ids in the order they should appear in the tree.
        next_id (int): Highest node id handed out so far.

    Returns:
        tuple[Cluster, int]: The subtree covering ``ids`` and the highest node id
            in use after building it.

    Raises:
        ValueError: If ``ids`` is empty.
    """
    if not ids:
        msg = "ids must contain at least one node id."
        raise ValueError(msg)
    if len(ids) == 1:
        # A single id becomes a leaf and consumes no new id.
        return Cluster(value=ids[0]), next_id

    half = len(ids) // 2
    left_branch, counter = _bisect_tree(ids[:half], next_id)
    right_branch, counter = _bisect_tree(ids[half:], counter)
    # The parent is numbered after everything created inside both branches.
    parent_id = counter + 1
    return Cluster(value=parent_id, left=left_branch, right=right_branch), parent_id

230 

231 

def _get_linkage(node: Cluster) -> list[list[float]]:
    """Flatten a cluster tree into linkage-style rows.

    Rows are emitted for the left subtree first, then the right subtree, then the
    node itself. Each row holds the values of the two children, ``node.size`` and
    the combined number of leaves of both children, all as floats. A node missing
    either child contributes no rows.
    """
    left, right = node.left, node.right
    if left is None or right is None:
        return []

    if not isinstance(left, Cluster):
        raise TypeError("Expected left child to be a Cluster")  # noqa: TRY003  # pragma: no cover
    if not isinstance(right, Cluster):
        raise TypeError("Expected right child to be a Cluster")  # noqa: TRY003  # pragma: no cover

    rows: list[list[float]] = []
    rows += _get_linkage(left)
    rows += _get_linkage(right)
    merged_leaves = len(left.leaves) + len(right.leaves)
    rows.append([float(left.value), float(right.value), float(node.size), float(merged_leaves)])
    return rows

251 

252 

def build_tree(
    cor: pl.DataFrame, method: Literal["single", "complete", "average", "ward"] = "ward", bisection: bool = False
) -> Dendrogram:
    """Cluster assets hierarchically, starting from their correlation matrix.

    The correlation matrix is validated, converted to a distance matrix and fed
    to scipy's linkage routine. The resulting scipy tree is translated into
    ``Cluster`` nodes and wrapped in a ``Dendrogram``.

    Args:
        cor (pl.DataFrame): Correlation matrix of asset returns (columns are assets)
        method (Literal["single", "complete", "average", "ward"]): Linkage method for hierarchical clustering
            - "single": minimum distance between points (nearest neighbor)
            - "complete": maximum distance between points (furthest neighbor)
            - "average": average distance between all points
            - "ward": Ward variance minimization
        bisection (bool): If True, keep only the leaf order of the clustered tree,
            rebuild the tree by recursive halving of that order, and derive the
            linkage matrix from the rebuilt tree.

    Returns:
        Dendrogram: Object containing the hierarchical clustering tree, with:
            - root: Root cluster node
            - linkage: Linkage matrix for plotting
            - assets: List of assets
            - method: Clustering method used
            - distance: Distance matrix

    Raises:
        TypeError: If ``cor`` is not a polars DataFrame.
        ValueError: If ``cor`` has fewer than two columns or contains non-finite values.

    Examples:
        >>> import polars as pl
        >>> from pyhrp.hrp import build_tree
        >>> cor = pl.DataFrame({"A": [1.0, 0.5], "B": [0.5, 1.0]})
        >>> dg = build_tree(cor, method="ward")
        >>> dg.root.leaf_count
        2
    """
    if not isinstance(cor, pl.DataFrame):
        msg = "Correlation matrix must be a polars DataFrame."
        raise TypeError(msg)

    assets = cor.columns
    if len(assets) < 2:
        raise ValueError("Correlation matrix must contain at least two assets.")  # noqa: TRY003

    values = cor.to_numpy()

    # Look at the diagonal first so the error can name the offending assets.
    bad = [name for name, entry in zip(assets, np.diagonal(values), strict=True) if not np.isfinite(entry)]
    if bad:
        msg = (
            f"Correlation matrix contains non-finite values for assets {bad}; "
            "constant (zero-variance) price series produce NaN correlations."
        )
        raise ValueError(msg)
    if not np.all(np.isfinite(values)):
        raise ValueError("Correlation matrix contains non-finite values.")  # noqa: TRY003

    distance = _compute_distance_matrix(cor)
    # scipy's linkage works on the condensed (vector) form of the distance matrix.
    condensed = ssd.squareform(distance.to_numpy(), checks=False)
    links = sch.linkage(condensed, method=method)

    def _convert(scipy_node: sch.ClusterNode) -> Cluster:
        """Recursively translate a scipy ClusterNode into a Cluster, keeping its id."""
        if scipy_node.left is None or scipy_node.right is None:
            return Cluster(value=scipy_node.id)
        return Cluster(
            value=scipy_node.id,
            left=_convert(scipy_node.left),
            right=_convert(scipy_node.right),
        )

    root = _convert(sch.to_tree(links, rd=False))

    if bisection:
        # New internal node ids continue after the largest leaf id.
        leaf_ids: list[int] = [int(leaf.value) for leaf in root.leaves]
        root, _ = _bisect_tree(ids=leaf_ids, next_id=max(leaf_ids))
        links = np.array(_get_linkage(root))

    return Dendrogram(root=root, assets=assets, distance=distance, linkage=links, method=method)