Implement a partial Cholesky factorisation and diagonal-plus-low-rank preconditioning (#185)

* Implement partial Cholesky factorisations

* Update the docstrings of low-rank approximations

* Simplify the test-code structure

* Flatten the function hierarchy in low_rank.py

* Reorganise test-code
pnkraemer authored May 24, 2024
1 parent f1cb7c6 commit 89327a0
Showing 8 changed files with 326 additions and 3 deletions.
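
For orientation, a usage sketch of the new `matfree.low_rank` API, pieced together from the tests added in this commit (not part of the diff; the matrix and all sizes are illustrative):

import jax.numpy as jnp

from matfree import low_rank, test_util

# A small symmetric positive-definite matrix, accessed element-wise.
cov = test_util.symmetric_matrix_from_eigenvalues(1.0 + jnp.arange(10.0))

def element(i, j):
    return cov[i, j]

# Rank-5 pivoted partial Cholesky: cov is approximately chol @ chol.T.
cholesky = low_rank.cholesky_partial_pivot(element, nrows=10, rank=5)
chol, info = cholesky()

# Preconditioner: approximately solves (s * I + cov) x = b.
precondition = low_rank.preconditioner(cholesky)
b = jnp.ones(10)
x_approx, info = precondition(b, 1e-1)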
1 change: 1 addition & 0 deletions README.md
@@ -13,6 +13,7 @@ Builds on [JAX](https://jax.readthedocs.io/en/latest/).
- ⚡ A stand-alone implementation of **stochastic Lanczos quadrature**
- ⚡ Matrix-decomposition algorithms for **large sparse eigenvalue problems**
- ⚡ Polynomial methods for approximating **functions of large matrices**
- ⚡ Partial Cholesky **preconditioners** with and without pivoting

and many other things.
Everything is natively compatible with the rest of JAX:
12 changes: 12 additions & 0 deletions matfree/backend/linalg.py
@@ -20,6 +20,18 @@ def eigh(x, /):
    return jnp.linalg.eigh(x)


def cholesky(x, /):
    return jnp.linalg.cholesky(x)


def cho_factor(matrix, /):
    return jax.scipy.linalg.cho_factor(matrix)


def cho_solve(factor, b, /):
    return jax.scipy.linalg.cho_solve(factor, b)


def slogdet(x, /):
    return jnp.linalg.slogdet(x)

12 changes: 12 additions & 0 deletions matfree/backend/np.py
@@ -88,6 +88,10 @@ def sign(x, /):
    return jnp.sign(x)


def logical_and(a, b, /):
    return jnp.logical_and(a, b)


# Utility functions


@@ -126,6 +130,14 @@ def array_max(x, /, axis=None):
    return jnp.amax(x, axis=axis)


def argmax(x, /, axis=None):
    return jnp.argmax(x, axis=axis)


def argsort(x, /):
    return jnp.argsort(x)


def elementwise_max(x1, x2, /):
    return jnp.maximum(x1, x2)

4 changes: 4 additions & 0 deletions matfree/backend/testing.py
@@ -13,6 +13,10 @@ def parametrize(argnames, argvalues, /):
    return pytest.mark.parametrize(argnames, argvalues)


def parametrize_with_cases(argnames, /, cases, prefix):
    return pytest_cases.parametrize_with_cases(argnames, cases=cases, prefix=prefix)


def check_grads(fun, /, args, *, order, atol, rtol):
    return jax.test_util.check_grads(fun, args, order=order, atol=atol, rtol=rtol)

186 changes: 186 additions & 0 deletions matfree/low_rank.py
@@ -0,0 +1,186 @@
"""Low-rank approximations (like partial Cholesky decompositions) of matrices."""

from matfree.backend import control_flow, func, linalg, np
from matfree.backend.typing import Array, Callable


def preconditioner(cholesky: Callable, /) -> Callable:
r"""Turn a low-rank approximation into a preconditioner.
Parameters
----------
cholesky
(Partial) Cholesky decomposition.
Usually, the result of either
[cholesky_partial][matfree.low_rank.cholesky_partial]
or
[cholesky_partial_pivot][matfree.low_rank.cholesky_partial_pivot].
Returns
-------
solve
A function that computes
$$
(v, s, *p) \mapsto (sI + L(*p) L(*p)^\top)^{-1} v,
$$
where $K = [k(i,j,*p)]_{ij} \approx L(*p) L(*p)^\top$
and $L$ comes from the low-rank approximation.
"""

def solve(v: Array, s: float, *cholesky_params):
chol, info = cholesky(*cholesky_params)

# Assert that the low-rank matrix is tall,
# not wide (every sign has a story...)
N, n = np.shape(chol)
assert n <= N, (N, n)

# Scale
U = chol / np.sqrt(s)
V = chol.T / np.sqrt(s)
v /= s

# Cholesky decompose the capacitance matrix
# and solve the system
eye_n = np.eye(n)
chol_cap = linalg.cho_factor(eye_n + V @ U)
sol = linalg.cho_solve(chol_cap, V @ v)
return v - U @ sol, info

return solve


def cholesky_partial(mat_el: Callable, /, *, nrows: int, rank: int) -> Callable:
"""Compute a partial Cholesky factorisation."""

def cholesky(*params):
if rank > nrows:
msg = f"Rank exceeds n: {rank} >= {nrows}."
raise ValueError(msg)
if rank < 1:
msg = f"Rank must be positive, but {rank} < {1}."
raise ValueError(msg)

step = _cholesky_partial_body(mat_el, nrows, *params)
chol = np.zeros((nrows, rank))
return control_flow.fori_loop(0, rank, step, chol), {}

return cholesky


def _cholesky_partial_body(fn: Callable, n: int, *args):
idx = np.arange(n)

def matrix_element(i, j):
return fn(i, j, *args)

def matrix_column(i):
fun = func.vmap(matrix_element, in_axes=(0, None))
return fun(idx, i)

def body(i, L):
element = matrix_element(i, i)
l_ii = np.sqrt(element - linalg.vecdot(L[i], L[i]))

column = matrix_column(i)
l_ji = column - L @ L[i, :]
l_ji /= l_ii

return L.at[:, i].set(l_ji)

return body


def cholesky_partial_pivot(mat_el: Callable, /, *, nrows: int, rank: int) -> Callable:
"""Compute a partial Cholesky factorisation with pivoting."""

def cholesky(*params):
if rank > nrows:
msg = f"Rank exceeds nrows: {rank} >= {nrows}."
raise ValueError(msg)
if rank < 1:
msg = f"Rank must be positive, but {rank} < {1}."
raise ValueError(msg)

body = _cholesky_partial_pivot_body(mat_el, nrows, *params)

L = np.zeros((nrows, rank))
P = np.arange(nrows)

init = (L, P, P, True)
(L, P, _matrix, success) = control_flow.fori_loop(0, rank, body, init)
return _pivot_invert(L, P), {"success": success}

return cholesky


def _cholesky_partial_pivot_body(fn: Callable, n: int, *args):
idx = np.arange(n)

def matrix_element(i, j):
return fn(i, j, *args)

def matrix_element_p(i, j, *, permute):
return matrix_element(permute[i], permute[j])

def matrix_column_p(i, *, permute):
fun = func.vmap(matrix_element, in_axes=(0, None))
return fun(permute[idx], permute[i])

def matrix_diagonal_p(*, permute):
fun = func.vmap(matrix_element)
return fun(permute[idx], permute[idx])

def body(i, carry):
L, P, P_matrix, success = carry

# Access the matrix
diagonal = matrix_diagonal_p(permute=P_matrix)

# Find the largest entry for the residuals
residual_diag = diagonal - func.vmap(linalg.vecdot)(L, L)
res = np.abs(residual_diag)
k = np.argmax(res)

# Pivot [pivot!!! pivot!!! pivot!!! :)]
P_matrix = _swap_cols(P_matrix, i, k)
L = _swap_rows(L, i, k)
P = _swap_rows(P, i, k)

# Access the matrix
element = matrix_element_p(i, i, permute=P_matrix)
column = matrix_column_p(i, permute=P_matrix)

# Perform a Cholesky step
# (The first line could also be accessed via
# residual_diag[k], but it might
# be more readable to do it again)
l_ii_squared = element - linalg.vecdot(L[i], L[i])
l_ii = np.sqrt(l_ii_squared)
l_ji = column - L @ L[i, :]
l_ji /= l_ii
success = np.logical_and(success, l_ii_squared > 0.0)

# Update the estimate
L = L.at[:, i].set(l_ji)
return L, P, P_matrix, success

return body


def _swap_cols(arr, i, j):
return _swap_rows(arr.T, i, j).T


def _swap_rows(arr, i, j):
ai, aj = arr[i], arr[j]
arr = arr.at[i].set(aj)
return arr.at[j].set(ai)


def _pivot_invert(arr, pivot, /):
"""Invert and apply a pivoting array to a matrix."""
return arr[np.argsort(pivot)]
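
A side note on the `solve` function above (not part of the commit): it applies the Woodbury matrix identity to $(sI + LL^\top)^{-1}$, so only the small capacitance matrix $I + VU$ is factorised. A minimal NumPy sketch that checks this identity against a dense solve:

import numpy as np

rng = np.random.default_rng(0)
n, rank, s = 6, 3, 0.1
L = rng.normal(size=(n, rank))
v = rng.normal(size=(n,))

# Woodbury-style solve, mirroring low_rank.preconditioner.
U, V = L / np.sqrt(s), L.T / np.sqrt(s)
sol = np.linalg.solve(np.eye(rank) + V @ U, V @ (v / s))
woodbury = v / s - U @ sol

# Dense reference solve of (s*I + L @ L.T) x = v.
dense = np.linalg.solve(s * np.eye(n) + L @ L.T, v)
assert np.allclose(woodbury, dense)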
73 changes: 73 additions & 0 deletions tests/test_low_rank/test_cholesky.py
@@ -0,0 +1,73 @@
"""Test the partial Cholesky decompositions."""

from matfree import low_rank, test_util
from matfree.backend import linalg, np, prng, testing
from matfree.backend.typing import Callable


def case_cholesky_partial():
return low_rank.cholesky_partial


def case_cholesky_partial_pivot():
return low_rank.cholesky_partial_pivot


@testing.parametrize_with_cases("cholesky", cases=".", prefix="case_cholesky")
def test_full_rank_cholesky_reconstructs_matrix(cholesky, n=5):
key = prng.prng_key(2)

cov_eig = 1.0 + prng.uniform(key, shape=(n,), dtype=float)
cov = test_util.symmetric_matrix_from_eigenvalues(cov_eig)

approximation, _info = cholesky(lambda i, j: cov[i, j], nrows=n, rank=n)()

tol = np.finfo_eps(approximation.dtype)
assert np.allclose(approximation @ approximation.T, cov, atol=tol, rtol=tol)


@testing.parametrize_with_cases("cholesky", cases=".", prefix="case_cholesky")
def test_output_the_right_shapes(cholesky: Callable, n=4, rank=4):
key = prng.prng_key(1)

cov_eig = 0.1 + prng.uniform(key, shape=(n,))
cov = test_util.symmetric_matrix_from_eigenvalues(cov_eig)

approximation, _info = cholesky(lambda i, j: cov[i, j], nrows=n, rank=rank)()
assert approximation.shape == (n, rank)


def test_full_rank_nopivot_matches_cholesky(n=10):
key = prng.prng_key(2)
cov_eig = 0.01 + prng.uniform(key, shape=(n,), dtype=float)
cov = test_util.symmetric_matrix_from_eigenvalues(cov_eig)
reference = linalg.cholesky(cov)

# Sanity check: pivoting should definitely not satisfy this:
cholesky_p = low_rank.cholesky_partial_pivot(
lambda i, j: cov[i, j], nrows=n, rank=n
)
received, info = cholesky_p()
assert not np.allclose(received, reference)

# But without pivoting, we should get there!
cholesky = low_rank.cholesky_partial(lambda i, j: cov[i, j], nrows=n, rank=n)
received, info = cholesky()
assert np.allclose(received, reference, atol=1e-6)


def test_pivoting_improves_the_estimate(n=10, rank=5):
key = prng.prng_key(1)

cov_eig = 0.1 + prng.uniform(key, shape=(n,))
cov = test_util.symmetric_matrix_from_eigenvalues(cov_eig)

def element(i, j):
return cov[i, j]

nopivot, _info = low_rank.cholesky_partial(element, nrows=n, rank=rank)()
pivot, _info = low_rank.cholesky_partial_pivot(element, nrows=n, rank=rank)()

error_nopivot = linalg.matrix_norm(cov - nopivot @ nopivot.T, which="fro")
error_pivot = linalg.matrix_norm(cov - pivot @ pivot.T, which="fro")
assert error_pivot < error_nopivot
35 changes: 35 additions & 0 deletions tests/test_low_rank/test_preconditioner.py
@@ -0,0 +1,35 @@
"""Test preconditioning with partial Cholesky decompositions."""

from matfree import low_rank, test_util
from matfree.backend import linalg, np


def test_preconditioner_solves_correctly(n=10):
# Create a relatively ill-conditioned matrix
cov_eig = 1.5 ** np.arange(-n // 2, n // 2, step=1.0)
cov = test_util.symmetric_matrix_from_eigenvalues(cov_eig)

def element(i, j):
return cov[i, j]

# Assert that the Cholesky decomposition is full-rank.
# This is important to ensure that the test below makes sense.
cholesky = low_rank.cholesky_partial(element, nrows=n, rank=n)
matrix, _info = cholesky()
assert np.allclose(matrix @ matrix.T, cov)

# Set up the test-problem
small_value = 1e-1
b = np.arange(1.0, 1 + len(cov))
b /= linalg.vector_norm(b)

# Solve the linear system
cov_added = cov + small_value * np.eye(len(cov))
expected = linalg.solve(cov_added, b)

# Derive the preconditioner
precondition = low_rank.preconditioner(cholesky)
received, info = precondition(b, small_value)

# Test that the preconditioner solves correctly
assert np.allclose(received, expected, rtol=1e-2)
6 changes: 3 additions & 3 deletions tutorials/6_low_memory_trace_estimation.py
@@ -41,10 +41,10 @@ def large_matvec(v):
print(trace)


-# The above code requires nrows*nsamples storage, which
+# The above code requires nrows $\times$ nsamples storage, which
# is prohibitive for extremely large matrices.
# Instead, we can loop around estimate() to do the following:
-# The below code requires nrows*1 storage:
+# The below code requires nrows $\times$ 1 storage:

sampler = hutchinson.sampler_rademacher(x0, num=1)
estimate = hutchinson.hutchinson(integrand, sampler)
@@ -57,7 +57,7 @@ def large_matvec(v):

# In practice, we often combine both approaches by choosing
# the largest nsamples (in the first implementation) so that
-# nrows*nsamples fits into memory, and handle all samples beyond
+# nrows $\times$ nsamples fits into memory, and handle all samples beyond
# that via the split-and-map combination.
#
#
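
As a rough check of the storage claim above (not part of the diff; the sizes are purely illustrative): with nrows = $10^6$ and nsamples = 100, storing all samples at once in float64 takes $10^6 \times 100 \times 8$ bytes $\approx$ 0.8 GB, whereas the looped variant with nrows $\times$ 1 storage needs only about 8 MB.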
