Coverage for src/tinycta/engine.py: 100%
77 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-06 05:36 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-06 05:36 +0000
1"""Engine for correlation-aware risk position optimization."""
3from __future__ import annotations
5import dataclasses
7import numpy as np
8import polars as pl
10from .config import Config
11from .ewm_cov import ewm_covariance as _ewm_covariance
12from .linalg import inv_a_norm as _inv_a_norm
13from .linalg import solve as _solve
14from .signal import shrink2id as _shrink2id
15from .util import vol_adj as _vol_adj
@dataclasses.dataclass(frozen=True)
class Engine:
    """Correlation-aware risk position optimizer (Basanos engine).

    The engine is an immutable container for a price frame, a signal frame
    and a configuration object. All results are exposed as properties that
    are recomputed on every access.

    Raises:
        ValueError: On construction, if either frame lacks a ``date`` column,
            or if the two frames differ in shape or in their set of column
            names.
    """

    # Price frame: a "date" column plus the asset columns.
    prices: pl.DataFrame
    # Signal frame; must share shape and column names with ``prices``.
    mu: pl.DataFrame
    # Settings object; this class reads ``vola``, ``clip``, ``corr`` and ``shrink``.
    cfg: Config

    def __post_init__(self) -> None:
        """Validate that prices and mu are aligned and both contain a date column."""
        if "date" not in self.prices.columns:
            raise ValueError("prices must contain a 'date' column")
        if "date" not in self.mu.columns:
            raise ValueError("mu must contain a 'date' column")
        if self.prices.shape != self.mu.shape:
            raise ValueError(
                f"prices and mu must have the same shape, got {self.prices.shape} and {self.mu.shape}"
            )
        # Column order may differ between the frames; only the set of names is compared.
        if set(self.prices.columns) != set(self.mu.columns):
            raise ValueError("prices and mu must have the same set of column names")

    @property
    def assets(self) -> list[str]:
        """List numeric asset column names, excluding the date column."""
        return [c for c in self.prices.columns if c != "date" and self.prices[c].dtype.is_numeric()]

    @property
    def ret_adj(self) -> pl.DataFrame:
        """Per-asset EWMA-volatility-adjusted log returns clipped by cfg.clip.

        Every asset column of ``prices`` is replaced by the expression built
        by ``_vol_adj`` with ``vola=cfg.vola`` and ``clip=cfg.clip``; other
        columns are carried over unchanged.
        """
        return self.prices.with_columns(
            [_vol_adj(pl.col(asset), vola=self.cfg.vola, clip=self.cfg.clip) for asset in self.assets]
        )

    @property
    def vola(self) -> pl.DataFrame:
        """Per-asset EWMA volatility of percentage returns.

        Uses ``ewm_std`` with ``com = cfg.vola - 1`` and ``min_samples =
        cfg.vola`` on the percentage change of every asset column.
        """
        return self.prices.with_columns(
            [
                pl.col(asset)
                .pct_change()
                .ewm_std(com=self.cfg.vola - 1, adjust=True, min_samples=self.cfg.vola)
                .alias(asset)
                for asset in self.assets
            ]
        )

    @property
    def cor(self) -> dict[object, np.ndarray]:
        """Per-timestamp EWMA correlation matrices, returned as a date-keyed dict.

        Each covariance matrix returned by ``_ewm_covariance`` is normalised
        by the outer product of the square roots of its absolute diagonal.
        Entries whose normaliser is not strictly positive (including NaN)
        are set to NaN.
        """
        cov = _ewm_covariance(
            self.ret_adj,
            assets=self.assets,
            index_col="date",
            window=2 * self.cfg.corr + 1,
            warmup=self.cfg.corr,
        )
        result = {}
        for k, mat in cov.items():
            std = np.sqrt(np.abs(np.diag(mat)))
            outer = np.outer(std, std)
            # np.where evaluates both branches eagerly, so the division also
            # runs on the entries that end up NaN; silence the floating-point
            # warnings this would otherwise emit for zero or NaN variances.
            with np.errstate(divide="ignore", invalid="ignore"):
                result[k] = np.where(outer > 0, mat / outer, np.nan)
        return result

    @property
    def cash_position(self) -> pl.DataFrame:
        """Correlation-shrinkage-optimized cash positions for each timestamp.

        Returns:
            ``prices`` with every asset column replaced by its cash position.
            Entries that are never assigned (for example assets without a
            finite price on that row) stay NaN.
        """
        cor = self.cor
        assets = self.assets

        prices_num = self.prices.select(assets).to_numpy()
        # Simple returns; the first row has no predecessor and stays at zero.
        returns_num = np.zeros_like(prices_num, dtype=float)
        returns_num[1:] = prices_num[1:] / prices_num[:-1] - 1.0

        mu = self.mu.select(assets).to_numpy()
        risk_pos_np = np.full_like(mu, fill_value=np.nan, dtype=float)
        cash_pos_np = np.full_like(mu, fill_value=np.nan, dtype=float)
        vola_np = self.vola.select(assets).to_numpy()

        # Exponentially weighted running estimate of the squared one-step
        # profit, seeded at 1.0 and updated with decay factor ``lamb``.
        profit_variance = 1.0
        lamb = 0.99

        # NOTE(review): assumes the i-th entry of ``cor`` belongs to the i-th
        # row of ``prices`` -- confirm against ``_ewm_covariance``.
        for i, corr_n in enumerate(cor.values()):
            # Assets with a finite price on this row.
            mask = np.isfinite(prices_num[i])

            if i > 0:
                ret_mask = np.isfinite(returns_num[i]) & mask
                if ret_mask.any():
                    # Profit realised on this row by the previous row's
                    # position; NaN positions and returns contribute zero.
                    cash_pos_np[i - 1] = risk_pos_np[i - 1] / vola_np[i - 1]
                    lhs = np.nan_to_num(cash_pos_np[i - 1, ret_mask], nan=0.0)
                    rhs = np.nan_to_num(returns_num[i, ret_mask], nan=0.0)
                    profit = lhs @ rhs
                    profit_variance = lamb * profit_variance + (1 - lamb) * profit**2

            if not mask.any():
                continue

            # Shrink the correlation matrix, then restrict it to priced assets.
            matrix = _shrink2id(corr_n, lamb=self.cfg.shrink)[np.ix_(mask, mask)]
            expected_mu = np.nan_to_num(mu[i][mask])
            denom = _inv_a_norm(expected_mu, matrix)

            # Fall back to a flat position when the normaliser is missing,
            # non-finite or numerically zero, or when the signal is all zero.
            if denom is None or not np.isfinite(denom) or denom <= 1e-12 or np.allclose(expected_mu, 0.0):
                pos = np.zeros_like(expected_mu)
            else:
                pos = _solve(matrix, expected_mu) / denom

            risk_pos_np[i, mask] = pos / profit_variance
            cash_pos_np[i, mask] = risk_pos_np[i, mask] / vola_np[i, mask]

        return self.prices.with_columns([(pl.lit(cash_pos_np[:, i]).alias(asset)) for i, asset in enumerate(assets)])