diff --git a/README.md b/README.md
index abdd149..3516571 100644
--- a/README.md
+++ b/README.md
@@ -13,6 +13,7 @@ Builds on [JAX](https://jax.readthedocs.io/en/latest/).
 - ⚡ A stand-alone implementation of **stochastic Lanczos quadrature**
 - ⚡ Matrix-decomposition algorithms for **large sparse eigenvalue problems**
 - ⚡ Polynomial methods for approximating **functions of large matrices**
+- ⚡ Partial Cholesky **preconditioners** with and without pivoting
 
 and many other things.
 Everything is natively compatible with the rest of JAX:
diff --git a/matfree/backend/linalg.py b/matfree/backend/linalg.py
index ed91743..baec55c 100644
--- a/matfree/backend/linalg.py
+++ b/matfree/backend/linalg.py
@@ -20,6 +20,18 @@ def eigh(x, /):
     return jnp.linalg.eigh(x)
 
 
+def cholesky(x, /):
+    return jnp.linalg.cholesky(x)
+
+
+def cho_factor(matrix, /):
+    return jax.scipy.linalg.cho_factor(matrix)
+
+
+def cho_solve(factor, b, /):
+    return jax.scipy.linalg.cho_solve(factor, b)
+
+
 def slogdet(x, /):
     return jnp.linalg.slogdet(x)
 
diff --git a/matfree/backend/np.py b/matfree/backend/np.py
index 548d164..3650a43 100644
--- a/matfree/backend/np.py
+++ b/matfree/backend/np.py
@@ -88,6 +88,10 @@ def sign(x, /):
     return jnp.sign(x)
 
 
+def logical_and(a, b, /):
+    return jnp.logical_and(a, b)
+
+
 # Utility functions
 
 
@@ -126,6 +130,14 @@ def array_max(x, /, axis=None):
     return jnp.amax(x, axis=axis)
 
 
+def argmax(x, /, axis=None):
+    return jnp.argmax(x, axis=axis)
+
+
+def argsort(x, /):
+    return jnp.argsort(x)
+
+
 def elementwise_max(x1, x2, /):
     return jnp.maximum(x1, x2)
 
diff --git a/matfree/backend/testing.py b/matfree/backend/testing.py
index 1680fdc..3665877 100644
--- a/matfree/backend/testing.py
+++ b/matfree/backend/testing.py
@@ -13,6 +13,10 @@ def parametrize(argnames, argvalues, /):
     return pytest.mark.parametrize(argnames, argvalues)
 
 
+def parametrize_with_cases(argnames, /, cases, prefix):
+    return pytest_cases.parametrize_with_cases(argnames, cases=cases, prefix=prefix)
+
+
 def check_grads(fun, /, args, *, order, atol, rtol):
     return jax.test_util.check_grads(fun, args, order=order, atol=atol, rtol=rtol)
 
diff --git a/matfree/low_rank.py b/matfree/low_rank.py
new file mode 100644
index 0000000..e9f8908
--- /dev/null
+++ b/matfree/low_rank.py
@@ -0,0 +1,186 @@
+"""Low-rank approximations (like partial Cholesky decompositions) of matrices."""
+
+from matfree.backend import control_flow, func, linalg, np
+from matfree.backend.typing import Array, Callable
+
+
+def preconditioner(cholesky: Callable, /) -> Callable:
+    r"""Turn a low-rank approximation into a preconditioner.
+
+    Parameters
+    ----------
+    cholesky
+        (Partial) Cholesky decomposition.
+        Usually, the result of either
+        [cholesky_partial][matfree.low_rank.cholesky_partial]
+        or
+        [cholesky_partial_pivot][matfree.low_rank.cholesky_partial_pivot].
+
+
+    Returns
+    -------
+    solve
+        A function that computes
+
+        $$
+        (v, s, *p) \mapsto (sI + L(*p) L(*p)^\top)^{-1} v,
+        $$
+
+        where $K = [k(i,j,*p)]_{ij} \approx L(*p) L(*p)^\top$
+        and $L$ comes from the low-rank approximation.
+    """
+
+    def solve(v: Array, s: float, *cholesky_params):
+        chol, info = cholesky(*cholesky_params)
+
+        # Assert that the low-rank matrix is tall,
+        # not wide (every sign has a story...)
+        N, n = np.shape(chol)
+        assert n <= N, (N, n)
+
+        # Scale
+        U = chol / np.sqrt(s)
+        V = chol.T / np.sqrt(s)
+        v /= s
+
+        # Cholesky decompose the capacitance matrix
+        # and solve the system
+        eye_n = np.eye(n)
+        chol_cap = linalg.cho_factor(eye_n + V @ U)
+        sol = linalg.cho_solve(chol_cap, V @ v)
+        return v - U @ sol, info
+
+    return solve
+
+
+def cholesky_partial(mat_el: Callable, /, *, nrows: int, rank: int) -> Callable:
+    """Compute a partial Cholesky factorisation."""
+
+    def cholesky(*params):
+        if rank > nrows:
+            msg = f"Rank exceeds nrows: {rank} > {nrows}."
+            raise ValueError(msg)
+        if rank < 1:
+            msg = f"Rank must be positive, but {rank} < 1."
+            raise ValueError(msg)
+
+        step = _cholesky_partial_body(mat_el, nrows, *params)
+        chol = np.zeros((nrows, rank))
+        return control_flow.fori_loop(0, rank, step, chol), {}
+
+    return cholesky
+
+
+def _cholesky_partial_body(fn: Callable, n: int, *args):
+    idx = np.arange(n)
+
+    def matrix_element(i, j):
+        return fn(i, j, *args)
+
+    def matrix_column(i):
+        fun = func.vmap(matrix_element, in_axes=(0, None))
+        return fun(idx, i)
+
+    def body(i, L):
+        element = matrix_element(i, i)
+        l_ii = np.sqrt(element - linalg.vecdot(L[i], L[i]))
+
+        column = matrix_column(i)
+        l_ji = column - L @ L[i, :]
+        l_ji /= l_ii
+
+        return L.at[:, i].set(l_ji)
+
+    return body
+
+
+def cholesky_partial_pivot(mat_el: Callable, /, *, nrows: int, rank: int) -> Callable:
+    """Compute a partial Cholesky factorisation with pivoting."""
+
+    def cholesky(*params):
+        if rank > nrows:
+            msg = f"Rank exceeds nrows: {rank} > {nrows}."
+            raise ValueError(msg)
+        if rank < 1:
+            msg = f"Rank must be positive, but {rank} < 1."
+            raise ValueError(msg)
+
+        body = _cholesky_partial_pivot_body(mat_el, nrows, *params)
+
+        L = np.zeros((nrows, rank))
+        P = np.arange(nrows)
+
+        init = (L, P, P, True)
+        (L, P, _matrix, success) = control_flow.fori_loop(0, rank, body, init)
+        return _pivot_invert(L, P), {"success": success}
+
+    return cholesky
+
+
+def _cholesky_partial_pivot_body(fn: Callable, n: int, *args):
+    idx = np.arange(n)
+
+    def matrix_element(i, j):
+        return fn(i, j, *args)
+
+    def matrix_element_p(i, j, *, permute):
+        return matrix_element(permute[i], permute[j])
+
+    def matrix_column_p(i, *, permute):
+        fun = func.vmap(matrix_element, in_axes=(0, None))
+        return fun(permute[idx], permute[i])
+
+    def matrix_diagonal_p(*, permute):
+        fun = func.vmap(matrix_element)
+        return fun(permute[idx], permute[idx])
+
+    def body(i, carry):
+        L, P, P_matrix, success = carry
+
+        # Access the matrix
+        diagonal = matrix_diagonal_p(permute=P_matrix)
+
+        # Find the largest entry among the residuals
+        residual_diag = diagonal - func.vmap(linalg.vecdot)(L, L)
+        res = np.abs(residual_diag)
+        k = np.argmax(res)
+
+        # Pivot: swap the current index with that of the largest residual
+        P_matrix = _swap_cols(P_matrix, i, k)
+        L = _swap_rows(L, i, k)
+        P = _swap_rows(P, i, k)
+
+        # Access the matrix
+        element = matrix_element_p(i, i, permute=P_matrix)
+        column = matrix_column_p(i, permute=P_matrix)
+
+        # Perform a Cholesky step
+        # (The first line could also be computed via
+        # residual_diag[k], but it might
+        # be more readable to do it again)
+        l_ii_squared = element - linalg.vecdot(L[i], L[i])
+        l_ii = np.sqrt(l_ii_squared)
+        l_ji = column - L @ L[i, :]
+        l_ji /= l_ii
+        success = np.logical_and(success, l_ii_squared > 0.0)
+
+        # Update the estimate
+        L = L.at[:, i].set(l_ji)
+        return L, P, P_matrix, success
+
+    return body
+
+
+def _swap_cols(arr, i, j):
+    return _swap_rows(arr.T, i, j).T
+
+
+def _swap_rows(arr, i, j):
+    ai, aj = arr[i], arr[j]
+    arr = arr.at[i].set(aj)
+    return arr.at[j].set(ai)
+
+
+def _pivot_invert(arr, pivot, /):
+    """Invert and apply a pivoting array to a matrix."""
+    return arr[np.argsort(pivot)]
diff --git a/tests/test_low_rank/test_cholesky.py b/tests/test_low_rank/test_cholesky.py
new file mode 100644
index 0000000..331b972
--- /dev/null
+++ b/tests/test_low_rank/test_cholesky.py
@@ -0,0 +1,73 @@
+"""Test the partial Cholesky decompositions."""
+
+from matfree import low_rank, test_util
+from matfree.backend import linalg, np, prng, testing
+from matfree.backend.typing import Callable
+
+
+def case_cholesky_partial():
+    return low_rank.cholesky_partial
+
+
+def case_cholesky_partial_pivot():
+    return low_rank.cholesky_partial_pivot
+
+
+@testing.parametrize_with_cases("cholesky", cases=".", prefix="case_cholesky")
+def test_full_rank_cholesky_reconstructs_matrix(cholesky, n=5):
+    key = prng.prng_key(2)
+
+    cov_eig = 1.0 + prng.uniform(key, shape=(n,), dtype=float)
+    cov = test_util.symmetric_matrix_from_eigenvalues(cov_eig)
+
+    approximation, _info = cholesky(lambda i, j: cov[i, j], nrows=n, rank=n)()
+
+    tol = np.finfo_eps(approximation.dtype)
+    assert np.allclose(approximation @ approximation.T, cov, atol=tol, rtol=tol)
+
+
+@testing.parametrize_with_cases("cholesky", cases=".", prefix="case_cholesky")
+def test_output_the_right_shapes(cholesky: Callable, n=4, rank=4):
+    key = prng.prng_key(1)
+
+    cov_eig = 0.1 + prng.uniform(key, shape=(n,))
+    cov = test_util.symmetric_matrix_from_eigenvalues(cov_eig)
+
+    approximation, _info = cholesky(lambda i, j: cov[i, j], nrows=n, rank=rank)()
+    assert approximation.shape == (n, rank)
+
+
+def test_full_rank_nopivot_matches_cholesky(n=10):
+    key = prng.prng_key(2)
+    cov_eig = 0.01 + prng.uniform(key, shape=(n,), dtype=float)
+    cov = test_util.symmetric_matrix_from_eigenvalues(cov_eig)
+    reference = linalg.cholesky(cov)
+
+    # Sanity check: pivoting should definitely not satisfy this:
+    cholesky_p = low_rank.cholesky_partial_pivot(
+        lambda i, j: cov[i, j], nrows=n, rank=n
+    )
+    received, info = cholesky_p()
+    assert not np.allclose(received, reference)
+
+    # But without pivoting, we should get there!
+    cholesky = low_rank.cholesky_partial(lambda i, j: cov[i, j], nrows=n, rank=n)
+    received, info = cholesky()
+    assert np.allclose(received, reference, atol=1e-6)
+
+
+def test_pivoting_improves_the_estimate(n=10, rank=5):
+    key = prng.prng_key(1)
+
+    cov_eig = 0.1 + prng.uniform(key, shape=(n,))
+    cov = test_util.symmetric_matrix_from_eigenvalues(cov_eig)
+
+    def element(i, j):
+        return cov[i, j]
+
+    nopivot, _info = low_rank.cholesky_partial(element, nrows=n, rank=rank)()
+    pivot, _info = low_rank.cholesky_partial_pivot(element, nrows=n, rank=rank)()
+
+    error_nopivot = linalg.matrix_norm(cov - nopivot @ nopivot.T, which="fro")
+    error_pivot = linalg.matrix_norm(cov - pivot @ pivot.T, which="fro")
+    assert error_pivot < error_nopivot
diff --git a/tests/test_low_rank/test_preconditioner.py b/tests/test_low_rank/test_preconditioner.py
new file mode 100644
index 0000000..89ca204
--- /dev/null
+++ b/tests/test_low_rank/test_preconditioner.py
@@ -0,0 +1,35 @@
+"""Test preconditioning with partial Cholesky decompositions."""
+
+from matfree import low_rank, test_util
+from matfree.backend import linalg, np
+
+
+def test_preconditioner_solves_correctly(n=10):
+    # Create a relatively ill-conditioned matrix
+    cov_eig = 1.5 ** np.arange(-n // 2, n // 2, step=1.0)
+    cov = test_util.symmetric_matrix_from_eigenvalues(cov_eig)
+
+    def element(i, j):
+        return cov[i, j]
+
+    # Assert that the Cholesky decomposition is full-rank.
+    # This is important to ensure that the test below makes sense.
+    cholesky = low_rank.cholesky_partial(element, nrows=n, rank=n)
+    matrix, _info = cholesky()
+    assert np.allclose(matrix @ matrix.T, cov)
+
+    # Set up the test problem
+    small_value = 1e-1
+    b = np.arange(1.0, 1 + len(cov))
+    b /= linalg.vector_norm(b)
+
+    # Solve the linear system
+    cov_added = cov + small_value * np.eye(len(cov))
+    expected = linalg.solve(cov_added, b)
+
+    # Derive the preconditioner
+    precondition = low_rank.preconditioner(cholesky)
+    received, info = precondition(b, small_value)
+
+    # Test that the preconditioner solves correctly
+    assert np.allclose(received, expected, rtol=1e-2)
diff --git a/tutorials/6_low_memory_trace_estimation.py b/tutorials/6_low_memory_trace_estimation.py
index 5880c87..6dabbfd 100644
--- a/tutorials/6_low_memory_trace_estimation.py
+++ b/tutorials/6_low_memory_trace_estimation.py
@@ -41,10 +41,10 @@ def large_matvec(v):
 
 print(trace)
 
-# The above code requires nrows*nsamples storage, which
+# The above code requires nrows $\times$ nsamples storage, which
 # is prohibitive for extremely large matrices.
 # Instead, we can loop around estimate() to do the following:
-# The below code requires nrows*1 storage:
+# The below code requires nrows $\times$ 1 storage:
 
 sampler = hutchinson.sampler_rademacher(x0, num=1)
 estimate = hutchinson.hutchinson(integrand, sampler)
@@ -57,7 +57,7 @@ def large_matvec(v):
 
 # In practice, we often combine both approaches by choosing
 # the largest nsamples (in the first implementation) so that
-# nrows*nsamples fits into memory, and handle all samples beyond
+# nrows $\times$ nsamples fits into memory, and handle all samples beyond
 # that via the split-and-map combination.
 #
 #
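Usage sketch (not part of the diff above): the snippet below combines the two new entry points, `low_rank.cholesky_partial_pivot` and `low_rank.preconditioner`, to apply an approximate solve with `(s*I + K)` for a kernel matrix `K`. The kernel function `kernel_element`, the sizes `n` and `rank`, and the shift `1e-4` are illustrative placeholders, not values used anywhere in this change.

```python
import jax.numpy as jnp

from matfree import low_rank


def kernel_element(i, j):
    # Hypothetical symmetric positive-definite kernel: k(i, j) = exp(-(i - j)^2 / 2)
    return jnp.exp(-0.5 * (i - j) ** 2)


n, rank = 100, 10

# Rank-10 pivoted partial Cholesky factor L of the n-by-n kernel matrix
cholesky = low_rank.cholesky_partial_pivot(kernel_element, nrows=n, rank=rank)

# solve(v, s) computes (s*I + L L^T)^{-1} v, which approximates (s*I + K)^{-1} v
solve = low_rank.preconditioner(cholesky)

v = jnp.ones((n,))
approx_solution, info = solve(v, 1e-4)
print(approx_solution.shape, info["success"])
```

In practice, such an approximate solve is typically used as a preconditioner inside an iterative solver (for example, conjugate gradients) applied to `s*I + K`.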