Coverage for src/pyhrp/cluster.py: 100%
51 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-16 16:29 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-16 16:29 +0000
1"""Data structures for hierarchical risk parity portfolio optimization.
3This module defines the core data structures used in the hierarchical risk parity algorithm:
4- Portfolio: Manages a collection of asset weights (strings identify assets)
5- Cluster: Represents a node in the hierarchical clustering tree
6"""
8from __future__ import annotations
10from dataclasses import dataclass, field
12import numpy as np
13import plotly.graph_objects as go
14import polars as pl
16from .treelib import Node
# Public names exported by ``from <module> import *``.
__all__ = [
    "Cluster",
    "Portfolio",
]
@dataclass
class Portfolio:
    """Asset-to-weight mapping used by the hierarchical risk parity code.

    Weights are keyed by asset name (a ``str``). Besides item access, the class
    offers a variance calculation against a covariance matrix and a bar-chart
    view of the weights.

    Attributes:
        _weights (dict[str, float]): Underlying storage; keys are asset names,
            values are weights, kept in insertion order.
    """

    _weights: dict[str, float] = field(default_factory=dict)

    @property
    def assets(self) -> list[str]:
        """Names of the assets held, in the order they were first added.

        Returns:
            list[str]: A fresh list of asset identifiers.
        """
        return [*self._weights]

    def variance(self, cov: pl.DataFrame) -> float:
        """Return the portfolio variance ``w^T C w``.

        Only the rows/columns of ``cov`` that belong to assets held in this
        portfolio are used; any other columns are ignored.

        Args:
            cov (pl.DataFrame): Square covariance matrix whose column names
                label the assets. Row ``i`` is taken to belong to the same
                asset as column ``i``.

        Returns:
            float: Portfolio variance.

        Raises:
            KeyError: If a held asset is missing from ``cov.columns``.
        """
        held = self.assets
        column_of: dict[str, int] = {}
        for position, name in enumerate(cov.columns):
            column_of[name] = position
        picks = [column_of[name] for name in held]
        # Cut out the sub-matrix for the held assets, in portfolio order, so it
        # lines up element-for-element with the weight vector built below.
        sub_cov = cov.to_numpy()[np.ix_(picks, picks)]
        vec = np.array([self._weights[name] for name in held])
        return float(vec @ sub_cov @ vec)

    def __getitem__(self, item: str) -> float:
        """Weight of a single asset.

        Args:
            item (str): Asset name.

        Returns:
            float: Stored weight for ``item``.

        Raises:
            KeyError: If ``item`` is not held in the portfolio.
        """
        return self._weights[item]

    def __setitem__(self, key: str, value: float) -> None:
        """Add an asset or overwrite its weight.

        An asset that is already present keeps its position in ``assets``.

        Args:
            key (str): Asset name.
            value (float): New weight.
        """
        self._weights[key] = value

    @property
    def weights(self) -> dict[str, float]:
        """Copy of all weights, ordered alphabetically by asset name.

        Returns:
            dict[str, float]: Asset name -> weight, keys in sorted order.
        """
        return {name: self._weights[name] for name in sorted(self._weights)}

    def plot(self, names: list[str]) -> go.Figure:
        """Draw the weights of the selected assets as a bar chart.

        Args:
            names (list[str]): Assets to show, in the order they should appear
                on the x-axis.

        Returns:
            go.Figure: Plotly figure with one bar per name.

        Raises:
            KeyError: If a name is not held in the portfolio.
        """
        lookup = self.weights
        heights = [lookup[name] for name in names]
        bars = go.Bar(x=names, y=heights, marker_color="steelblue")
        figure = go.Figure(bars)
        # Rotate the x tick labels to vertical (-90 degrees).
        figure.update_layout(xaxis={"tickangle": -90})
        return figure
110class Cluster(Node[int]):
111 """Represents a cluster in the hierarchical clustering tree.
113 Clusters are the nodes of the graphs we build.
114 Each cluster is aware of the left and the right cluster
115 it is connecting to. Each cluster also has an associated portfolio.
117 Attributes:
118 portfolio (Portfolio): The portfolio associated with this cluster
119 """
121 def __init__(self, value: int, left: Cluster | None = None, right: Cluster | None = None) -> None:
122 """Initialize a new Cluster.
124 Args:
125 value (int): The identifier for this cluster
126 left (Cluster, optional): The left child cluster
127 right (Cluster, optional): The right child cluster
128 """
129 super().__init__(value=value, left=left, right=right)
130 self.portfolio = Portfolio()
132 # Override narrows the return type to list[Cluster] and validates tree integrity;
133 # the traversal order (left to right) matches Node.leaves.
134 @property
135 def leaves(self) -> list[Cluster]:
136 """Get all reachable leaf nodes in left-to-right dendrogram order.
138 Returns:
139 list[Cluster]: List of all leaf nodes reachable from this cluster
140 """
141 if self.is_leaf:
142 return [self]
143 else:
144 if self.left is None:
145 raise ValueError("Expected left child to exist for non-leaf cluster") # noqa: TRY003
146 if self.right is None:
147 raise ValueError("Expected right child to exist for non-leaf cluster") # noqa: TRY003
148 if not isinstance(self.left, Cluster):
149 raise TypeError(f"Expected left child to be a Cluster for node {self.value}") # noqa: TRY003
150 if not isinstance(self.right, Cluster):
151 raise TypeError(f"Expected right child to be a Cluster for node {self.value}") # noqa: TRY003
152 return self.left.leaves + self.right.leaves