Coverage for src / tinycta / linalg.py: 100%
39 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 13:04 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 13:04 +0000
1# Copyright (c) 2023 Thomas Schmelzer
2#
3# Permission is hereby granted, free of charge, to any person obtaining a copy
4# of this software and associated documentation files (the "Software"), to deal
5# in the Software without restriction, including without limitation the rights
6# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7# copies of the Software, and to permit persons to whom the Software is
8# furnished to do so, subject to the following conditions:
9#
10# The above copyright notice and this permission notice shall be included in all
11# copies or substantial portions of the Software.
12"""Linear algebra utilities that handle matrices with missing (NaN) values.
14All operations extract the finite submatrix via the ``valid`` function before
15performing computations, so partially-observed covariance matrices are handled
16gracefully without raising errors.
17"""
19from __future__ import annotations
21import numpy as np
def valid(matrix: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Extract the finite submatrix by checking diagonal elements.

    Only the diagonal is inspected: a row/column pair is kept when its
    diagonal entry is finite. Off-diagonal entries are not checked.

    Args:
        matrix: A square, two-dimensional NumPy array.

    Returns:
        A tuple of (mask, submatrix) where mask is a boolean array marking
        the finite diagonal elements and submatrix contains only the rows
        and columns whose diagonal entry is finite.

    Raises:
        AssertionError: If the input is not a two-dimensional square array.
    """
    # Check the rank first: indexing shape[1] on a 0-D/1-D array would raise
    # IndexError, and np.diag rejects arrays with more than two dimensions,
    # so without this guard the documented AssertionError is not what
    # callers would see for those inputs.
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        raise AssertionError(
            f"expected a square two-dimensional matrix, got shape {matrix.shape}"
        )

    mask = np.isfinite(np.diag(matrix))
    # Select the columns first, then the rows, both with the same mask.
    return mask, matrix[:, mask][mask]
def a_norm(vector: np.ndarray, matrix: np.ndarray | None = None) -> float:
    """Return the norm of a vector induced by an optional matrix.

    Without a matrix this is the Euclidean norm of the finite entries of
    ``vector``. With a matrix ``A`` it is ``sqrt(v^T A v)``, evaluated on the
    rows/columns of ``A`` selected by ``valid`` and the matching entries of
    ``vector``.

    Args:
        vector: The input one-dimensional vector.
        matrix: Square matrix defining the quadratic form. Its size must
            match the vector length. If None, the Euclidean norm is used.

    Returns:
        The norm as a float, or NaN when ``valid`` selects no row/column.

    Raises:
        AssertionError: If the matrix is not square or its size does not
            match the vector.
    """
    if matrix is None:
        finite_entries = vector[np.isfinite(vector)]
        return float(np.linalg.norm(finite_entries, 2))

    n_rows = matrix.shape[0]
    if n_rows != matrix.shape[1]:
        raise AssertionError

    if vector.size != n_rows:
        raise AssertionError

    mask, submatrix = valid(matrix)

    # Nothing usable on the diagonal: the norm is undefined.
    if not mask.any():
        return float("nan")

    quadratic_form = np.dot(vector[mask], np.dot(submatrix, vector[mask]))
    return float(np.sqrt(quadratic_form))
def inv_a_norm(vector: np.ndarray, matrix: np.ndarray | None = None) -> float:
    """Return the norm of a vector induced by the inverse of a matrix.

    With a matrix ``A`` this is ``sqrt(v^T A^{-1} v)``, evaluated on the
    rows/columns of ``A`` selected by ``valid`` and the matching entries of
    ``vector``. The inverse is never formed explicitly; a linear system is
    solved instead. Without a matrix, the Euclidean norm of the finite
    entries of ``vector`` is returned.

    Args:
        vector: The input vector for which the norm is calculated.
        matrix: Square matrix used for computing the norm. If not provided,
            the Euclidean norm of finite entries is returned.

    Returns:
        The norm as a float, or NaN when ``valid`` selects no row/column.

    Raises:
        AssertionError: If the matrix is not square or the vector size is
            incompatible with the matrix dimensions.
    """
    if matrix is None:
        finite_entries = vector[np.isfinite(vector)]
        return float(np.linalg.norm(finite_entries, 2))

    n_rows = matrix.shape[0]
    if n_rows != matrix.shape[1]:
        raise AssertionError

    if vector.size != n_rows:
        raise AssertionError

    mask, submatrix = valid(matrix)

    # Nothing usable on the diagonal: the norm is undefined.
    if not mask.any():
        return float("nan")

    # Solve A x = v on the valid block, then take v^T x.
    quadratic_form = np.dot(vector[mask], np.linalg.solve(submatrix, vector[mask]))
    return float(np.sqrt(quadratic_form))
def solve(matrix: np.ndarray, rhs: np.ndarray) -> np.ndarray:
    """Solve a linear system on the valid (finite-diagonal) part of a matrix.

    The rows/columns selected by ``valid`` form a reduced system that is
    solved against the matching entries of ``rhs``. All other positions of
    the result stay NaN.

    Args:
        matrix: Square coefficient matrix of the linear system.
        rhs: Right-hand side vector. Its length must match the number of
            rows in the matrix.

    Returns:
        Solution vector with ``rhs.size`` entries. Entries corresponding to
        invalid rows/columns are NaN.

    Raises:
        AssertionError: If the matrix is not square or the ``rhs`` length
            does not match the matrix dimensions.
    """
    n_rows = matrix.shape[0]
    if n_rows != matrix.shape[1]:
        raise AssertionError

    if rhs.size != n_rows:
        raise AssertionError

    # Start from all-NaN so positions outside the valid block remain NaN.
    solution = np.full(rhs.size, np.nan)
    mask, submatrix = valid(matrix)

    if mask.any():
        solution[mask] = np.linalg.solve(submatrix, rhs[mask])

    return solution