Coverage for src / antarctic / polars_field.py: 100%
32 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-01-27 10:29 +0000
« prev ^ index » next coverage.py v7.13.2, created at 2026-01-27 10:29 +0000
1"""Provide a custom Polars field type for MongoEngine.
3It allows storing polars DataFrames in MongoDB by converting them to and from
4parquet-formatted byte streams. This enables efficient storage and retrieval
5of polars data structures within MongoDB documents.
6"""
8from __future__ import annotations
10from io import BytesIO
11from typing import Any, Literal
13import polars as pl
14from mongoengine.base import BaseField
16CompressionType = Literal["lz4", "uncompressed", "snappy", "gzip", "brotli", "zstd"]
def _read(value: bytes, columns: list[str] | None = None) -> pl.DataFrame:
    """Rebuild a polars DataFrame from a parquet byte-stream.

    Args:
        value: Parquet-encoded bytes describing a DataFrame
        columns: Names of the columns to load; every column is loaded if None

    Returns:
        pl.DataFrame: The DataFrame decoded from ``value``

    Examples:
        >>> import polars as pl
        >>> df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
        >>> data = _write(df)
        >>> result = _read(data)
        >>> result["a"].to_list()
        [1, 2, 3]

        Restrict the load to a subset of columns:

        >>> result = _read(data, columns=["b"])
        >>> result.columns
        ['b']

    """
    # Expose the raw bytes as an in-memory stream and hand it to polars.
    stream = BytesIO(value)
    return pl.read_parquet(stream, columns=columns)
def _write(value: pl.DataFrame, compression: CompressionType = "zstd") -> bytes:
    """Serialize a polars DataFrame into compressed parquet bytes.

    The parquet byte-stream carries, in its metadata, the structure of the
    polars object, including column names and data types.

    Args:
        value: The DataFrame to serialize
        compression: Compression algorithm passed to the parquet writer (default: "zstd")

    Returns:
        bytes: The DataFrame encoded in parquet format

    Examples:
        >>> import polars as pl
        >>> df = pl.DataFrame({"x": [1.0, 2.0], "y": [3.0, 4.0]})
        >>> data = _write(df)
        >>> isinstance(data, bytes)
        True
        >>> data[:4]  # Parquet magic bytes
        b'PAR1'

    """
    # Write into an in-memory sink and snapshot its contents before it is closed.
    with BytesIO() as sink:
        value.write_parquet(sink, compression=compression)
        return sink.getvalue()
class PolarsField(BaseField):  # type: ignore[misc]
    """MongoEngine field type that stores polars DataFrames as parquet bytes.

    Assigning a DataFrame to the field serializes it to a parquet byte-stream;
    reading the field on a document instance deserializes the stored bytes
    back into a DataFrame.
    """

    def __init__(self, compression: CompressionType = "zstd", **kwargs: Any) -> None:
        """Create the field and remember which parquet compression to use.

        Args:
            compression: Compression algorithm applied when a DataFrame is assigned (default: "zstd")
            **kwargs: Forwarded unchanged to the parent BaseField

        """
        super().__init__(**kwargs)
        self.compression = compression

    def __set__(self, instance: Any, value: pl.DataFrame | bytes | None) -> None:
        """Normalize the assigned value and delegate storage to the parent field.

        A DataFrame is serialized to parquet bytes using this field's
        compression setting. Bytes and None are passed through untouched.

        Args:
            instance: The document instance
            value: The value to set (DataFrame, bytes, or None)

        Raises:
            TypeError: If the value is neither a DataFrame, bytes, nor None

        """
        if isinstance(value, pl.DataFrame):
            # Serialize the DataFrame so the parent field only ever sees bytes.
            value = _write(value, compression=self.compression)
        elif value is not None and not isinstance(value, bytes):
            # Anything other than DataFrame / bytes / None is rejected.
            msg = f"Type of value {type(value)} not supported. Expected DataFrame or bytes."
            raise TypeError(msg)

        super().__set__(instance, value)

    def __get__(self, instance: Any, owner: type) -> pl.DataFrame | PolarsField | None:
        """Return the stored value decoded as a DataFrame.

        Args:
            instance: The document instance (None for class-level access)
            owner: The document class

        Returns:
            PolarsField: The descriptor itself when accessed at class level
            pl.DataFrame: The decoded DataFrame when accessed on an instance
            None: If no data is stored

        """
        # Class-level access yields the descriptor rather than a value.
        if instance is None:
            return self

        raw = super().__get__(instance, owner)
        return None if raw is None else _read(raw)