Coverage for src / flight / numpy_client.py: 95%

39 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-27 04:42 +0000

1"""Client module for handling NumPy array operations over Apache Arrow Flight. 

2 

3This module provides a client interface for sending NumPy arrays to a Flight server, 

4retrieving data, and performing computations with automatic conversion between 

5NumPy arrays and Arrow Tables. 

6""" 

7 

8import numpy as np 

9import pyarrow as pa 

10import pyarrow.flight as fl 

11 

12from .utils.alter import np_2_pa, pa_2_np 

13 

14 

class Client:
    """A client for handling NumPy array operations over Apache Arrow Flight.

    This class provides an interface for sending NumPy arrays to a Flight server,
    retrieving data, and performing computations. It handles the conversion between
    NumPy arrays and Arrow Tables automatically.

    The network connection is opened in ``__enter__`` and closed in ``__exit__``,
    so the client is intended to be used as a context manager.

    Attributes:
        _location (str): The URI location passed to ``fl.connect``.
        _kwargs (dict[str, object]): Extra keyword arguments passed to ``fl.connect``.
        _client (fl.FlightClient): The underlying Flight client for network
            communication. Only set once the context has been entered.
    """

    def __init__(self, location: str, **kwargs: object) -> None:
        """Initialize the Client with a Flight server location.

        No connection is made here; the arguments are stored and used when the
        context is entered.

        Args:
            location: The URI location of the Flight server to connect to.
            **kwargs: Additional keyword arguments to pass to the Flight client.
        """
        self._location = location
        self._kwargs = kwargs

    def __enter__(self) -> "Client":
        """Open the connection to the Flight server.

        Returns:
            Client: The client instance with an active connection.
        """
        self._client = fl.connect(self._location, **self._kwargs)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: object,
    ) -> None:
        """Close the connection.

        Args:
            exc_type: The exception type if an exception was raised in the context.
            exc_val: The exception value if an exception was raised in the context.
            exc_tb: The traceback if an exception was raised in the context.

        Returns:
            None, so an exception raised inside the ``with`` block is not
            suppressed and propagates to the caller unchanged.
        """
        # Returning None (falsy) already tells the interpreter to propagate the
        # in-flight exception, so no explicit re-raise is needed. A bare
        # ``raise`` here would fail with RuntimeError whenever ``__exit__`` is
        # called outside an active ``except`` context.
        self._client.close()

    @property
    def flight(self) -> fl.FlightClient:
        """Get the underlying Flight client.

        The client only exists after ``__enter__`` has run; accessing this
        property earlier raises AttributeError.

        Returns:
            fl.FlightClient: The Flight client instance used for communication.
        """
        return self._client

    @staticmethod
    def descriptor(command: str) -> fl.FlightDescriptor:
        """Create a FlightDescriptor for an opaque command.

        A FlightDescriptor is used to identify and describe the data being transferred
        over the Flight protocol.

        Args:
            command: The command string that identifies the operation to perform.
                It is encoded to bytes with ``str.encode()`` before use.

        Returns:
            A FlightDescriptor containing the command.
        """
        return fl.FlightDescriptor.for_command(command.encode())

    def write(self, command: str, data: dict[str, np.ndarray]) -> None:
        """Write NumPy array data to the Flight server.

        This method converts the input NumPy arrays to an Arrow Table and sends it
        to the server using the specified command.

        Args:
            command: The command string identifying the operation.
            data: A dictionary mapping column names to NumPy arrays.

        Raises:
            ValueError: If ``data`` is empty.
            TypeError: If the converted Arrow Table has no rows.
            FlightError: If there's an error in the Flight protocol communication.
        """
        # Validate before touching any Flight objects so bad input fails fast.
        if not data:
            msg = "Empty data"
            raise ValueError(msg)

        # Convert NumPy arrays to Arrow Table
        table = np_2_pa(data)

        if table.num_rows == 0:
            msg = "Empty table"
            raise TypeError(msg)

        # Initialize the write operation with the server; the second element
        # returned by do_put is not used here.
        writer, _ = self.flight.do_put(self.descriptor(command), table.schema)

        try:
            # Send the data to the server
            writer.write_table(table)
        finally:
            # Ensure the writer is closed even if an error occurs
            writer.close()

    def get(self, command: str) -> pa.Table:
        """Retrieve data from the Flight server.

        Issues a GET request to the server and returns the results as an Arrow Table.

        Args:
            command: The command string identifying the data to retrieve.

        Returns:
            An Arrow Table containing the retrieved data.

        Raises:
            FlightError: If there's an error in the Flight protocol communication.
        """
        # Create a ticket for the data request
        ticket = fl.Ticket(command)

        # Get a reader for the requested data
        reader = self.flight.do_get(ticket)

        # Read and return all data as an Arrow Table
        return reader.read_all()

    def compute(self, command: str, data: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        """Send data to the server, perform computation, and retrieve results.

        This is a convenience method that combines write and get operations into
        a single call, using the same command for both. It handles the conversion
        between NumPy arrays and Arrow Tables in both directions.

        Args:
            command: The command string identifying the computation to perform.
            data: A dictionary mapping column names to NumPy arrays for input.

        Returns:
            A dictionary mapping column names to NumPy arrays containing the results.

        Raises:
            ValueError: If ``data`` is empty.
            TypeError: If the converted Arrow Table has no rows.
            FlightError: If there's an error in the Flight protocol communication.
        """
        # Send input data to the server
        self.write(command, data)

        # Retrieve and convert results back to NumPy arrays
        return pa_2_np(self.get(command))

166 

167 # Retrieve and convert results back to NumPy arrays 

168 return pa_2_np(self.get(command))