Coverage for src/flight/numpy_client.py: 95%

37 statements  

« prev     ^ index     » next       coverage.py v7.10.4, created at 2025-08-19 01:30 +0000

1"""Client module for handling NumPy array operations over Apache Arrow Flight. 

2 

3This module provides a client interface for sending NumPy arrays to a Flight server, 

4retrieving data, and performing computations with automatic conversion between 

5NumPy arrays and Arrow Tables. 

6""" 

7 

8import numpy as np 

9import pyarrow as pa 

10import pyarrow.flight as fl 

11 

12from .utils.alter import np_2_pa, pa_2_np 

13 

14 

class Client:
    """A client for handling NumPy array operations over Apache Arrow Flight.

    This class provides an interface for sending NumPy arrays to a Flight server,
    retrieving data, and performing computations. It handles the conversion between
    NumPy arrays and Arrow Tables automatically.

    The network connection is not opened by the constructor: it is established in
    ``__enter__`` and closed in ``__exit__``, so the client is meant to be used as a
    context manager::

        with Client(location) as client:
            result = client.compute(command, data)

    Attributes:
        _client (fl.FlightClient): The underlying Flight client for network
            communication. Only set once ``__enter__`` has run.
    """

    def __init__(self, location: str, **kwargs) -> None:
        """Store the connection settings for a Flight server.

        No connection is made here; it is opened by ``__enter__``.

        Args:
            location: The URI location of the Flight server to connect to.
            **kwargs: Additional keyword arguments forwarded to ``fl.connect``.
        """
        self._location = location
        self._kwargs = kwargs

    def __enter__(self) -> "Client":
        """Open the connection to the Flight server.

        Returns:
            Client: This client instance, now holding an active Flight connection.
        """
        self._client = fl.connect(self._location, **self._kwargs)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: object,
    ) -> None:
        """Close the connection to the Flight server.

        Any exception raised inside the ``with`` block is propagated unchanged:
        returning ``None`` (a falsy value) tells the ``with`` statement not to
        suppress it, so no explicit re-raise is needed here.

        Args:
            exc_type: The exception type if an exception was raised in the context.
            exc_val: The exception value if an exception was raised in the context.
            exc_tb: The traceback if an exception was raised in the context.
        """
        # No bare ``raise`` here: it only works while an exception is actively
        # being handled, and would fail with "No active exception to reraise"
        # if ``__exit__`` were invoked directly with exception details.
        self._client.close()

    @property
    def flight(self) -> fl.FlightClient:
        """Get the underlying Flight client.

        Only available after ``__enter__`` has opened the connection.

        Returns:
            fl.FlightClient: The Flight client instance used for communication.
        """
        return self._client

    @staticmethod
    def descriptor(command: str) -> fl.FlightDescriptor:
        """Create a FlightDescriptor for an opaque command.

        A FlightDescriptor is used to identify and describe the data being transferred
        over the Flight protocol.

        Args:
            command: The command string that identifies the operation to perform.
                It is encoded to bytes with ``str.encode()`` (UTF-8).

        Returns:
            A FlightDescriptor containing the command.
        """
        return fl.FlightDescriptor.for_command(command.encode())

    def write(self, command: str, data: dict[str, np.ndarray]) -> None:
        """Write NumPy array data to the Flight server.

        This method converts the input NumPy arrays to an Arrow Table and sends it
        to the server using the specified command.

        Args:
            command: The command string identifying the operation.
            data: A dictionary mapping column names to NumPy arrays.

        Raises:
            ValueError: If ``data`` is empty.
            TypeError: If the converted Arrow Table has no rows.
            FlightError: If there's an error in the Flight protocol communication.
        """
        # Create a descriptor for the data transfer
        descriptor = self.descriptor(command)

        # An empty mapping has no columns to build a table from
        if not data:
            raise ValueError("Data cannot be converted to an Arrow Table.")

        # Convert NumPy arrays to Arrow Table
        table = np_2_pa(data)

        # Refuse to upload a table with zero rows
        if table.num_rows == 0:
            raise TypeError("The data cannot be converted to an Arrow Table.")

        # Initialize the write operation with the server; the second return
        # value (the metadata reader) is not used by this client.
        writer, _ = self.flight.do_put(descriptor, table.schema)

        try:
            # Send the data to the server
            writer.write_table(table)
        finally:
            # Ensure the writer is closed even if an error occurs
            writer.close()

    def get(self, command: str) -> pa.Table:
        """Retrieve data from the Flight server.

        Issues a GET request to the server and returns the results as an Arrow Table.

        Args:
            command: The command string identifying the data to retrieve.

        Returns:
            An Arrow Table containing the retrieved data.

        Raises:
            FlightError: If there's an error in the Flight protocol communication.
        """
        # Create a ticket for the data request
        ticket = fl.Ticket(command)

        # Get a reader for the requested data
        reader = self.flight.do_get(ticket)

        # Read the whole stream and return it as a single Arrow Table
        return reader.read_all()

    def compute(self, command: str, data: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        """Send data to the server, perform computation, and retrieve results.

        This is a convenience method that combines write and get operations into
        a single call, using the same ``command`` for both. It handles the
        conversion between NumPy arrays and Arrow Tables in both directions.

        Args:
            command: The command string identifying the computation to perform.
            data: A dictionary mapping column names to NumPy arrays for input.

        Returns:
            A dictionary mapping column names to NumPy arrays containing the results.

        Raises:
            ValueError: If ``data`` is empty.
            TypeError: If the converted input table has no rows.
            FlightError: If there's an error in the Flight protocol communication.
        """
        # Send input data to the server
        self.write(command, data)

        # Retrieve and convert results back to NumPy arrays
        return pa_2_np(self.get(command))