Coverage for src / flight / numpy_client.py: 95%
39 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-27 04:42 +0000
"""Client module for handling NumPy array operations over Apache Arrow Flight.

This module provides a client interface for sending NumPy arrays to a Flight server,
retrieving data, and performing computations with automatic conversion between
NumPy arrays and Arrow Tables.

The entry point is :class:`Client`, which opens its Flight connection when used as a
context manager (``with Client(location) as client: ...``) and closes it on exit.
"""
8import numpy as np
9import pyarrow as pa
10import pyarrow.flight as fl
12from .utils.alter import np_2_pa, pa_2_np
class Client:
    """A client for handling NumPy array operations over Apache Arrow Flight.

    This class provides an interface for sending NumPy arrays to a Flight server,
    retrieving data, and performing computations. It handles the conversion between
    NumPy arrays and Arrow Tables automatically.

    The network connection is opened in ``__enter__`` and closed in ``__exit__``,
    so the client is meant to be used as a context manager::

        with Client("grpc://host:port") as client:
            result = client.compute("command", {"x": array})

    Attributes:
        _client (fl.FlightClient): The underlying Flight client for network
            communication. Only exists after ``__enter__`` has run.
    """

    def __init__(self, location: str, **kwargs: object) -> None:
        """Store the connection parameters for a Flight server.

        No connection is made here; it is established when the client is entered
        as a context manager.

        Args:
            location: The URI location of the Flight server to connect to.
            **kwargs: Additional keyword arguments forwarded to ``fl.connect``.
        """
        self._location = location
        self._kwargs = kwargs

    def __enter__(self) -> "Client":
        """Open the connection to the Flight server.

        Returns:
            Client: The client instance with an active connection.
        """
        self._client = fl.connect(self._location, **self._kwargs)
        return self

    def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None:
        """Close the connection to the Flight server.

        Any exception raised inside the ``with`` block propagates to the caller:
        this method returns ``None`` (falsy), which tells the ``with`` statement
        not to suppress it. No explicit re-raise is needed. A bare ``raise`` here
        would only add this frame to the traceback.

        Args:
            exc_type: The exception type if an exception was raised in the context.
            exc_val: The exception value if an exception was raised in the context.
            exc_tb: The traceback if an exception was raised in the context.
        """
        self._client.close()

    @property
    def flight(self) -> fl.FlightClient:
        """Get the underlying Flight client.

        ``_client`` is only assigned in ``__enter__``, so accessing this property
        before entering the context raises ``AttributeError``.

        Returns:
            fl.FlightClient: The Flight client instance used for communication.
        """
        return self._client

    @staticmethod
    def descriptor(command: str) -> fl.FlightDescriptor:
        """Create a FlightDescriptor for an opaque command.

        A FlightDescriptor is used to identify and describe the data being transferred
        over the Flight protocol.

        Args:
            command: The command string that identifies the operation to perform.
                It is encoded with ``str.encode()`` (UTF-8).

        Returns:
            A FlightDescriptor containing the command.
        """
        return fl.FlightDescriptor.for_command(command.encode())

    def write(self, command: str, data: dict[str, np.ndarray]) -> None:
        """Write NumPy array data to the Flight server.

        This method converts the input NumPy arrays to an Arrow Table and sends it
        to the server using the specified command.

        Args:
            command: The command string identifying the operation.
            data: A dictionary mapping column names to NumPy arrays.

        Raises:
            ValueError: If ``data`` is empty.
            TypeError: If the Arrow Table built from ``data`` has zero rows.
            FlightError: If there's an error in the Flight protocol communication.
        """
        # Validate the input before doing any protocol work.
        if not data:
            msg = "Empty data"
            raise ValueError(msg)

        # Convert NumPy arrays to Arrow Table; conversion errors (if any) come from np_2_pa.
        table = np_2_pa(data)

        # NOTE(review): TypeError is kept for backward compatibility with existing callers,
        # although ValueError would match the empty-dict check above.
        if table.num_rows == 0:
            msg = "Empty table"
            raise TypeError(msg)

        # Only build the descriptor once we know the upload will actually happen.
        descriptor = self.descriptor(command)

        # Initialize the write operation with the server; the metadata reader is unused.
        writer, _ = self.flight.do_put(descriptor, table.schema)

        try:
            writer.write_table(table)
        finally:
            # Ensure the writer is closed even if sending the table fails.
            writer.close()

    def get(self, command: str) -> pa.Table:
        """Retrieve data from the Flight server.

        Issues a GET request to the server and returns the results as an Arrow Table.

        Args:
            command: The command string identifying the data to retrieve; used
                directly as the Flight ticket.

        Returns:
            An Arrow Table containing the retrieved data.

        Raises:
            FlightError: If there's an error in the Flight protocol communication.
        """
        ticket = fl.Ticket(command)
        reader = self.flight.do_get(ticket)
        # Drain the whole stream into a single Arrow Table.
        return reader.read_all()

    def compute(self, command: str, data: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        """Send data to the server, perform computation, and retrieve results.

        This is a convenience method that combines write and get operations into
        a single call. It handles the conversion between NumPy arrays and Arrow
        Tables in both directions. The same ``command`` is used for both the
        upload and the retrieval.

        Args:
            command: The command string identifying the computation to perform.
            data: A dictionary mapping column names to NumPy arrays for input.

        Returns:
            A dictionary mapping column names to NumPy arrays containing the results.

        Raises:
            ValueError: If ``data`` is empty.
            TypeError: If the Arrow Table built from ``data`` has zero rows.
            FlightError: If there's an error in the Flight protocol communication.
        """
        self.write(command, data)
        return pa_2_np(self.get(command))