Coverage for src/jquantstats/_plots.py: 100%
49 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-05 07:23 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-05 07:23 +0000
1import dataclasses
3import plotly.express as px
4import plotly.graph_objects as go
5import polars as pl
6from plotly.subplots import make_subplots
def _hex_to_rgba(hex_color: str, alpha: float = 0.5) -> str:
    """Convert a ``#RRGGBB`` hex colour string to a CSS ``rgba(...)`` string.

    Args:
        hex_color: Colour in ``#RRGGBB`` (or bare ``RRGGBB``) form.
        alpha: Opacity value to attach to the colour.

    Returns:
        A string such as ``"rgba(99, 110, 250, 0.5)"``.
    """
    hex_color = hex_color.lstrip("#")
    r, g, b = (int(hex_color[i : i + 2], 16) for i in (0, 2, 4))
    return f"rgba({r}, {g}, {b}, {alpha})"


def _plot_performance_dashboard(returns: pl.DataFrame, log_scale: bool = False) -> go.Figure:
    """Build a three-row dashboard: cumulative returns, drawdowns and monthly returns.

    Args:
        returns: Frame whose first column is the date index and whose remaining
            columns hold per-period simple returns, one column per ticker. The date
            column must be temporal and sorted, because it drives
            ``group_by_dynamic`` for the monthly resample.
        log_scale: If True, the cumulative-returns panel uses a logarithmic y-axis.

    Returns:
        go.Figure: The assembled Plotly figure.
    """
    # The first column is the date index; every other column is a ticker.
    date_col = returns.columns[0]
    tickers = [col for col in returns.columns if col != date_col]

    # Compound simple returns into a growth-of-1 price path per ticker.
    prices = returns.with_columns([((1 + pl.col(ticker)).cum_prod()).alias(f"{ticker}_price") for ticker in tickers])

    # One solid palette colour per ticker, plus a translucent variant for fills
    # and negative bars.
    palette = px.colors.qualitative.Plotly
    colors = {ticker: palette[i % len(palette)] for i, ticker in enumerate(tickers)}
    colors.update({f"{ticker}_light": _hex_to_rgba(colors[ticker]) for ticker in tickers})

    # Compound returns within calendar months. Windows start at the month start
    # (start_by="window" truncates by `every`), so closed="left" yields
    # [month start, next month start): an observation dated on the 1st stays in
    # its own month. (closed="right" would push it into the previous month.)
    # label="left" stamps each bar with the month it actually represents.
    monthly_returns = returns.group_by_dynamic(
        index_column=date_col, every="1mo", period="1mo", closed="left", label="left"
    ).agg([((pl.col(ticker) + 1.0).product() - 1.0).alias(ticker) for ticker in tickers])

    fig = make_subplots(
        rows=3,
        cols=1,
        shared_xaxes=True,
        row_heights=[0.5, 0.25, 0.25],
        subplot_titles=["Cumulative Returns", "Drawdowns", "Monthly Returns"],
        vertical_spacing=0.05,
    )

    # --- Row 1: Cumulative Returns
    for ticker in tickers:
        fig.add_trace(
            go.Scatter(
                x=prices[date_col],
                y=prices[f"{ticker}_price"],
                mode="lines",
                name=ticker,
                legendgroup=ticker,
                line={"color": colors[ticker], "width": 2},
                hovertemplate=f"<b>%{{x|%b %Y}}</b><br>{ticker}: %{{y:.2f}}x",
                showlegend=True,
            ),
            row=1,
            col=1,
        )

    # --- Row 2: Drawdowns (relative distance below the running peak price)
    for ticker in tickers:
        price = pl.col(f"{ticker}_price")
        running_peak = price.cum_max()
        drawdown = prices.select(((price - running_peak) / running_peak).alias("drawdown"))["drawdown"]

        fig.add_trace(
            go.Scatter(
                x=prices[date_col],
                y=drawdown.to_list(),
                mode="lines",
                fill="tozeroy",
                fillcolor=colors[f"{ticker}_light"],
                line={"color": colors[ticker], "width": 1},
                name=ticker,
                legendgroup=ticker,
                hovertemplate=f"{ticker} Drawdown: %{{y:.2%}}",
                showlegend=False,
            ),
            row=2,
            col=1,
        )

    fig.add_hline(y=0, line_width=1, line_color="gray", row=2, col=1)

    # --- Row 3: Monthly Returns
    single_ticker = len(tickers) == 1
    for ticker in tickers:
        monthly_values = monthly_returns[ticker].to_list()

        # With a single ticker use green/red for up/down months; otherwise keep
        # each ticker's own hue (solid = up, translucent = down) so series stay
        # distinguishable. Missing values are treated as "not up" instead of
        # raising a TypeError on comparison.
        if single_ticker:
            bar_colors = ["green" if val is not None and val > 0 else "red" for val in monthly_values]
        else:
            bar_colors = [
                colors[ticker] if val is not None and val > 0 else colors[f"{ticker}_light"] for val in monthly_values
            ]

        fig.add_trace(
            go.Bar(
                x=monthly_returns[date_col],
                y=monthly_returns[ticker],
                name=ticker,
                legendgroup=ticker,
                marker={
                    "color": bar_colors,
                    "line": {"width": 0},
                },
                opacity=0.8,
                hovertemplate=f"{ticker} Monthly Return: %{{y:.2%}}",
                showlegend=False,
            ),
            row=3,
            col=1,
        )

    # Layout
    fig.update_layout(
        title=f"{' vs '.join(tickers)} Performance Dashboard",
        height=1200,
        hovermode="x unified",
        plot_bgcolor="white",
        legend={"orientation": "h", "yanchor": "bottom", "y": 1.02, "xanchor": "right", "x": 1},
        xaxis={
            "rangeselector": {
                "buttons": [
                    {"count": 6, "label": "6m", "step": "month", "stepmode": "backward"},
                    {"count": 1, "label": "1y", "step": "year", "stepmode": "backward"},
                    {"count": 3, "label": "3y", "step": "year", "stepmode": "backward"},
                    {"step": "year", "stepmode": "todate", "label": "YTD"},
                    {"step": "all", "label": "All"},
                ]
            },
            "rangeslider": {"visible": False},
            "type": "date",
        },
    )

    fig.update_yaxes(title_text="Cumulative Return", row=1, col=1, tickformat=".2f")
    fig.update_yaxes(title_text="Drawdown", row=2, col=1, tickformat=".0%")
    fig.update_yaxes(title_text="Monthly Return", row=3, col=1, tickformat=".0%")

    fig.update_xaxes(showgrid=True, gridwidth=0.5, gridcolor="lightgrey")
    fig.update_yaxes(showgrid=True, gridwidth=0.5, gridcolor="lightgrey")

    if log_scale:
        fig.update_yaxes(type="log", row=1, col=1)

    return fig
@dataclasses.dataclass(frozen=True)
class Plots:
    """Visualization tools for financial returns data.

    Wraps a ``Data`` object and renders its returns with Plotly. The visualization
    currently provided is :meth:`plot_snapshot`, a three-panel performance
    dashboard.

    Attributes:
        data: The ``Data`` object whose ``all`` frame is plotted. The dashboard
            treats that frame's first column as the date index and each remaining
            column as one returns series.
    """

    data: "Data"  # type: ignore

    def plot_snapshot(self, title: str = "Portfolio Summary", log_scale: bool = False) -> go.Figure:
        """Create a dashboard with multiple plots for portfolio analysis.

        The figure has three vertically stacked panels sharing the x-axis:
        1. Cumulative (compounded) returns over time
        2. Drawdowns over time
        3. Monthly returns as bars

        Args:
            title (str, optional): Title of the figure. Defaults to "Portfolio Summary".
            log_scale (bool, optional): Whether to use a logarithmic y-axis for the
                cumulative-returns panel. Defaults to False.

        Returns:
            go.Figure: A Plotly figure object containing the dashboard.

        Example:
            >>> from jquantstats._data import Data
            >>> import polars as pl
            >>> returns = pl.DataFrame(...)
            >>> data = Data(returns=returns)
            >>> fig = data.plots.plot_snapshot(title="My Portfolio Performance")
            >>> fig.show()

        """
        fig = _plot_performance_dashboard(returns=self.data.all, log_scale=log_scale)
        # The dashboard helper sets its own "<tickers> Performance Dashboard" title;
        # override it so the documented `title` argument actually takes effect.
        fig.update_layout(title=title)
        return fig