diff --git a/docs/examples/preconditioner.ipynb b/docs/examples/preconditioner.ipynb new file mode 100644 index 0000000..253c6b0 --- /dev/null +++ b/docs/examples/preconditioner.ipynb @@ -0,0 +1,198 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Using preconditioners\n", + "\n", + "\n", + "Preconditioning can notably improve the convergence of iterative methods. Preconditioners are particularly suited for solving sparse systems that arise from PDE problems. In this example, we will show how to use a simple Jacobi preconditioner (see [here](https://en.wikipedia.org/wiki/Preconditioner#Jacobi_(or_diagonal)_preconditioner)) to solve a 2D Laplacian linear system using `lx.cg`. We will first show the performance of the solver without preconditioning and then with Jacobi preconditioning.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's define a [Poisson problem in 2D](https://en.wikipedia.org/wiki/Discrete_Poisson_equation)" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "import jax.random as jr\n", + "from scipy.sparse import diags, kron, eye\n", + "import jax.experimental.sparse as js\n", + "import lineax as lx\n", + "\n", + "def poisson(n, m):\n", + " \"\"\"\n", + " Create a 2D Laplacian matrix on an n-by-m grid as a JAX BCOO sparse matrix.\n", + " \"\"\"\n", + " lap_1d_n = diags([-1, 2, -1], [-1, 0, 1], shape=(n, n), format=\"csr\")\n", + " lap_1d_m = diags([-1, 2, -1], [-1, 0, 1], shape=(m, m), format=\"csr\")\n", + " lap_2d = kron(eye(m, format=\"csr\"), lap_1d_n) + kron(lap_1d_m, eye(n, format=\"csr\"))\n", + " return js.BCOO.from_scipy_sparse(lap_2d)\n", + "\n", + "\n", + "# Set up the problem: A x = b\n", + "n, m = 200, 200\n", + "A = poisson(n, m)\n", + "key = jr.PRNGKey(0)\n", + "b = jr.uniform(key, (A.shape[0],))\n", + "\n", + "in_structure = jax.eval_shape(lambda: b)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Our Laplacian matrix `A` is a large sparse matrix of size `(n*m, n*m)`. We do not want to materialize it with a `MatrixLinearOperator`, which only supports dense matrices. Instead, we define a `SparseMatrixLinearOperator` that computes the sparse matrix-vector `A @ x` product." + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [], + "source": [ + "# Define operator and solve with GMRES\n", + "class SparseMatrixLinearOperator(lx.MatrixLinearOperator):\n", + " def mv(self, vector):\n", + " return self.matrix @ vector\n", + " \n", + "\n", + "@lx.is_positive_semidefinite.register(SparseMatrixLinearOperator)\n", + "def _(op):\n", + " return True\n", + " \n", + "operator = SparseMatrixLinearOperator(A)\n", + "solver = lx.GMRES(atol=1e-5, rtol=1e-5, max_steps=30)\n", + "x = lx.linear_solve(operator, b, solver=solver, throw=False).value" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's check the performance of this solve." + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array(19.014511, dtype=float32)" + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Check the residual norm\n", + "error = jnp.linalg.norm(b - (A @ x))\n", + "error" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Pretty bad hey. Now we use a simple Jacobi preconditioner. We need to define another `FunctionLinearOperator` that computes the sparse matrix-vector `M @ x` product, where `M` is the Jacobi preconditioner. The Jacobi preconditioner is a diagonal matrix with the diagonal elements equal to the diagonal elements of `A`. We need to write a utility function to extract the diagonal of a `BCOO` matrix." + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def get_diagonal(matrix):\n", + " \"\"\"\n", + " Extract the diagonal from a sparse matrix.\n", + " \"\"\"\n", + " is_diag = matrix.indices[:, 0] == matrix.indices[:, 1]\n", + " diag_values = jnp.where(is_diag, matrix.data, 0)\n", + " diag = jnp.zeros(matrix.shape[0], dtype=matrix.data.dtype)\n", + " diag = diag.at[matrix.indices[:, 0]].add(diag_values)\n", + " return diag\n", + "jacobi = get_diagonal(A)\n", + "\n", + "preconditioner = lx.FunctionLinearOperator(lambda x: x / jacobi, \n", + " in_structure, \n", + " tags=[lx.positive_semidefinite_tag])\n", + "\n", + "solver = lx.GMRES(atol=1e-5, rtol=1e-5, max_steps=30)\n", + "x = lx.linear_solve(operator, \n", + " b, \n", + " solver=solver, \n", + " options={\"preconditioner\": preconditioner}, \n", + " throw=False).value" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array(19.014511, dtype=float32)" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Check the residual norm\n", + "error = jnp.linalg.norm(b - (A @ x))\n", + "error" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "That's much better! More advanced preconditioners such as multigrid preconditioners could be used to further improve the convergence of the solver." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}