Coverage for src / pyhrp / cluster.py: 96%
47 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-10 05:36 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-10 05:36 +0000
"""Data structures for hierarchical risk parity portfolio optimization.

This module defines the core data structures used in the hierarchical risk parity algorithm:
- Portfolio: Manages a collection of asset weights (strings identify assets)
- Cluster: Represents a node in the hierarchical clustering tree
"""
8from __future__ import annotations
10from dataclasses import dataclass, field
11from typing import Any
13import numpy as np
14import pandas as pd
15from matplotlib.axes import Axes
17from .treelib import Node
@dataclass
class Portfolio:
    """Mapping from asset names to portfolio weights, with small analysis helpers.

    Weights are kept in a plain dictionary keyed by asset name. The class adds
    item access (``portfolio[name]``), a pandas view of the weights, a variance
    calculation against a covariance matrix, and a bar-chart helper.

    Attributes:
        _weights (dict[str, float]): Internal mapping from asset symbol to weight.
    """

    _weights: dict[str, float] = field(default_factory=dict)

    @property
    def assets(self) -> list[str]:
        """Names of all assets currently held.

        Returns:
            list[str]: Asset identifiers, in the order they were inserted.
        """
        return list(self._weights)

    def variance(self, cov: pd.DataFrame) -> float:
        """Compute the portfolio variance as the quadratic form ``w' C w``.

        Args:
            cov (pd.DataFrame): Covariance matrix labelled by asset name on
                both the rows and the columns. Only the entries belonging to
                this portfolio's assets are used.

        Returns:
            float: Portfolio variance
        """
        names = self.assets
        # Select columns first, then rows, so the sub-matrix follows ``names``.
        sub_cov = cov[names].loc[names].values
        # ``weights`` is sorted by label; re-select to line up with ``names``.
        vec = self.weights[names].values
        return float(np.linalg.multi_dot((vec, sub_cov, vec)))

    def __getitem__(self, item: str) -> float:
        """Look up the weight of a single asset.

        Args:
            item (str): Asset name/symbol.

        Returns:
            float: Weight stored for that asset.

        Raises:
            KeyError: If the asset is not present in the portfolio.
        """
        return self._weights[item]

    def __setitem__(self, key: str, value: float) -> None:
        """Store a weight for an asset, replacing any existing entry.

        Args:
            key (str): Asset name/symbol.
            value (float): Portfolio weight for the asset.
        """
        self._weights[key] = value

    @property
    def weights(self) -> pd.Series:
        """All weights as a pandas Series named ``"Weights"``.

        Returns:
            pd.Series: Weights indexed by asset name, sorted by index.
        """
        unsorted = pd.Series(self._weights, name="Weights")
        return unsorted.sort_index()

    def plot(self, names: list[str]) -> Axes:
        """Draw a bar chart of the weights of the selected assets.

        Args:
            names (list[str]): Assets to include; also used as the x tick labels.

        Returns:
            matplotlib.axes.Axes: The axes returned by the pandas plot call.
        """
        selected = self.weights.loc[names]

        axes = selected.plot(kind="bar", color="skyblue")

        # Vertical, small tick labels keep long asset names readable.
        axes.set_xticklabels(names, rotation=90, fontsize=8)
        return axes
class Cluster(Node):
    """Node of the hierarchical clustering tree, carrying its own portfolio.

    A cluster may link to a left and a right child cluster. Every cluster
    starts out with an empty ``Portfolio``.

    Attributes:
        portfolio (Portfolio): The portfolio associated with this cluster
    """

    def __init__(self, value: int, left: Cluster | None = None, right: Cluster | None = None, **kwargs: Any) -> None:
        """Create a cluster and attach an empty portfolio to it.

        Args:
            value (int): The identifier for this cluster
            left (Cluster, optional): The left child cluster
            right (Cluster, optional): The right child cluster
            **kwargs: Extra keyword arguments forwarded to the parent class
        """
        super().__init__(value=value, left=left, right=right, **kwargs)
        self.portfolio = Portfolio()

    @property
    def is_leaf(self) -> bool:
        """Whether this cluster has neither a left nor a right child.

        Returns:
            bool: True if this is a leaf node, False otherwise
        """
        has_child = self.left is not None or self.right is not None
        return not has_child

    @property
    def leaves(self) -> list[Cluster]:
        """Collect the leaf clusters below this node, left subtree first.

        This overrides the inherited ``leaves``: per the original note, the
        implementation on the base ``Node`` class does not return the leaves
        in the order required here.

        Returns:
            list[Cluster]: Leaves of the left child followed by leaves of the
                right child, or ``[self]`` when this cluster is a leaf.

        Raises:
            ValueError: If exactly one of the two children is missing.
        """
        if self.is_leaf:
            return [self]

        # A non-leaf cluster must have both children.
        if self.left is None:
            raise ValueError("Expected left child to exist for non-leaf cluster")  # noqa: TRY003
        if self.right is None:
            raise ValueError("Expected right child to exist for non-leaf cluster")  # noqa: TRY003

        from_left: list[Cluster] = self.left.leaves  # type: ignore[assignment]
        from_right: list[Cluster] = self.right.leaves  # type: ignore[assignment]
        return from_left + from_right