Coverage for src / antarctic / polars_field.py: 100%

32 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-01-27 10:29 +0000

1"""Provide a custom Polars field type for MongoEngine. 

2 

3It allows storing polars DataFrames in MongoDB by converting them to and from 

4parquet-formatted byte streams. This enables efficient storage and retrieval 

5of polars data structures within MongoDB documents. 

6""" 

7 

8from __future__ import annotations 

9 

10from io import BytesIO 

11from typing import Any, Literal 

12 

13import polars as pl 

14from mongoengine.base import BaseField 

15 

# Closed set of codec names accepted for the ``compression`` argument of
# ``_write`` and ``PolarsField``.
CompressionType = Literal[
    "lz4",
    "uncompressed",
    "snappy",
    "gzip",
    "brotli",
    "zstd",
]

17 

18 

def _read(value: bytes, columns: list[str] | None = None) -> pl.DataFrame:
    """Rebuild a polars DataFrame from parquet-encoded bytes.

    Args:
        value: Parquet-encoded bytes holding a serialized DataFrame
        columns: Names of the columns to load; ``None`` loads every column

    Returns:
        pl.DataFrame: The DataFrame decoded from ``value``

    Examples:
        >>> import polars as pl
        >>> frame = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
        >>> payload = _write(frame)
        >>> _read(payload)["a"].to_list()
        [1, 2, 3]

        Restrict the load to a subset of columns:

        >>> _read(payload, columns=["b"]).columns
        ['b']

    """
    # Expose the raw bytes through a file-like object for the parquet reader.
    stream = BytesIO(value)
    return pl.read_parquet(stream, columns=columns)

45 

46 

def _write(value: pl.DataFrame, compression: CompressionType = "zstd") -> bytes:
    """Serialize a polars DataFrame into parquet-formatted bytes.

    The parquet payload carries the frame's schema (column names and data
    types) alongside the data, so the frame can be rebuilt from the bytes alone.

    Args:
        value: The DataFrame to serialize
        compression: Name of the parquet compression codec (default: "zstd")

    Returns:
        bytes: The parquet-encoded contents of ``value``

    Examples:
        >>> import polars as pl
        >>> frame = pl.DataFrame({"x": [1.0, 2.0], "y": [3.0, 4.0]})
        >>> payload = _write(frame)
        >>> isinstance(payload, bytes)
        True
        >>> payload[:4]  # Parquet magic bytes
        b'PAR1'

    """
    # Write into an in-memory sink and copy the bytes out before it is closed.
    with BytesIO() as sink:
        value.write_parquet(sink, compression=compression)
        return sink.getvalue()

73 

74 

class PolarsField(BaseField):  # type: ignore[misc]
    """MongoEngine field that stores polars DataFrames as parquet bytes.

    Assigning a DataFrame serializes it to a parquet byte stream before it is
    handed to the parent field; reading the attribute on a document instance
    decodes the stored bytes back into a DataFrame.
    """

    def __init__(self, compression: CompressionType = "zstd", **kwargs: Any) -> None:
        """Create the field and remember the parquet compression codec.

        Args:
            compression: Codec used when serializing DataFrames (default: "zstd")
            **kwargs: Forwarded unchanged to ``BaseField.__init__``

        """
        super().__init__(**kwargs)
        self.compression = compression

    def __set__(self, instance: Any, value: pl.DataFrame | bytes | None) -> None:
        """Normalize the assigned value and delegate storage to the parent field.

        DataFrames are serialized with the field's compression codec; ``bytes``
        and ``None`` are passed through untouched.

        Args:
            instance: The document instance
            value: The value to set (DataFrame, bytes, or None)

        Raises:
            TypeError: If the value is neither a DataFrame, bytes, nor None

        """
        if isinstance(value, pl.DataFrame):
            # Serialize the frame so that only bytes reach the parent field.
            value = _write(value, compression=self.compression)
        elif value is not None and not isinstance(value, bytes):
            # Anything other than None or ready-made bytes is rejected.
            msg = f"Type of value {type(value)} not supported. Expected DataFrame or bytes."
            raise TypeError(msg)

        super().__set__(instance, value)

    def __get__(self, instance: Any, owner: type) -> pl.DataFrame | PolarsField | None:
        """Return the stored value decoded into a DataFrame.

        Args:
            instance: The document instance (None for class-level access)
            owner: The document class

        Returns:
            PolarsField: The descriptor itself when accessed at class level
            pl.DataFrame: The decoded DataFrame when accessed on an instance
            None: If the parent field holds no data

        """
        if instance is None:
            # Class-level access: hand back the descriptor, not a value.
            return self

        stored = super().__get__(instance, owner)
        return None if stored is None else _read(stored)