Coverage for src/tinycta/engine.py: 100%

77 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-06 05:36 +0000

1"""Engine for correlation-aware risk position optimization.""" 

2 

3from __future__ import annotations 

4 

5import dataclasses 

6 

7import numpy as np 

8import polars as pl 

9 

10from .config import Config 

11from .ewm_cov import ewm_covariance as _ewm_covariance 

12from .linalg import inv_a_norm as _inv_a_norm 

13from .linalg import solve as _solve 

14from .signal import shrink2id as _shrink2id 

15from .util import vol_adj as _vol_adj 

16 

17 

@dataclasses.dataclass(frozen=True)
class Engine:
    """Correlation-aware risk position optimizer (Basanos engine).

    Attributes:
        prices: Frame with a ``date`` column plus one column per asset.
        mu: Signal frame; must have the same shape and the same set of
            column names as ``prices``.
        cfg: Configuration object. This class reads ``cfg.vola``, ``cfg.clip``,
            ``cfg.corr`` and ``cfg.shrink``.
    """

    prices: pl.DataFrame
    mu: pl.DataFrame
    cfg: Config

    def __post_init__(self) -> None:
        """Validate that prices and mu are aligned and both contain a date column.

        Raises:
            ValueError: If either frame lacks a ``date`` column, if the two
                frames differ in shape, or if their column-name sets differ.
        """
        if "date" not in self.prices.columns:
            raise ValueError("prices must contain a 'date' column")
        if "date" not in self.mu.columns:
            raise ValueError("mu must contain a 'date' column")
        if self.prices.shape != self.mu.shape:
            raise ValueError(
                f"prices and mu must have the same shape, got {self.prices.shape} and {self.mu.shape}"
            )
        if set(self.prices.columns) != set(self.mu.columns):
            mismatched = sorted(set(self.prices.columns) ^ set(self.mu.columns))
            raise ValueError(f"prices and mu must have the same columns; mismatched columns: {mismatched}")

    @property
    def assets(self) -> list[str]:
        """List numeric asset column names, excluding the date column."""
        return [c for c in self.prices.columns if c != "date" and self.prices[c].dtype.is_numeric()]

    @property
    def ret_adj(self) -> pl.DataFrame:
        """Per-asset EWMA-volatility-adjusted log returns clipped by cfg.clip.

        Each asset column of ``prices`` is replaced by the expression returned
        from ``_vol_adj`` (called with ``vola=cfg.vola`` and ``clip=cfg.clip``);
        all other columns are kept as they are.
        """
        # NOTE(review): the exact transform lives in `_vol_adj`, which is defined elsewhere.
        return self.prices.with_columns(
            [_vol_adj(pl.col(asset), vola=self.cfg.vola, clip=self.cfg.clip) for asset in self.assets]
        )

    @property
    def vola(self) -> pl.DataFrame:
        """Per-asset EWMA volatility of percentage returns.

        Each asset column of ``prices`` is replaced by the exponentially
        weighted standard deviation of its ``pct_change()``, using
        ``com=cfg.vola - 1`` and ``min_samples=cfg.vola``.
        """
        return self.prices.with_columns(
            pl.col(asset)
            .pct_change()
            .ewm_std(com=self.cfg.vola - 1, adjust=True, min_samples=self.cfg.vola)
            .alias(asset)
            for asset in self.assets
        )

    @property
    def cor(self) -> dict[object, np.ndarray]:
        """Per-timestamp EWMA correlation matrices, returned as a date-keyed dict.

        The covariance matrices come from ``_ewm_covariance`` applied to
        ``ret_adj`` (``window=2 * cfg.corr + 1``, ``warmup=cfg.corr``). Each is
        normalised by the outer product of the square roots of its absolute
        diagonal; entries whose normaliser is not strictly positive become NaN.
        """
        cov = _ewm_covariance(
            self.ret_adj,
            assets=self.assets,
            index_col="date",
            window=2 * self.cfg.corr + 1,
            warmup=self.cfg.corr,
        )
        result: dict[object, np.ndarray] = {}
        for k, mat in cov.items():
            std = np.sqrt(np.abs(np.diag(mat)))
            outer = np.outer(std, std)
            # `mat / outer` is evaluated for every entry before np.where selects,
            # so silence the divide/invalid warnings raised by masked-out entries.
            with np.errstate(divide="ignore", invalid="ignore"):
                result[k] = np.where(outer > 0, mat / outer, np.nan)
        return result

    @property
    def cash_position(self) -> pl.DataFrame:
        """Correlation-shrinkage-optimized cash positions for each timestamp.

        Returns:
            ``prices`` with every asset column replaced by its cash position.
            Entries stay NaN for rows/assets where no position was computed.
        """
        cor = self.cor
        assets = self.assets

        prices_num = self.prices.select(assets).to_numpy()
        # Simple returns; the first row has no predecessor and stays at zero.
        returns_num = np.zeros_like(prices_num, dtype=float)
        returns_num[1:] = prices_num[1:] / prices_num[:-1] - 1.0

        mu = self.mu.select(assets).to_numpy()
        risk_pos_np = np.full_like(mu, fill_value=np.nan, dtype=float)
        cash_pos_np = np.full_like(mu, fill_value=np.nan, dtype=float)
        vola_np = self.vola.select(assets).to_numpy()

        # Exponentially weighted running estimate of squared one-step profit,
        # started at 1.0 and updated with decay factor `lamb`.
        profit_variance = 1.0
        lamb = 0.99

        # NOTE(review): the i-th entry of `cor` is used with the i-th row of
        # prices/mu/vola -- assumes `_ewm_covariance` yields one entry per row,
        # in row order; confirm against its implementation.
        for i, corr_n in enumerate(cor.values()):
            # Assets with a finite price at this row.
            mask = np.isfinite(prices_num[i])

            if i > 0:
                ret_mask = np.isfinite(returns_num[i]) & mask
                if ret_mask.any():
                    # Profit of the previous row's cash position over this row's
                    # return; NaN positions/returns contribute zero.
                    cash_pos_np[i - 1] = risk_pos_np[i - 1] / vola_np[i - 1]
                    lhs = np.nan_to_num(cash_pos_np[i - 1, ret_mask], nan=0.0)
                    rhs = np.nan_to_num(returns_num[i, ret_mask], nan=0.0)
                    profit = lhs @ rhs
                    profit_variance = lamb * profit_variance + (1 - lamb) * profit**2

            if not mask.any():
                continue

            # Shrink the correlation matrix, then restrict it to the priced assets.
            matrix = _shrink2id(corr_n, lamb=self.cfg.shrink)[np.ix_(mask, mask)]
            expected_mu = np.nan_to_num(mu[i][mask])
            denom = _inv_a_norm(expected_mu, matrix)

            # Fall back to a flat position when the normaliser is unusable or
            # the signal is (numerically) zero.
            if denom is None or not np.isfinite(denom) or denom <= 1e-12 or np.allclose(expected_mu, 0.0):
                pos = np.zeros_like(expected_mu)
            else:
                pos = _solve(matrix, expected_mu) / denom

            risk_pos_np[i, mask] = pos / profit_variance
            cash_pos_np[i, mask] = risk_pos_np[i, mask] / vola_np[i, mask]

        return self.prices.with_columns([(pl.lit(cash_pos_np[:, i]).alias(asset)) for i, asset in enumerate(assets)])