Coverage for src/tinycta/hyper/_study.py: 98%

56 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-06 05:36 +0000

1"""Frozen Study result and Optuna-based hyperparameter optimisation.""" 

2 

3from __future__ import annotations 

4 

5import contextlib 

6from dataclasses import dataclass, field 

7from pathlib import Path 

8 

9import optuna 

10from jquantstats import Portfolio 

11 

12 

@dataclass(frozen=True)
class Study:
    """Frozen wrapper around a completed Optuna study.

    Instances are normally built with :meth:`from_optuna`. The summary fields
    are captured at construction time; the wrapped ``optuna.Study`` is kept so
    :meth:`plot` can render it, and is omitted from ``repr``.
    """

    # Parameters of the best completed trial ({} when no trial completed).
    best_params: dict
    # Objective value of the best trial, labelled "Sharpe" by __str__
    # (NaN when no trial completed).
    best_value: float
    # Number of trials whose state is COMPLETE.
    n_completed: int
    # Total number of trials in the study, whatever their state.
    n_trials: int
    # Underlying optuna study, used by plot(); excluded from repr.
    optuna_study: optuna.Study = field(repr=False)

    def __str__(self) -> str:
        """Return a human-readable summary of the best trial.

        Lists each best parameter, the best value (4 decimals, labelled
        Sharpe) and the completed/total trial counts. When no trial completed,
        a fixed one-line message is returned instead.
        """
        if self.n_completed == 0:
            return "No completed trials — all returned NaN Sharpe."
        lines = ["=== Best parameters ==="]
        lines.extend(f" {k:<12} = {v}" for k, v in self.best_params.items())
        lines.append(f" {'Sharpe':<12} = {self.best_value:.4f}")
        lines.append(f" {'Completed':<12} = {self.n_completed} / {self.n_trials} trials")
        return "\n".join(lines)

    @classmethod
    def from_optuna(cls, s: optuna.Study) -> Study:
        """Wrap a completed ``optuna.Study`` in a frozen :class:`Study`.

        Args:
            s: The study to summarise.

        Returns:
            A :class:`Study` snapshot. If no trial reached the COMPLETE state,
            ``best_params`` is ``{}`` and ``best_value`` is NaN, because Optuna's
            ``best_params``/``best_value`` raise ``ValueError`` in that case.
        """
        # ``Study.trials`` deep-copies every trial on each access. Only trial
        # states and the count are needed here, so take one shallow list and
        # reuse it for both.
        all_trials = s.get_trials(deepcopy=False)
        n_completed = sum(1 for t in all_trials if t.state == optuna.trial.TrialState.COMPLETE)
        if n_completed == 0:
            best_params, best_value = {}, float("nan")
        else:
            best_params, best_value = s.best_params, s.best_value
        return cls(
            best_params=best_params,
            best_value=best_value,
            n_completed=n_completed,
            n_trials=len(all_trials),
            optuna_study=s,
        )

    def plot(self, output_dir: Path) -> None:
        """Write Optuna visualisation plots for the wrapped study to ``output_dir``.

        Creates ``output_dir`` (and any parents) if needed. Writes one HTML file
        per figure: optimisation history, parameter importances, parallel
        coordinates and contour. A 2x-scaled PNG is also attempted for each
        figure; any failure during PNG export (e.g. kaleido not installed) is
        silently ignored.
        """
        output_dir.mkdir(parents=True, exist_ok=True)
        figures = {
            "optuna_history": optuna.visualization.plot_optimization_history(self.optuna_study),
            "optuna_importance": optuna.visualization.plot_param_importances(self.optuna_study),
            "optuna_parallel": optuna.visualization.plot_parallel_coordinate(self.optuna_study),
            "optuna_contour": optuna.visualization.plot_contour(self.optuna_study),
        }
        for name, fig in figures.items():
            fig.write_html(str(output_dir / f"{name}.html"))
            # PNG export needs an optional image backend. HTML is the
            # guaranteed output, so export errors are deliberately ignored.
            with contextlib.suppress(Exception):
                fig.write_image(str(output_dir / f"{name}.png"), scale=2)

63 

64 

65def _sharpe(portfolio: Portfolio) -> float: 

66 """Compute Sharpe ratio, raising TrialPruned if the result is NaN or None.""" 

67 result = portfolio.stats.sharpe() 

68 sharpe = result["returns"] if isinstance(result, dict) else float(result) 

69 if sharpe is None or sharpe != sharpe: 

70 raise optuna.exceptions.TrialPruned() 

71 return sharpe 

72 

73 

def _run_study(
    objective,
    *,
    n_trials: int = 100,
    seed: int = 42,
    name: str | None = None,
) -> optuna.Study:
    """Run a seeded, maximising TPE study over ``objective`` and return it.

    Side effect: sets Optuna's global logging verbosity to WARNING before
    the study is created.

    Args:
        objective: Callable taking an ``optuna.Trial`` and returning a float.
        n_trials: Number of trials to run.
        seed: Seed for the TPE sampler.
        name: Optional study name passed to ``optuna.create_study``.

    Returns:
        The ``optuna.Study`` after optimisation has finished.
    """
    optuna.logging.set_verbosity(optuna.logging.WARNING)
    tpe = optuna.samplers.TPESampler(seed=seed)
    study = optuna.create_study(
        study_name=name,
        direction="maximize",
        sampler=tpe,
    )
    study.optimize(objective, n_trials=n_trials, show_progress_bar=False)
    return study

86 

87 

88def _build_objective(suggest_portfolio_fn): 

89 """Objective factory: wraps a portfolio-returning function with Sharpe scoring.""" 

90 

91 def objective(trial: optuna.Trial) -> float: 

92 """Call suggest_portfolio_fn and return the Sharpe ratio.""" 

93 return _sharpe(suggest_portfolio_fn(trial)) 

94 

95 return objective 

96 

97 

def optimize(suggest_portfolio_fn, n_trials: int = 100, seed: int = 42) -> Study:
    """Optimise ``suggest_portfolio_fn`` for Sharpe ratio and return a frozen Study.

    The summary of the result is printed to stdout before it is returned.

    Args:
        suggest_portfolio_fn: Callable mapping an ``optuna.Trial`` to a portfolio.
        n_trials: Number of Optuna trials to run.
        seed: Seed for the TPE sampler.

    Returns:
        A :class:`Study` wrapping the finished optuna study.
    """
    objective = _build_objective(suggest_portfolio_fn)
    raw_study = _run_study(objective, n_trials=n_trials, seed=seed)
    result = Study.from_optuna(raw_study)
    print(result)
    return result