Coverage for src/pyhrp/cluster.py: 100%

57 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-08 06:19 +0000

1"""Data structures for hierarchical risk parity portfolio optimization. 

2 

3This module defines the core data structures used in the hierarchical risk parity algorithm: 

4- Asset: Represents a financial asset in a portfolio 

5- Portfolio: Manages a collection of assets and their weights 

6- Cluster: Represents a node in the hierarchical clustering tree 

7""" 

8 

9from __future__ import annotations 

10 

11from dataclasses import dataclass, field 

12from typing import Any 

13 

14import numpy as np 

15import pandas as pd 

16 

17from .treelib import Node 

18 

19 

@dataclass(frozen=True)
class Asset:
    """Represents a financial asset in a portfolio.

    Identity is determined solely by ``name``: two assets with the same name
    compare equal and hash identically regardless of ``mu``. Assets are
    ordered lexicographically by ``name``.

    Attributes:
        mu (float, optional): Expected return of the asset
        name (str, optional): Name of the asset
    """

    mu: float | None = None
    name: str | None = None

    def __hash__(self) -> int:
        """Hash function for Asset objects.

        Returns:
            int: Hash value based on the asset name (consistent with ``__eq__``)
        """
        return hash(self.name)

    def __eq__(self, other: Any) -> bool:
        """Equality comparison for Asset objects.

        Args:
            other (Any): Object to compare with

        Returns:
            bool: True if other is an Asset with the same name. For non-Asset
            operands ``NotImplemented`` is returned so Python can try the
            reflected comparison; ``==`` then falls back to identity (False).
        """
        if not isinstance(other, Asset):
            return NotImplemented
        return self.name == other.name

    def __lt__(self, other: Any) -> bool:
        """Less than comparison for Asset objects.

        Args:
            other (Any): Object to compare with

        Returns:
            bool: True if this asset's name is lexicographically less than other's name.
            For non-Asset operands ``NotImplemented`` is returned, so ``asset < 5``
            raises TypeError instead of silently claiming an ordering.

        Raises:
            TypeError: If either asset's name is None (None is not orderable against str).
        """
        if not isinstance(other, Asset):
            return NotImplemented
        return self.name < other.name

65 

66 

@dataclass
class Portfolio:
    """A set of assets together with the weight each one carries.

    Besides plain item access (``portfolio[asset]``), the class can aggregate
    the weights into a pandas Series, compute the portfolio variance against
    a covariance matrix, and draw a bar chart of the allocation.

    Attributes:
        _weights (dict[Asset, float]): Mapping from each held asset to its weight
    """

    _weights: dict[Asset, float] = field(default_factory=dict)

    @property
    def assets(self) -> list[Asset]:
        """List the assets held, in the order they were first assigned a weight.

        Returns:
            list[Asset]: The portfolio's assets
        """
        return [*self._weights]

    def variance(self, cov: pd.DataFrame) -> float:
        """Compute the portfolio variance w' C w.

        Args:
            cov (pd.DataFrame): Covariance matrix whose rows and columns are
                labelled by (at least) the portfolio's assets

        Returns:
            float: Portfolio variance
        """
        held = self.assets
        # Align both the covariance sub-matrix and the weight vector to the same asset order.
        sigma = cov[held].loc[held].values
        vec = self.weights[held].values
        return np.linalg.multi_dot((vec, sigma, vec))

    def __getitem__(self, item: Asset) -> float:
        """Look up the weight assigned to an asset.

        Args:
            item (Asset): Asset whose weight is requested

        Returns:
            float: The stored weight (KeyError if the asset is not held)
        """
        weight = self._weights[item]
        return weight

    def __setitem__(self, key: Asset, value: float) -> None:
        """Assign (or overwrite) the weight of an asset.

        Args:
            key (Asset): Asset to update
            value (float): New weight for the asset
        """
        self._weights.update({key: value})

    @property
    def weights(self) -> pd.Series:
        """Collect all weights into a Series sorted by asset.

        Returns:
            pd.Series: Series named "Weights", indexed by asset in sorted order
        """
        series = pd.Series(self._weights, name="Weights")
        return series.sort_index()

    def weights_by_name(self, names: list[str]) -> pd.Series:
        """Return the weights of the requested assets, keyed by asset name.

        Args:
            names (list[str]): Asset names to include, in the desired output order

        Returns:
            pd.Series: Weights indexed by asset name (KeyError for an unknown name)
        """
        lookup = {}
        for asset, weight in self.weights.items():
            lookup[asset.name] = weight
        selected = {name: lookup[name] for name in names}
        return pd.Series(selected)

    def plot(self, names: list[str]):
        """Draw the weights of the given assets as a bar chart.

        Args:
            names (list[str]): Asset names to show, in bar order

        Returns:
            matplotlib.axes.Axes: Axes holding the bar chart
        """
        ax = self.weights_by_name(names).plot(kind="bar", color="skyblue")
        # Label each bar with its asset name; vertical small labels keep many assets readable.
        ax.set_xticklabels(names, rotation=90, fontsize=8)
        return ax

159 

160 

class Cluster(Node):
    """Represents a cluster in the hierarchical clustering tree.

    Clusters are the nodes of the graphs we build.
    Each cluster is aware of the left and the right cluster
    it is connecting to. Each cluster also has an associated portfolio.

    Attributes:
        portfolio (Portfolio): The portfolio associated with this cluster
    """

    def __init__(self, value: int, left: Cluster | None = None, right: Cluster | None = None, **kwargs):
        """Initialize a new Cluster.

        Args:
            value (int): The identifier for this cluster
            left (Cluster, optional): The left child cluster
            right (Cluster, optional): The right child cluster
            **kwargs: Additional arguments to pass to the parent class
        """
        super().__init__(value=value, left=left, right=right, **kwargs)
        # Each cluster owns its own, initially empty, portfolio.
        self.portfolio = Portfolio()

    @property
    def is_leaf(self) -> bool:
        """Check if this cluster is a leaf node (has no children).

        Returns:
            bool: True if this is a leaf node, False otherwise
        """
        return self.left is None and self.right is None

    @property
    def leaves(self) -> list[Cluster]:
        """Get all reachable leaf nodes in the correct order.

        Note that the leaves method of the Node class implemented in BinaryTree
        is not respecting the 'correct' order of the nodes. Leaves are returned
        left-to-right (all leaves of the left subtree before those of the right).

        The tree is walked iteratively with an explicit stack instead of
        recursively, so deep, unbalanced trees neither hit Python's recursion
        limit nor pay quadratic cost for repeated list concatenation. A node
        with only one child simply contributes that child's leaves.

        Returns:
            list[Cluster]: List of all leaf nodes reachable from this cluster
        """
        found: list[Cluster] = []
        stack: list[Cluster] = [self]
        while stack:
            node = stack.pop()
            if node.left is None and node.right is None:
                found.append(node)
                continue
            # Push right before left so the left subtree is popped (and emitted) first.
            if node.right is not None:
                stack.append(node.right)
            if node.left is not None:
                stack.append(node.left)
        return found