Coverage for src/pyhrp/cluster.py: 100%
39 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-26 16:50 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-26 16:50 +0000
1"""Data structures for hierarchical risk parity portfolio optimization.
3This module defines the core data structures used in the hierarchical risk parity algorithm:
4- Portfolio: Manages a collection of asset weights (strings identify assets)
5- Cluster: Represents a node in the hierarchical clustering tree
6"""
8from __future__ import annotations
10from dataclasses import dataclass, field
12import numpy as np
13import pandas as pd
15from .treelib import Node
@dataclass
class Portfolio:
    """Container for portfolio asset weights.

    Wraps a plain mapping from asset name to portfolio weight and offers
    item-style access plus a few helpers (sorted weight Series, portfolio
    variance, bar plot).

    Attributes:
        _weights (dict[str, float]): Internal mapping from asset symbol to weight.
    """

    _weights: dict[str, float] = field(default_factory=dict)

    @property
    def assets(self) -> list[str]:
        """List of asset names present in the portfolio.

        Returns:
            list[str]: Asset identifiers in the insertion order of the
                underlying dict.
        """
        return list(self._weights)

    def variance(self, cov: pd.DataFrame) -> float:
        """Calculate the variance of the portfolio.

        Computes ``w' C w`` where ``C`` is the covariance matrix restricted to
        the portfolio's assets and ``w`` is the matching weight vector.

        Args:
            cov (pd.DataFrame): Covariance matrix whose index and columns
                contain every asset of this portfolio.

        Returns:
            float: Portfolio variance.

        Raises:
            KeyError: If an asset of the portfolio is missing from ``cov``.
        """
        names = self.assets
        # Select rows/columns and weights by the same label list so the matrix
        # and the vector are aligned regardless of how ``cov`` is ordered.
        c = cov.loc[names, names].to_numpy()
        w = self.weights.loc[names].to_numpy()
        return np.linalg.multi_dot((w, c, w))

    def __getitem__(self, item: str) -> float:
        """Return the weight for a given asset.

        Args:
            item (str): Asset name/symbol.

        Returns:
            float: The weight associated with the asset.

        Raises:
            KeyError: If the asset is not present in the portfolio.
        """
        return self._weights[item]

    def __setitem__(self, key: str, value: float) -> None:
        """Set or update the weight for an asset.

        Args:
            key (str): Asset name/symbol.
            value (float): Portfolio weight for the asset.
        """
        self._weights[key] = value

    @property
    def weights(self) -> pd.Series:
        """Get all weights as a pandas Series.

        Returns:
            pd.Series: Series named ``"Weights"``, indexed by asset and sorted
                by index. An empty portfolio yields an empty float Series.
        """
        # Without an explicit dtype, pandas infers ``object`` for an empty
        # mapping; pin it to float only in that case so non-empty portfolios
        # keep exactly the dtype pandas infers from their values.
        dtype = None if self._weights else float
        series = pd.Series(self._weights, name="Weights", dtype=dtype)
        return series.sort_index()

    def plot(self, names: list[str]):
        """Plot the portfolio weights as a bar chart.

        Args:
            names (list[str]): Asset names to include in the plot, in the
                order they should appear on the x-axis.

        Returns:
            matplotlib.axes.Axes: The plot axes.
        """
        selected = self.weights.loc[names]

        ax = selected.plot(kind="bar", color="skyblue")

        # Set x-axis labels and rotations
        ax.set_xticklabels(names, rotation=90, fontsize=8)
        return ax
class Cluster(Node):
    """Represents a cluster in the hierarchical clustering tree.

    Clusters are the nodes of the graphs we build.
    Each cluster is aware of the left and the right cluster
    it is connecting to. Each cluster also has an associated portfolio.

    Attributes:
        portfolio (Portfolio): The portfolio associated with this cluster
    """

    def __init__(self, value: int, left: Cluster | None = None, right: Cluster | None = None, **kwargs):
        """Initialize a new Cluster.

        Args:
            value (int): The identifier for this cluster
            left (Cluster, optional): The left child cluster
            right (Cluster, optional): The right child cluster
            **kwargs: Additional arguments to pass to the parent class
        """
        super().__init__(value=value, left=left, right=right, **kwargs)
        # Every cluster starts with its own empty portfolio.
        self.portfolio = Portfolio()

    @property
    def is_leaf(self) -> bool:
        """Check if this cluster is a leaf node (has no children).

        Returns:
            bool: True if both ``left`` and ``right`` are None, False otherwise
        """
        return self.left is None and self.right is None

    @property
    def leaves(self) -> list[Cluster]:
        """Get all reachable leaf nodes in left-to-right order.

        Note that the leaves method of the Node class implemented in BinaryTree
        is not respecting the 'correct' order of the nodes.

        A node with only one child contributes the leaves of that child;
        the missing side is skipped instead of raising ``AttributeError``.

        Returns:
            list[Cluster]: List of all leaf nodes reachable from this cluster
        """
        if self.is_leaf:
            return [self]

        collected: list[Cluster] = []
        # Left subtree first, then right, so the order matches the tree layout.
        for child in (self.left, self.right):
            if child is not None:
                collected.extend(child.leaves)
        return collected