Coverage for src/pyhrp/hrp.py: 98%
80 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-10 05:36 +0000
1"""Hierarchical Risk Parity (HRP) algorithm implementation.
3This module implements the core HRP algorithm and related functions:
4- hrp: Main function to compute HRP portfolio weights
5- build_tree: Function to build hierarchical cluster tree from correlation matrix
6- Dendrogram: Class to store and visualize hierarchical clustering results
7"""
9from __future__ import annotations
11from dataclasses import dataclass
12from typing import Literal
14import numpy as np
15import pandas as pd
16import scipy.cluster.hierarchy as sch
17import scipy.spatial.distance as ssd
19from .algos import risk_parity
20from .cluster import Cluster
def hrp(
    prices: pd.DataFrame,
    node: Cluster | None = None,
    method: Literal["single", "complete", "average", "ward"] = "ward",
    bisection: bool = False,
) -> Cluster:
    """Compute the hierarchical risk parity portfolio weights.

    This is the main entry point for the HRP algorithm. It calculates returns from prices,
    builds a hierarchical clustering tree if not provided, and applies risk parity weights.

    Args:
        prices (pd.DataFrame): Asset price time series (one column per asset).
        node (Cluster, optional): Root node of a precomputed hierarchical clustering tree.
            If None, a tree is built from the correlation matrix of the returns.
            When a tree is supplied, ``method`` and ``bisection`` are ignored.
        method (Literal["single", "complete", "average", "ward"]): Linkage method used
            for the hierarchical clustering (only when ``node`` is None)
            - "single": minimum distance between points (nearest neighbor)
            - "complete": maximum distance between points (furthest neighbor)
            - "average": average distance between all points
            - "ward": Ward variance minimization
        bisection (bool): Whether to rebuild the clustered tree by recursive bisection
            of its leaf order (only when ``node`` is None)

    Returns:
        Cluster: The root cluster with portfolio weights assigned according to HRP
    """
    # Simple returns; drop only rows where every asset is NaN (e.g. the first row).
    returns = prices.pct_change().dropna(axis=0, how="all")
    cov = returns.cov()

    # Explicit None check: a caller-supplied tree must be used as given, regardless of
    # how the Cluster object evaluates in a boolean context. The correlation matrix is
    # only needed (and only computed) when the tree has to be built here.
    if node is None:
        node = build_tree(returns.corr(), method=method, bisection=bisection).root

    return risk_parity(root=node, cov=cov)
@dataclass(frozen=True)
class Dendrogram:
    """Container for hierarchical clustering dendrogram data and visualization.

    Stores the result of hierarchical clustering and exposes the leaf order of
    the tree, the matching asset names, and a plotting helper.

    Attributes:
        root (Cluster): The root node of the hierarchical clustering tree
        assets (pd.Index): Asset labels; leaf values of ``root`` are used as
            positional indices into this index (see ``names``)
        distance (pd.DataFrame | None): Distance matrix used for clustering; when
            provided, its index and columns must both equal ``assets`` (same order)
        linkage (np.ndarray | None): Linkage matrix in scipy format, used by ``plot``
        method (str | None): Linkage method used for clustering

    Raises:
        TypeError: If ``distance`` is provided but is not a pandas DataFrame.
        ValueError: If ``distance`` is not aligned with ``assets``, or if the number
            of leaves under ``root`` differs from the number of assets.
    """

    root: Cluster
    assets: pd.Index
    distance: pd.DataFrame | None = None
    linkage: np.ndarray | None = None
    method: str | None = None

    def __post_init__(self) -> None:
        """Validate dataclass fields after initialization.

        Ensures that the optional distance matrix, when provided, is a pandas
        DataFrame whose index and columns match the asset order, and verifies
        that the number of leaves in the cluster tree matches the number of assets.
        """
        # ---- Optional: validate distance type and alignment ----
        if self.distance is not None:
            if not isinstance(self.distance, pd.DataFrame):
                raise TypeError("distance must be a pandas DataFrame.")  # noqa: TRY003

            # Build the reference index once and compare both axes against it.
            expected = pd.Index(self.assets)
            if not (self.distance.index.equals(expected) and self.distance.columns.equals(expected)):
                raise ValueError("Distance matrix index/columns must align with assets.")  # noqa: TRY003

        # Every asset must correspond to exactly one leaf of the tree.
        if len(self.root.leaves) != len(self.assets):
            raise ValueError("Number of leaves does not match number of assets.")  # noqa: TRY003

    def plot(self, **kwargs: object) -> None:
        """Plot the dendrogram with scipy, labelling leaves with ``assets``.

        ``linkage`` must be set. Extra keyword arguments are forwarded to
        ``scipy.cluster.hierarchy.dendrogram``.
        """
        sch.dendrogram(self.linkage, leaf_rotation=90, leaf_font_size=8, labels=self.assets, **kwargs)

    @property
    def ids(self) -> list[int]:
        """Leaf values of ``root`` in the order returned by ``root.leaves``.

        This is intended to be the left -> right order in which leaves appear
        in the dendrogram.
        """
        return [node.value for node in self.root.leaves]  # type: ignore[misc]

    @property
    def names(self) -> list[str]:
        """Asset names in the order of ``ids`` (each id is a position in ``assets``)."""
        return [self.assets[i] for i in self.ids]
def _compute_distance_matrix(corr: pd.DataFrame) -> pd.DataFrame:
    """Map a correlation matrix to the distance matrix used for clustering.

    Every entry ``rho`` becomes ``sqrt((1 - rho) / 2)``. The argument of the
    square root is clipped to ``[0, 1]`` so that rounding noise (``|rho|``
    marginally above 1) cannot yield NaN, and the diagonal is set to exactly
    zero. NaN correlations off the diagonal stay NaN.

    Args:
        corr: Square correlation matrix.

    Returns:
        pd.DataFrame: Distance matrix labelled with ``corr.index`` on both axes.
    """
    rho = corr.to_numpy()
    # Clip before the square root so tiny negative arguments become 0, not NaN.
    half_gap = np.clip((1.0 - rho) / 2.0, 0.0, 1.0)
    distances = np.sqrt(half_gap)
    # Zero self-distance, regardless of what the input diagonal contained.
    np.fill_diagonal(distances, 0.0)
    return pd.DataFrame(distances, index=corr.index, columns=corr.index)
def build_tree(
    cor: pd.DataFrame, method: Literal["single", "complete", "average", "ward"] = "ward", bisection: bool = False
) -> Dendrogram:
    """Build hierarchical cluster tree from correlation matrix.

    This function converts a correlation matrix to a distance matrix, performs
    hierarchical clustering, and returns a Dendrogram object containing the
    resulting tree structure.

    Args:
        cor (pd.DataFrame): Correlation matrix of asset returns
        method (Literal["single", "complete", "average", "ward"]): Linkage method for hierarchical clustering
            - "single": minimum distance between points (nearest neighbor)
            - "complete": maximum distance between points (furthest neighbor)
            - "average": average distance between all points
            - "ward": Ward variance minimization
        bisection (bool): If True, keep only the leaf order of the clustered tree and
            rebuild the tree by recursively splitting that order in half

    Returns:
        Dendrogram: Object containing the hierarchical clustering tree, with:
            - root: Root cluster node
            - linkage: Linkage matrix for plotting (rebuilt from the bisected tree
              when ``bisection`` is True)
            - assets: ``cor.columns``
            - method: Clustering method used
            - distance: Distance matrix

    Raises:
        TypeError: If ``cor`` is not a pandas DataFrame.
    """
    if not isinstance(cor, pd.DataFrame):
        raise TypeError("Correlation matrix must be a pandas DataFrame.")  # noqa: TRY003

    # Create distance matrix and linkage
    dist = _compute_distance_matrix(cor)
    links = sch.linkage(ssd.squareform(dist), method=method)

    # Convert scipy tree to our Cluster format
    def to_cluster(top: sch.ClusterNode) -> Cluster:
        """Convert a scipy ClusterNode tree into an equivalent Cluster tree.

        Uses an explicit stack (post-order) instead of recursion: linkage trees
        can be chain-shaped with depth close to the number of assets, which a
        recursive conversion could push past Python's recursion limit.

        Args:
            top (sch.ClusterNode): Root of scipy's hierarchical clustering tree

        Returns:
            Cluster: Equivalent tree in our Cluster format, keeping scipy node ids as values
        """
        built: dict[int, Cluster] = {}
        stack: list[tuple[sch.ClusterNode, bool]] = [(top, False)]
        while stack:
            node, children_done = stack.pop()
            if node.left is None or node.right is None:
                built[node.id] = Cluster(value=node.id)
            elif children_done:
                # Both children have been converted; attach them and release their slots.
                built[node.id] = Cluster(
                    value=node.id,
                    left=built.pop(node.left.id),
                    right=built.pop(node.right.id),
                )
            else:
                # Revisit this node after its children; push right first so the
                # left subtree is converted first.
                stack.append((node, True))
                stack.append((node.right, False))
                stack.append((node.left, False))
        return built[top.id]

    root = to_cluster(sch.to_tree(links, rd=False))

    # Apply bisection if requested
    if bisection:
        # Rebuild tree using bisection of the clustered leaf order
        leaf_ids: list[int] = [int(node.value) for node in root.leaves]
        # Internal ids continue after the largest leaf id and are assigned in
        # post-order, the same order in which get_linkage emits rows, so row k of
        # the rebuilt linkage describes the cluster with id max(leaf_ids) + 1 + k.
        nnn: int = max(leaf_ids)

        def bisect_tree(ids: list[int]) -> Cluster:
            """Build tree by recursive bisection.

            Recursively splits the list of IDs in half (the left half gets
            ``len(ids) // 2`` items) and creates a binary tree where each
            internal node represents one split.

            Args:
                ids (list[int]): List of leaf node IDs to organize into a tree

            Returns:
                Cluster: Root node of the constructed tree
            """
            nonlocal nnn

            if len(ids) == 1:
                return Cluster(value=ids[0])

            mid = len(ids) // 2
            left_ids, right_ids = ids[:mid], ids[mid:]

            left = bisect_tree(left_ids)
            right = bisect_tree(right_ids)

            nnn += 1
            return Cluster(value=nnn, left=left, right=right)

        root = bisect_tree(leaf_ids)

        # Convert back to linkage format for plotting
        links_list: list[list[float]] = []

        def get_linkage(node: Cluster) -> None:
            """Append linkage rows for the subtree rooted at ``node`` (post-order).

            Each row is ``[left id, right id, height, leaf count]``. The height
            column is ``node.size`` of the merged node — a synthetic value used so
            the tree can be drawn, not a clustering distance (assumes ``size``
            grows with the subtree — TODO confirm against Cluster).

            Args:
                node (Cluster): Current node being processed
            """
            if node.left is not None and node.right is not None:
                if not isinstance(node.left, Cluster):
                    raise TypeError("Expected left child to be a Cluster")  # noqa: TRY003
                if not isinstance(node.right, Cluster):
                    raise TypeError("Expected right child to be a Cluster")  # noqa: TRY003
                get_linkage(node.left)
                get_linkage(node.right)
                links_list.append(
                    [
                        float(node.left.value),
                        float(node.right.value),
                        float(node.size),
                        float(len(node.left.leaves) + len(node.right.leaves)),
                    ]
                )

        get_linkage(root)
        links = np.array(links_list)

    return Dendrogram(root=root, linkage=links, method=method, distance=dist, assets=cor.columns)