Coverage for src / tinycta / linalg.py: 100%

39 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-21 13:04 +0000

1# Copyright (c) 2023 Thomas Schmelzer 

2# 

3# Permission is hereby granted, free of charge, to any person obtaining a copy 

4# of this software and associated documentation files (the "Software"), to deal 

5# in the Software without restriction, including without limitation the rights 

6# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

7# copies of the Software, and to permit persons to whom the Software is 

8# furnished to do so, subject to the following conditions: 

9# 

10# The above copyright notice and this permission notice shall be included in all 

11# copies or substantial portions of the Software. 

12"""Linear algebra utilities that handle matrices with missing (NaN) values. 

13 

14All operations extract the finite submatrix via the ``valid`` function before 

15performing computations, so partially-observed covariance matrices are handled 

16gracefully without raising errors. 

17""" 

18 

19from __future__ import annotations 

20 

21import numpy as np 

22 

23 

def valid(matrix: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Extract the finite submatrix by checking diagonal elements.

    A row/column is considered valid when its diagonal entry is finite.
    Off-diagonal entries are not inspected, so the returned submatrix may
    still contain non-finite off-diagonal values.

    Args:
        matrix: A square, two-dimensional NumPy array.

    Returns:
        A tuple of (mask, submatrix) where mask is a boolean array marking
        the finite diagonal elements and submatrix contains only the rows and
        columns with finite diagonals (shape ``(k, k)`` with ``k = mask.sum()``).

    Raises:
        AssertionError: If the input is not a square two-dimensional array.
    """
    # Check ndim first so 1-D/0-D inputs raise the documented AssertionError
    # instead of an IndexError from ``matrix.shape[1]``.
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        raise AssertionError(f"matrix must be a square 2-D array, got shape {matrix.shape}")

    v = np.isfinite(np.diag(matrix))
    # np.ix_ selects the valid rows and columns in one step
    # (equivalent to matrix[:, v][v]).
    return v, matrix[np.ix_(v, v)]

43 

44 

45def a_norm(vector: np.ndarray, matrix: np.ndarray | None = None) -> float: 

46 """Calculate the generalized norm of a vector with respect to a matrix. 

47 

48 Computes the Euclidean norm when no matrix is provided, or the quadratic 

49 form ``sqrt(v^T A v)`` when a matrix is given. 

50 

51 Args: 

52 vector: The input one-dimensional vector. 

53 matrix: Square matrix defining the quadratic form. Its size must match 

54 the vector length. If None, the standard Euclidean norm is returned. 

55 

56 Returns: 

57 The computed norm, or NaN if no valid entries exist. 

58 

59 Raises: 

60 AssertionError: If the matrix is not square or its dimensions do not 

61 match the vector. 

62 """ 

63 if matrix is None: 

64 return float(np.linalg.norm(vector[np.isfinite(vector)], 2)) 

65 

66 if matrix.shape[0] != matrix.shape[1]: 

67 raise AssertionError 

68 

69 if vector.size != matrix.shape[0]: 

70 raise AssertionError 

71 

72 v, mat = valid(matrix) 

73 

74 if v.any(): 

75 return float(np.sqrt(np.dot(vector[v], np.dot(mat, vector[v])))) 

76 return float("nan") 

77 

78 

79def inv_a_norm(vector: np.ndarray, matrix: np.ndarray | None = None) -> float: 

80 """Calculate the inverse A-norm of a vector using an optional matrix. 

81 

82 Computes ``sqrt(v^T A^{-1} v)`` when a matrix is provided, or the 

83 Euclidean norm of finite entries when no matrix is given. 

84 

85 Args: 

86 vector: The input vector for which the norm is to be calculated. 

87 matrix: Square matrix used for computing the norm. If not provided, 

88 the Euclidean norm of finite entries is returned. 

89 

90 Returns: 

91 The computed norm as a float. Returns ``np.nan`` if no valid entries exist. 

92 

93 Raises: 

94 AssertionError: If the matrix is not square or the vector size is 

95 incompatible with the matrix dimensions. 

96 """ 

97 if matrix is None: 

98 return float(np.linalg.norm(vector[np.isfinite(vector)], 2)) 

99 

100 if matrix.shape[0] != matrix.shape[1]: 

101 raise AssertionError 

102 

103 if vector.size != matrix.shape[0]: 

104 raise AssertionError 

105 

106 v, mat = valid(matrix) 

107 

108 if v.any(): 

109 return float(np.sqrt(np.dot(vector[v], np.linalg.solve(mat, vector[v])))) 

110 return float("nan") 

111 

112 

def solve(matrix: np.ndarray, rhs: np.ndarray) -> np.ndarray:
    """Solve a linear system restricted to the valid (finite-diagonal) submatrix.

    Rows/columns of ``matrix`` with a non-finite diagonal are excluded. The
    remaining system ``A_sub x_sub = rhs_sub`` is solved, and the excluded
    positions of the result are filled with NaN.

    Args:
        matrix: Square 2-D matrix representing the coefficients of the
            linear system.
        rhs: Right-hand side vector. Its size must match the number of rows
            in the matrix.

    Returns:
        A float solution vector of length ``rhs.size``. Entries corresponding
        to invalid rows/columns are NaN, and the result is all NaN when no
        diagonal entry is finite. Non-finite entries of ``rhs`` at valid
        positions are not filtered and propagate into the solution.

    Raises:
        AssertionError: If the matrix is not a square 2-D array or the size
            of ``rhs`` does not match the matrix dimension.
        numpy.linalg.LinAlgError: If the valid submatrix is singular.
    """
    # ndim check first so a 1-D "matrix" raises AssertionError, not IndexError.
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        raise AssertionError(f"matrix must be a square 2-D array, got shape {matrix.shape}")

    if rhs.size != matrix.shape[0]:
        raise AssertionError(
            f"rhs has {rhs.size} entries but matrix has dimension {matrix.shape[0]}"
        )

    # Start with all-NaN so invalid positions stay NaN after the partial solve.
    x = np.full(rhs.size, np.nan)
    v, mat = valid(matrix)

    if v.any():
        x[v] = np.linalg.solve(mat, rhs[v])

    return x