Coverage for src/pyhrp/cluster.py: 100%
57 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-08 06:19 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-08 06:19 +0000
1"""Data structures for hierarchical risk parity portfolio optimization.
3This module defines the core data structures used in the hierarchical risk parity algorithm:
4- Asset: Represents a financial asset in a portfolio
5- Portfolio: Manages a collection of assets and their weights
6- Cluster: Represents a node in the hierarchical clustering tree
7"""
9from __future__ import annotations
11from dataclasses import dataclass, field
12from typing import Any
14import numpy as np
15import pandas as pd
17from .treelib import Node
@dataclass(frozen=True)
class Asset:
    """Represents a financial asset in a portfolio.

    Equality, hashing and ordering are all based on ``name`` only; ``mu`` is
    ignored by every comparison, so two assets sharing a name collapse to the
    same key in a dict or set.

    Attributes:
        mu (float | None): Expected return of the asset
        name (str | None): Name of the asset
    """

    # Both fields are optional, hence the explicit ``| None`` in the annotations.
    mu: float | None = None
    name: str | None = None

    def __hash__(self) -> int:
        """Hash function for Asset objects.

        Returns:
            int: Hash value based on the asset name
        """
        return hash(self.name)

    def __eq__(self, other: Any) -> bool:
        """Equality comparison for Asset objects.

        Args:
            other (Any): Object to compare with

        Returns:
            bool: True if other is an Asset with the same name. For any other
                type ``NotImplemented`` is returned so that Python can try the
                reflected comparison; ``asset == x`` still evaluates to False
                when ``x`` does not know how to compare itself to an Asset.
        """
        if not isinstance(other, Asset):
            return NotImplemented
        return self.name == other.name

    def __lt__(self, other: Any) -> bool:
        """Less than comparison for Asset objects.

        Args:
            other (Any): Object to compare with

        Returns:
            bool: True if this asset's name is lexicographically less than
                other's name. Comparing against a non-Asset yields False
                rather than raising.
        """
        if not isinstance(other, Asset):
            # Kept as a plain False (not NotImplemented) so that existing
            # callers comparing against foreign objects keep getting a bool.
            return False
        return self.name < other.name
@dataclass
class Portfolio:
    """Container mapping assets to their portfolio weights.

    Offers item-style access to individual weights, a sorted Series view of
    all weights, a variance calculation against a covariance matrix and a
    simple bar-chart helper.

    Attributes:
        _weights (dict[Asset, float]): Mapping from each asset to its weight
    """

    _weights: dict[Asset, float] = field(default_factory=dict)

    @property
    def assets(self) -> list[Asset]:
        """Return the assets held in the portfolio.

        Returns:
            list[Asset]: Assets in the insertion order of the weight mapping
        """
        return [*self._weights]

    def variance(self, cov: pd.DataFrame) -> float:
        """Compute the portfolio variance ``w' C w``.

        Args:
            cov (pd.DataFrame): Covariance matrix whose rows and columns are
                labelled by the portfolio's assets

        Returns:
            float: Portfolio variance
        """
        members = self.assets
        # Restrict the covariance matrix to this portfolio's assets, using the
        # same ordering for rows, columns and the weight vector.
        sub_cov = cov[members].loc[members].values
        weight_vec = self.weights[members].values
        return np.linalg.multi_dot((weight_vec, sub_cov, weight_vec))

    def __getitem__(self, item: Asset) -> float:
        """Look up the weight of a single asset.

        Args:
            item (Asset): Asset whose weight is requested

        Returns:
            float: Weight stored for the asset
        """
        return self._weights[item]

    def __setitem__(self, key: Asset, value: float) -> None:
        """Store the weight of a single asset.

        Args:
            key (Asset): Asset whose weight is being set
            value (float): Weight to store
        """
        self._weights[key] = value

    @property
    def weights(self) -> pd.Series:
        """Return every weight as a pandas Series.

        Returns:
            pd.Series: Weights named "Weights", indexed by asset and sorted
                by index
        """
        unsorted = pd.Series(self._weights, name="Weights")
        return unsorted.sort_index()

    def weights_by_name(self, names: list[str]) -> pd.Series:
        """Return weights keyed by asset name, in the requested order.

        Args:
            names (list[str]): Asset names to include

        Returns:
            pd.Series: Weights indexed by the given asset names
        """
        by_name = {}
        for asset, weight in self.weights.items():
            by_name[asset.name] = weight
        selected = {name: by_name[name] for name in names}
        return pd.Series(selected)

    def plot(self, names: list[str]):
        """Draw the portfolio weights as a bar chart.

        Args:
            names (list[str]): Asset names to include in the plot

        Returns:
            matplotlib.axes.Axes: The plot axes
        """
        axes = self.weights_by_name(names).plot(kind="bar", color="skyblue")

        # Label the x-axis with the asset names, rotated to stay readable.
        axes.set_xticklabels(names, rotation=90, fontsize=8)
        return axes
class Cluster(Node):
    """Represents a cluster in the hierarchical clustering tree.

    Clusters are the nodes of the graphs we build.
    Each cluster is aware of the left and the right cluster
    it is connecting to. Each cluster also has an associated portfolio.

    Attributes:
        portfolio (Portfolio): The portfolio associated with this cluster
    """

    def __init__(self, value: int, left: Cluster | None = None, right: Cluster | None = None, **kwargs):
        """Initialize a new Cluster.

        Args:
            value (int): The identifier for this cluster
            left (Cluster, optional): The left child cluster
            right (Cluster, optional): The right child cluster
            **kwargs: Additional arguments to pass to the parent class
        """
        super().__init__(value=value, left=left, right=right, **kwargs)
        # Every cluster starts out with its own empty portfolio.
        self.portfolio = Portfolio()

    @property
    def is_leaf(self) -> bool:
        """Check if this cluster is a leaf node (has no children).

        Returns:
            bool: True if this is a leaf node, False otherwise
        """
        return self.left is None and self.right is None

    @property
    def leaves(self) -> list[Cluster]:
        """Get all reachable leaf nodes in left-to-right order.

        Per the original note, the leaves method of the Node class implemented
        in BinaryTree does not respect the 'correct' order of the nodes, which
        is why it is overridden here.

        A node with only one child is handled by skipping the missing side
        instead of failing on ``None``.

        Returns:
            list[Cluster]: List of all leaf nodes reachable from this cluster
        """
        if self.is_leaf:
            return [self]

        collected: list[Cluster] = []
        # Left subtree first, then right, so the leaf order is preserved.
        for child in (self.left, self.right):
            if child is not None:
                collected.extend(child.leaves)
        return collected