Coverage for src/tinycta/hyper/_study.py: 98%
56 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-06 05:36 +0000
1"""Frozen Study result and Optuna-based hyperparameter optimisation."""
3from __future__ import annotations
5import contextlib
6from dataclasses import dataclass, field
7from pathlib import Path
9import optuna
10from jquantstats import Portfolio
@dataclass(frozen=True)
class Study:
    """Frozen wrapper around a completed Optuna study.

    Holds the headline results of an optimisation run together with the
    underlying ``optuna.Study`` (kept for plotting, hidden from ``repr``).
    """

    # Parameters of the best completed trial; empty when no trial completed.
    best_params: dict
    # Best objective value (the Sharpe ratio when built by ``optimize``);
    # NaN when no trial completed.
    best_value: float
    # Number of trials that finished in the COMPLETE state.
    n_completed: int
    # Total number of trials recorded on the study, whatever their state.
    n_trials: int
    # The wrapped optuna.Study; excluded from repr to keep it readable.
    optuna_study: optuna.Study = field(repr=False)

    def __str__(self) -> str:
        """Return a human-readable summary of the best trial."""
        if self.n_completed == 0:
            return "No completed trials — all returned NaN Sharpe."
        lines = ["=== Best parameters ==="]
        for k, v in self.best_params.items():
            lines.append(f" {k:<12} = {v}")
        lines.append(f" {'Sharpe':<12} = {self.best_value:.4f}")
        lines.append(f" {'Completed':<12} = {self.n_completed} / {self.n_trials} trials")
        return "\n".join(lines)

    @classmethod
    def from_optuna(cls, s: optuna.Study) -> Study:
        """Wrap a completed optuna.Study in a frozen Study.

        Only trials in the COMPLETE state count as completed. When there are
        none, ``best_params`` is ``{}`` and ``best_value`` is NaN. Accessing
        ``s.best_params`` / ``s.best_value`` is avoided in that case.
        """
        n_completed = sum(1 for t in s.trials if t.state == optuna.trial.TrialState.COMPLETE)
        if n_completed == 0:
            best_params, best_value = {}, float("nan")
        else:
            best_params, best_value = s.best_params, s.best_value
        return cls(
            best_params=best_params,
            best_value=best_value,
            n_completed=n_completed,
            n_trials=len(s.trials),
            optuna_study=s,
        )

    def plot(self, output_dir: Path | str) -> None:
        """Write Optuna visualisation plots to output_dir (HTML, PNG if kaleido available).

        Args:
            output_dir: Target directory, given as a ``Path`` or a path string.
                It is created (with parents) if it does not exist.

        For each of the history, importance, parallel-coordinate and contour
        plots, ``<name>.html`` is always written. ``<name>.png`` is written
        only when image export succeeds.
        """
        # Accept plain strings as well as Path objects; Path(Path) is a no-op.
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        figures = {
            "optuna_history": optuna.visualization.plot_optimization_history(self.optuna_study),
            "optuna_importance": optuna.visualization.plot_param_importances(self.optuna_study),
            "optuna_parallel": optuna.visualization.plot_parallel_coordinate(self.optuna_study),
            "optuna_contour": optuna.visualization.plot_contour(self.optuna_study),
        }
        for name, fig in figures.items():
            fig.write_html(str(output_dir / f"{name}.html"))
            # PNG export is deliberately best-effort: it depends on an optional
            # image backend, and a failure there must not lose the HTML output.
            with contextlib.suppress(Exception):
                fig.write_image(str(output_dir / f"{name}.png"), scale=2)
def _sharpe(portfolio: Portfolio) -> float:
    """Return the portfolio's Sharpe ratio as a float, pruning the trial if it is unusable.

    ``portfolio.stats.sharpe()`` may return either a scalar or a dict. In the
    dict case, the value stored under the ``"returns"`` key is used.

    Raises:
        optuna.exceptions.TrialPruned: If the Sharpe ratio is ``None`` or NaN.
        KeyError: If a dict result has no ``"returns"`` entry.
    """
    import math

    result = portfolio.stats.sharpe()
    raw = result["returns"] if isinstance(result, dict) else result
    # Check for None *before* converting: float(None) raises TypeError,
    # which would crash the study instead of pruning this trial.
    if raw is None:
        raise optuna.exceptions.TrialPruned()
    sharpe = float(raw)
    if math.isnan(sharpe):
        raise optuna.exceptions.TrialPruned()
    return sharpe
def _run_study(
    objective,
    *,
    n_trials: int = 100,
    seed: int = 42,
    name: str | None = None,
) -> optuna.Study:
    """Create a maximising, TPE-sampled Optuna study, optimise it, and return it.

    Args:
        objective: Callable passed to ``optuna.Study.optimize``.
        n_trials: Number of trials to run.
        seed: Seed for the ``TPESampler``, making the search reproducible.
        name: Optional study name handed to ``optuna.create_study``.

    Note:
        Optuna's logging verbosity is set to WARNING as a side effect and is
        not restored afterwards.
    """
    optuna.logging.set_verbosity(optuna.logging.WARNING)
    tpe = optuna.samplers.TPESampler(seed=seed)
    study = optuna.create_study(
        study_name=name,
        direction="maximize",
        sampler=tpe,
    )
    study.optimize(objective, n_trials=n_trials, show_progress_bar=False)
    return study
def _build_objective(suggest_portfolio_fn):
    """Adapt a ``trial -> Portfolio`` callable into an Optuna objective scored by Sharpe ratio."""

    def objective(trial: optuna.Trial) -> float:
        """Build this trial's portfolio and return its Sharpe ratio via ``_sharpe``."""
        portfolio = suggest_portfolio_fn(trial)
        return _sharpe(portfolio)

    return objective
def optimize(suggest_portfolio_fn, n_trials: int = 100, seed: int = 42) -> Study:
    """Search for the parameters that maximise the Sharpe ratio of generated portfolios.

    Args:
        suggest_portfolio_fn: Callable taking an ``optuna.Trial`` and returning a
            Portfolio whose Sharpe ratio is the objective.
        n_trials: Number of trials to run.
        seed: Sampler seed for reproducibility.

    Returns:
        A frozen ``Study`` summarising the run. The summary is also printed.
    """
    objective = _build_objective(suggest_portfolio_fn)
    raw_study = _run_study(objective, n_trials=n_trials, seed=seed)
    summary = Study.from_optuna(raw_study)
    print(summary)
    return summary