Coverage for src / jquantstats / _plots / _portfolio.py: 100%
173 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-26 18:44 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-26 18:44 +0000
1"""Plotting utilities for portfolio analytics using Plotly.
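This module defines the PortfolioPlots facade which renders common portfolio visuals
such as snapshots, lagged performance curves, smoothed-holdings curves, lead/lag
information ratio bar charts, rolling Sharpe/volatility lines, annual Sharpe bars,
correlation and monthly-returns heatmaps, and trading-cost impact curves. Designed
for notebook use.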
7Examples:
8 >>> import dataclasses
9 >>> from jquantstats._plots import PortfolioPlots
10 >>> dataclasses.is_dataclass(PortfolioPlots)
11 True
12"""
from __future__ import annotations

import dataclasses
import math
from typing import TYPE_CHECKING

import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import polars as pl
from plotly.subplots import make_subplots
25if TYPE_CHECKING:
26 from ._protocol import PortfolioLike
# Ensure Plotly figures render in Marimo notebooks (set after the imports to
# satisfy linters).
# NOTE(review): this assigns Plotly's global default renderer at import time,
# replacing whatever renderer the user configured before importing this module
# — confirm this is intended outside Marimo as well.
pio.renderers.default = "plotly_mimetype"
@dataclasses.dataclass(frozen=True)
class PortfolioPlots:
    """Facade for portfolio plots built with Plotly.

    Provides convenience methods to visualize portfolio performance and
    diagnostics directly from a Portfolio instance: snapshot dashboards,
    lagged and smoothed-holdings NAV curves, lead/lag IR bars, rolling
    Sharpe/volatility lines, annual Sharpe bars, correlation and monthly
    returns heatmaps, and trading-cost sensitivity curves.
    """

    portfolio: PortfolioLike

    # ------------------------------------------------------------------
    # Shared layout helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _legend() -> dict[str, object]:
        """Return the horizontal legend config anchored above the top-right corner."""
        return {"orientation": "h", "yanchor": "bottom", "y": 1.02, "xanchor": "right", "x": 1}

    @staticmethod
    def _date_xaxis() -> dict[str, object]:
        """Return a date x-axis config with a 6m/1y/3y/YTD/All range selector.

        A fresh dict is built on every call so that no two figures share the
        same mutable layout object.
        """
        return {
            "rangeselector": {
                "buttons": [
                    {"count": 6, "label": "6m", "step": "month", "stepmode": "backward"},
                    {"count": 1, "label": "1y", "step": "year", "stepmode": "backward"},
                    {"count": 3, "label": "3y", "step": "year", "stepmode": "backward"},
                    {"step": "year", "stepmode": "todate", "label": "YTD"},
                    {"step": "all", "label": "All"},
                ]
            },
            "rangeslider": {"visible": False},
            "type": "date",
        }

    @staticmethod
    def _apply_grid(fig: go.Figure) -> None:
        """Draw thin light-grey grid lines on every x- and y-axis of *fig* in-place."""
        fig.update_xaxes(showgrid=True, gridwidth=0.5, gridcolor="lightgrey")
        fig.update_yaxes(showgrid=True, gridwidth=0.5, gridcolor="lightgrey")

    @classmethod
    def _apply_time_series_layout(cls, fig: go.Figure, title: str, yaxis_title: str) -> None:
        """Apply the common date-indexed chart layout to *fig* in-place.

        Sets the title, unified x hover, white background, horizontal legend,
        range-selector date x-axis, the y-axis title, and grid lines.

        Args:
            fig: The Plotly Figure to configure.
            title: Chart title text.
            yaxis_title: Label for the y-axis.
        """
        fig.update_layout(
            title=title,
            hovermode="x unified",
            plot_bgcolor="white",
            legend=cls._legend(),
            xaxis=cls._date_xaxis(),
        )
        fig.update_yaxes(title_text=yaxis_title)
        cls._apply_grid(fig)

    @staticmethod
    def _apply_nav_layout(fig: go.Figure, title: str, log_scale: bool = False) -> None:
        """Apply common NAV-accumulated layout to *fig* in-place.

        Configures the plot background, legend, hover mode, x-axis date range
        selector, y-axis label, grid lines, and optional logarithmic y-scale.
        Shared by :meth:`lagged_performance_plot` and
        :meth:`smoothed_holdings_performance_plot`.

        Args:
            fig: The Plotly Figure to configure.
            title: Chart title text.
            log_scale: If True, set the primary y-axis to logarithmic scale.
        """
        PortfolioPlots._apply_time_series_layout(fig, title, "NAV (accumulated)")
        if log_scale:
            fig.update_yaxes(type="log")
            # Also set the primary y-axis directly, for environments where the
            # update_yaxes call may not propagate to the layout alias.
            if hasattr(fig.layout, "yaxis"):
                fig.layout.yaxis.type = "log"

    @staticmethod
    def _check_window(window: int) -> None:
        """Raise ValueError unless *window* is a positive integer."""
        if not isinstance(window, int) or window <= 0:
            raise ValueError(f"window must be a positive integer, got {window!r}")

    @classmethod
    def _rolling_figure(cls, rolling: pl.DataFrame, title: str, yaxis_title: str) -> go.Figure:
        """Build a line chart with one trace per non-``date`` column of *rolling*.

        The ``date`` column, when present, supplies the x values. Otherwise the
        traces fall back to Plotly's positional index.
        """
        fig = go.Figure()
        dates = rolling["date"] if "date" in rolling.columns else None
        for col in rolling.columns:
            if col == "date":
                continue
            fig.add_trace(
                go.Scatter(
                    x=dates,
                    y=rolling[col],
                    mode="lines",
                    name=col,
                    line={"width": 1},
                )
            )
        cls._apply_time_series_layout(fig, title, yaxis_title)
        return fig

    # ------------------------------------------------------------------
    # Public plots
    # ------------------------------------------------------------------

    def lead_lag_ir_plot(self, start: int = -10, end: int = 19) -> go.Figure:
        """Plot Sharpe ratio (IR) across lead/lag variants of the portfolio.

        Builds portfolios with cash positions lagged from ``start`` to ``end``
        (inclusive) and plots a bar chart of the Sharpe ratio for each lag.
        Positive lags delay weights; negative lags lead them. If ``start`` is
        greater than ``end``, the two are swapped.

        Args:
            start: First lag to include (default: -10).
            end: Last lag to include (default: +19).

        Returns:
            A Plotly Figure with one bar per lag labeled by the lag value. The
            zero-lag bar is highlighted in red.

        Raises:
            TypeError: If ``start`` or ``end`` is not an integer.
        """
        if not isinstance(start, int) or not isinstance(end, int):
            raise TypeError(
                f"start and end must be integers, got {type(start).__name__} and {type(end).__name__}"
            )
        if start > end:
            # Accept reversed bounds rather than producing an empty chart.
            start, end = end, start

        x_vals = list(range(start, end + 1))
        y_vals: list[float] = []
        for n in x_vals:
            pf = self.portfolio if n == 0 else self.portfolio.lag(n)
            # Stats.sharpe() returns a mapping of series name -> value; the
            # portfolio's own series is keyed "returns".
            sharpe_val = pf.stats.sharpe().get("returns", float("nan"))  # type: ignore[union-attr]
            y_vals.append(float("nan") if sharpe_val is None else float(sharpe_val))

        colors = ["red" if x == 0 else "#1f77b4" for x in x_vals]
        fig = go.Figure(
            data=[
                go.Bar(x=x_vals, y=y_vals, name="Sharpe by lag", marker_color=colors),
            ]
        )
        fig.update_layout(
            title="Lead/Lag Information Ratio (Sharpe) by Lag",
            xaxis_title="Lag (steps)",
            yaxis_title="Sharpe ratio",
            plot_bgcolor="white",
            hovermode="x",
        )
        self._apply_grid(fig)
        return fig

    def snapshot(self, log_scale: bool = False) -> go.Figure:
        """Return a snapshot dashboard of NAV and drawdown.

        When the portfolio has a non-zero ``cost_model.cost_per_unit``, an additional
        ``"Net-of-Cost NAV"`` trace is overlaid on the NAV panel showing the
        realised NAV path after deducting position-delta trading costs.

        Args:
            log_scale (bool, optional): If True, display NAV on a log scale. Defaults to False.

        Returns:
            plotly.graph_objects.Figure: A Figure with accumulated NAV (including tilt/timing)
            and drawdown shaded area, equipped with a range selector.
        """
        # Two stacked panels sharing the date axis: NAV (top) and drawdown (bottom).
        fig = make_subplots(
            rows=2,
            cols=1,
            shared_xaxes=True,
            row_heights=[0.66, 0.33],
            subplot_titles=["Accumulated Profit", "Drawdown"],
            vertical_spacing=0.05,
        )

        # --- Row 1: accumulated NAV of the portfolio and its tilt/timing parts.
        nav_series = (
            ("NAV", self.portfolio.nav_accumulated),
            ("Tilt", self.portfolio.tilt.nav_accumulated),
            ("Timing", self.portfolio.timing.nav_accumulated),
        )
        for trace_name, nav in nav_series:
            fig.add_trace(
                go.Scatter(
                    x=nav["date"],
                    y=nav["NAV_accumulated"],
                    mode="lines",
                    name=trace_name,
                    showlegend=False,
                ),
                row=1,
                col=1,
            )

        # Net-of-cost NAV overlay (only when a cost model is active).
        if self.portfolio.cost_model.cost_per_unit > 0:
            net_nav_df = self.portfolio.net_cost_nav
            x_dates = net_nav_df["date"] if "date" in net_nav_df.columns else None
            fig.add_trace(
                go.Scatter(
                    x=x_dates,
                    y=net_nav_df["NAV_accumulated_net"],
                    mode="lines",
                    name="Net-of-Cost NAV",
                    line={"dash": "dash"},
                    showlegend=True,
                ),
                row=1,
                col=1,
            )

        # --- Row 2: drawdown as a filled area down to zero.
        drawdown = self.portfolio.drawdown
        fig.add_trace(
            go.Scatter(
                x=drawdown["date"],
                y=drawdown["drawdown_pct"],
                mode="lines",
                fill="tozeroy",
                name="Drawdown",
                showlegend=False,
            ),
            row=2,
            col=1,
        )
        fig.add_hline(y=0, line_width=1, line_color="gray", row=2, col=1)

        # Layout: "xaxis" is the top panel's axis; the bottom one shares it.
        fig.update_layout(
            title="Performance Dashboard",
            height=1200,
            hovermode="x unified",
            plot_bgcolor="white",
            legend=self._legend(),
            xaxis=self._date_xaxis(),
        )
        fig.update_yaxes(title_text="NAV (accumulated)", row=1, col=1, tickformat=".2s")
        fig.update_yaxes(title_text="Drawdown", row=2, col=1, tickformat=".0%")
        self._apply_grid(fig)

        if log_scale:
            fig.update_yaxes(type="log", row=1, col=1)
            # Ensure the first y-axis is explicitly set for environments
            # where subplot updates may not propagate to the layout alias.
            if hasattr(fig.layout, "yaxis"):
                fig.layout.yaxis.type = "log"

        return fig

    def lagged_performance_plot(self, lags: list[int] | None = None, log_scale: bool = False) -> go.Figure:
        """Plot NAV_accumulated for multiple lagged portfolios.

        Creates a Plotly figure with one line per lag value showing the
        accumulated NAV series for the portfolio with cash positions
        shifted by that lag. By default, lags [0, 1, 2, 3, 4] are used.

        Args:
            lags: A list of integer lags to apply; defaults to [0, 1, 2, 3, 4].
            log_scale: If True, set the primary y-axis to logarithmic scale.

        Returns:
            A Plotly Figure containing one trace per requested lag.

        Raises:
            TypeError: If ``lags`` is not a list of integers.
        """
        if lags is None:
            lags = [0, 1, 2, 3, 4]
        if not isinstance(lags, list) or not all(isinstance(x, int) for x in lags):
            raise TypeError(f"lags must be a list of integers, got {lags!r}")

        fig = go.Figure()
        for lag in lags:
            pf = self.portfolio if lag == 0 else self.portfolio.lag(lag)
            nav = pf.nav_accumulated
            fig.add_trace(
                go.Scatter(
                    x=nav["date"],
                    y=nav["NAV_accumulated"],
                    mode="lines",
                    name=f"lag {lag}",
                    line={"width": 1},
                )
            )

        self._apply_nav_layout(fig, title="NAV accumulated by lag", log_scale=log_scale)
        return fig

    def rolling_sharpe_plot(self, window: int = 63) -> go.Figure:
        """Plot rolling annualised Sharpe ratio over time.

        Computes the rolling Sharpe for each asset column using the given
        window and renders one line per asset, plus a dashed zero line.

        Args:
            window: Rolling-window size in periods. Defaults to 63.

        Returns:
            A Plotly Figure with one trace per asset.

        Raises:
            ValueError: If ``window`` is not a positive integer.
        """
        self._check_window(window)
        rolling = self.portfolio.stats.rolling_sharpe(window=window)  # type: ignore[union-attr]
        fig = self._rolling_figure(
            rolling,
            title=f"Rolling Sharpe Ratio ({window}-period window)",
            yaxis_title="Sharpe ratio",
        )
        fig.add_hline(y=0, line_width=1, line_dash="dash", line_color="gray")
        return fig

    def rolling_volatility_plot(self, window: int = 63) -> go.Figure:
        """Plot rolling annualised volatility over time.

        Computes the rolling volatility for each asset column using the given
        window and renders one line per asset.

        Args:
            window: Rolling-window size in periods. Defaults to 63.

        Returns:
            A Plotly Figure with one trace per asset.

        Raises:
            ValueError: If ``window`` is not a positive integer.
        """
        self._check_window(window)
        rolling = self.portfolio.stats.rolling_volatility(window=window)  # type: ignore[union-attr]
        return self._rolling_figure(
            rolling,
            title=f"Rolling Volatility ({window}-period window)",
            yaxis_title="Annualised volatility",
        )

    def annual_sharpe_plot(self) -> go.Figure:
        """Plot annualised Sharpe ratio broken down by calendar year.

        Takes the ``metric == "sharpe"`` rows of the portfolio's annual
        breakdown and renders a grouped bar chart with one bar per year per
        asset.

        Returns:
            A Plotly Figure with one bar group per asset.
        """
        breakdown = self.portfolio.stats.annual_breakdown()  # type: ignore[union-attr]

        # Every column other than the year/metric keys holds one asset's values.
        sharpe_rows = breakdown.filter(pl.col("metric") == "sharpe")
        asset_cols = [c for c in sharpe_rows.columns if c not in ("year", "metric")]

        fig = go.Figure()
        for asset in asset_cols:
            fig.add_trace(
                go.Bar(
                    x=sharpe_rows["year"],
                    y=sharpe_rows[asset],
                    name=asset,
                )
            )

        fig.add_hline(y=0, line_width=1, line_color="gray")

        fig.update_layout(
            title="Annual Sharpe Ratio by Year",
            barmode="group",
            hovermode="x unified",
            plot_bgcolor="white",
            legend=self._legend(),
        )
        fig.update_yaxes(title_text="Sharpe ratio")
        fig.update_xaxes(title_text="Year")
        self._apply_grid(fig)
        return fig

    def correlation_heatmap(
        self,
        frame: pl.DataFrame | None = None,
        name: str = "portfolio",
        title: str = "Correlation heatmap",
    ) -> go.Figure:
        """Plot a correlation heatmap for assets and the portfolio.

        If ``frame`` is None, uses the portfolio's prices. The correlation
        matrix itself comes from ``self.portfolio.correlation(frame, name=name)``.
        Per the original contract, that call includes the portfolio's profit
        series under ``name``.

        Args:
            frame: Optional Polars DataFrame with at least the asset price
                columns. If omitted, uses ``self.portfolio.prices``.
            name: Column name under which to include the portfolio profit.
            title: Plot title.

        Returns:
            A Plotly Figure rendering the correlation matrix as a heatmap with
            a fixed [-1, 1] diverging colour scale.
        """
        if frame is None:
            frame = self.portfolio.prices

        corr = self.portfolio.correlation(frame, name=name)

        fig = px.imshow(
            corr,
            x=corr.columns,
            y=corr.columns,
            text_auto=".2f",  # annotate each cell with its correlation value
            color_continuous_scale="RdBu_r",  # red-blue diverging colormap
            zmin=-1,
            zmax=1,  # pin the colour range to the full correlation range
            title=title,
        )
        fig.update_layout(
            xaxis_title="", yaxis_title="", width=700, height=600, coloraxis_colorbar={"title": "Correlation"}
        )
        return fig

    def monthly_returns_heatmap(self) -> go.Figure:
        """Plot a monthly returns calendar heatmap.

        Uses the portfolio's ``monthly`` frame (``year``, ``month``,
        ``returns`` columns) and renders a Plotly heatmap with months on the
        x-axis and years on the y-axis. Green cells indicate positive months
        and red cells indicate negative months. Cell text shows the
        percentage return for that month. Months without data, or with a
        null return, are left blank.

        Returns:
            A Plotly Figure with a calendar heatmap of monthly returns.

        Raises:
            ValueError: Propagated from ``portfolio.monthly`` when the portfolio
                has no ``date`` column.
        """
        monthly = self.portfolio.monthly

        years = monthly["year"].unique().sort().to_list()
        month_names = ["Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]

        # Index every (year, month) return once instead of re-filtering the
        # frame for each of the 12 x len(years) cells. setdefault keeps the
        # first row per key, matching a filter-then-[0] lookup.
        returns_by_month: dict[tuple[object, object], object] = {}
        for yr, mo, ret in zip(
            monthly["year"].to_list(),
            monthly["month"].to_list(),
            monthly["returns"].to_list(),
        ):
            returns_by_month.setdefault((yr, mo), ret)

        z: list[list[float | None]] = []
        text: list[list[str]] = []
        for year in years:
            year_row: list[float | None] = []
            year_text: list[str] = []
            for m in range(1, 13):
                ret = returns_by_month.get((year, m))
                if ret is None:
                    # Missing month, or a null return: leave the cell blank
                    # instead of failing on float(None).
                    year_row.append(None)
                    year_text.append("")
                else:
                    pct = float(ret) * 100.0
                    year_row.append(pct)
                    year_text.append(f"{pct:.1f}%")
            z.append(year_row)
            text.append(year_text)

        fig = go.Figure(
            data=go.Heatmap(
                z=z,
                x=month_names,
                y=[str(y) for y in years],
                text=text,
                texttemplate="%{text}",
                colorscale="RdYlGn",
                zmid=0,  # centre the colour scale on a flat month
                colorbar={"title": "Return (%)"},
                hovertemplate="<b>%{y} %{x}</b><br>Return: %{text}<extra></extra>",
            )
        )

        fig.update_layout(
            title="Monthly Returns Heatmap",
            xaxis_title="Month",
            yaxis_title="Year",
            plot_bgcolor="white",
            # Treat years as labels so Plotly does not interpolate fractional years.
            yaxis={"type": "category"},
        )

        return fig

    def smoothed_holdings_performance_plot(
        self,
        windows: list[int] | None = None,
        log_scale: bool = False,
    ) -> go.Figure:
        """Plot NAV_accumulated for smoothed-holding portfolios.

        Builds portfolios with cash positions smoothed by a trailing rolling
        mean over the previous ``n`` steps (window size n+1) for n in
        ``windows`` (defaults to [0, 1, 2, 3, 4]) and plots their
        accumulated NAV curves. ``n == 0`` uses the unsmoothed portfolio.

        Args:
            windows: List of non-negative integers specifying smoothing steps
                to include; defaults to [0, 1, 2, 3, 4].
            log_scale: If True, set the primary y-axis to logarithmic scale.

        Returns:
            A Plotly Figure containing one line per requested smoothing level.

        Raises:
            TypeError: If ``windows`` is not a list of non-negative integers.
        """
        if windows is None:
            windows = [0, 1, 2, 3, 4]
        if not isinstance(windows, list) or not all(isinstance(x, int) and x >= 0 for x in windows):
            raise TypeError(f"windows must be a list of non-negative integers, got {windows!r}")

        fig = go.Figure()
        for n in windows:
            pf = self.portfolio if n == 0 else self.portfolio.smoothed_holding(n)
            nav = pf.nav_accumulated
            fig.add_trace(
                go.Scatter(
                    x=nav["date"],
                    y=nav["NAV_accumulated"],
                    mode="lines",
                    name=f"smooth {n}",
                    line={"width": 1},
                )
            )

        self._apply_nav_layout(fig, title="NAV accumulated by smoothed holdings", log_scale=log_scale)
        return fig

    def trading_cost_impact_plot(self, max_bps: int = 20) -> go.Figure:
        """Plot the Sharpe ratio as a function of one-way trading costs.

        Evaluates the portfolio's annualised Sharpe ratio at each integer
        cost level from 0 up to ``max_bps`` basis points and renders the
        result as a line chart. The zero-cost Sharpe is shown as a dashed
        reference line, when it is finite, so the reader can quickly gauge
        at what cost level the strategy's edge is eroded.

        Args:
            max_bps: Maximum one-way trading cost to evaluate, in basis
                points. Defaults to 20.

        Returns:
            A Plotly Figure with one line trace showing Sharpe vs. cost.

        Raises:
            ValueError: Presumably raised by ``portfolio.trading_cost_impact``
                when ``max_bps`` is not a positive integer. This method does
                not validate ``max_bps`` itself.
        """
        impact = self.portfolio.trading_cost_impact(max_bps=max_bps)

        cost_vals = impact["cost_bps"].to_list()
        sharpe_vals = impact["sharpe"].to_list()

        # Baseline Sharpe at zero cost. This assumes the first row is the
        # 0 bps evaluation.
        first = sharpe_vals[0] if sharpe_vals else None
        baseline = float(first) if first is not None else float("nan")

        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=cost_vals,
                y=sharpe_vals,
                mode="lines+markers",
                name="Sharpe (cost-adjusted)",
                marker={"size": 6},
                line={"width": 2, "color": "#1f77b4"},
            )
        )
        # NaN or +/-inf cannot be drawn as a reference line.
        if math.isfinite(baseline):
            fig.add_hline(
                y=baseline,
                line_width=1,
                line_dash="dash",
                line_color="gray",
                annotation_text="0 bps baseline",
                annotation_position="top right",
            )

        fig.update_layout(
            title=f"Trading Cost Impact on Sharpe Ratio (0\u2013{max_bps} bps)",
            hovermode="x unified",
            plot_bgcolor="white",
        )
        fig.update_xaxes(
            title_text="One-way cost (basis points)",
            showgrid=True,
            gridwidth=0.5,
            gridcolor="lightgrey",
            dtick=1,
        )
        fig.update_yaxes(
            title_text="Annualised Sharpe ratio",
            showgrid=True,
            gridwidth=0.5,
            gridcolor="lightgrey",
        )
        return fig