Coverage for src/pyhrp/hrp.py: 100%
75 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-26 16:50 +0000
1"""Hierarchical Risk Parity (HRP) algorithm implementation.
3This module implements the core HRP algorithm and related functions:
4- hrp: Main function to compute HRP portfolio weights
5- build_tree: Function to build hierarchical cluster tree from correlation matrix
6- Dendrogram: Class to store and visualize hierarchical clustering results
7"""
9from __future__ import annotations
11from dataclasses import dataclass
12from typing import Literal
14import numpy as np
15import pandas as pd
16import scipy.cluster.hierarchy as sch
17import scipy.spatial.distance as ssd
19from .algos import risk_parity
20from .cluster import Cluster
def hrp(
    prices: pd.DataFrame,
    node: Cluster | None = None,
    method: Literal["single", "complete", "average", "ward"] = "ward",
    bisection: bool = False,
) -> Cluster:
    """Compute the hierarchical risk parity portfolio weights.

    This is the main entry point for the HRP algorithm. It calculates returns from
    prices, builds a hierarchical clustering tree if none is provided, and applies
    risk parity weights via ``risk_parity``.

    Args:
        prices (pd.DataFrame): Asset price time series (one column per asset).
        node (Cluster | None): Root node of the hierarchical clustering tree.
            If None, a tree is built from the correlation matrix of returns.
        method (Literal["single", "complete", "average", "ward"]): Linkage method
            used when building the tree (ignored if ``node`` is given):
            - "single": minimum distance between points (nearest neighbor)
            - "complete": maximum distance between points (furthest neighbor)
            - "average": average distance between all points
            - "ward": Ward variance minimization
        bisection (bool): Whether to rebuild the tree by recursive bisection of the
            clustered leaf order (ignored if ``node`` is given).

    Returns:
        Cluster: The root cluster returned by ``risk_parity`` for the returns'
        covariance matrix.
    """
    # Rows where every asset is NaN (e.g. the first row after pct_change) are dropped.
    returns = prices.pct_change().dropna(axis=0, how="all")
    cov = returns.cov()

    # Explicit ``is None`` instead of ``node or ...``: a caller-supplied tree must
    # never be silently replaced just because the object happens to be falsy.
    if node is None:
        node = build_tree(returns.corr(), method=method, bisection=bisection).root

    return risk_parity(root=node, cov=cov)
@dataclass(frozen=True)
class Dendrogram:
    """Container for hierarchical clustering dendrogram data and visualization.

    This class stores the results of hierarchical clustering and provides methods
    for accessing and visualizing the dendrogram structure.

    Attributes:
        root (Cluster): The root node of the hierarchical clustering tree.
        assets (list[str]): Asset labels; each leaf's ``value`` is used as a
            position into this sequence.
        distance (pd.DataFrame | None): Distance matrix used for clustering,
            indexed and labelled by ``assets``.
        linkage (np.ndarray | None): Linkage matrix in scipy format, used for plotting.
        method (str | None): Linkage method used for clustering.
    """

    root: Cluster
    assets: list[str]
    distance: pd.DataFrame | None = None
    linkage: np.ndarray | None = None
    method: str | None = None

    def __post_init__(self) -> None:
        """Validate dataclass fields after initialization.

        Raises:
            TypeError: If ``distance`` is given but is not a pandas DataFrame.
            ValueError: If ``distance`` index/columns do not equal ``assets`` (in
                order), or if the tree's leaf count differs from ``len(assets)``.
        """
        if self.distance is not None:
            if not isinstance(self.distance, pd.DataFrame):
                raise TypeError("distance must be a pandas DataFrame.")

            # Build the reference index once and require both axes to match it.
            expected = pd.Index(self.assets)
            if not (self.distance.index.equals(expected) and self.distance.columns.equals(expected)):
                raise ValueError("Distance matrix index/columns must align with assets.")

        if len(self.root.leaves) != len(self.assets):
            raise ValueError("Number of leaves does not match number of assets.")

    def plot(self, **kwargs) -> dict:
        """Render the dendrogram with :func:`scipy.cluster.hierarchy.dendrogram`.

        Defaults are ``leaf_rotation=90``, ``leaf_font_size=8`` and
        ``labels=self.assets``. Any of them may be overridden through ``kwargs``;
        all other keyword arguments are forwarded to scipy unchanged.

        Returns:
            dict: The dictionary scipy returns (e.g. ``"ivl"`` leaf labels,
            ``"leaves"``), useful together with ``no_plot=True``.
        """
        options = {"leaf_rotation": 90, "leaf_font_size": 8, "labels": self.assets, **kwargs}
        return sch.dendrogram(self.linkage, **options)

    @property
    def ids(self) -> list:
        """Leaf ``value``s in the left-to-right order of ``root.leaves``."""
        return [leaf.value for leaf in self.root.leaves]

    @property
    def names(self) -> list:
        """Asset labels in the order given by ``ids`` (ids are positions into ``assets``)."""
        return [self.assets[i] for i in self.ids]
113def _compute_distance_matrix(corr: pd.DataFrame) -> pd.DataFrame:
114 """Convert correlation matrix to distance matrix."""
115 c = corr.values
116 dist = np.sqrt(np.clip((1.0 - c) / 2.0, a_min=0.0, a_max=1.0))
117 np.fill_diagonal(dist, 0.0)
118 return pd.DataFrame(data=dist, index=corr.index, columns=corr.index)
def _scipy_to_cluster(node: sch.ClusterNode) -> Cluster:
    """Recursively convert a scipy ``ClusterNode`` tree into the equivalent ``Cluster`` tree."""
    if node.left is None:
        return Cluster(value=node.id)
    # Left subtree is converted before the right one, then the parent is created.
    return Cluster(value=node.id, left=_scipy_to_cluster(node.left), right=_scipy_to_cluster(node.right))


def _bisect_leaves(leaf_ids: list[int], first_internal_id: int) -> Cluster:
    """Build a balanced binary tree over ``leaf_ids`` by recursive halving.

    Internal nodes are numbered consecutively from ``first_internal_id`` in
    post-order (both children before their parent). With leaf ids ``0..n-1`` and
    ``first_internal_id == n``, internal node ``n + i`` is therefore the ``i``-th one
    created, which is the row order scipy's linkage format requires.
    """
    next_id = first_internal_id

    def split(ids: list[int]) -> Cluster:
        nonlocal next_id
        if len(ids) == 1:
            return Cluster(value=ids[0])
        mid = len(ids) // 2
        left = split(ids[:mid])
        right = split(ids[mid:])
        node = Cluster(value=next_id, left=left, right=right)
        next_id += 1
        return node

    return split(leaf_ids)


def _cluster_to_linkage(root: Cluster) -> np.ndarray:
    """Flatten a ``Cluster`` tree into a scipy-style linkage matrix for plotting.

    One row per internal node, emitted in post-order (matching ``_bisect_leaves``'s
    numbering): ``[left id, right id, float(node.size), leaves under the node]``.
    NOTE(review): ``node.size`` is used as the merge height; presumably a leaf
    count — confirm against ``Cluster``.
    """
    rows: list[list[float]] = []

    def visit(node: Cluster) -> None:
        if node.left is None:
            return
        visit(node.left)
        visit(node.right)
        rows.append(
            [
                node.left.value,
                node.right.value,
                float(node.size),
                len(node.left.leaves) + len(node.right.leaves),
            ]
        )

    visit(root)
    return np.array(rows)


def build_tree(
    cor: pd.DataFrame, method: Literal["single", "complete", "average", "ward"] = "ward", bisection: bool = False
) -> Dendrogram:
    """Build hierarchical cluster tree from correlation matrix.

    This function converts a correlation matrix to a distance matrix, performs
    hierarchical clustering, and returns a Dendrogram object containing the
    resulting tree structure.

    Args:
        cor (pd.DataFrame): Correlation matrix of asset returns.
        method (Literal["single", "complete", "average", "ward"]): Linkage method for hierarchical clustering
            - "single": minimum distance between points (nearest neighbor)
            - "complete": maximum distance between points (furthest neighbor)
            - "average": average distance between all points
            - "ward": Ward variance minimization
        bisection (bool): If True, keep only the left-to-right leaf order of the
            clustering and rebuild a balanced tree over it by recursive halving;
            the linkage matrix is regenerated from that tree for plotting.

    Returns:
        Dendrogram: Object containing the hierarchical clustering tree, with:
            - root: Root cluster node
            - linkage: Linkage matrix for plotting
            - assets: The columns of ``cor``
            - method: Clustering method used
            - distance: Distance matrix

    Raises:
        TypeError: If ``cor`` is not a pandas DataFrame.
    """
    # Explicit check rather than ``assert``: assertions are stripped under ``python -O``.
    if not isinstance(cor, pd.DataFrame):
        raise TypeError("Correlation matrix must be a pandas DataFrame.")

    dist = _compute_distance_matrix(cor)
    links = sch.linkage(ssd.squareform(dist), method=method)
    root = _scipy_to_cluster(sch.to_tree(links, rd=False))

    if bisection:
        leaf_ids = [leaf.value for leaf in root.leaves]
        root = _bisect_leaves(leaf_ids, first_internal_id=max(leaf_ids) + 1)
        links = _cluster_to_linkage(root)

    return Dendrogram(root=root, linkage=links, method=method, distance=dist, assets=cor.columns)