Coverage for src/pyhrp/hrp.py: 97%
119 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-16 16:29 +0000
"""Hierarchical Risk Parity (HRP) algorithm implementation.

This module implements the core HRP algorithm and related functions:

- hrp: Main function to compute HRP portfolio weights
- schur_hrp: HRP variant using Schur complementary allocation (interpolates
  between HRP and the minimum-variance portfolio)
- build_tree: Function to build hierarchical cluster tree from correlation matrix
- compute_cov: Function to compute a covariance matrix from returns
- compute_corr: Function to compute a correlation matrix from returns
- Dendrogram: Class to store and visualize hierarchical clustering results
"""
11from __future__ import annotations
13from dataclasses import dataclass
14from typing import Literal
16import numpy as np
17import plotly.graph_objects as go
18import polars as pl
19import scipy.cluster.hierarchy as sch
20import scipy.spatial.distance as ssd
22from .algos import risk_parity, schur_risk_parity
23from .cluster import Cluster
# Public API of this module (the names exported by ``from ... import *``).
__all__ = [
    "Dendrogram",
    "build_tree",
    "compute_corr",
    "compute_cov",
    "hrp",
    "schur_hrp",
]
def compute_cov(df: pl.DataFrame) -> pl.DataFrame:
    """Return the sample covariance matrix of the columns of ``df``.

    Each column of ``df`` is one variable (asset) and each row one observation.
    The result is square, and its columns carry the names of ``df``'s columns
    in the same order.

    Args:
        df (pl.DataFrame): Observations, one column per asset.

    Returns:
        pl.DataFrame: Covariance matrix with one column per asset.
    """
    # rowvar=False tells numpy that variables are columns, so no transpose is needed.
    matrix = np.atleast_2d(np.cov(df.to_numpy(), rowvar=False))
    return pl.DataFrame({name: row for name, row in zip(df.columns, matrix, strict=True)})
def compute_corr(df: pl.DataFrame) -> pl.DataFrame:
    """Return the Pearson correlation matrix of the columns of ``df``.

    Each column of ``df`` is one variable (asset) and each row one observation.
    The result is square, and its columns carry the names of ``df``'s columns
    in the same order.

    Args:
        df (pl.DataFrame): Observations, one column per asset.

    Returns:
        pl.DataFrame: Correlation matrix with one column per asset.
    """
    # rowvar=False tells numpy that variables are columns, so no transpose is needed.
    matrix = np.atleast_2d(np.corrcoef(df.to_numpy(), rowvar=False))
    return pl.DataFrame({name: row for name, row in zip(df.columns, matrix, strict=True)})
def _returns(prices: pl.DataFrame) -> pl.DataFrame:
    """Turn a price table into a table of simple (percentage-change) returns.

    Rows in which every column is null are dropped. This always includes the
    first row, which has no previous price. Any remaining nulls or NaNs (for
    example caused by missing prices) are replaced by a return of zero.

    Args:
        prices (pl.DataFrame): Price series, one column per asset.

    Returns:
        pl.DataFrame: Simple returns with the same columns as ``prices``.
    """
    changes = prices.select(pl.all().pct_change())
    row_has_data = pl.any_horizontal(pl.all().is_not_null())
    return changes.filter(row_has_data).fill_null(0.0).fill_nan(0.0)
def _covariance_and_tree(
    prices: pl.DataFrame,
    node: Cluster | None,
    method: Literal["single", "complete", "average", "ward"],
    bisection: bool,
) -> tuple[pl.DataFrame, Cluster]:
    """Compute the return covariance and resolve the clustering tree for an HRP variant.

    Args:
        prices (pl.DataFrame): Asset price time series (columns are assets, rows are dates)
        node (Cluster | None): Pre-built root of the clustering tree, or None to build one
        method (Literal["single", "complete", "average", "ward"]): Linkage method passed to
            ``build_tree`` when ``node`` is None
        bisection (bool): Bisection flag passed to ``build_tree`` when ``node`` is None

    Returns:
        tuple[pl.DataFrame, Cluster]: Covariance matrix of simple returns and the root
        node to allocate over.
    """
    returns = _returns(prices)
    cov = compute_cov(returns)
    # The correlation matrix is only needed to build a tree, so skip it when the caller
    # supplies one. Test against None explicitly instead of relying on Cluster truthiness.
    if node is None:
        node = build_tree(compute_corr(returns), method=method, bisection=bisection).root
    return cov, node


def hrp(
    prices: pl.DataFrame,
    node: Cluster | None = None,
    method: Literal["single", "complete", "average", "ward"] = "ward",
    bisection: bool = False,
) -> Cluster:
    """Compute the hierarchical risk parity portfolio weights.

    This is the main entry point for the HRP algorithm. It calculates returns from prices,
    builds a hierarchical clustering tree if not provided, and applies risk parity weights.

    Args:
        prices (pl.DataFrame): Asset price time series (columns are assets, rows are dates)
        node (Cluster, optional): Root node of the hierarchical clustering tree.
            If None, a tree will be built from the correlation matrix.
        method (Literal["single", "complete", "average", "ward"]): Linkage method to use for distance calculation
            - "single": minimum distance between points (nearest neighbor)
            - "complete": maximum distance between points (furthest neighbor)
            - "average": average distance between all points
            - "ward": Ward variance minimization
        bisection (bool): Whether to use bisection method for tree construction

    Returns:
        Cluster: The root cluster with portfolio weights assigned according to HRP

    Examples:
        >>> import polars as pl
        >>> from pyhrp.hrp import hrp
        >>> prices = pl.DataFrame({"A": [100.0, 101.0, 99.0, 102.0], "B": [50.0, 51.0, 49.0, 52.0]})
        >>> root = hrp(prices, method="ward")
        >>> round(sum(root.portfolio.weights.values()), 6)
        1.0
    """
    cov, root = _covariance_and_tree(prices, node, method, bisection)
    return risk_parity(root=root, cov=cov)


def schur_hrp(
    prices: pl.DataFrame,
    node: Cluster | None = None,
    method: Literal["single", "complete", "average", "ward"] = "ward",
    bisection: bool = False,
    gamma: float = 0.5,
) -> Cluster:
    """Compute Schur Complementary Allocation portfolio weights.

    Extends HRP by augmenting each sub-covariance block with off-diagonal information
    via Schur complements before splitting risk between clusters. Introduced by Peter Cotton
    (arXiv:2411.05807). At gamma=0 this is identical to HRP; at gamma=1 it recovers the
    global minimum-variance portfolio through the same recursive hierarchy.

    Args:
        prices (pl.DataFrame): Asset price time series (columns are assets, rows are dates)
        node (Cluster, optional): Root node of the hierarchical clustering tree.
            If None, a tree will be built from the correlation matrix.
        method (Literal["single", "complete", "average", "ward"]): Linkage method for clustering
        bisection (bool): Whether to use bisection method for tree construction
        gamma (float): Schur interpolation parameter in [0, 1].
            0 recovers standard HRP; 1 recovers minimum-variance portfolio.
            The value is passed through to ``schur_risk_parity`` unchanged.

    Returns:
        Cluster: The root cluster with portfolio weights assigned

    Examples:
        >>> import polars as pl
        >>> from pyhrp.hrp import schur_hrp
        >>> prices = pl.DataFrame({"A": [100.0, 101.0, 99.0, 102.0], "B": [50.0, 51.0, 49.0, 52.0]})
        >>> root = schur_hrp(prices, method="ward", gamma=0.5)
        >>> round(sum(root.portfolio.weights.values()), 6)
        1.0
    """
    cov, root = _covariance_and_tree(prices, node, method, bisection)
    return schur_risk_parity(root=root, cov=cov, gamma=gamma)
@dataclass(frozen=True)
class Dendrogram:
    """Container for hierarchical clustering dendrogram data and visualization.

    This class stores the results of hierarchical clustering and provides methods
    for accessing and visualizing the dendrogram structure.

    Attributes:
        root (Cluster): The root node of the hierarchical clustering tree
        assets (list[str]): Names of assets included in the clustering
        distance (pl.DataFrame | None): Distance matrix used for clustering
        linkage (np.ndarray | None): Linkage matrix in scipy format for plotting
        method (str | None): Linkage method used for clustering
    """

    root: Cluster
    assets: list[str]
    distance: pl.DataFrame | None = None
    linkage: np.ndarray | None = None
    method: str | None = None

    def __post_init__(self) -> None:
        """Validate dataclass fields after initialization.

        Ensures that the optional distance matrix, when provided, is a polars
        DataFrame whose columns equal the asset list, and verifies that the
        number of leaves in the cluster tree matches the number of assets.

        Raises:
            TypeError: If ``distance`` is given but is not a polars DataFrame.
            ValueError: If the distance columns differ from ``assets`` or the
                leaf count differs from the number of assets.
        """
        if self.distance is not None:
            if not isinstance(self.distance, pl.DataFrame):
                raise TypeError("distance must be a polars DataFrame.")  # noqa: TRY003
            if self.distance.columns != list(self.assets):
                raise ValueError("Distance matrix index/columns must align with assets.")  # noqa: TRY003
        if len(self.root.leaves) != len(self.assets):
            raise ValueError("Number of leaves does not match number of assets.")  # noqa: TRY003

    def plot(self, **kwargs: object) -> go.Figure:
        """Build and return a plotly dendrogram figure.

        Args:
            **kwargs: Extra keyword arguments forwarded to
                ``scipy.cluster.hierarchy.dendrogram`` (for example ``truncate_mode``
                and ``p``). ``labels`` and ``no_plot`` are set by this method, so
                passing them raises a TypeError.

        Returns:
            go.Figure: One line trace per dendrogram link, with asset names as x-axis ticks.

        Raises:
            ValueError: If no linkage matrix is stored.
        """
        if self.linkage is None:
            msg = "Dendrogram has no linkage matrix to plot."
            raise ValueError(msg)
        ddata = sch.dendrogram(self.linkage, labels=self.assets, no_plot=True, **kwargs)
        fig = go.Figure()
        for xs, ys in zip(ddata["icoord"], ddata["dcoord"], strict=True):
            fig.add_trace(go.Scatter(x=xs, y=ys, mode="lines", line={"color": "steelblue"}, showlegend=False))
        # scipy places the k-th displayed leaf at x = 5 + 10 * k. Size the ticks from the
        # leaf labels scipy actually drew (``ivl``). With truncation options forwarded in
        # ``kwargs``, there can be fewer of those than ``len(self.assets)``.
        n_ticks = len(ddata["ivl"])
        fig.update_layout(
            xaxis={
                "tickmode": "array",
                "tickvals": [5 + 10 * i for i in range(n_ticks)],
                "ticktext": ddata["ivl"],
                "tickangle": -90,
            },
        )
        return fig

    @property
    def ids(self) -> list[int]:
        """Node values of the leaves, in the order returned by ``self.root.leaves``.

        NOTE(review): intended to be left -> right dendrogram order; this holds only if
        ``Cluster.leaves`` yields leaves left to right. Confirm in ``cluster.py``.
        """
        return [node.value for node in self.root.leaves]

    @property
    def names(self) -> list[str]:
        """The asset names as induced by the order of ids.

        Leaf values are used as positional indices into ``assets``. This assumes leaves
        are numbered 0..n-1 in asset order, as ``build_tree`` produces them.
        """
        return [self.assets[i] for i in self.ids]
def _compute_distance_matrix(corr: pl.DataFrame) -> pl.DataFrame:
    """Map a correlation matrix to the distance matrix ``sqrt((1 - rho) / 2)``.

    Perfectly correlated pairs get distance 0 and perfectly anti-correlated pairs
    get distance 1. The diagonal is forced to exactly zero.

    Args:
        corr (pl.DataFrame): Square correlation matrix, one column per asset.

    Returns:
        pl.DataFrame: Distance matrix with the same column names as ``corr``.
    """
    rho = corr.to_numpy()
    # Clipping guards against rounding pushing the radicand slightly outside [0, 1].
    radicand = np.clip((1.0 - rho) / 2.0, a_min=0.0, a_max=1.0)
    distances = np.sqrt(radicand)
    np.fill_diagonal(distances, 0.0)
    return pl.DataFrame({name: row for name, row in zip(corr.columns, distances, strict=True)})
217def _bisect_tree(ids: list[int], next_id: int) -> tuple[Cluster, int]:
218 """Build tree by recursive bisection."""
219 if not ids:
220 raise ValueError("ids must contain at least one node id.") # noqa: TRY003
221 if len(ids) == 1:
222 return Cluster(value=ids[0]), next_id
224 mid = len(ids) // 2
225 left_ids, right_ids = ids[:mid], ids[mid:]
226 left, next_id = _bisect_tree(left_ids, next_id)
227 right, next_id = _bisect_tree(right_ids, next_id)
228 next_id += 1
229 return Cluster(value=next_id, left=left, right=right), next_id
def _get_linkage(node: Cluster) -> list[list[float]]:
    """Convert a Cluster tree into scipy linkage-matrix rows.

    Rows are emitted in post-order (left subtree, right subtree, then the node).
    Each row is ``[left.value, right.value, node.size, leaves_under_node]``.
    Nodes that do not have both children contribute no row.

    Args:
        node (Cluster): Root of the (sub)tree to convert.

    Returns:
        list[list[float]]: One four-element row per internal node.

    Raises:
        TypeError: If a child of an internal node is not a Cluster.
    """
    rows: list[list[float]] = []

    def _visit(current: Cluster) -> None:
        left, right = current.left, current.right
        if left is None or right is None:
            return
        if not isinstance(left, Cluster):
            raise TypeError("Expected left child to be a Cluster")  # noqa: TRY003 # pragma: no cover
        if not isinstance(right, Cluster):
            raise TypeError("Expected right child to be a Cluster")  # noqa: TRY003 # pragma: no cover
        _visit(left)
        _visit(right)
        rows.append(
            [
                float(left.value),
                float(right.value),
                float(current.size),
                float(len(left.leaves) + len(right.leaves)),
            ]
        )

    _visit(node)
    return rows
def build_tree(
    cor: pl.DataFrame, method: Literal["single", "complete", "average", "ward"] = "ward", bisection: bool = False
) -> Dendrogram:
    """Build hierarchical cluster tree from correlation matrix.

    This function converts a correlation matrix to a distance matrix, performs
    hierarchical clustering, and returns a Dendrogram object containing the
    resulting tree structure.

    Args:
        cor (pl.DataFrame): Square correlation matrix of asset returns (columns are assets)
        method (Literal["single", "complete", "average", "ward"]): Linkage method for hierarchical clustering
            - "single": minimum distance between points (nearest neighbor)
            - "complete": maximum distance between points (furthest neighbor)
            - "average": average distance between all points
            - "ward": Ward variance minimization
        bisection (bool): Whether to use bisection method for tree construction

    Returns:
        Dendrogram: Object containing the hierarchical clustering tree, with:
            - root: Root cluster node
            - linkage: Linkage matrix for plotting
            - assets: List of assets
            - method: Clustering method used
            - distance: Distance matrix

    Raises:
        TypeError: If ``cor`` is not a polars DataFrame.
        ValueError: If ``cor`` has fewer than two assets, is not square, or contains
            non-finite values.

    Examples:
        >>> import polars as pl
        >>> from pyhrp.hrp import build_tree
        >>> cor = pl.DataFrame({"A": [1.0, 0.5], "B": [0.5, 1.0]})
        >>> dg = build_tree(cor, method="ward")
        >>> dg.root.leaf_count
        2
    """
    if not isinstance(cor, pl.DataFrame):
        raise TypeError("Correlation matrix must be a polars DataFrame.")  # noqa: TRY003
    if len(cor.columns) < 2:
        msg = "Correlation matrix must contain at least two assets."
        raise ValueError(msg)
    c = cor.to_numpy()
    # Check the shape explicitly. A non-square input would otherwise fail later with an
    # unrelated zip(strict=True) length error.
    if c.shape[0] != c.shape[1]:
        msg = f"Correlation matrix must be square; got shape {c.shape}."
        raise ValueError(msg)
    bad = [col for col, diag in zip(cor.columns, np.diagonal(c), strict=True) if not np.isfinite(diag)]
    if bad:
        msg = (
            f"Correlation matrix contains non-finite values for assets {bad}; "
            "constant (zero-variance) price series produce NaN correlations."
        )
        raise ValueError(msg)
    if not np.isfinite(c).all():
        msg = "Correlation matrix contains non-finite values."
        raise ValueError(msg)
    dist = _compute_distance_matrix(cor)
    links = sch.linkage(ssd.squareform(dist.to_numpy(), checks=False), method=method)

    # Build our Cluster tree directly from the linkage matrix. Leaves are 0..n-1, and
    # row i merges ids links[i, 0] (left) and links[i, 1] (right) into id n + i. This
    # mirrors scipy's to_tree, but iteratively, so deep (chain-like) trees cannot hit
    # Python's recursion limit during the conversion.
    n_assets = c.shape[0]
    nodes: list[Cluster] = [Cluster(value=i) for i in range(n_assets)]
    for row_index, merge in enumerate(links):
        left_id, right_id = int(merge[0]), int(merge[1])
        nodes.append(Cluster(value=n_assets + row_index, left=nodes[left_id], right=nodes[right_id]))
    root = nodes[-1]

    if bisection:
        # Keep the clustered leaf order, but replace the tree with a balanced bisection
        # and regenerate a matching linkage matrix for plotting.
        leaf_ids: list[int] = [int(node.value) for node in root.leaves]
        root, _ = _bisect_tree(ids=leaf_ids, next_id=max(leaf_ids))
        links = np.array(_get_linkage(root))

    return Dendrogram(root=root, linkage=links, method=method, distance=dist, assets=cor.columns)