Coverage for src/flight/numpy_server.py: 49%
41 statements
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 01:30 +0000
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 01:30 +0000
1"""Server module for handling NumPy array operations over Apache Arrow Flight.
3This module provides a server implementation that can receive NumPy arrays via
4Arrow Flight protocol, perform computations on them, and return the results.
5It includes thread-safe storage for data and an abstract method for implementing
6specific computation logic.
7"""
9import logging
10import threading
11from abc import ABC, abstractmethod
13import numpy as np
14import pyarrow.flight as fl
16from .utils.alter import np_2_pa, pa_2_np
class Server(fl.FlightServerBase, ABC):
    """A Flight server that stores uploaded Arrow tables and serves computed results.

    Clients upload data with ``do_put``; the resulting Arrow table is kept in
    memory under the descriptor's command string. A later ``do_get`` with a
    ticket carrying the same command converts the stored table to NumPy arrays
    (via ``pa_2_np``), passes them to :meth:`f`, converts the result back to
    Arrow (via ``np_2_pa``) and streams it to the client. Subclasses must
    implement :meth:`f` to define the computation.

    Attributes:
        _logger: Logger instance for recording server activities.
        _storage: Dictionary mapping command strings to uploaded Arrow tables.
            NOTE: nothing in this class removes entries, so it grows with every
            distinct command that is uploaded.
        _lock: Lock guarding every read and write of ``_storage``.
    """

    def __init__(
        self, host: str = "127.0.0.1", port: int = 8080, logger: logging.Logger | None = None, **kwargs
    ) -> None:
        """Initialize the server bound to ``grpc://{host}:{port}``.

        Args:
            host: Host address for the server to bind to.
            port: Port on which the server will listen.
            logger: Optional logger; defaults to this module's logger.
            **kwargs: Additional arguments passed to the FlightServerBase constructor.
        """
        uri = f"grpc://{host}:{port}"
        super().__init__(uri, **kwargs)
        self._logger = logger or logging.getLogger(__name__)
        self._storage = {}  # command string -> uploaded table
        self._lock = threading.Lock()  # guards all access to self._storage

    @property
    def logger(self) -> logging.Logger:
        """Return the logger used by this server."""
        return self._logger

    @staticmethod
    def _extract_command_from_ticket(ticket: fl.Ticket) -> str:
        """Decode the command string carried by a Flight ticket.

        Args:
            ticket: The Flight Ticket whose payload is a UTF-8 encoded command.

        Returns:
            The decoded command string.
        """
        return ticket.ticket.decode("utf-8")

    def do_put(
        self,
        context: fl.ServerCallContext,
        descriptor: fl.FlightDescriptor,
        reader: fl.MetadataRecordBatchReader,
        writer: fl.FlightMetadataWriter,
    ) -> fl.FlightDescriptor:
        """Handle a PUT request by storing the uploaded table under its command.

        The whole upload stream is read into a single Arrow table, which is then
        stored in ``_storage`` keyed by the descriptor's command, replacing any
        earlier upload for the same command.

        Args:
            context: The request context containing client information (unused).
            descriptor: The Flight Descriptor for the PUT request; must be a
                command descriptor whose command is UTF-8 encoded.
            reader: Reader for the Arrow data sent by the client.
            writer: Writer for metadata responses back to the client (unused).

        Returns:
            A command descriptor for the stored data. Note that pyarrow's Flight
            machinery does not forward this value to the client; it is only
            visible to direct callers of this method.

        Raises:
            fl.FlightServerError: If the descriptor carries no command
                (e.g. a path descriptor).
        """
        if descriptor.command is None:
            raise fl.FlightServerError("PUT descriptor must be a command descriptor")
        command = descriptor.command.decode("utf-8")
        self.logger.info("Processing PUT request for command: %s", command)

        # Drain the upload *before* taking the lock: read_all() blocks on network
        # I/O, and holding the lock here would serialize every concurrent client.
        table = reader.read_all()
        # Log a summary only; formatting the full table can be very large.
        self.logger.debug(
            "Received table with %d rows and columns %s", table.num_rows, table.schema.names
        )

        with self._lock:
            self._storage[command] = table
        self.logger.info("Data stored for command: %s", command)

        return fl.FlightDescriptor.for_command(command)

    def do_get(self, context: fl.ServerCallContext, ticket: fl.Ticket) -> fl.RecordBatchStream:
        """Handle a GET request by computing on the data stored for the ticket's command.

        The stored table is converted to NumPy arrays with ``pa_2_np``, passed to
        :meth:`f`, and the result is converted back with ``np_2_pa`` and streamed.

        Args:
            context: The request context containing client information (unused).
            ticket: The Flight Ticket whose payload is the command to look up.

        Returns:
            A RecordBatchStream containing the processed result data.

        Raises:
            fl.FlightServerError: If no data is stored for the requested command.
        """
        command = self._extract_command_from_ticket(ticket)
        self.logger.info("Processing GET request for command: %s", command)

        # Look up under the same lock do_put writes with. The computation itself
        # runs outside the lock so slow f() calls don't block other clients.
        with self._lock:
            table = self._storage.get(command)
        if table is None:
            raise fl.FlightServerError(f"No data found for command: {command}")
        self.logger.info("Retrieved data for command: %s", command)

        matrices = pa_2_np(table)
        np_data = self.f(matrices)
        result_table = np_2_pa(np_data)

        self.logger.info("Result schema: %s", result_table.schema.names)
        self.logger.info("Computation completed. Returning results.")

        return fl.RecordBatchStream(result_table)

    @classmethod
    def start(
        cls, host: str = "127.0.0.1", port: int = 8080, logger: logging.Logger | None = None, **kwargs
    ) -> "Server":  # pragma: no cover
        """Create a configured server instance without serving.

        Despite its name, this method does NOT call ``serve()``. Callers must
        invoke ``serve()`` on the returned instance to begin accepting requests.

        Args:
            host: The host address to bind the server to.
            port: The port on which to run the server.
            logger: Optional logger to use for recording server activities.
            **kwargs: Additional arguments passed to the constructor.

        Returns:
            A configured server instance, ready for ``serve()``.
        """
        server = cls(host=host, port=port, logger=logger, **kwargs)
        server.logger.info("Starting %s Flight server on %s:%s...", cls.__name__, host, port)
        return server

    @abstractmethod
    def f(self, matrices: dict[str, np.ndarray]) -> dict[str, np.ndarray]:  # pragma: no cover
        """Process the input matrices and return the computation results.

        Subclasses implement this to define the computation applied in ``do_get``.

        Args:
            matrices: A dictionary mapping column names to NumPy arrays
                containing the input data to process.

        Returns:
            A dictionary mapping column names to NumPy arrays containing the
            computation results.
        """
        ...