Coverage for src/tinycta/hyper/_setup.py: 100%

37 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-06 05:36 +0000

1"""Experiment setup helpers: logger configuration.""" 

2 

3import os 

4from pathlib import Path 

5from typing import Any, NamedTuple 

6 

7import yaml 

8from loguru import logger 

9 

10_FILE_SINKS: dict[str, int] = {} 

11 

12 

13class ExperimentConfig(NamedTuple): 

14 """Resources bundled for a notebook experiment run.""" 

15 

16 name: str 

17 logger: Any 

18 params: dict | None = None 

19 optuna: dict | None = None 

20 data: dict | None = None 

21 

22 

23def _load_yaml(path: Path) -> dict: 

24 """Load a YAML file, returning an empty dict if the file does not exist.""" 

25 if not path.exists(): 

26 return {} 

27 with open(path) as f: 

28 return yaml.safe_load(f) or {} 

29 

30 

def get_config(name: str, config_path: Path | str | None = None) -> ExperimentConfig:
    """Return logger and config sections for an experiment.

    Accepts either a shared ``config.yml`` or an experiment-specific
    ``config/{name}.yml``. Paths in the config are resolved relative to the
    notebooks directory (one level above any ``config/`` subdirectory).
    ``NOTEBOOK_OUTPUT_FOLDER`` env var overrides the output directory used for
    the log file sink.

    Args:
        name: Experiment name. It selects ``config/{name}.yml`` and names the
            default output subdirectory ``<base>/<output_path>/<name>``.
        config_path: Config file to read. Defaults to ``config.yml`` in the
            current working directory. A missing file is treated as empty.

    Returns:
        An ``ExperimentConfig`` carrying the module's loguru ``logger`` and the
        ``params``, ``optuna`` and ``data`` sections (each ``{}`` if absent).
        For each section, a non-empty value from *config_path* wins over the
        one in ``config/{name}.yml``. Sections are not merged key by key.

    Side effects:
        Creates the output directory. Attaches a loguru file sink writing
        ``output.log`` there, at most once per resolved log path. Logs the
        output and working directories.
    """
    config_path = Path(config_path) if config_path else Path.cwd() / "config.yml"
    cfg = _load_yaml(config_path)

    # Resolve paths relative to the notebooks dir, not the config subdir.
    config_dir = config_path.parent
    base = config_dir.parent if config_dir.name == "config" else config_dir
    sibling = _load_yaml(base / "config" / f"{name}.yml")

    # Section-level precedence: the first non-empty value wins.
    data = cfg.get("data") or sibling.get("data") or {}
    params = cfg.get("params") or sibling.get("params") or {}
    optuna_cfg = cfg.get("optuna") or sibling.get("optuna") or {}

    env_folder = os.environ.get("NOTEBOOK_OUTPUT_FOLDER")
    if env_folder:
        # Expand "~" (a quoted env value is not expanded by the shell) and make
        # the path absolute, matching the config-derived branch below.
        output_dir = Path(env_folder).expanduser().resolve()
    else:
        folder = data.get("output_path")
        if folder is None:
            # Covers a missing key and an explicit ``output_path: null``.
            # ``base / None`` would otherwise raise TypeError.
            folder = "output"
        output_dir = (base / folder / name).resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    log_path = output_dir / "output.log"
    # Key on the resolved path so repeated calls don't attach duplicate sinks.
    key = str(log_path.resolve())
    if key not in _FILE_SINKS:
        _FILE_SINKS[key] = logger.add(log_path)
    logger.info(f"Writing output to: {output_dir}\nCurrent working directory: {os.getcwd()}")

    return ExperimentConfig(
        name=name,
        logger=logger,
        params=params,
        optuna=optuna_cfg,
        data=data,
    )