Coverage for src / jquantstats / _plots / _portfolio.py: 100%

173 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-26 18:44 +0000

1"""Plotting utilities for portfolio analytics using Plotly. 

2 

3This module defines the PortfolioPlots facade which renders common portfolio visuals 

4such as snapshots, lagged performance curves, smoothed-holdings curves, and 

5lead/lag information ratio bar charts. Designed for notebook use. 

6 

7Examples: 

8 >>> import dataclasses 

9 >>> from jquantstats._plots import PortfolioPlots 

10 >>> dataclasses.is_dataclass(PortfolioPlots) 

11 True 

12""" 

13 

from __future__ import annotations

import dataclasses
import math
from typing import TYPE_CHECKING

import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import polars as pl
from plotly.subplots import make_subplots

if TYPE_CHECKING:
    from ._protocol import PortfolioLike

27 

# Select the "plotly_mimetype" renderer so figures display in Marimo notebooks.
# NOTE: this is a process-wide Plotly setting applied at import time. It changes
# the default renderer for every Plotly figure, not only those built here.
# It is assigned after the imports so linters do not flag code before imports.
_MARIMO_RENDERER = "plotly_mimetype"
pio.renderers.default = _MARIMO_RENDERER

30 

31 

@dataclasses.dataclass(frozen=True)
class PortfolioPlots:
    """Facade for portfolio plots built with Plotly.

    Provides convenience methods that render portfolio performance and
    diagnostics directly from a portfolio object: a NAV/drawdown snapshot,
    lagged and smoothed-holding NAV curves, lead/lag Sharpe bars, rolling
    Sharpe and volatility lines, annual Sharpe bars, correlation and
    monthly-return heatmaps, and a trading-cost sensitivity curve.

    Attributes:
        portfolio: The portfolio whose data and statistics are plotted.
    """

    portfolio: PortfolioLike

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _is_strict_int(value: object) -> bool:
        """Return True if *value* is an ``int`` but not a ``bool``.

        ``bool`` subclasses ``int``, so a plain ``isinstance(value, int)``
        would silently accept ``True``/``False`` as lag or window sizes.
        """
        return isinstance(value, int) and not isinstance(value, bool)

    @staticmethod
    def _horizontal_legend() -> dict[str, object]:
        """Return the horizontal legend settings (anchored top-right) shared by most plots."""
        return {"orientation": "h", "yanchor": "bottom", "y": 1.02, "xanchor": "right", "x": 1}

    @staticmethod
    def _date_xaxis() -> dict[str, object]:
        """Return x-axis settings for time-series plots.

        The axis is typed as ``date``, gets 6m/1y/3y/YTD/All range-selector
        buttons, and has the range slider hidden. A fresh dict is built on
        every call so separate figures never share mutable layout state.
        """
        return {
            "rangeselector": {
                "buttons": [
                    {"count": 6, "label": "6m", "step": "month", "stepmode": "backward"},
                    {"count": 1, "label": "1y", "step": "year", "stepmode": "backward"},
                    {"count": 3, "label": "3y", "step": "year", "stepmode": "backward"},
                    {"step": "year", "stepmode": "todate", "label": "YTD"},
                    {"step": "all", "label": "All"},
                ]
            },
            "rangeslider": {"visible": False},
            "type": "date",
        }

    @staticmethod
    def _add_grid(fig: go.Figure) -> None:
        """Draw light-grey grid lines on every x- and y-axis of *fig* in place."""
        fig.update_xaxes(showgrid=True, gridwidth=0.5, gridcolor="lightgrey")
        fig.update_yaxes(showgrid=True, gridwidth=0.5, gridcolor="lightgrey")

    @staticmethod
    def _apply_nav_layout(fig: go.Figure, title: str, log_scale: bool = False) -> None:
        """Apply common NAV-accumulated layout to *fig* in-place.

        Configures the plot background, legend, hover mode, x-axis date range
        selector, y-axis label, grid lines, and optional logarithmic y-scale.
        Shared by :meth:`lagged_performance_plot` and
        :meth:`smoothed_holdings_performance_plot`.

        Args:
            fig: The Plotly Figure to configure.
            title: Chart title text.
            log_scale: If True, set the primary y-axis to logarithmic scale.
        """
        fig.update_layout(
            title=title,
            hovermode="x unified",
            plot_bgcolor="white",
            legend=PortfolioPlots._horizontal_legend(),
            xaxis=PortfolioPlots._date_xaxis(),
        )
        fig.update_yaxes(title_text="NAV (accumulated)")
        PortfolioPlots._add_grid(fig)

        if log_scale:
            fig.update_yaxes(type="log")
            # Also set the first y-axis explicitly, for environments where the
            # bulk update may not propagate to the layout alias.
            if hasattr(fig.layout, "yaxis"):
                fig.layout.yaxis.type = "log"

    @staticmethod
    def _rolling_figure(
        rolling: pl.DataFrame,
        title: str,
        yaxis_title: str,
        zero_line: bool = False,
    ) -> go.Figure:
        """Render one line per non-``date`` column of a rolling-statistics frame.

        Args:
            rolling: Frame with an optional ``date`` column and one numeric
                column per series to plot.
            title: Chart title text.
            yaxis_title: Y-axis label.
            zero_line: If True, draw a dashed grey horizontal line at y=0.

        Returns:
            A Plotly Figure with one trace per series column.
        """
        fig = go.Figure()
        # Without a date column Plotly falls back to the row index on the x-axis.
        dates = rolling["date"] if "date" in rolling.columns else None
        for col in rolling.columns:
            if col == "date":
                continue
            fig.add_trace(go.Scatter(x=dates, y=rolling[col], mode="lines", name=col, line={"width": 1}))

        if zero_line:
            fig.add_hline(y=0, line_width=1, line_dash="dash", line_color="gray")

        fig.update_layout(
            title=title,
            hovermode="x unified",
            plot_bgcolor="white",
            legend=PortfolioPlots._horizontal_legend(),
            xaxis=PortfolioPlots._date_xaxis(),
        )
        fig.update_yaxes(title_text=yaxis_title)
        PortfolioPlots._add_grid(fig)
        return fig

    @staticmethod
    def _require_positive_window(window: object) -> int:
        """Return *window* if it is a positive, non-bool integer, else raise ValueError."""
        if not PortfolioPlots._is_strict_int(window) or window <= 0:  # type: ignore[operator]
            raise ValueError(f"window must be a positive integer, got {window!r}")
        return window  # type: ignore[return-value]

    @staticmethod
    def _monthly_grid(
        rows: list[tuple[int, int, float | None]],
    ) -> tuple[list[int], list[list[float | None]], list[list[str]]]:
        """Arrange ``(year, month, return)`` rows into a year-by-month grid.

        Args:
            rows: One tuple per month, with ``month`` in 1..12 and the
                return as a fraction (0.05 == 5%). If a ``(year, month)`` pair
                appears more than once, the first occurrence wins.

        Returns:
            ``(years, z, text)`` where ``years`` is sorted ascending, ``z`` has
            one 12-element row per year holding returns in percent (``None``
            for months with no row or a null return), and ``text`` holds the
            matching ``"x.x%"`` labels (``""`` for blank cells).
        """
        by_cell: dict[tuple[int, int], float | None] = {}
        for year, month, ret in rows:
            if year is None:
                continue
            by_cell.setdefault((year, month), ret)

        years = sorted({year for year, _ in by_cell})
        z: list[list[float | None]] = []
        text: list[list[str]] = []
        for year in years:
            z_row: list[float | None] = []
            text_row: list[str] = []
            for month in range(1, 13):
                ret = by_cell.get((year, month))
                if ret is None:
                    z_row.append(None)
                    text_row.append("")
                else:
                    pct = float(ret) * 100.0
                    z_row.append(pct)
                    text_row.append(f"{pct:.1f}%")
            z.append(z_row)
            text.append(text_row)
        return years, z, text

    # ------------------------------------------------------------------
    # Public plots
    # ------------------------------------------------------------------

    def lead_lag_ir_plot(self, start: int = -10, end: int = 19) -> go.Figure:
        """Plot Sharpe ratio (IR) across lead/lag variants of the portfolio.

        Builds portfolios with cash positions lagged by every integer from
        ``start`` to ``end`` (inclusive) and plots a bar chart of the Sharpe
        ratio of each variant's ``"returns"`` series. Positive lags delay
        weights; negative lags lead them. Lag 0 is the unmodified portfolio,
        and its bar is highlighted in red. If ``start > end`` the two bounds
        are swapped rather than rejected.

        Args:
            start: First lag to include (default: -10).
            end: Last lag to include (default: +19).

        Returns:
            A Plotly Figure with one bar per lag labeled by the lag value.
            A lag whose Sharpe is missing is plotted as NaN.

        Raises:
            TypeError: If ``start`` or ``end`` is not an integer (bools are rejected).
        """
        if not (self._is_strict_int(start) and self._is_strict_int(end)):
            raise TypeError(f"start and end must be integers, got start={start!r}, end={end!r}")

        lags = list(range(min(start, end), max(start, end) + 1))

        sharpes: list[float] = []
        for n in lags:
            pf = self.portfolio if n == 0 else self.portfolio.lag(n)
            # Stats.sharpe() maps series name -> value; only the "returns" series is used.
            value = pf.stats.sharpe().get("returns", float("nan"))  # type: ignore[union-attr]
            sharpes.append(float("nan") if value is None else float(value))

        colors = ["red" if n == 0 else "#1f77b4" for n in lags]
        fig = go.Figure(data=[go.Bar(x=lags, y=sharpes, name="Sharpe by lag", marker_color=colors)])
        fig.update_layout(
            title="Lead/Lag Information Ratio (Sharpe) by Lag",
            xaxis_title="Lag (steps)",
            yaxis_title="Sharpe ratio",
            plot_bgcolor="white",
            hovermode="x",
        )
        self._add_grid(fig)
        return fig

    def snapshot(self, log_scale: bool = False) -> go.Figure:
        """Return a snapshot dashboard of NAV and drawdown.

        The top panel shows the accumulated NAV of the portfolio together with
        its ``tilt`` and ``timing`` components. The bottom panel shows the
        drawdown as a shaded area. When the portfolio has a non-zero
        ``cost_model.cost_per_unit``, an additional ``"Net-of-Cost NAV"``
        trace is overlaid on the NAV panel. It shows the realised NAV path
        after deducting position-delta trading costs.

        Args:
            log_scale (bool, optional): If True, display NAV on a log scale. Defaults to False.

        Returns:
            plotly.graph_objects.Figure: A Figure with accumulated NAV (including tilt/timing)
            and drawdown shaded area, equipped with a range selector.
        """
        # Two stacked panels sharing the x-axis: NAV (about 2/3 height) over drawdown (about 1/3).
        fig = make_subplots(
            rows=2,
            cols=1,
            shared_xaxes=True,
            row_heights=[0.66, 0.33],
            subplot_titles=["Accumulated Profit", "Drawdown"],
            vertical_spacing=0.05,
        )

        # --- Row 1: accumulated NAV of the portfolio and its tilt/timing parts
        nav_frames = (
            ("NAV", self.portfolio.nav_accumulated),
            ("Tilt", self.portfolio.tilt.nav_accumulated),
            ("Timing", self.portfolio.timing.nav_accumulated),
        )
        for label, nav in nav_frames:
            fig.add_trace(
                go.Scatter(x=nav["date"], y=nav["NAV_accumulated"], mode="lines", name=label, showlegend=False),
                row=1,
                col=1,
            )

        # Net-of-cost NAV overlay (only when a cost model is active)
        if self.portfolio.cost_model.cost_per_unit > 0:
            net_nav = self.portfolio.net_cost_nav
            net_dates = net_nav["date"] if "date" in net_nav.columns else None
            fig.add_trace(
                go.Scatter(
                    x=net_dates,
                    y=net_nav["NAV_accumulated_net"],
                    mode="lines",
                    name="Net-of-Cost NAV",
                    line={"dash": "dash"},
                    showlegend=True,
                ),
                row=1,
                col=1,
            )

        # --- Row 2: drawdown, shaded down to zero
        drawdown = self.portfolio.drawdown
        fig.add_trace(
            go.Scatter(
                x=drawdown["date"],
                y=drawdown["drawdown_pct"],
                mode="lines",
                fill="tozeroy",
                name="Drawdown",
                showlegend=False,
            ),
            row=2,
            col=1,
        )
        fig.add_hline(y=0, line_width=1, line_color="gray", row=2, col=1)

        fig.update_layout(
            title="Performance Dashboard",
            height=1200,
            hovermode="x unified",
            plot_bgcolor="white",
            legend=self._horizontal_legend(),
            xaxis=self._date_xaxis(),
        )
        fig.update_yaxes(title_text="NAV (accumulated)", row=1, col=1, tickformat=".2s")
        fig.update_yaxes(title_text="Drawdown", row=2, col=1, tickformat=".0%")
        self._add_grid(fig)

        if log_scale:
            fig.update_yaxes(type="log", row=1, col=1)
            # Also set the first y-axis explicitly, for environments where
            # subplot updates may not propagate to the layout alias.
            if hasattr(fig.layout, "yaxis"):
                fig.layout.yaxis.type = "log"

        return fig

    def lagged_performance_plot(self, lags: list[int] | None = None, log_scale: bool = False) -> go.Figure:
        """Plot NAV_accumulated for multiple lagged portfolios.

        Creates a Plotly figure with one line per lag value showing the
        accumulated NAV series for the portfolio with cash positions
        shifted by that lag. Lag 0 uses the portfolio itself.

        Args:
            lags: A list of integer lags to apply; defaults to [0, 1, 2, 3, 4].
            log_scale: If True, set the primary y-axis to logarithmic scale.

        Returns:
            A Plotly Figure containing one trace per requested lag.

        Raises:
            TypeError: If ``lags`` is not a list of integers (bools are rejected).
        """
        if lags is None:
            lags = [0, 1, 2, 3, 4]
        if not isinstance(lags, list) or not all(self._is_strict_int(x) for x in lags):
            raise TypeError(f"lags must be a list of integers, got {lags!r}")

        fig = go.Figure()
        for lag in lags:
            pf = self.portfolio if lag == 0 else self.portfolio.lag(lag)
            nav = pf.nav_accumulated
            fig.add_trace(
                go.Scatter(x=nav["date"], y=nav["NAV_accumulated"], mode="lines", name=f"lag {lag}", line={"width": 1})
            )

        self._apply_nav_layout(fig, title="NAV accumulated by lag", log_scale=log_scale)
        return fig

    def rolling_sharpe_plot(self, window: int = 63) -> go.Figure:
        """Plot rolling annualised Sharpe ratio over time.

        Computes the rolling Sharpe for each asset column using the given
        window and renders one line per asset, with a dashed reference line
        at zero.

        Args:
            window: Rolling-window size in periods. Defaults to 63.

        Returns:
            A Plotly Figure with one trace per asset.

        Raises:
            ValueError: If ``window`` is not a positive integer (bools are rejected).
        """
        self._require_positive_window(window)
        rolling = self.portfolio.stats.rolling_sharpe(window=window)  # type: ignore[union-attr]
        return self._rolling_figure(
            rolling,
            title=f"Rolling Sharpe Ratio ({window}-period window)",
            yaxis_title="Sharpe ratio",
            zero_line=True,
        )

    def rolling_volatility_plot(self, window: int = 63) -> go.Figure:
        """Plot rolling annualised volatility over time.

        Computes the rolling volatility for each asset column using the given
        window and renders one line per asset.

        Args:
            window: Rolling-window size in periods. Defaults to 63.

        Returns:
            A Plotly Figure with one trace per asset.

        Raises:
            ValueError: If ``window`` is not a positive integer (bools are rejected).
        """
        self._require_positive_window(window)
        rolling = self.portfolio.stats.rolling_volatility(window=window)  # type: ignore[union-attr]
        return self._rolling_figure(
            rolling,
            title=f"Rolling Volatility ({window}-period window)",
            yaxis_title="Annualised volatility",
        )

    def annual_sharpe_plot(self) -> go.Figure:
        """Plot annualised Sharpe ratio broken down by calendar year.

        Takes the ``"sharpe"`` rows of ``stats.annual_breakdown()`` and renders
        a grouped bar chart with one bar per year per asset.

        Returns:
            A Plotly Figure with one bar group per asset.
        """
        breakdown = self.portfolio.stats.annual_breakdown()  # type: ignore[union-attr]

        # Every column other than the "year"/"metric" keys is an asset series.
        sharpe_rows = breakdown.filter(pl.col("metric") == "sharpe")
        asset_cols = [c for c in sharpe_rows.columns if c not in ("year", "metric")]

        fig = go.Figure()
        for asset in asset_cols:
            fig.add_trace(go.Bar(x=sharpe_rows["year"], y=sharpe_rows[asset], name=asset))

        fig.add_hline(y=0, line_width=1, line_color="gray")

        fig.update_layout(
            title="Annual Sharpe Ratio by Year",
            barmode="group",
            hovermode="x unified",
            plot_bgcolor="white",
            legend=self._horizontal_legend(),
        )
        fig.update_yaxes(title_text="Sharpe ratio")
        fig.update_xaxes(title_text="Year")
        self._add_grid(fig)
        return fig

    def correlation_heatmap(
        self,
        frame: pl.DataFrame | None = None,
        name: str = "portfolio",
        title: str = "Correlation heatmap",
    ) -> go.Figure:
        """Plot a correlation heatmap for assets and the portfolio.

        If ``frame`` is None, uses the portfolio's prices. The correlation
        matrix comes from ``portfolio.correlation(frame, name=name)``; the
        portfolio's profit series is expected to appear in it under ``name``.

        Args:
            frame: Optional Polars DataFrame with at least the asset price
                columns. If omitted, uses ``self.portfolio.prices``.
            name: Column name under which to include the portfolio profit.
            title: Plot title.

        Returns:
            A Plotly Figure rendering the correlation matrix as a heatmap,
            with values annotated to two decimals on a fixed [-1, 1] scale.
        """
        if frame is None:
            frame = self.portfolio.prices

        corr = self.portfolio.correlation(frame, name=name)

        fig = px.imshow(
            corr,
            x=corr.columns,
            y=corr.columns,
            text_auto=".2f",  # annotate each cell with its correlation value
            color_continuous_scale="RdBu_r",  # diverging red-blue colormap
            zmin=-1,
            zmax=1,  # pin the colour scale to the full correlation range
            title=title,
        )
        fig.update_layout(
            xaxis_title="", yaxis_title="", width=700, height=600, coloraxis_colorbar={"title": "Correlation"}
        )
        return fig

    def monthly_returns_heatmap(self) -> go.Figure:
        """Plot a monthly returns calendar heatmap.

        Arranges ``portfolio.monthly`` by calendar year and month, then
        renders a Plotly heatmap with months on the x-axis and years on the
        y-axis. Green cells indicate positive months; red cells indicate
        negative months. Cell text shows the percentage return for that
        month. Months with no data or a null return are left blank.

        Returns:
            A Plotly Figure with a calendar heatmap of monthly returns.

        Raises:
            ValueError: If the portfolio has no ``date`` column.
        """
        monthly = self.portfolio.monthly
        # A single pass over the rows replaces one DataFrame filter per (year, month) cell.
        rows = list(monthly.select("year", "month", "returns").iter_rows())
        years, z, text = self._monthly_grid(rows)

        month_names = ["Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]
        fig = go.Figure(
            data=go.Heatmap(
                z=z,
                x=month_names,
                y=[str(y) for y in years],
                text=text,
                texttemplate="%{text}",
                colorscale="RdYlGn",
                zmid=0,  # centre the diverging scale on a 0% return
                colorbar={"title": "Return (%)"},
                hovertemplate="<b>%{y} %{x}</b><br>Return: %{text}<extra></extra>",
            )
        )
        fig.update_layout(
            title="Monthly Returns Heatmap",
            xaxis_title="Month",
            yaxis_title="Year",
            plot_bgcolor="white",
            yaxis={"type": "category"},  # years as discrete labels, not a numeric axis
        )
        return fig

    def smoothed_holdings_performance_plot(
        self,
        windows: list[int] | None = None,
        log_scale: bool = False,
    ) -> go.Figure:
        """Plot NAV_accumulated for smoothed-holding portfolios.

        Builds portfolios with cash positions smoothed by a trailing rolling
        mean over the previous ``n`` steps (window size n+1) for n in
        ``windows`` (defaults to [0, 1, 2, 3, 4]) and plots their
        accumulated NAV curves. ``n == 0`` uses the portfolio itself.

        Args:
            windows: List of non-negative integers specifying smoothing steps
                to include; defaults to [0, 1, 2, 3, 4].
            log_scale: If True, set the primary y-axis to logarithmic scale.

        Returns:
            A Plotly Figure containing one line per requested smoothing level.

        Raises:
            TypeError: If ``windows`` is not a list of non-negative integers
                (bools are rejected). Negative values also raise TypeError,
                which is kept for backward compatibility.
        """
        if windows is None:
            windows = [0, 1, 2, 3, 4]
        if not isinstance(windows, list) or not all(self._is_strict_int(x) and x >= 0 for x in windows):
            raise TypeError(f"windows must be a list of non-negative integers, got {windows!r}")

        fig = go.Figure()
        for n in windows:
            pf = self.portfolio if n == 0 else self.portfolio.smoothed_holding(n)
            nav = pf.nav_accumulated
            fig.add_trace(
                go.Scatter(x=nav["date"], y=nav["NAV_accumulated"], mode="lines", name=f"smooth {n}", line={"width": 1})
            )

        self._apply_nav_layout(fig, title="NAV accumulated by smoothed holdings", log_scale=log_scale)
        return fig

    def trading_cost_impact_plot(self, max_bps: int = 20) -> go.Figure:
        """Plot the Sharpe ratio as a function of one-way trading costs.

        Plots ``portfolio.trading_cost_impact(max_bps=max_bps)``, which gives
        the Sharpe ratio per cost level in basis points, as a line chart. The
        Sharpe of the first row (the zero-cost baseline) is drawn as a dashed
        reference line when it is finite, so the reader can see at what cost
        level the strategy's edge is eroded.

        Args:
            max_bps: Maximum one-way trading cost to evaluate, in basis
                points. Defaults to 20.

        Returns:
            A Plotly Figure with one line trace showing Sharpe vs. cost.

        Raises:
            ValueError: If ``max_bps`` is not a positive integer. Validation is
                delegated to ``portfolio.trading_cost_impact``.
        """
        impact = self.portfolio.trading_cost_impact(max_bps=max_bps)

        cost_vals = impact["cost_bps"].to_list()
        sharpe_vals = impact["sharpe"].to_list()

        # Baseline Sharpe taken from the first row. This assumes rows are ordered by
        # ascending cost starting at 0 bps — TODO confirm against trading_cost_impact.
        first = sharpe_vals[0] if sharpe_vals else None
        baseline = float("nan") if first is None else float(first)

        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=cost_vals,
                y=sharpe_vals,
                mode="lines+markers",
                name="Sharpe (cost-adjusted)",
                marker={"size": 6},
                line={"width": 2, "color": "#1f77b4"},
            )
        )
        # The old `baseline == baseline` test only excluded NaN; ±inf would still
        # have produced an unusable reference line.
        if math.isfinite(baseline):
            fig.add_hline(
                y=baseline,
                line_width=1,
                line_dash="dash",
                line_color="gray",
                annotation_text="0 bps baseline",
                annotation_position="top right",
            )

        fig.update_layout(
            title=f"Trading Cost Impact on Sharpe Ratio (0\u2013{max_bps} bps)",
            hovermode="x unified",
            plot_bgcolor="white",
        )
        self._add_grid(fig)
        fig.update_xaxes(title_text="One-way cost (basis points)", dtick=1)
        fig.update_yaxes(title_text="Annualised Sharpe ratio")
        return fig