Coverage for src/pyhrp/cluster.py: 100%

51 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-16 16:29 +0000

1"""Data structures for hierarchical risk parity portfolio optimization. 

2 

3This module defines the core data structures used in the hierarchical risk parity algorithm: 

4- Portfolio: Manages a collection of asset weights (strings identify assets) 

5- Cluster: Represents a node in the hierarchical clustering tree 

6""" 

7 

8from __future__ import annotations 

9 

10from dataclasses import dataclass, field 

11 

12import numpy as np 

13import plotly.graph_objects as go 

14import polars as pl 

15 

16from .treelib import Node 

17 

# Public names re-exported by a star import of this module.
__all__ = [
    "Cluster",
    "Portfolio",
]

19 

20 

@dataclass
class Portfolio:
    """Container for portfolio asset weights.

    This lightweight class stores and manipulates a mapping from asset names to
    their portfolio weights, and provides convenience helpers for analysis and
    visualization.

    Attributes:
        _weights (dict[str, float]): Internal mapping from asset symbol to weight.
    """

    _weights: dict[str, float] = field(default_factory=dict)

    @property
    def assets(self) -> list[str]:
        """List of asset names present in the portfolio.

        Returns:
            list[str]: Asset identifiers in insertion order (Python 3.7+ dict order).
        """
        return list(self._weights)

    def variance(self, cov: pl.DataFrame) -> float:
        """Calculate the variance of the portfolio, ``w^T C w``.

        Args:
            cov (pl.DataFrame): Square covariance matrix whose column names are
                asset identifiers; row ``i`` is assumed to belong to the same
                asset as column ``i``.

        Returns:
            float: Portfolio variance (``0.0`` for an empty portfolio).

        Raises:
            KeyError: If a portfolio asset is not a column of ``cov``.
        """
        assets = self.assets
        position = {name: i for i, name in enumerate(cov.columns)}
        idx = [position[a] for a in assets]
        # Pull out the sub-matrix of held assets, ordered like the portfolio so
        # it lines up element-for-element with the weight vector below.
        sub = cov.to_numpy()[np.ix_(idx, idx)]
        w = np.array([self._weights[a] for a in assets])
        return float(w @ sub @ w)

    def __getitem__(self, item: str) -> float:
        """Return the weight for a given asset.

        Args:
            item (str): Asset name/symbol.

        Returns:
            float: The weight associated with the asset.

        Raises:
            KeyError: If the asset is not present in the portfolio.
        """
        return self._weights[item]

    def __setitem__(self, key: str, value: float) -> None:
        """Set or update the weight for an asset.

        Args:
            key (str): Asset name/symbol.
            value (float): Portfolio weight for the asset.
        """
        self._weights[key] = value

    def __contains__(self, item: object) -> bool:
        """Return whether ``item`` is an asset held in the portfolio.

        Defined explicitly because, with only ``__getitem__`` available, Python
        falls back to the legacy sequence protocol for ``in`` and probes
        ``self[0]``, ``self[1]``, ..., which raises ``KeyError`` on this mapping.

        Args:
            item (object): Candidate asset name/symbol.

        Returns:
            bool: True if the asset has a weight in this portfolio.
        """
        return item in self._weights

    @property
    def weights(self) -> dict[str, float]:
        """Get all weights as a dict sorted alphabetically by asset name.

        Returns:
            dict[str, float]: Mapping from asset name to weight, sorted by name.
        """
        return dict(sorted(self._weights.items()))

    def plot(self, names: list[str]) -> go.Figure:
        """Plot the portfolio weights as a bar chart.

        Args:
            names (list[str]): Asset names to include, in the order they should
                appear on the x-axis.

        Returns:
            go.Figure: The plotly figure.

        Raises:
            KeyError: If any name is not present in the portfolio.
        """
        # Direct lookups; no need to build a sorted copy of every weight first.
        values = [self._weights[n] for n in names]
        fig = go.Figure(go.Bar(x=names, y=values, marker_color="steelblue"))
        fig.update_layout(xaxis={"tickangle": -90})
        return fig

108 

109 

class Cluster(Node[int]):
    """Represents a cluster in the hierarchical clustering tree.

    Clusters are the nodes of the graphs we build.
    Each cluster is aware of the left and the right cluster
    it is connecting to. Each cluster also has an associated portfolio.

    Attributes:
        portfolio (Portfolio): The portfolio associated with this cluster
    """

    def __init__(self, value: int, left: Cluster | None = None, right: Cluster | None = None) -> None:
        """Initialize a new Cluster.

        Args:
            value (int): The identifier for this cluster
            left (Cluster, optional): The left child cluster
            right (Cluster, optional): The right child cluster
        """
        super().__init__(value=value, left=left, right=right)
        self.portfolio = Portfolio()

    # Override narrows the return type to list[Cluster] and validates tree integrity;
    # the traversal order (left to right) matches Node.leaves.
    @property
    def leaves(self) -> list[Cluster]:
        """Get all reachable leaf nodes in left-to-right dendrogram order.

        Returns:
            list[Cluster]: List of all leaf nodes reachable from this cluster

        Raises:
            ValueError: If a non-leaf cluster is missing a child, or if a node
                is reachable more than once (cycle or shared subtree).
            TypeError: If a child of a non-leaf cluster is not a Cluster.
        """
        # Iterative depth-first walk with an explicit stack. A binary tree with
        # n leaves can be n - 1 levels deep (e.g. a fully chained dendrogram);
        # recursing per level would hit Python's recursion limit for large
        # universes, and concatenating child lists per level is O(n^2).
        result: list[Cluster] = []
        seen: set[int] = set()
        stack: list[Cluster] = [self]
        while stack:
            node = stack.pop()
            # Without this guard a cyclic structure would grow the stack forever.
            if id(node) in seen:
                raise ValueError(f"Cluster {node.value} is reachable more than once; expected a tree")  # noqa: TRY003
            seen.add(id(node))

            if node.is_leaf:
                result.append(node)
                continue

            if node.left is None:
                raise ValueError("Expected left child to exist for non-leaf cluster")  # noqa: TRY003
            if node.right is None:
                raise ValueError("Expected right child to exist for non-leaf cluster")  # noqa: TRY003
            if not isinstance(node.left, Cluster):
                raise TypeError(f"Expected left child to be a Cluster for node {node.value}")  # noqa: TRY003
            if not isinstance(node.right, Cluster):
                raise TypeError(f"Expected right child to be a Cluster for node {node.value}")  # noqa: TRY003

            # LIFO: push right first so the whole left subtree is emitted before
            # the right one, preserving left-to-right order.
            stack.append(node.right)
            stack.append(node.left)
        return result