Coverage for src/pyhrp/cluster.py: 100%

39 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-26 16:50 +0000

1"""Data structures for hierarchical risk parity portfolio optimization. 

2 

3This module defines the core data structures used in the hierarchical risk parity algorithm: 

4- Portfolio: Manages a collection of asset weights (strings identify assets) 

5- Cluster: Represents a node in the hierarchical clustering tree 

6""" 

7 

8from __future__ import annotations 

9 

10from dataclasses import dataclass, field 

11 

12import numpy as np 

13import pandas as pd 

14 

15from .treelib import Node 

16 

17 

18@dataclass 

19class Portfolio: 

20 """Container for portfolio asset weights. 

21 

22 This lightweight class stores and manipulates a mapping from asset names to 

23 their portfolio weights, and provides convenience helpers for analysis and 

24 visualization. 

25 

26 Attributes: 

27 _weights (dict[str, float]): Internal mapping from asset symbol to weight. 

28 """ 

29 

30 _weights: dict[str, float] = field(default_factory=dict) 

31 

32 @property 

33 def assets(self) -> list[str]: 

34 """List of asset names present in the portfolio. 

35 

36 Returns: 

37 list[str]: Asset identifiers in insertion order (Python 3.7+ dict order). 

38 """ 

39 return list(self._weights.keys()) 

40 

41 def variance(self, cov: pd.DataFrame) -> float: 

42 """Calculate the variance of the portfolio. 

43 

44 Args: 

45 cov (pd.DataFrame): Covariance matrix 

46 

47 Returns: 

48 float: Portfolio variance 

49 """ 

50 c = cov[self.assets].loc[self.assets].values 

51 w = self.weights[self.assets].values 

52 return np.linalg.multi_dot((w, c, w)) 

53 

54 def __getitem__(self, item: str) -> float: 

55 """Return the weight for a given asset. 

56 

57 Args: 

58 item (str): Asset name/symbol. 

59 

60 Returns: 

61 float: The weight associated with the asset. 

62 

63 Raises: 

64 KeyError: If the asset is not present in the portfolio. 

65 """ 

66 return self._weights[item] 

67 

68 def __setitem__(self, key: str, value: float) -> None: 

69 """Set or update the weight for an asset. 

70 

71 Args: 

72 key (str): Asset name/symbol. 

73 value (float): Portfolio weight for the asset. 

74 """ 

75 self._weights[key] = value 

76 

77 @property 

78 def weights(self) -> pd.Series: 

79 """Get all weights as a pandas Series. 

80 

81 Returns: 

82 pd.Series: Series of weights indexed by assets 

83 """ 

84 return pd.Series(self._weights, name="Weights").sort_index() 

85 

86 def plot(self, names: list[str]): 

87 """Plot the portfolio weights. 

88 

89 Args: 

90 names (list[str]): List of asset names to include in the plot 

91 

92 Returns: 

93 matplotlib.axes.Axes: The plot axes 

94 """ 

95 a = self.weights.loc[names] 

96 

97 ax = a.plot(kind="bar", color="skyblue") 

98 

99 # Set x-axis labels and rotations 

100 ax.set_xticklabels(names, rotation=90, fontsize=8) 

101 return ax 

102 

103 

class Cluster(Node):
    """Represents a cluster in the hierarchical clustering tree.

    Clusters are the nodes of the graphs we build.
    Each cluster is aware of the left and the right cluster
    it is connecting to. Each cluster also has an associated portfolio.

    Attributes:
        portfolio (Portfolio): The portfolio associated with this cluster
    """

    def __init__(self, value: int, left: Cluster | None = None, right: Cluster | None = None, **kwargs):
        """Initialize a new Cluster.

        Args:
            value (int): The identifier for this cluster
            left (Cluster, optional): The left child cluster
            right (Cluster, optional): The right child cluster
            **kwargs: Additional arguments to pass to the parent class
        """
        super().__init__(value=value, left=left, right=right, **kwargs)
        # Each cluster starts with an empty portfolio.
        self.portfolio = Portfolio()

    @property
    def is_leaf(self) -> bool:
        """Check if this cluster is a leaf node (has no children).

        Returns:
            bool: True if this is a leaf node, False otherwise
        """
        return self.left is None and self.right is None

    @property
    def leaves(self) -> list[Cluster]:
        """Get all reachable leaf nodes in left-to-right order.

        Note that the leaves method of the Node class implemented in BinaryTree
        is not respecting the 'correct' order of the nodes.

        The traversal is iterative, so deep (chained) trees such as those produced
        by single linkage neither hit Python's recursion limit nor pay for
        repeated list concatenation. A node with only one child simply
        contributes the leaves of that child.

        Returns:
            list[Cluster]: List of all leaf nodes reachable from this cluster
        """
        result: list[Cluster] = []
        stack: list[Cluster] = [self]
        while stack:
            node = stack.pop()
            if node.left is None and node.right is None:
                result.append(node)
                continue
            # LIFO stack: push the right child first so the left subtree is
            # emitted before the right one.
            if node.right is not None:
                stack.append(node.right)
            if node.left is not None:
                stack.append(node.left)
        return result