Coverage for src/flight/numpy_server.py: 49%

41 statements  

« prev     ^ index     » next       coverage.py v7.10.4, created at 2025-08-19 01:30 +0000

1"""Server module for handling NumPy array operations over Apache Arrow Flight. 

2 

3This module provides a server implementation that can receive NumPy arrays via 

4Arrow Flight protocol, perform computations on them, and return the results. 

5It includes thread-safe storage for data and an abstract method for implementing 

6specific computation logic. 

7""" 

8 

9import logging 

10import threading 

11from abc import ABC, abstractmethod 

12 

13import numpy as np 

14import pyarrow.flight as fl 

15 

16from .utils.alter import np_2_pa, pa_2_np 

17 

18 

class Server(fl.FlightServerBase, ABC):
    """A Flight Server implementation that handles matrix data and performs computations on it.

    This abstract base class provides the foundation for creating Flight servers that can
    receive NumPy arrays, store them, perform computations, and return results. Subclasses
    must implement the `f` method to define specific computation logic.

    Attributes:
        _logger: Logger instance for recording server activities.
        _storage: Dictionary mapping a command string to the Arrow Table uploaded for it.
        _lock: Threading lock guarding every read and write of ``_storage``.
    """

    def __init__(
        self, host: str = "127.0.0.1", port: int = 8080, logger: logging.Logger | None = None, **kwargs
    ) -> None:
        """Initialize the server with the provided host and port, and optionally a logger.

        Args:
            host: Host address for the server to bind to. IPv6 literals such as ``"::1"``
                are accepted and are bracketed automatically when building the URI.
            port: Port on which the server will listen.
            logger: Optional logger to use for logging messages. Defaults to the module logger.
            **kwargs: Additional arguments passed to the FlightServerBase constructor.
        """
        # Initialise all state *before* calling the base constructor: pyarrow's
        # FlightServerBase binds the listening socket in __init__, so an RPC could
        # otherwise reach do_put/do_get before these attributes exist.
        self._logger = logger or logging.getLogger(__name__)
        self._storage = {}  # command string -> uploaded Arrow Table
        self._lock = threading.Lock()  # guards _storage

        # A bare IPv6 literal must be bracketed, otherwise the port separator is ambiguous.
        if ":" in host and not host.startswith("["):
            host = f"[{host}]"
        uri = f"grpc://{host}:{port}"
        super().__init__(uri, **kwargs)

    @property
    def logger(self) -> logging.Logger:
        """Get the logger instance used by this server.

        Returns:
            The logger instance for recording server activities.
        """
        return self._logger

    @staticmethod
    def _extract_command_from_ticket(ticket: fl.Ticket) -> str:
        """Extract the command string from a Flight Ticket.

        Args:
            ticket: The Flight Ticket containing the command.

        Returns:
            The command string extracted from the ticket, decoded as UTF-8.
        """
        return ticket.ticket.decode("utf-8")

    def do_put(
        self,
        context: fl.ServerCallContext,
        descriptor: fl.FlightDescriptor,
        reader: fl.MetadataRecordBatchReader,
        writer: fl.FlightMetadataWriter,
    ) -> fl.FlightDescriptor:
        """Handle a PUT request, storing the provided data in the server's storage.

        Reads the complete Arrow Table sent by the client and stores it under the
        command carried by the descriptor. Any previous table stored under the same
        command is replaced.

        Args:
            context: The request context containing client information.
            descriptor: The Flight Descriptor for the PUT request containing the command.
            reader: Reader for reading the Arrow Table data sent by the client.
            writer: Writer for writing metadata responses back to the client (unused).

        Returns:
            A command Flight Descriptor for the key the data was stored under.

        Raises:
            fl.FlightServerError: If the descriptor is not a command descriptor.
        """
        raw_command = descriptor.command
        if raw_command is None:
            # Path-type descriptors carry no command; fail with a clear message
            # instead of an AttributeError on ``None.decode``.
            raise fl.FlightServerError("PUT requests must use a command descriptor")
        command = raw_command.decode("utf-8")
        self.logger.info("Processing PUT request for command: %s", command)

        # Drain the client stream *outside* the lock: read_all() blocks on network I/O,
        # and holding the lock here would serialise all uploads and stall do_get.
        table = reader.read_all()
        self.logger.info("Table: %s", table)

        # Only the shared-dict mutation needs to be protected.
        with self._lock:
            self._storage[command] = table
        self.logger.info("Data stored for command: %s", command)

        return fl.FlightDescriptor.for_command(command)

    def do_get(self, context: fl.ServerCallContext, ticket: fl.Ticket) -> fl.RecordBatchStream:
        """Handle a GET request, retrieving and processing stored data.

        Extracts the command from the ticket and looks up the table stored for it.
        The table is converted to NumPy arrays, passed through the abstract `f` method,
        and the result is converted back to Arrow and streamed to the client.

        Args:
            context: The request context containing client information.
            ticket: The Flight Ticket containing the command for the GET request.

        Returns:
            A RecordBatchStream containing the processed result data.

        Raises:
            fl.FlightServerError: If no data is found for the requested command.
        """
        command = self._extract_command_from_ticket(ticket)
        self.logger.info("Processing GET request for command: %s", command)

        # Single locked lookup (instead of an unlocked ``in`` check followed by indexing)
        # so the read is consistent with concurrent do_put writes.
        with self._lock:
            table = self._storage.get(command)
        if table is None:
            raise fl.FlightServerError(f"No data found for command: {command}")
        self.logger.info("Retrieved data for command: %s", command)

        # Convert Arrow Table -> NumPy, run the subclass computation, convert back to Arrow.
        # This work runs outside the lock so long computations do not block uploads.
        matrices = pa_2_np(table)
        np_data = self.f(matrices)
        result_table = np_2_pa(np_data)

        self.logger.info("Result schema: %s", result_table.schema.names)
        self.logger.info("Computation completed. Returning results.")

        return fl.RecordBatchStream(result_table)

    @classmethod
    def start(
        cls, host: str = "127.0.0.1", port: int = 8080, logger: logging.Logger | None = None, **kwargs
    ) -> "Server":  # pragma: no cover
        """Create and return a configured server instance.

        This method constructs the server but does NOT call ``serve()``. Callers that want
        to block until the server shuts down must call ``serve()`` on the returned instance.

        Args:
            host: The host address to bind the server to.
            port: The port on which to run the server.
            logger: Optional logger to use for recording server activities.
            **kwargs: Additional arguments passed to the constructor.

        Returns:
            The constructed server instance.
        """
        server = cls(host=host, port=port, logger=logger, **kwargs)
        server.logger.info(f"Starting {cls.__name__} Flight server on {host}:{port}...")
        return server

    @abstractmethod
    def f(self, matrices: dict[str, np.ndarray]) -> dict[str, np.ndarray]:  # pragma: no cover
        """Process the input matrices and return the computation results.

        This abstract method must be implemented by subclasses to define the specific
        computation logic to be applied to the input data.

        Args:
            matrices: A dictionary mapping column names to NumPy arrays containing
                the input data to process.

        Returns:
            A dictionary mapping column names to NumPy arrays containing the
            computation results.
        """
        ...