Coverage for src / tinycta / linalg.py: 100%

39 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-03 00:31 +0000

1# Copyright (c) 2023 Thomas Schmelzer 

2# 

3# Permission is hereby granted, free of charge, to any person obtaining a copy 

4# of this software and associated documentation files (the "Software"), to deal 

5# in the Software without restriction, including without limitation the rights 

6# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

7# copies of the Software, and to permit persons to whom the Software is 

8# furnished to do so, subject to the following conditions: 

9# 

10# The above copyright notice and this permission notice shall be included in all 

11# copies or substantial portions of the Software. 

12"""Linear algebra utilities that handle matrices with missing (NaN) values. 

13 

14All operations extract the finite submatrix via the ``valid`` function before 

15performing computations, so partially-observed covariance matrices are handled 

16gracefully without raising errors. 

17""" 

18 

19from __future__ import annotations 

20 

21import numpy as np 

22 

23 

def valid(matrix: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Extract the submatrix whose diagonal entries are finite.

    The input must be a square two-dimensional matrix. Rows and columns whose
    diagonal element is not finite (NaN or +/-inf) are dropped; non-finite
    entries do NOT raise an error, they are simply filtered out.

    Parameters:
        matrix (np.ndarray): A square 2D array (or array-like), possibly
            containing NaN values.

    Returns:
        tuple[np.ndarray, np.ndarray]: A tuple ``(v, sub)`` where ``v`` is a
        boolean mask of length ``n`` that is True where ``matrix[i, i]`` is
        finite, and ``sub`` is ``matrix`` restricted to the rows and columns
        selected by ``v`` (shape ``(k, k)`` with ``k = v.sum()``, possibly
        ``(0, 0)``).

    Raises:
        AssertionError: If ``matrix`` is not a square two-dimensional array.
    """
    matrix = np.asarray(matrix)

    # make sure matrix is a two-dimensional, quadratic array
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        raise AssertionError(f"matrix must be square and 2D, got shape {matrix.shape}")

    v = np.isfinite(np.diag(matrix))
    # np.ix_ selects the valid rows and columns in one step; the result equals
    # matrix[:, v][v] but avoids the intermediate (n, k) copy.
    return v, matrix[np.ix_(v, v)]

53 

54 

# NOTE(review): the original author remarked that this helper is "somewhat
# not needed"; confirm with callers before removing it.
def a_norm(vector: np.ndarray, matrix: np.ndarray | None = None) -> float:
    """Calculate the A-norm ``sqrt(x^T A x)`` of a vector, ignoring missing data.

    Without a matrix this is the Euclidean norm of the finite entries of
    ``vector``. With a matrix, only the rows/columns of ``matrix`` whose
    diagonal entry is finite (as selected by ``valid``) and the matching
    entries of ``vector`` enter the quadratic form.

    Parameters:
        vector: The input vector (array-like). When ``matrix`` is given it
            must have exactly one entry per row of ``matrix``; column or row
            vectors such as shape ``(n, 1)`` are flattened.
        matrix: Optional square matrix ``A`` defining the quadratic form. If
            None, the Euclidean norm of the finite entries of ``vector`` is
            returned.

    Returns:
        float: ``sqrt(x^T A x)`` over the valid subset, or the Euclidean norm
        of the finite entries if ``matrix`` is None. NaN if no diagonal entry
        of ``matrix`` is finite, if a selected entry of ``vector`` is NaN, or
        if the quadratic form is negative.

    Raises:
        AssertionError: If ``matrix`` is not square or ``vector`` does not
            have exactly one entry per row of ``matrix``.
    """
    if matrix is None:
        vector = np.asarray(vector)
        return float(np.linalg.norm(vector[np.isfinite(vector)], 2))

    matrix = np.asarray(matrix)
    # flatten so that (n, 1) / (1, n) vectors index and multiply like (n,)
    vector = np.asarray(vector).ravel()

    # make sure matrix is quadratic
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        raise AssertionError(f"matrix must be square and 2D, got shape {matrix.shape}")

    # make sure the vector has the right number of entries
    if vector.size != matrix.shape[0]:
        raise AssertionError(f"vector has {vector.size} entries but matrix has {matrix.shape[0]} rows")

    v, mat = valid(matrix)

    if v.any():
        x = vector[v]
        return float(np.sqrt(x @ (mat @ x)))
    return float("nan")

102 

103 

def inv_a_norm(vector: np.ndarray, matrix: np.ndarray | None = None) -> float:
    """Calculate the inverse A-norm ``sqrt(x^T A^{-1} x)``, ignoring missing data.

    Without a matrix this is the Euclidean norm of the finite entries of
    ``vector`` (i.e. ``A`` is taken to be the identity). With a matrix, only
    the rows/columns of ``matrix`` whose diagonal entry is finite (as selected
    by ``valid``) and the matching entries of ``vector`` are used; ``A^{-1} x``
    is obtained with a linear solve rather than an explicit inverse.

    Parameters:
        vector: The input vector (array-like). When ``matrix`` is given it
            must have exactly one entry per row of ``matrix``; column or row
            vectors such as shape ``(n, 1)`` are flattened.
        matrix: Optional square matrix ``A``. If None, the Euclidean norm of
            the finite entries of ``vector`` is returned.

    Returns:
        float: ``sqrt(x^T A^{-1} x)`` over the valid subset, or the Euclidean
        norm of the finite entries if ``matrix`` is None. NaN if no diagonal
        entry of ``matrix`` is finite, if a selected entry of ``vector`` is
        NaN, or if the quadratic form is negative.

    Raises:
        AssertionError: If ``matrix`` is not square or ``vector`` does not
            have exactly one entry per row of ``matrix``.
        numpy.linalg.LinAlgError: If the valid submatrix is singular.
    """
    if matrix is None:
        vector = np.asarray(vector)
        return float(np.linalg.norm(vector[np.isfinite(vector)], 2))

    matrix = np.asarray(matrix)
    # flatten so that (n, 1) / (1, n) vectors index and multiply like (n,)
    vector = np.asarray(vector).ravel()

    # make sure matrix is quadratic
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        raise AssertionError(f"matrix must be square and 2D, got shape {matrix.shape}")

    # make sure the vector has the right number of entries
    if vector.size != matrix.shape[0]:
        raise AssertionError(f"vector has {vector.size} entries but matrix has {matrix.shape[0]} rows")

    v, mat = valid(matrix)

    if v.any():
        x = vector[v]
        # solve A y = x instead of forming A^{-1} explicitly
        return float(np.sqrt(x @ np.linalg.solve(mat, x)))
    return float("nan")

148 

149 

def solve(matrix: np.ndarray, rhs: np.ndarray) -> np.ndarray:
    """Solve ``A x = b`` on the rows/columns of ``A`` with a finite diagonal.

    Rows and columns of ``matrix`` whose diagonal entry is not finite (as
    selected by ``valid``) are removed together with the matching entries of
    ``rhs``. The reduced system is solved with ``numpy.linalg.solve``, and the
    solution is scattered back into a full-length result.

    Parameters:
        matrix: A square 2D array (or array-like) ``A``, possibly containing
            NaN values.
        rhs: The right-hand side ``b`` (array-like) with exactly one entry per
            row of ``matrix``; column or row vectors such as shape ``(n, 1)``
            are flattened.

    Returns:
        np.ndarray: A float array of shape ``(n,)``. Entries whose diagonal
        element of ``matrix`` is not finite are NaN; if no diagonal element is
        finite, every entry is NaN.

    Raises:
        AssertionError: If ``matrix`` is not square or ``rhs`` does not have
            exactly one entry per row of ``matrix``.
        numpy.linalg.LinAlgError: If the valid submatrix is singular.
    """
    matrix = np.asarray(matrix)
    # flatten so that an (n, 1) rhs can be assigned into the (n,) result
    rhs = np.asarray(rhs).ravel()

    # make sure matrix is quadratic
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        raise AssertionError(f"matrix must be square and 2D, got shape {matrix.shape}")

    # make sure the vector rhs has the right number of entries
    if rhs.size != matrix.shape[0]:
        raise AssertionError(f"rhs has {rhs.size} entries but matrix has {matrix.shape[0]} rows")

    # positions that are not solved for stay NaN
    x = np.full(rhs.size, np.nan)
    v, mat = valid(matrix)

    if v.any():
        x[v] = np.linalg.solve(mat, rhs[v])

    return x