Coverage for src / jquantstats / data.py: 100%
159 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-26 18:44 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-26 18:44 +0000
1"""Financial returns data container and manipulation utilities."""
3from __future__ import annotations
5import dataclasses
6from collections.abc import Iterator
7from datetime import timedelta
8from typing import TYPE_CHECKING, cast
10import narwhals as nw
11import polars as pl
13from ._types import NativeFrame, NativeFrameOrScalar
15if TYPE_CHECKING:
16 from ._plots import DataPlots
17 from ._reports import Reports
18 from ._stats import Stats
def _to_polars(df: NativeFrame) -> pl.DataFrame:
    """Return ``df`` as a polars DataFrame.

    A frame that already is a ``pl.DataFrame`` is handed back unchanged (the very
    same object). Any other input is wrapped with ``nw.from_native`` in
    eager-only mode and converted through narwhals' ``to_polars()``.
    """
    if not isinstance(df, pl.DataFrame):
        # Foreign frame: let narwhals do the conversion.
        return nw.from_native(df, eager_only=True).to_polars()
    return df
def _subtract_risk_free(dframe: pl.DataFrame, rf: float | pl.DataFrame, date_col: str) -> pl.DataFrame:
    """Subtract the risk-free rate from every asset column of ``dframe``.

    Parameters
    ----------
    dframe : pl.DataFrame
        DataFrame containing returns data with a date column
        and one or more numeric columns representing asset returns.
        The column name ``"rf"`` is reserved for the risk-free rate: a column
        with that name in ``dframe`` is not treated as an asset and is not
        part of the result.

    rf : float | pl.DataFrame
        Risk-free rate to subtract from returns.

        - If float: A constant risk-free rate applied to every row.
        - If pl.DataFrame: A DataFrame holding the date column ``date_col``
          and, as its *second* column, the time-varying risk-free rate.
          Only those two columns are used; any further columns are ignored.

    date_col : str
        Name of the date column. It must exist in ``dframe`` and, when ``rf``
        is a DataFrame, in ``rf`` as well (it is the join key).

    Returns:
    -------
    pl.DataFrame
        ``date_col`` followed by the asset columns of ``dframe`` (original
        names and order) with the risk-free rate subtracted. When ``rf`` is a
        DataFrame the frames are inner-joined on ``date_col``, so only dates
        present in both frames appear in the result.

    Raises:
    ------
    TypeError
        If ``rf`` is neither a ``float`` nor a ``pl.DataFrame`` (this includes
        plain ``int`` values).

    """
    # Fix the asset list *before* anything is joined so that columns coming from
    # the rf frame (or suffixed duplicates created by the join) can never be
    # mistaken for assets.
    asset_cols = [col for col in dframe.columns if col not in {date_col, "rf"}]
    # Work on the date and asset columns only; this frees the name "rf" for the rate.
    base = dframe.select([pl.col(date_col)] + [pl.col(col) for col in asset_cols])

    if isinstance(rf, float):
        # A constant rate needs no join: attach it as a column directly. This
        # keeps exactly one output row per input row, even with repeated dates.
        with_rf = base.with_columns(pl.lit(rf).alias("rf"))
    else:
        if not isinstance(rf, pl.DataFrame):
            raise TypeError("rf must be a float or DataFrame")  # noqa: TRY003
        # Keep just the join key and the rate (second column), named "rf".
        rf_dframe = rf.select([pl.col(date_col), pl.col(rf.columns[1]).alias("rf")])
        with_rf = base.join(rf_dframe, on=date_col, how="inner")

    return with_rf.select(
        [pl.col(date_col)] + [(pl.col(col) - pl.col("rf")).alias(col) for col in asset_cols]
    )
@dataclasses.dataclass(frozen=True, slots=True)
class Data:
    """Immutable container for financial returns, their index and an optional benchmark.

    Instances are usually created with :meth:`from_returns` or :meth:`from_prices`.
    The class offers row-wise helpers (:meth:`head`, :meth:`tail`, :meth:`truncate`,
    :meth:`resample`, :meth:`copy`), summaries (:meth:`describe`, :meth:`items`) and
    access to statistics, plots and reports through the ``stats``, ``plots`` and
    ``reports`` properties.

    Attributes:
        returns (pl.DataFrame): DataFrame containing returns data with assets as columns.
        index (pl.DataFrame): DataFrame whose first column is the (date) index of the
            returns data. It must have as many rows as ``returns``.
        benchmark (pl.DataFrame, optional): DataFrame containing benchmark returns data.
            Defaults to None.

    """

    returns: pl.DataFrame
    index: pl.DataFrame
    benchmark: pl.DataFrame | None = None

    def __post_init__(self) -> None:
        """Validate the Data object after initialization.

        Raises:
            ValueError: If the index has fewer than two rows, if its first column
                is not sorted, or if ``returns`` / ``benchmark`` do not have the
                same number of rows as the index.

        """
        n_rows = self.index.shape[0]

        # You need at least two points
        if n_rows < 2:
            raise ValueError("Index must contain at least two timestamps.")  # noqa: TRY003

        # Check index is monotonically increasing
        datetime_col = self.index[self.index.columns[0]]
        if not datetime_col.is_sorted():
            raise ValueError("Index must be monotonically increasing.")  # noqa: TRY003

        # Check row count matches returns
        if self.returns.shape[0] != n_rows:
            raise ValueError("Returns and index must have the same number of rows.")  # noqa: TRY003

        # Check row count matches benchmark (if provided)
        if self.benchmark is not None and self.benchmark.shape[0] != n_rows:
            raise ValueError("Benchmark and index must have the same number of rows.")  # noqa: TRY003

    @classmethod
    def from_returns(
        cls,
        returns: NativeFrame,
        rf: NativeFrameOrScalar = 0.0,
        benchmark: NativeFrame | None = None,
        date_col: str = "Date",
    ) -> Data:
        """Create a Data object from returns and optional benchmark.

        Parameters
        ----------
        returns : NativeFrame
            Financial returns data. Must contain the column ``date_col``;
            the remaining columns are asset returns.

        rf : float | NativeFrame, optional
            Risk-free rate. Default is 0.0 (no risk-free rate adjustment).

            - If float: Constant risk-free rate applied to all dates.
            - If NativeFrame: Time-varying risk-free rate with dates matching returns.

        benchmark : NativeFrame | None, optional
            Benchmark returns. Default is None (no benchmark).
            Must contain the column ``date_col``; the remaining columns are
            benchmark returns. Only dates present in both ``returns`` and
            ``benchmark`` are kept, and the benchmark rows are ordered by date
            so that they line up with the returns rows.

        date_col : str, optional
            Name of the date column in the DataFrames. Default is "Date".

        Returns:
        -------
        Data
            Object containing excess returns and benchmark (if any), with methods for
            analysis and visualization through the ``stats`` and ``plots`` properties.

        Raises:
        ------
        ValueError
            If there are no overlapping dates between returns and benchmark, or
            if the resulting frames fail the validation in ``__post_init__``.
        TypeError
            If ``rf`` is an ``int`` (raised by ``_subtract_risk_free``).

        Examples:
        --------
        Basic usage:

        ```python
        from jquantstats import Data
        import polars as pl

        returns = pl.DataFrame({
            "Date": ["2023-01-01", "2023-01-02", "2023-01-03"],
            "Asset1": [0.01, -0.02, 0.03]
        }).with_columns(pl.col("Date").str.to_date())

        data = Data.from_returns(returns=returns)
        ```

        With benchmark and risk-free rate:

        ```python
        benchmark = pl.DataFrame({
            "Date": ["2023-01-01", "2023-01-02", "2023-01-03"],
            "Market": [0.005, -0.01, 0.02]
        }).with_columns(pl.col("Date").str.to_date())

        data = Data.from_returns(returns=returns, benchmark=benchmark, rf=0.0002)
        ```

        """
        returns_pl = _to_polars(returns)
        benchmark_pl = _to_polars(benchmark) if benchmark is not None else None

        rf_converted: float | pl.DataFrame
        if isinstance(rf, (float, int)):
            rf_converted = rf  # int is not float/DataFrame: _subtract_risk_free raises TypeError
        else:
            rf_converted = _to_polars(rf)

        if benchmark_pl is not None:
            joined_dates = returns_pl.join(benchmark_pl, on=date_col, how="inner").select(date_col)
            if joined_dates.is_empty():
                raise ValueError("No overlapping dates between returns and benchmark.")  # noqa: TRY003
            returns_pl = returns_pl.join(joined_dates, on=date_col, how="inner")
            # Nothing above ties the benchmark's row order to the returns' row
            # order, and the date column is dropped below. A valid index is
            # ascending (see __post_init__), so ordering the benchmark by date
            # makes its rows line up with the returns rows instead of being
            # silently paired with the wrong dates.
            benchmark_pl = benchmark_pl.join(joined_dates, on=date_col, how="inner").sort(date_col)

        index = returns_pl.select(date_col)
        excess_returns = _subtract_risk_free(returns_pl, rf_converted, date_col).drop(date_col)
        excess_benchmark = None
        if benchmark_pl is not None:
            excess_benchmark = _subtract_risk_free(benchmark_pl, rf_converted, date_col).drop(date_col)

        return cls(returns=excess_returns, benchmark=excess_benchmark, index=index)

    @classmethod
    def from_prices(
        cls,
        prices: NativeFrame,
        rf: NativeFrameOrScalar = 0.0,
        benchmark: NativeFrame | None = None,
        date_col: str = "Date",
    ) -> Data:
        """Create a Data object from prices and optional benchmark.

        Converts price levels to returns via percentage change and delegates
        to :meth:`from_returns`. The first row of each frame is dropped
        because no prior price is available to compute a return.

        Parameters
        ----------
        prices : NativeFrame
            Price-level data. Must contain the column ``date_col``;
            the remaining columns are asset prices.

        rf : float | NativeFrame, optional
            Risk-free rate. Forwarded unchanged to :meth:`from_returns`.
            Default is 0.0 (no risk-free rate adjustment).

        benchmark : NativeFrame | None, optional
            Benchmark prices. Converted to returns in the same way as
            ``prices`` before being forwarded to :meth:`from_returns`.
            Default is None (no benchmark).

        date_col : str, optional
            Name of the date column in the DataFrames. Default is ``"Date"``.

        Returns:
        -------
        Data
            Object containing excess returns derived from the supplied prices,
            with methods for analysis and visualization through the ``stats``
            and ``plots`` properties.

        Examples:
        --------
        ```python
        from jquantstats import Data
        import polars as pl

        prices = pl.DataFrame({
            "Date": ["2023-01-01", "2023-01-02", "2023-01-03"],
            "Asset1": [100.0, 101.0, 99.0]
        }).with_columns(pl.col("Date").str.to_date())

        data = Data.from_prices(prices=prices)
        ```

        """

        def pct_returns(frame: pl.DataFrame) -> pl.DataFrame:
            """Replace every non-date column by its percentage change and drop the first row."""
            value_cols = [c for c in frame.columns if c != date_col]
            return frame.with_columns([pl.col(c).pct_change().alias(c) for c in value_cols]).slice(1)

        returns_pl = pct_returns(_to_polars(prices))

        benchmark_returns: NativeFrame | None = None
        if benchmark is not None:
            benchmark_returns = pct_returns(_to_polars(benchmark))

        return cls.from_returns(returns=returns_pl, rf=rf, benchmark=benchmark_returns, date_col=date_col)

    def __repr__(self) -> str:
        """Return a string representation of the Data object."""
        rows = len(self.index)
        date_cols = self.date_col
        if date_cols:
            date_column = date_cols[0]
            start = self.index[date_column].min()
            end = self.index[date_column].max()
            return f"Data(assets={self.assets}, rows={rows}, start={start}, end={end})"
        return f"Data(assets={self.assets}, rows={rows})"  # pragma: no cover  # __post_init__ requires ≥1 index column

    @property
    def plots(self) -> DataPlots:
        """Provides access to visualization methods for the financial data.

        Returns:
            DataPlots: An instance of the DataPlots class initialized with this data.

        """
        # Imported here rather than at module level (only needed on access).
        from ._plots import DataPlots

        return DataPlots(self)

    @property
    def stats(self) -> Stats:
        """Provides access to statistical analysis methods for the financial data.

        Returns:
            Stats: An instance of the Stats class initialized with this data.

        """
        # Imported here rather than at module level (only needed on access).
        from ._stats import Stats

        return Stats(self)

    @property
    def reports(self) -> Reports:
        """Provides access to reporting methods for the financial data.

        Returns:
            Reports: An instance of the Reports class initialized with this data.

        """
        # Imported here rather than at module level (only needed on access).
        from ._reports import Reports

        return Reports(self)

    @property
    def date_col(self) -> list[str]:
        """Return the column names of the index DataFrame.

        Returns:
            list[str]: List of column names in the index DataFrame, typically containing
            the date column name.

        """
        return list(self.index.columns)

    @property
    def assets(self) -> list[str]:
        """Return the combined list of asset column names from returns and benchmark.

        Returns:
            list[str]: Column names of ``returns`` followed by the column names of
            ``benchmark`` (if available).

        """
        names = list(self.returns.columns)
        if self.benchmark is not None:
            names += list(self.benchmark.columns)
        return names

    @property
    def all(self) -> pl.DataFrame:
        """Combine index, returns, and benchmark data into a single DataFrame.

        Returns:
            pl.DataFrame: The index, the returns and (if available) the benchmark
            concatenated horizontally, in that order.

        """
        frames = [self.index, self.returns]
        if self.benchmark is not None:
            frames.append(self.benchmark)
        return pl.concat(frames, how="horizontal")

    def resample(self, every: str = "1mo") -> Data:
        """Resamples returns and benchmark to a different frequency using Polars.

        Within each window the returns are compounded, i.e. the result is
        ``prod(1 + r) - 1``. Windows are closed on the right and labelled by
        their right edge.

        Args:
            every (str, optional): Resampling frequency (e.g., '1mo', '1y'). Defaults to '1mo'.

        Returns:
            Data: Resampled data.

        """
        date_column = self.index.columns[0]

        def resample_frame(dframe: pl.DataFrame) -> pl.DataFrame:
            """Resample a single DataFrame to the target frequency using compound returns."""
            dated = self.index.hstack(dframe)  # Add the date column for resampling
            compounded = [
                ((pl.col(col) + 1.0).product() - 1.0).alias(col) for col in dated.columns if col != date_column
            ]
            return dated.group_by_dynamic(
                index_column=date_column, every=every, period=every, closed="right", label="right"
            ).agg(compounded)

        resampled_returns = resample_frame(self.returns)
        resampled_benchmark = resample_frame(self.benchmark) if self.benchmark is not None else None
        resampled_index = resampled_returns.select(date_column)

        return Data(
            returns=resampled_returns.drop(date_column),
            benchmark=resampled_benchmark.drop(date_column) if resampled_benchmark is not None else None,
            index=resampled_index,
        )

    def describe(self) -> pl.DataFrame:
        """Return a tidy summary of shape, date range and asset names.

        Returns:
        -------
        pl.DataFrame
            One row per column of ``returns`` with columns:
            asset, start, end, rows, has_benchmark.

        """
        date_column = self.date_col[0]
        start = self.index[date_column].min()
        end = self.index[date_column].max()
        rows = len(self.index)
        asset_names = self.returns.columns
        n_assets = len(asset_names)
        return pl.DataFrame(
            {
                "asset": asset_names,
                "start": [start] * n_assets,
                "end": [end] * n_assets,
                "rows": [rows] * n_assets,
                "has_benchmark": [self.benchmark is not None] * n_assets,
            }
        )

    def copy(self) -> Data:
        """Create a copy of the Data object.

        Returns:
            Data: A new Data object built from clones of the returns, the
            benchmark (if any) and the index.

        """
        benchmark_clone = self.benchmark.clone() if self.benchmark is not None else None
        return Data(returns=self.returns.clone(), benchmark=benchmark_clone, index=self.index.clone())

    def head(self, n: int = 5) -> Data:
        """Return the first n rows of the returns, benchmark and index.

        Args:
            n (int, optional): Number of rows to return. Defaults to 5.

        Returns:
            Data: A new Data object containing the first n rows.

        """
        benchmark_head = self.benchmark.head(n) if self.benchmark is not None else None
        return Data(returns=self.returns.head(n), benchmark=benchmark_head, index=self.index.head(n))

    def tail(self, n: int = 5) -> Data:
        """Return the last n rows of the returns, benchmark and index.

        Args:
            n (int, optional): Number of rows to return. Defaults to 5.

        Returns:
            Data: A new Data object containing the last n rows.

        """
        benchmark_tail = self.benchmark.tail(n) if self.benchmark is not None else None
        return Data(returns=self.returns.tail(n), benchmark=benchmark_tail, index=self.index.tail(n))

    def truncate(self, start: object = None, end: object = None) -> Data:
        """Return a new Data object truncated to the inclusive [start, end] range.

        When the index is temporal (Date/Datetime), truncation is performed by
        comparing the date column against ``start`` and ``end`` values.

        When the index is integer-based, row slicing is used instead, and
        ``start`` and ``end`` must be non-negative integers. Passing
        non-integer bounds to an integer-indexed Data raises :exc:`TypeError`.

        Args:
            start: Optional lower bound (inclusive). A date/datetime value
                when the index is temporal; a non-negative :class:`int` row
                index when the data has no temporal index.
            end: Optional upper bound (inclusive). Same type rules as
                ``start``.

        Returns:
            Data: A new Data object filtered to the specified range.

        Raises:
            TypeError: When the index is not temporal and a non-integer bound
                is supplied.

        """
        date_column = self.index.columns[0]
        is_temporal = self.index[date_column].dtype.is_temporal()

        if is_temporal:
            cond = pl.lit(True)
            if start is not None:
                cond = cond & (pl.col(date_column) >= pl.lit(start))
            if end is not None:
                cond = cond & (pl.col(date_column) <= pl.lit(end))
            # Evaluate the condition once on the index and reuse the boolean mask
            # for returns/benchmark, which do not carry the date column.
            mask = self.index.select(cond.alias("mask"))["mask"]
            new_index = self.index.filter(mask)
            new_returns = self.returns.filter(mask)
            new_benchmark = self.benchmark.filter(mask) if self.benchmark is not None else None
        else:
            if start is not None and not isinstance(start, int):
                raise TypeError(f"start must be an integer, got {type(start).__name__}.")  # noqa: TRY003
            if end is not None and not isinstance(end, int):
                raise TypeError(f"end must be an integer, got {type(end).__name__}.")  # noqa: TRY003
            row_start = start if start is not None else 0
            # ``end`` is inclusive, hence the +1 to obtain an exclusive stop.
            row_end = end + 1 if end is not None else self.index.height
            length = max(0, row_end - row_start)
            new_index = self.index.slice(row_start, length)
            new_returns = self.returns.slice(row_start, length)
            new_benchmark = self.benchmark.slice(row_start, length) if self.benchmark is not None else None

        return Data(returns=new_returns, benchmark=new_benchmark, index=new_index)

    @property
    def _periods_per_year(self) -> float:
        """Estimate the number of periods per year based on average frequency in the index.

        For temporal (Date/Datetime) indices, computes the mean gap between observations
        and divides the number of seconds in 365 days by that gap
        (e.g. ~252 for daily, ~52 for weekly).

        For integer indices (date-free portfolios), falls back to 252 trading days per year
        because integer diffs have no time meaning.
        """
        datetime_col = self.index[self.index.columns[0]]

        if not datetime_col.dtype.is_temporal():
            return 252.0

        sorted_dt = datetime_col.sort()
        diffs = sorted_dt.diff().drop_nulls()
        mean_diff = diffs.mean()

        if isinstance(mean_diff, timedelta):
            seconds = mean_diff.total_seconds()
        else:  # pragma: no cover  # Polars always returns timedelta for temporal diff
            seconds = cast(float, mean_diff) if mean_diff is not None else 1.0

        return (365 * 24 * 60 * 60) / seconds

    def items(self) -> Iterator[tuple[str, pl.Series]]:
        """Iterate over all assets and their corresponding data series.

        Assets are visited in the order given by :attr:`assets` (returns columns
        first, then benchmark columns).

        Yields:
            tuple[str, pl.Series]: A tuple containing the asset name and its data series.

        """
        matrix = self.all

        for col in self.assets:
            yield col, matrix.get_column(col)