Coverage for src/tinycta/hyper/_setup.py: 100%
37 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-06 05:36 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-06 05:36 +0000
1"""Experiment setup helpers: logger configuration."""
3import os
4from pathlib import Path
5from typing import Any, NamedTuple
7import yaml
8from loguru import logger
# Registry of loguru file sinks already attached by get_config():
# maps the resolved log-file path (as a string) to the handler id returned by
# ``logger.add``. It is consulted before adding a sink so that repeated
# get_config() calls for the same log file do not attach a second sink.
_FILE_SINKS: dict[str, int] = {}
13class ExperimentConfig(NamedTuple):
14 """Resources bundled for a notebook experiment run."""
16 name: str
17 logger: Any
18 params: dict | None = None
19 optuna: dict | None = None
20 data: dict | None = None
23def _load_yaml(path: Path) -> dict:
24 """Load a YAML file, returning an empty dict if the file does not exist."""
25 if not path.exists():
26 return {}
27 with open(path) as f:
28 return yaml.safe_load(f) or {}
def get_config(name: str, config_path: Path | str | None = None) -> ExperimentConfig:
    """Return logger and config sections for an experiment.

    Accepts either a shared ``config.yml`` or an experiment-specific
    ``config/{name}.yml``. Paths in the config are resolved relative to the
    base directory: the parent of ``config/`` when the config file lives in a
    directory named ``config``, otherwise the config file's own directory.
    For each of the ``data``, ``params`` and ``optuna`` sections, the value
    from ``config_path`` wins; if it is missing or empty, the value from
    ``<base>/config/{name}.yml`` is used; otherwise ``{}``.

    The ``NOTEBOOK_OUTPUT_FOLDER`` env var, when set and non-empty, is used
    as the output directory as-is. Otherwise the output directory is
    ``<base>/<data.output_path>/<name>``, with ``output_path`` defaulting to
    ``"output"``.

    Side effects: creates the output directory, attaches a loguru file sink
    for ``<output_dir>/output.log`` (once per distinct log file), and logs an
    info message with the output directory and current working directory.

    Args:
        name: Experiment name; selects the sibling config file and the output
            subdirectory.
        config_path: Path to the config file. Defaults to ``config.yml`` in
            the current working directory.

    Returns:
        An ``ExperimentConfig`` holding ``name``, the loguru ``logger`` and
        the resolved ``params``, ``optuna`` and ``data`` sections.
    """
    config_path = Path(config_path) if config_path else Path.cwd() / "config.yml"
    cfg = _load_yaml(config_path)
    # Resolve paths relative to the notebooks dir, not the config subdir.
    in_config_dir = config_path.parent.name == "config"
    base = config_path.parent.parent if in_config_dir else config_path.parent
    sibling = _load_yaml(base / "config" / f"{name}.yml")

    data = cfg.get("data") or sibling.get("data") or {}
    params = cfg.get("params") or sibling.get("params") or {}
    optuna_cfg = cfg.get("optuna") or sibling.get("optuna") or {}

    env_folder = os.environ.get("NOTEBOOK_OUTPUT_FOLDER")
    if env_folder:
        output_dir = Path(env_folder)
    else:
        # A key that is present but null in YAML (``output_path:``) comes back
        # as None, which a ``.get(key, default)`` would pass straight through
        # and ``base / None`` would reject with TypeError. Fall back explicitly.
        folder = data.get("output_path")
        if folder is None:
            folder = "output"
        output_dir = (base / folder / name).resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    log_path = output_dir / "output.log"
    # Key on the resolved path so different spellings of the same file share
    # one sink.
    key = str(log_path.resolve())
    if key not in _FILE_SINKS:
        _FILE_SINKS[key] = logger.add(log_path)
    logger.info(f"Writing output to: {output_dir}\nCurrent working directory: {os.getcwd()}")

    return ExperimentConfig(
        name=name,
        logger=logger,
        params=params,
        optuna=optuna_cfg,
        data=data,
    )