Coverage for src/flight/numpy_client.py: 95%
37 statements
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 01:30 +0000
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 01:30 +0000
1"""Client module for handling NumPy array operations over Apache Arrow Flight.
3This module provides a client interface for sending NumPy arrays to a Flight server,
4retrieving data, and performing computations with automatic conversion between
5NumPy arrays and Arrow Tables.
6"""
8import numpy as np
9import pyarrow as pa
10import pyarrow.flight as fl
12from .utils.alter import np_2_pa, pa_2_np
class Client:
    """A client for handling NumPy array operations over Apache Arrow Flight.

    This class provides an interface for sending NumPy arrays to a Flight server,
    retrieving data, and performing computations. It handles the conversion between
    NumPy arrays and Arrow Tables automatically.

    The connection is opened in ``__enter__`` and closed in ``__exit__``, so the
    client is intended to be used as a context manager::

        with Client("grpc://host:port") as client:
            result = client.compute("command", {"x": values})

    Attributes:
        _location (str): The URI location of the Flight server.
        _kwargs (dict): Extra keyword arguments forwarded to ``fl.connect``.
        _client (fl.FlightClient): The underlying Flight client for network
            communication. Only set once the context has been entered.
    """

    def __init__(self, location: str, **kwargs) -> None:
        """Initialize the Client with a Flight server location.

        No connection is made here; the arguments are only stored and used
        when the context is entered.

        Args:
            location: The URI location of the Flight server to connect to.
            **kwargs: Additional keyword arguments to pass to the Flight client.
        """
        self._location = location
        self._kwargs = kwargs

    def __enter__(self) -> "Client":
        """Open the connection to the Flight server.

        Returns:
            Client: The client instance with an active connection.
        """
        self._client = fl.connect(self._location, **self._kwargs)
        return self

    def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None:
        """Close the connection.

        Any exception raised inside the ``with`` block is left untouched: this
        method returns ``None``, so the exception is not suppressed and keeps
        propagating to the caller.

        Args:
            exc_type: The exception type if an exception was raised in the context.
            exc_val: The exception value if an exception was raised in the context.
            exc_tb: The traceback if an exception was raised in the context.
        """
        self._client.close()
        # No explicit re-raise: a falsy return value already lets the in-flight
        # exception propagate. A bare ``raise`` here would fail with
        # "RuntimeError: No active exception to reraise" whenever __exit__ is
        # invoked directly outside an exception handler.

    @property
    def flight(self) -> fl.FlightClient:
        """Get the underlying Flight client.

        The attribute is created by ``__enter__``; accessing this property
        before entering the context raises ``AttributeError``.

        Returns:
            fl.FlightClient: The Flight client instance used for communication.
        """
        return self._client

    @staticmethod
    def descriptor(command: str) -> fl.FlightDescriptor:
        """Create a FlightDescriptor for an opaque command.

        A FlightDescriptor is used to identify and describe the data being transferred
        over the Flight protocol.

        Args:
            command: The command string that identifies the operation to perform.

        Returns:
            A FlightDescriptor containing the encoded command.
        """
        return fl.FlightDescriptor.for_command(command.encode())

    def write(self, command: str, data: dict[str, np.ndarray]) -> None:
        """Write NumPy array data to the Flight server.

        This method converts the input NumPy arrays to an Arrow Table and sends it
        to the server using the specified command.

        Args:
            command: The command string identifying the operation.
            data: A dictionary mapping column names to NumPy arrays.

        Raises:
            ValueError: If ``data`` is empty.
            TypeError: If the converted Arrow Table contains no rows.

        Errors raised by the underlying Flight calls (e.g. ``FlightError``) are
        not caught here and propagate to the caller.
        """
        # Create a descriptor for the data transfer
        descriptor = self.descriptor(command)

        # An empty mapping has no columns to send
        if not data:
            raise ValueError("Data cannot be converted to an Arrow Table.")

        # Convert NumPy arrays to Arrow Table
        table = np_2_pa(data)

        # Refuse to open a stream for a table without rows
        if table.num_rows == 0:
            raise TypeError("The data cannot be converted to an Arrow Table.")

        # Initialize the write operation with the server; the second element
        # returned by do_put is not needed here
        writer, _ = self.flight.do_put(descriptor, table.schema)

        try:
            # Send the data to the server
            writer.write_table(table)
        finally:
            # Ensure the writer is closed even if an error occurs
            writer.close()

    def get(self, command: str) -> pa.Table:
        """Retrieve data from the Flight server.

        Issues a GET request to the server and returns the results as an Arrow Table.

        Args:
            command: The command string identifying the data to retrieve.

        Returns:
            An Arrow Table containing the retrieved data.

        Errors raised by the underlying Flight calls (e.g. ``FlightError``) are
        not caught here and propagate to the caller.
        """
        # Create a ticket for the data request
        ticket = fl.Ticket(command)

        # Get a reader for the requested data
        reader = self.flight.do_get(ticket)

        # Read and return all data as an Arrow Table
        return reader.read_all()

    def compute(self, command: str, data: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        """Send data to the server, perform computation, and retrieve results.

        This is a convenience method that combines write and get operations into
        a single call, using the same command for both. It handles the conversion
        between NumPy arrays and Arrow Tables in both directions.

        Args:
            command: The command string identifying the computation to perform.
            data: A dictionary mapping column names to NumPy arrays for input.

        Returns:
            A dictionary mapping column names to NumPy arrays containing the results.

        Raises:
            ValueError: If ``data`` is empty (raised by ``write``).
            TypeError: If the converted Arrow Table contains no rows (raised by ``write``).

        Errors raised by the underlying Flight calls (e.g. ``FlightError``) are
        not caught here and propagate to the caller.
        """
        # Send input data to the server
        self.write(command, data)

        # Retrieve and convert results back to NumPy arrays
        return pa_2_np(self.get(command))