-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement a partial Cholesky factorisation and diagonal-plus-low-rank…
… preconditioning (#185) * Implement partial Cholesky factorisations * Update the docstrings of low-rank approximations * Simplify the test-code structure * Flatten the function hierarchy in low_rank.py * Reorganise test-code
- Loading branch information
Showing
8 changed files
with
326 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
"""Low-rank approximations (like partial Cholesky decompositions) of matrices.""" | ||
|
||
from matfree.backend import control_flow, func, linalg, np | ||
from matfree.backend.typing import Array, Callable | ||
|
||
|
||
def preconditioner(cholesky: Callable, /) -> Callable:
    r"""Turn a low-rank approximation into a preconditioner.

    Parameters
    ----------
    cholesky
        (Partial) Cholesky decomposition.
        Usually, the result of either
        [cholesky_partial][matfree.low_rank.cholesky_partial]
        or
        [cholesky_partial_pivot][matfree.low_rank.cholesky_partial_pivot].
        It is called as ``cholesky(*p)`` and must return
        a tuple ``(L, info)``.

    Returns
    -------
    solve
        A function that computes

        $$
        (v, s, *p) \mapsto (sI + L(*p) L(*p)^\top)^{-1} v,
        $$

        where $K = [k(i,j,*p)]_{ij} \approx L(*p) L(*p)^\top$
        and $L$ comes from the low-rank approximation.
        The function returns a tuple: the solution of the linear system
        and the ``info`` object produced by ``cholesky(*p)``
        (passed through unchanged).
    """

    def solve(v: Array, s: float, *cholesky_params):
        chol, info = cholesky(*cholesky_params)

        # Assert that the low-rank matrix is tall,
        # not wide (every sign has a story...)
        N, n = np.shape(chol)
        assert n <= N, (N, n)

        # Scale, so that sI + L L^T = s (I + U V) with V = U^T.
        # The right-hand side is scaled out-of-place on purpose:
        # an augmented assignment (v /= s) would overwrite the
        # caller's buffer whenever v is a mutable (e.g. NumPy) array.
        sqrt_s = np.sqrt(s)
        U = chol / sqrt_s
        V = chol.T / sqrt_s
        v_scaled = v / s

        # Cholesky decompose the (n, n) capacitance matrix
        # and solve the system
        eye_n = np.eye(n)
        chol_cap = linalg.cho_factor(eye_n + V @ U)
        sol = linalg.cho_solve(chol_cap, V @ v_scaled)
        return v_scaled - U @ sol, info

    return solve
|
||
|
||
def cholesky_partial(mat_el: Callable, /, *, nrows: int, rank: int) -> Callable:
    """Compute a partial Cholesky factorisation.

    Parameters
    ----------
    mat_el
        Matrix-element function. It is evaluated as
        ``mat_el(i, j, *params)``.
    nrows
        Number of rows of the factor (and of index values
        passed to ``mat_el``).
    rank
        Number of columns of the factor.
        Must satisfy ``1 <= rank <= nrows``.

    Returns
    -------
    cholesky
        A function ``cholesky(*params)`` that returns a tuple:
        the factor of shape ``(nrows, rank)`` and an empty info dictionary.
        It raises a ``ValueError`` if ``rank`` is out of range.
    """

    def cholesky(*params):
        # rank == nrows (a full factorisation) is valid,
        # so the message reports a strict inequality.
        if rank > nrows:
            msg = f"Rank exceeds nrows: {rank} > {nrows}."
            raise ValueError(msg)
        if rank < 1:
            msg = f"Rank must be positive, but {rank} < {1}."
            raise ValueError(msg)

        # One loop iteration fills one column of the factor
        step = _cholesky_partial_body(mat_el, nrows, *params)
        chol = np.zeros((nrows, rank))
        return control_flow.fori_loop(0, rank, step, chol), {}

    return cholesky
|
||
|
||
def _cholesky_partial_body(fn: Callable, n: int, *args): | ||
idx = np.arange(n) | ||
|
||
def matrix_element(i, j): | ||
return fn(i, j, *args) | ||
|
||
def matrix_column(i): | ||
fun = func.vmap(matrix_element, in_axes=(0, None)) | ||
return fun(idx, i) | ||
|
||
def body(i, L): | ||
element = matrix_element(i, i) | ||
l_ii = np.sqrt(element - linalg.vecdot(L[i], L[i])) | ||
|
||
column = matrix_column(i) | ||
l_ji = column - L @ L[i, :] | ||
l_ji /= l_ii | ||
|
||
return L.at[:, i].set(l_ji) | ||
|
||
return body | ||
|
||
|
||
def cholesky_partial_pivot(mat_el: Callable, /, *, nrows: int, rank: int) -> Callable:
    """Compute a partial Cholesky factorisation with pivoting.

    Parameters
    ----------
    mat_el
        Matrix-element function. It is evaluated as
        ``mat_el(i, j, *params)``.
    nrows
        Number of rows of the factor (and of index values
        passed to ``mat_el``).
    rank
        Number of columns of the factor.
        Must satisfy ``1 <= rank <= nrows``.

    Returns
    -------
    cholesky
        A function ``cholesky(*params)`` that returns a tuple:
        the factor of shape ``(nrows, rank)`` (with the pivoting undone)
        and an info dictionary with the key ``"success"``.
        ``"success"`` is true if every squared diagonal entry
        encountered during the factorisation was positive.
        The function raises a ``ValueError`` if ``rank`` is out of range.
    """

    def cholesky(*params):
        # rank == nrows (a full factorisation) is valid,
        # so the message reports a strict inequality.
        if rank > nrows:
            msg = f"Rank exceeds nrows: {rank} > {nrows}."
            raise ValueError(msg)
        if rank < 1:
            msg = f"Rank must be positive, but {rank} < {1}."
            raise ValueError(msg)

        body = _cholesky_partial_pivot_body(mat_el, nrows, *params)

        L = np.zeros((nrows, rank))
        P = np.arange(nrows)

        # Carry: factor, row permutation, matrix permutation, success flag
        init = (L, P, P, True)
        (L, P, _matrix, success) = control_flow.fori_loop(0, rank, body, init)

        # Undo the row permutation before returning the factor
        return _pivot_invert(L, P), {"success": success}

    return cholesky
|
||
|
||
def _cholesky_partial_pivot_body(fn: Callable, n: int, *args):
    """Build the loop body of the pivoted partial Cholesky factorisation.

    The returned function ``body(i, carry)`` expects
    ``carry = (L, P, P_matrix, success)`` and returns a carry
    of the same structure, where

    - ``L`` is the current factor,
    - ``P`` is the permutation that is swapped alongside the rows of ``L``,
    - ``P_matrix`` is the permutation used to index the matrix elements,
    - ``success`` records whether all squared diagonal entries
      seen so far were positive.
    """
    idx = np.arange(n)

    def matrix_element(i, j):
        # Forward the extra arguments to the element function
        return fn(i, j, *args)

    def matrix_element_p(i, j, *, permute):
        # Element (i, j) with both indices mapped through the permutation
        return matrix_element(permute[i], permute[j])

    def matrix_column_p(i, *, permute):
        # Column i with all indices mapped through the permutation
        # (vectorised over the row index)
        fun = func.vmap(matrix_element, in_axes=(0, None))
        return fun(permute[idx], permute[i])

    def matrix_diagonal_p(*, permute):
        # Diagonal with all indices mapped through the permutation
        fun = func.vmap(matrix_element)
        return fun(permute[idx], permute[idx])

    def body(i, carry):
        L, P, P_matrix, success = carry

        # Access the matrix
        diagonal = matrix_diagonal_p(permute=P_matrix)

        # Find the largest entry for the residuals
        residual_diag = diagonal - func.vmap(linalg.vecdot)(L, L)
        res = np.abs(residual_diag)
        k = np.argmax(res)

        # Pivot [pivot!!! pivot!!! pivot!!! :)]
        # The swaps must happen before the matrix is accessed again below.
        # NOTE(review): P_matrix presumably is one-dimensional, in which case
        # the column swap acts like a row swap -- confirm.
        P_matrix = _swap_cols(P_matrix, i, k)
        L = _swap_rows(L, i, k)
        P = _swap_rows(P, i, k)

        # Access the matrix
        element = matrix_element_p(i, i, permute=P_matrix)
        column = matrix_column_p(i, permute=P_matrix)

        # Perform a Cholesky step
        # (The first line could also be accessed via
        # residual_diag[k], but it might
        # be more readable to do it again)
        l_ii_squared = element - linalg.vecdot(L[i], L[i])
        l_ii = np.sqrt(l_ii_squared)
        l_ji = column - L @ L[i, :]
        l_ji /= l_ii
        # A non-positive squared diagonal entry marks the factorisation as failed
        success = np.logical_and(success, l_ii_squared > 0.0)

        # Update the estimate
        L = L.at[:, i].set(l_ji)
        return L, P, P_matrix, success

    return body
|
||
|
||
def _swap_cols(arr, i, j):
    """Swap columns i and j by swapping the rows of the transpose."""
    transposed = arr.T
    swapped = _swap_rows(transposed, i, j)
    return swapped.T
|
||
|
||
def _swap_rows(arr, i, j): | ||
ai, aj = arr[i], arr[j] | ||
arr = arr.at[i].set(aj) | ||
return arr.at[j].set(ai) | ||
|
||
|
||
def _pivot_invert(arr, pivot, /): | ||
"""Invert and apply a pivoting array to a matrix.""" | ||
return arr[np.argsort(pivot)] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
"""Test the partial Cholesky decompositions.""" | ||
|
||
from matfree import low_rank, test_util | ||
from matfree.backend import linalg, np, prng, testing | ||
from matfree.backend.typing import Callable | ||
|
||
|
||
def case_cholesky_partial():
    """Test case: the partial Cholesky factorisation without pivoting."""
    factorisation = low_rank.cholesky_partial
    return factorisation
|
||
|
||
def case_cholesky_partial_pivot():
    """Test case: the partial Cholesky factorisation with pivoting."""
    factorisation = low_rank.cholesky_partial_pivot
    return factorisation
|
||
|
||
@testing.parametrize_with_cases("cholesky", cases=".", prefix="case_cholesky")
def test_full_rank_cholesky_reconstructs_matrix(cholesky, n=5):
    """Assert that a rank-n factor of an (n, n) matrix reproduces the matrix."""
    eigvals = 1.0 + prng.uniform(prng.prng_key(2), shape=(n,), dtype=float)
    matrix = test_util.symmetric_matrix_from_eigenvalues(eigvals)

    def element(i, j):
        return matrix[i, j]

    factorise = cholesky(element, nrows=n, rank=n)
    factor, _info = factorise()

    # Compare up to machine precision of the factor's dtype
    eps = np.finfo_eps(factor.dtype)
    assert np.allclose(factor @ factor.T, matrix, atol=eps, rtol=eps)
|
||
|
||
@testing.parametrize_with_cases("cholesky", cases=".", prefix="case_cholesky")
def test_output_the_right_shapes(cholesky: Callable, n=4, rank=4):
    """Assert that the factor has shape (n, rank)."""
    eigvals = 0.1 + prng.uniform(prng.prng_key(1), shape=(n,))
    matrix = test_util.symmetric_matrix_from_eigenvalues(eigvals)

    def element(i, j):
        return matrix[i, j]

    factorise = cholesky(element, nrows=n, rank=rank)
    factor, _info = factorise()
    assert factor.shape == (n, rank)
|
||
|
||
def test_full_rank_nopivot_matches_cholesky(n=10):
    """Assert that the unpivoted full-rank factor equals the dense Cholesky factor."""
    eigvals = 0.01 + prng.uniform(prng.prng_key(2), shape=(n,), dtype=float)
    matrix = test_util.symmetric_matrix_from_eigenvalues(eigvals)
    expected = linalg.cholesky(matrix)

    def element(i, j):
        return matrix[i, j]

    # Sanity check: pivoting should definitely not satisfy this:
    with_pivot = low_rank.cholesky_partial_pivot(element, nrows=n, rank=n)
    factor_pivot, _info = with_pivot()
    assert not np.allclose(factor_pivot, expected)

    # But without pivoting, we should get there!
    without_pivot = low_rank.cholesky_partial(element, nrows=n, rank=n)
    factor_nopivot, _info = without_pivot()
    assert np.allclose(factor_nopivot, expected, atol=1e-6)
|
||
|
||
def test_pivoting_improves_the_estimate(n=10, rank=5):
    """Assert that pivoting reduces the Frobenius error of a rank-deficient factor."""
    eigvals = 0.1 + prng.uniform(prng.prng_key(1), shape=(n,))
    matrix = test_util.symmetric_matrix_from_eigenvalues(eigvals)

    def element(i, j):
        return matrix[i, j]

    def frobenius_error(factor):
        return linalg.matrix_norm(matrix - factor @ factor.T, which="fro")

    factorise_nopivot = low_rank.cholesky_partial(element, nrows=n, rank=rank)
    factor_nopivot, _info = factorise_nopivot()

    factorise_pivot = low_rank.cholesky_partial_pivot(element, nrows=n, rank=rank)
    factor_pivot, _info = factorise_pivot()

    error_nopivot = frobenius_error(factor_nopivot)
    error_pivot = frobenius_error(factor_pivot)
    assert error_pivot < error_nopivot
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
"""Test preconditioning with partial Cholesky decompositions.""" | ||
|
||
from matfree import low_rank, test_util | ||
from matfree.backend import linalg, np | ||
|
||
|
||
def test_preconditioner_solves_correctly(n=10):
    """Assert that a full-rank preconditioner solves the shifted linear system."""
    # Create a relatively ill-conditioned matrix
    eigvals = 1.5 ** np.arange(-n // 2, n // 2, step=1.0)
    matrix = test_util.symmetric_matrix_from_eigenvalues(eigvals)

    def element(i, j):
        return matrix[i, j]

    # Assert that the Cholesky decomposition is full-rank.
    # This is important to ensure that the test below makes sense.
    cholesky = low_rank.cholesky_partial(element, nrows=n, rank=n)
    factor, _info = cholesky()
    assert np.allclose(factor @ factor.T, matrix)

    # Set up the test-problem: a normalised right-hand side
    shift = 1e-1
    rhs = np.arange(1.0, 1 + len(matrix))
    rhs = rhs / linalg.vector_norm(rhs)

    # Solve the linear system with a dense solver
    shifted = matrix + shift * np.eye(len(matrix))
    expected = linalg.solve(shifted, rhs)

    # Derive the preconditioner and apply it
    precondition = low_rank.preconditioner(cholesky)
    received, _info = precondition(rhs, shift)

    # Test that the preconditioner solves correctly
    assert np.allclose(received, expected, rtol=1e-2)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters