Coverage for src/pyhrp/cluster.py: 96%

47 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-10 05:36 +0000

1"""Data structures for hierarchical risk parity portfolio optimization. 

2 

3This module defines the core data structures used in the hierarchical risk parity algorithm: 

4- Portfolio: Manages a collection of asset weights (strings identify assets) 

5- Cluster: Represents a node in the hierarchical clustering tree 

6""" 

7 

8from __future__ import annotations 

9 

10from dataclasses import dataclass, field 

11from typing import Any 

12 

13import numpy as np 

14import pandas as pd 

15from matplotlib.axes import Axes 

16 

17from .treelib import Node 

18 

19 

@dataclass
class Portfolio:
    """Mapping of asset names to portfolio weights.

    The weights live in a plain dictionary. Items are read and written with
    subscript syntax (``portfolio["A"] = 0.5``), and the helpers below expose
    the same data as a list of names, a pandas Series, a variance figure, or
    a bar chart.

    Attributes:
        _weights (dict[str, float]): Backing dictionary keyed by asset name.
    """

    _weights: dict[str, float] = field(default_factory=dict)

    def __getitem__(self, item: str) -> float:
        """Look up the weight stored for one asset.

        Args:
            item (str): Name of the asset.

        Returns:
            float: Weight currently stored under ``item``.

        Raises:
            KeyError: If ``item`` has no entry in the portfolio.
        """
        return self._weights[item]

    def __setitem__(self, key: str, value: float) -> None:
        """Store a weight for an asset, replacing any previous entry.

        Args:
            key (str): Name of the asset.
            value (float): Weight to store for that asset.
        """
        self._weights[key] = value

    @property
    def assets(self) -> list[str]:
        """Names of all assets held in the portfolio.

        Returns:
            list[str]: Asset names in the order they were inserted.
        """
        return list(self._weights)

    @property
    def weights(self) -> pd.Series:
        """All weights as a pandas Series.

        Returns:
            pd.Series: Series named ``"Weights"``, indexed by asset name and
            sorted by that index (not by insertion order).
        """
        unsorted = pd.Series(self._weights, name="Weights")
        return unsorted.sort_index()

    def variance(self, cov: pd.DataFrame) -> float:
        """Compute the portfolio variance ``w' C w``.

        Args:
            cov (pd.DataFrame): Covariance matrix whose row and column labels
                include every asset of the portfolio.

        Returns:
            float: Variance of the portfolio under ``cov``.
        """
        names = self.assets
        # Select columns, then rows, so the matrix follows the same asset
        # order as the weight vector built below.
        by_column = cov[names]
        matrix = by_column.loc[names].to_numpy()
        vector = self.weights[names].to_numpy()
        return float(np.linalg.multi_dot([vector, matrix, vector]))

    def plot(self, names: list[str]) -> Axes:
        """Draw a bar chart of the weights of the selected assets.

        Args:
            names (list[str]): Assets to show, in the order they should appear.

        Returns:
            matplotlib.axes.Axes: Axes holding the bar chart.
        """
        selected = self.weights.loc[names]
        axes = selected.plot(kind="bar", color="skyblue")

        # Label the bars with the asset names, rotated to stay readable.
        axes.set_xticklabels(names, rotation=90, fontsize=8)
        return axes

104 

105 

class Cluster(Node):
    """Node of the hierarchical clustering tree.

    A cluster may link to a left and a right child cluster, and every cluster
    owns a portfolio of its own.

    Attributes:
        portfolio (Portfolio): Portfolio attached to this cluster; it starts
            out empty.
    """

    def __init__(
        self,
        value: int,
        left: Cluster | None = None,
        right: Cluster | None = None,
        **kwargs: Any,
    ) -> None:
        """Create a cluster and attach an empty portfolio to it.

        Args:
            value (int): Identifier of the cluster.
            left (Cluster, optional): Left child cluster.
            right (Cluster, optional): Right child cluster.
            **kwargs: Extra keyword arguments forwarded to the parent class.
        """
        super().__init__(value=value, left=left, right=right, **kwargs)
        self.portfolio = Portfolio()

    @property
    def is_leaf(self) -> bool:
        """Tell whether the cluster has no children at all.

        Returns:
            bool: True when both ``left`` and ``right`` are None.
        """
        return not (self.left is not None or self.right is not None)

    @property
    def leaves(self) -> list[Cluster]:
        """Collect the leaves below this cluster, left subtree first.

        This overrides the parent class's ``leaves`` because that
        implementation does not return the leaves in the order needed here.

        Returns:
            list[Cluster]: The cluster itself if it is a leaf, otherwise the
            leaves of the left child followed by those of the right child.

        Raises:
            ValueError: If exactly one of the two children is missing.
        """
        if self.is_leaf:
            return [self]

        # Past this point at least one child exists, so a missing one means
        # the node is malformed.
        if self.left is None:
            raise ValueError("Expected left child to exist for non-leaf cluster")  # noqa: TRY003
        if self.right is None:
            raise ValueError("Expected right child to exist for non-leaf cluster")  # noqa: TRY003

        from_left: list[Cluster] = self.left.leaves  # type: ignore[assignment]
        from_right: list[Cluster] = self.right.leaves  # type: ignore[assignment]
        return [*from_left, *from_right]