Coverage for src / tinycta / linalg.py: 100%
39 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-03 00:31 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-03 00:31 +0000
1# Copyright (c) 2023 Thomas Schmelzer
2#
3# Permission is hereby granted, free of charge, to any person obtaining a copy
4# of this software and associated documentation files (the "Software"), to deal
5# in the Software without restriction, including without limitation the rights
6# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7# copies of the Software, and to permit persons to whom the Software is
8# furnished to do so, subject to the following conditions:
9#
10# The above copyright notice and this permission notice shall be included in all
11# copies or substantial portions of the Software.
12"""Linear algebra utilities that handle matrices with missing (NaN) values.
14All operations extract the finite submatrix via the ``valid`` function before
15performing computations, so partially-observed covariance matrices are handled
16gracefully without raising errors.
17"""
19from __future__ import annotations
21import numpy as np
def valid(matrix: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Extract the sub-matrix whose diagonal entries are finite.

    The input must be a square two-dimensional array. A row/column pair is
    kept when the corresponding diagonal element is finite. Non-finite
    diagonal entries do not raise an error; the affected rows and columns are
    simply left out of the returned sub-matrix. Off-diagonal entries are not
    inspected.

    Parameters:
        matrix (np.ndarray): The square matrix to filter.

    Returns:
        tuple[np.ndarray, np.ndarray]: A boolean mask flagging the finite
            diagonal entries, and the sub-matrix restricted to the rows and
            columns selected by that mask.

    Raises:
        AssertionError: If the input is not a square two-dimensional array.
    """
    # Check ndim first so that 0-d/1-d input fails with the documented
    # AssertionError rather than an IndexError from ``shape[1]``.
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        raise AssertionError(f"matrix must be square, got shape {matrix.shape}")

    v = np.isfinite(np.diag(matrix))
    # Select the columns first, then the rows, both with the same mask.
    return v, matrix[:, v][v]
55# that's somewhat not needed...
56def a_norm(vector: np.ndarray, matrix: np.ndarray | None = None) -> float:
57 """Calculate the generalized norm of a vector optionally using a matrix.
59 The function computes the Euclidean norm of the vector if no matrix is
60 provided. If a matrix is provided, it computes the generalized norm
61 of the vector with respect to the quadratic form defined by the matrix.
62 The function ensures both the matrix and vector meet certain validity
63 conditions before proceeding with calculations.
65 Parameters:
66 vector: np.ndarray
67 The input vector. Must be a one-dimensional numpy array.
69 matrix: np.ndarray | None, optional
70 An optional square matrix defining the quadratic form to calculate the
71 generalized norm. If provided, its size must match the size of the
72 input vector. If None, the function computes the standard Euclidean
73 norm of the vector.
75 Raises:
76 AssertionError
77 If the matrix is not square or if its dimensions do not align with
78 the size of the input vector.
80 Returns:
81 float
82 The norm of the vector calculated using the given quadratic form
83 defined by the matrix, or the Euclidean norm if no matrix is supplied.
84 If the computation is invalid, returns NaN.
85 """
86 if matrix is None:
87 return float(np.linalg.norm(vector[np.isfinite(vector)], 2))
89 # make sure matrix is quadratic
90 if matrix.shape[0] != matrix.shape[1]:
91 raise AssertionError
93 # make sure the vector has the right number of entries
94 if vector.size != matrix.shape[0]:
95 raise AssertionError
97 v, mat = valid(matrix)
99 if v.any():
100 return float(np.sqrt(np.dot(vector[v], np.dot(mat, vector[v]))))
101 return float("nan")
104def inv_a_norm(vector: np.ndarray, matrix: np.ndarray | None = None) -> float:
105 """Calculates the inverse A-norm of a given vector using an optional matrix.
107 If the matrix is not provided, it defaults to calculating the Euclidean norm of the
108 finite entries in the vector. If the matrix is provided, it checks that the
109 matrix is square and that the dimensions are compatible with the vector before
110 computing the norm.
112 Parameters
113 ----------
114 vector : np.ndarray
115 The input vector for which the norm is to be calculated.
116 matrix : np.ndarray | None, optional
117 An optional square matrix used for computing the norm. If not provided,
118 the function computes the Euclidean norm.
120 Returns:
121 -------
122 float
123 The computed norm as a float value. If no valid entries exist for the
124 calculation, it returns 'np.nan'.
126 Raises:
127 ------
128 AssertionError
129 If the matrix is not square or if the vector's size is incompatible with
130 the matrix's dimensions.
131 """
132 if matrix is None:
133 return float(np.linalg.norm(vector[np.isfinite(vector)], 2))
135 # make sure matrix is quadratic
136 if matrix.shape[0] != matrix.shape[1]:
137 raise AssertionError
139 # make sure the vector has the right number of entries
140 if vector.size != matrix.shape[0]:
141 raise AssertionError
143 v, mat = valid(matrix)
145 if v.any():
146 return float(np.sqrt(np.dot(vector[v], np.linalg.solve(mat, vector[v]))))
147 return float("nan")
def solve(matrix: np.ndarray, rhs: np.ndarray) -> np.ndarray:
    """Solve a linear system restricted to the valid part of the matrix.

    The system ``A x = b`` is solved only for the rows/columns of ``A`` whose
    diagonal entry is finite, using the matching entries of ``b``. All other
    entries of the result are left as NaN.

    Parameters:
        matrix (np.ndarray): A square matrix holding the coefficients of the
            linear system.
        rhs (np.ndarray): The right-hand side vector. Its size must match the
            number of rows of the matrix.

    Returns:
        np.ndarray: A 1D array of length ``rhs.size``. Entries belonging to
            the valid subset hold the solution; the remaining entries are
            NaN. If no diagonal entry is finite, every entry is NaN.

    Raises:
        AssertionError: If the matrix is not a square two-dimensional array,
            or if the size of ``rhs`` does not match the matrix.
    """
    # Check ndim first so that 0-d/1-d input fails with the documented
    # AssertionError rather than an IndexError from ``shape[1]``.
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        raise AssertionError(f"matrix must be square, got shape {matrix.shape}")

    # make sure the vector rhs has the right number of entries
    if rhs.size != matrix.shape[0]:
        raise AssertionError(
            f"rhs has {rhs.size} entries but matrix has shape {matrix.shape}"
        )

    # Start from all-NaN so that entries outside the valid subset stay NaN.
    x = np.full(rhs.size, np.nan)
    v, mat = valid(matrix)

    if v.any():
        x[v] = np.linalg.solve(mat, rhs[v])

    return x