From 3fd1b0012b163342bdda1fde972e687d96365e46 Mon Sep 17 00:00:00 2001 From: Victor Boussange Date: Thu, 19 Dec 2024 10:10:56 +0100 Subject: [PATCH 1/2] Added preconditioner example --- docs/examples/preconditioner.ipynb | 195 +++++++++++++++++++++++++++++ 1 file changed, 195 insertions(+) create mode 100644 docs/examples/preconditioner.ipynb diff --git a/docs/examples/preconditioner.ipynb b/docs/examples/preconditioner.ipynb new file mode 100644 index 0000000..20ec7d5 --- /dev/null +++ b/docs/examples/preconditioner.ipynb @@ -0,0 +1,195 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Using preconditioners\n", + "\n", + "\n", + "Preconditioning can notably improve the convergence of iterative methods. Preconditioners are particularly suited for solving sparse systems that arise from PDE problems. In this example, we will show how to use a simple Jacobi preconditioner (see [here](https://en.wikipedia.org/wiki/Preconditioner#Jacobi_(or_diagonal)_preconditioner)) to solve a 2D Laplacian linear system using `lx.GMRES`. 
We will first show the performance of the solver without preconditioning and then with Jacobi preconditioning.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's define a [Poisson problem in 2D](https://en.wikipedia.org/wiki/Discrete_Poisson_equation)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import jax.random as jr\n", + "from scipy.sparse import diags, kron, eye\n", + "from jax.experimental.sparse import BCOO\n", + "import lineax as lx\n", + "\n", + "def create_sparse_2d_laplacian(n, m):\n", + " \"\"\"\n", + " Create a 2D Laplacian matrix on an n-by-m grid as a JAX BCOO sparse matrix.\n", + " \"\"\"\n", + " lap_1d_n = diags([1, -2, 1], [-1, 0, 1], shape=(n, n), format=\"csr\")\n", + " lap_1d_m = diags([1, -2, 1], [-1, 0, 1], shape=(m, m), format=\"csr\")\n", + " lap_2d = kron(eye(m, format=\"csr\"), lap_1d_n) + kron(lap_1d_m, eye(n, format=\"csr\"))\n", + " return BCOO.from_scipy_sparse(lap_2d)\n", + "\n", + "\n", + "# Set up the problem: A x = b\n", + "n, m = 200, 200\n", + "A = create_sparse_2d_laplacian(n, m)\n", + "key = jr.PRNGKey(0)\n", + "b = jr.uniform(key, (A.shape[0],))\n", + "\n", + "in_structure = jax.eval_shape(lambda: b)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Our Laplacian matrix `A` is a large sparse matrix of size `(n*m, n*m)`. We do not want to materialize it with a `MatrixLinearOperator`, which only supports dense matrices. Instead, we define a `FunctionLinearOperator` that computes the sparse matrix-vector `A @ x` product." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Define operator and solve with GMRES\n", + "operator = lx.FunctionLinearOperator(lambda x: A @ x, in_structure)\n", + "solver = lx.GMRES(atol=1e-5, rtol=1e-5, max_steps=100)\n", + "x = lx.linear_solve(operator, b, solver=solver, throw=False).value" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's check the performance of this solve." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array(0.5655846, dtype=float32)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Check the residual norm\n", + "error = jnp.linalg.norm(b - (A @ x))\n", + "error" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "That is a poor result: the residual is well above the requested tolerances. Now we use a simple Jacobi preconditioner. We need to define another `FunctionLinearOperator` that computes the matrix-vector product `M @ x`, where `M` is the Jacobi preconditioner. The Jacobi preconditioner is built from the diagonal of `A`: applying it divides each entry of `x` by the corresponding diagonal entry of `A`. We need to write a utility function to extract the diagonal of a `BCOO` matrix." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "ename": "ValueError", + "evalue": "The preconditioner must be positive definite.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[4], line 15\u001b[0m\n\u001b[1;32m 12\u001b[0m preconditioner \u001b[38;5;241m=\u001b[39m lx\u001b[38;5;241m.\u001b[39mFunctionLinearOperator(\u001b[38;5;28;01mlambda\u001b[39;00m x: x \u001b[38;5;241m/\u001b[39m jacobi, in_structure)\n\u001b[1;32m 14\u001b[0m solver \u001b[38;5;241m=\u001b[39m lx\u001b[38;5;241m.\u001b[39mGMRES(atol\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-5\u001b[39m, rtol\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-5\u001b[39m)\n\u001b[0;32m---> 15\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[43mlx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear_solve\u001b[49m\u001b[43m(\u001b[49m\u001b[43moperator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mb\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msolver\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msolver\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mpreconditioner\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mpreconditioner\u001b[49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mvalue\n", + " \u001b[0;31m[... 
skipping hidden 15 frame]\u001b[0m\n", + "File \u001b[0;32m~/projects/connectivity/connectivity_analysis/code/python/.env/lib/python3.12/site-packages/lineax/_solve.py:810\u001b[0m, in \u001b[0;36mlinear_solve\u001b[0;34m(operator, vector, solver, options, state, throw)\u001b[0m\n\u001b[1;32m 804\u001b[0m options \u001b[38;5;241m=\u001b[39m eqxi\u001b[38;5;241m.\u001b[39mnondifferentiable(\n\u001b[1;32m 805\u001b[0m options, name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`lineax.linear_solve(..., options=...)`\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 806\u001b[0m )\n\u001b[1;32m 807\u001b[0m solver \u001b[38;5;241m=\u001b[39m eqxi\u001b[38;5;241m.\u001b[39mnondifferentiable(\n\u001b[1;32m 808\u001b[0m solver, name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`lineax.linear_solve(..., solver=...)`\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 809\u001b[0m )\n\u001b[0;32m--> 810\u001b[0m solution, result, stats \u001b[38;5;241m=\u001b[39m \u001b[43meqxi\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfilter_primitive_bind\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 811\u001b[0m \u001b[43m \u001b[49m\u001b[43mlinear_solve_p\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moperator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvector\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msolver\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mthrow\u001b[49m\n\u001b[1;32m 812\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 813\u001b[0m \u001b[38;5;66;03m# TODO: prevent forward-mode autodiff through stats\u001b[39;00m\n\u001b[1;32m 814\u001b[0m stats \u001b[38;5;241m=\u001b[39m eqxi\u001b[38;5;241m.\u001b[39mnondifferentiable_backward(stats)\n", + "File 
\u001b[0;32m~/projects/connectivity/connectivity_analysis/code/python/.env/lib/python3.12/site-packages/equinox/internal/_primitive.py:272\u001b[0m, in \u001b[0;36mfilter_primitive_bind\u001b[0;34m(prim, *args)\u001b[0m\n\u001b[1;32m 270\u001b[0m static \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mtuple\u001b[39m(_missing_dynamic \u001b[38;5;28;01mif\u001b[39;00m is_array(x) \u001b[38;5;28;01melse\u001b[39;00m x \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m flat)\n\u001b[1;32m 271\u001b[0m flatten \u001b[38;5;241m=\u001b[39m Flatten()\n\u001b[0;32m--> 272\u001b[0m flat_out \u001b[38;5;241m=\u001b[39m \u001b[43mprim\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbind\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mdynamic\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtreedef\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtreedef\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstatic\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstatic\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mflatten\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mflatten\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 273\u001b[0m treedef_out, static_out \u001b[38;5;241m=\u001b[39m flatten\u001b[38;5;241m.\u001b[39mget()\n\u001b[1;32m 274\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m combine(jtu\u001b[38;5;241m.\u001b[39mtree_unflatten(treedef_out, flat_out), static_out)\n", + " \u001b[0;31m[... 
skipping hidden 5 frame]\u001b[0m\n", + "File \u001b[0;32m~/projects/connectivity/connectivity_analysis/code/python/.env/lib/python3.12/site-packages/equinox/internal/_primitive.py:155\u001b[0m, in \u001b[0;36mfilter_primitive_def.._wrapper\u001b[0;34m(treedef, static, flatten, *dynamic)\u001b[0m\n\u001b[1;32m 153\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_wrapper\u001b[39m(\u001b[38;5;241m*\u001b[39mdynamic, treedef, static, flatten):\n\u001b[1;32m 154\u001b[0m args \u001b[38;5;241m=\u001b[39m jtu\u001b[38;5;241m.\u001b[39mtree_unflatten(treedef, _combine(dynamic, static))\n\u001b[0;32m--> 155\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mrule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 156\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m flatten(out)\n", + "File \u001b[0;32m~/projects/connectivity/connectivity_analysis/code/python/.env/lib/python3.12/site-packages/lineax/_solve.py:116\u001b[0m, in \u001b[0;36m_linear_solve_abstract_eval\u001b[0;34m(operator, state, vector, options, solver, throw)\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[38;5;129m@eqxi\u001b[39m\u001b[38;5;241m.\u001b[39mfilter_primitive_def\n\u001b[1;32m 112\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_linear_solve_abstract_eval\u001b[39m(operator, state, vector, options, solver, throw):\n\u001b[1;32m 113\u001b[0m state, vector, options, solver \u001b[38;5;241m=\u001b[39m jtu\u001b[38;5;241m.\u001b[39mtree_map(\n\u001b[1;32m 114\u001b[0m _to_struct, (state, vector, options, solver)\n\u001b[1;32m 115\u001b[0m )\n\u001b[0;32m--> 116\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43meqx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfilter_eval_shape\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 117\u001b[0m \u001b[43m \u001b[49m\u001b[43m_linear_solve_impl\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 118\u001b[0m \u001b[43m 
\u001b[49m\u001b[43moperator\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 119\u001b[0m \u001b[43m \u001b[49m\u001b[43mstate\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 120\u001b[0m \u001b[43m \u001b[49m\u001b[43mvector\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 121\u001b[0m \u001b[43m \u001b[49m\u001b[43moptions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 122\u001b[0m \u001b[43m \u001b[49m\u001b[43msolver\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 123\u001b[0m \u001b[43m \u001b[49m\u001b[43mthrow\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 124\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_closure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 125\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 126\u001b[0m out \u001b[38;5;241m=\u001b[39m jtu\u001b[38;5;241m.\u001b[39mtree_map(_to_shapedarray, out)\n\u001b[1;32m 127\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out\n", + " \u001b[0;31m[... skipping hidden 14 frame]\u001b[0m\n", + "File \u001b[0;32m~/projects/connectivity/connectivity_analysis/code/python/.env/lib/python3.12/site-packages/lineax/_solve.py:87\u001b[0m, in \u001b[0;36m_linear_solve_impl\u001b[0;34m(_, state, vector, options, solver, throw, check_closure)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_linear_solve_impl\u001b[39m(_, state, vector, options, solver, throw, \u001b[38;5;241m*\u001b[39m, check_closure):\n\u001b[0;32m---> 87\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43msolver\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvector\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptions\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 88\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m check_closure:\n\u001b[1;32m 89\u001b[0m out \u001b[38;5;241m=\u001b[39m 
eqxi\u001b[38;5;241m.\u001b[39mnontraceable(\n\u001b[1;32m 90\u001b[0m out, name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlineax.linear_solve with respect to a closed-over value\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 91\u001b[0m )\n", + " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", + "File \u001b[0;32m~/projects/connectivity/connectivity_analysis/code/python/.env/lib/python3.12/site-packages/lineax/_solver/gmres.py:124\u001b[0m, in \u001b[0;36mGMRES.compute\u001b[0;34m(self, state, vector, options)\u001b[0m\n\u001b[1;32m 122\u001b[0m b_scale \u001b[38;5;241m=\u001b[39m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39matol \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrtol \u001b[38;5;241m*\u001b[39m ω(vector)\u001b[38;5;241m.\u001b[39mcall(jnp\u001b[38;5;241m.\u001b[39mabs))\u001b[38;5;241m.\u001b[39mω\n\u001b[1;32m 123\u001b[0m operator \u001b[38;5;241m=\u001b[39m state\n\u001b[0;32m--> 124\u001b[0m preconditioner, y0 \u001b[38;5;241m=\u001b[39m \u001b[43mpreconditioner_and_y0\u001b[49m\u001b[43m(\u001b[49m\u001b[43moperator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvector\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptions\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 125\u001b[0m leaves, _ \u001b[38;5;241m=\u001b[39m jtu\u001b[38;5;241m.\u001b[39mtree_flatten(vector)\n\u001b[1;32m 126\u001b[0m size \u001b[38;5;241m=\u001b[39m \u001b[38;5;28msum\u001b[39m(leaf\u001b[38;5;241m.\u001b[39msize \u001b[38;5;28;01mfor\u001b[39;00m leaf \u001b[38;5;129;01min\u001b[39;00m leaves)\n", + "File \u001b[0;32m~/projects/connectivity/connectivity_analysis/code/python/.env/lib/python3.12/site-packages/lineax/_solver/misc.py:55\u001b[0m, in \u001b[0;36mpreconditioner_and_y0\u001b[0;34m(operator, vector, options)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 
51\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe preconditioner must have `out_structure` that matches the \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 52\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124moperator\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124ms `in_structure`.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 53\u001b[0m )\n\u001b[1;32m 54\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_positive_semidefinite(preconditioner):\n\u001b[0;32m---> 55\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe preconditioner must be positive definite.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 56\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 57\u001b[0m y0 \u001b[38;5;241m=\u001b[39m options[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124my0\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n", + "\u001b[0;31mValueError\u001b[0m: The preconditioner must be positive definite." 
+ ] + } + ], + "source": [ + "@jax.jit\n", + "def get_diagonal(matrix):\n", + " \"\"\"\n", + " Extract the diagonal from a sparse matrix.\n", + " \"\"\"\n", + " is_diag = matrix.indices[:, 0] == matrix.indices[:, 1]\n", + " diag_values = jnp.where(is_diag, matrix.data, 0)\n", + " diag = jnp.zeros(matrix.shape[0], dtype=matrix.data.dtype)\n", + " diag = diag.at[matrix.indices[:, 0]].add(diag_values)\n", + " return diag\n", + "jacobi = get_diagonal(A)\n", + "preconditioner = lx.FunctionLinearOperator(lambda x: x / jacobi, in_structure)\n", + "\n", + "solver = lx.GMRES(atol=1e-5, rtol=1e-5)\n", + "x = lx.linear_solve(operator, b, solver=solver, options={\"preconditioner\": preconditioner}).value" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Check the residual norm\n", + "error = jnp.linalg.norm(b - (A @ x))\n", + "error" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "That's much better! More advanced preconditioners such as multigrid preconditioners could be used to further improve the convergence of the solver." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 9aa97bfc29fe1703876f3217fa203e8d83205427 Mon Sep 17 00:00:00 2001 From: Victor Boussange Date: Thu, 19 Dec 2024 10:41:45 +0100 Subject: [PATCH 2/2] fixed preconditioner tags --- docs/examples/preconditioner.ipynb | 93 +++++++++++++++--------------- 1 file changed, 48 insertions(+), 45 deletions(-) diff --git a/docs/examples/preconditioner.ipynb b/docs/examples/preconditioner.ipynb index 20ec7d5..253c6b0 100644 --- a/docs/examples/preconditioner.ipynb +++ b/docs/examples/preconditioner.ipynb @@ -20,31 +20,30 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 51, "metadata": {}, "outputs": [], "source": [ - "\n", "import jax\n", "import jax.numpy as jnp\n", "import jax.random as jr\n", "from scipy.sparse import diags, kron, eye\n", - "from jax.experimental.sparse import BCOO\n", + "import jax.experimental.sparse as js\n", "import lineax as lx\n", "\n", - "def create_sparse_2d_laplacian(n, m):\n", + "def poisson(n, m):\n", " \"\"\"\n", " Create a 2D Laplacian matrix on an n-by-m grid as a JAX BCOO sparse matrix.\n", " \"\"\"\n", - " lap_1d_n = diags([1, -2, 1], [-1, 0, 1], shape=(n, n), format=\"csr\")\n", - " lap_1d_m = diags([1, -2, 1], [-1, 0, 1], shape=(m, m), format=\"csr\")\n", + " lap_1d_n = diags([-1, 2, -1], [-1, 0, 1], shape=(n, n), format=\"csr\")\n", + " lap_1d_m = diags([-1, 2, -1], [-1, 0, 1], shape=(m, m), format=\"csr\")\n", " lap_2d = kron(eye(m, format=\"csr\"), lap_1d_n) + kron(lap_1d_m, eye(n, format=\"csr\"))\n", - " return BCOO.from_scipy_sparse(lap_2d)\n", + " return 
js.BCOO.from_scipy_sparse(lap_2d)\n", "\n", "\n", "# Set up the problem: A x = b\n", "n, m = 200, 200\n", - "A = create_sparse_2d_laplacian(n, m)\n", + "A = poisson(n, m)\n", "key = jr.PRNGKey(0)\n", "b = jr.uniform(key, (A.shape[0],))\n", "\n", @@ -55,18 +54,27 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Our Laplacian matrix `A` is a large sparse matrix of size `(n*m, n*m)`. We do not want to materialize it with a `MatrixLinearOperator`, which only supports dense matrices. Instead, we define a `FunctionLinearOperator` that computes the sparse matrix-vector `A @ x` product." + "Our Laplacian matrix `A` is a large sparse matrix of size `(n*m, n*m)`. We do not want to materialize it with a `MatrixLinearOperator`, which only supports dense matrices. Instead, we define a `SparseMatrixLinearOperator` that computes the sparse matrix-vector `A @ x` product." ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "# Define operator and solve with GMRES\n", - "operator = lx.FunctionLinearOperator(lambda x: A @ x, in_structure)\n", - "solver = lx.GMRES(atol=1e-5, rtol=1e-5, max_steps=100)\n", + "class SparseMatrixLinearOperator(lx.MatrixLinearOperator):\n", + " def mv(self, vector):\n", + " return self.matrix @ vector\n", + " \n", + "\n", + "@lx.is_positive_semidefinite.register(SparseMatrixLinearOperator)\n", + "def _(op):\n", + " return True\n", + " \n", + "operator = SparseMatrixLinearOperator(A)\n", + "solver = lx.GMRES(atol=1e-5, rtol=1e-5, max_steps=30)\n", "x = lx.linear_solve(operator, b, solver=solver, throw=False).value" ] }, @@ -79,16 +87,16 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 53, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "Array(0.5655846, dtype=float32)" + "Array(19.014511, dtype=float32)" ] }, - "execution_count": 3, + "execution_count": 53, "metadata": {}, "output_type": "execute_result" } @@ -108,32 +116,9 @@ }, { "cell_type": 
"code", - "execution_count": 4, + "execution_count": 54, "metadata": {}, - "outputs": [ - { - "ename": "ValueError", - "evalue": "The preconditioner must be positive definite.", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[4], line 15\u001b[0m\n\u001b[1;32m 12\u001b[0m preconditioner \u001b[38;5;241m=\u001b[39m lx\u001b[38;5;241m.\u001b[39mFunctionLinearOperator(\u001b[38;5;28;01mlambda\u001b[39;00m x: x \u001b[38;5;241m/\u001b[39m jacobi, in_structure)\n\u001b[1;32m 14\u001b[0m solver \u001b[38;5;241m=\u001b[39m lx\u001b[38;5;241m.\u001b[39mGMRES(atol\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-5\u001b[39m, rtol\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-5\u001b[39m)\n\u001b[0;32m---> 15\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[43mlx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear_solve\u001b[49m\u001b[43m(\u001b[49m\u001b[43moperator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mb\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msolver\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msolver\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mpreconditioner\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mpreconditioner\u001b[49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mvalue\n", - " \u001b[0;31m[... 
skipping hidden 15 frame]\u001b[0m\n", - "File \u001b[0;32m~/projects/connectivity/connectivity_analysis/code/python/.env/lib/python3.12/site-packages/lineax/_solve.py:810\u001b[0m, in \u001b[0;36mlinear_solve\u001b[0;34m(operator, vector, solver, options, state, throw)\u001b[0m\n\u001b[1;32m 804\u001b[0m options \u001b[38;5;241m=\u001b[39m eqxi\u001b[38;5;241m.\u001b[39mnondifferentiable(\n\u001b[1;32m 805\u001b[0m options, name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`lineax.linear_solve(..., options=...)`\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 806\u001b[0m )\n\u001b[1;32m 807\u001b[0m solver \u001b[38;5;241m=\u001b[39m eqxi\u001b[38;5;241m.\u001b[39mnondifferentiable(\n\u001b[1;32m 808\u001b[0m solver, name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`lineax.linear_solve(..., solver=...)`\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 809\u001b[0m )\n\u001b[0;32m--> 810\u001b[0m solution, result, stats \u001b[38;5;241m=\u001b[39m \u001b[43meqxi\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfilter_primitive_bind\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 811\u001b[0m \u001b[43m \u001b[49m\u001b[43mlinear_solve_p\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moperator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvector\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msolver\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mthrow\u001b[49m\n\u001b[1;32m 812\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 813\u001b[0m \u001b[38;5;66;03m# TODO: prevent forward-mode autodiff through stats\u001b[39;00m\n\u001b[1;32m 814\u001b[0m stats \u001b[38;5;241m=\u001b[39m eqxi\u001b[38;5;241m.\u001b[39mnondifferentiable_backward(stats)\n", - "File 
\u001b[0;32m~/projects/connectivity/connectivity_analysis/code/python/.env/lib/python3.12/site-packages/equinox/internal/_primitive.py:272\u001b[0m, in \u001b[0;36mfilter_primitive_bind\u001b[0;34m(prim, *args)\u001b[0m\n\u001b[1;32m 270\u001b[0m static \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mtuple\u001b[39m(_missing_dynamic \u001b[38;5;28;01mif\u001b[39;00m is_array(x) \u001b[38;5;28;01melse\u001b[39;00m x \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m flat)\n\u001b[1;32m 271\u001b[0m flatten \u001b[38;5;241m=\u001b[39m Flatten()\n\u001b[0;32m--> 272\u001b[0m flat_out \u001b[38;5;241m=\u001b[39m \u001b[43mprim\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbind\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mdynamic\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtreedef\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtreedef\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstatic\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstatic\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mflatten\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mflatten\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 273\u001b[0m treedef_out, static_out \u001b[38;5;241m=\u001b[39m flatten\u001b[38;5;241m.\u001b[39mget()\n\u001b[1;32m 274\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m combine(jtu\u001b[38;5;241m.\u001b[39mtree_unflatten(treedef_out, flat_out), static_out)\n", - " \u001b[0;31m[... 
skipping hidden 5 frame]\u001b[0m\n", - "File \u001b[0;32m~/projects/connectivity/connectivity_analysis/code/python/.env/lib/python3.12/site-packages/equinox/internal/_primitive.py:155\u001b[0m, in \u001b[0;36mfilter_primitive_def.._wrapper\u001b[0;34m(treedef, static, flatten, *dynamic)\u001b[0m\n\u001b[1;32m 153\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_wrapper\u001b[39m(\u001b[38;5;241m*\u001b[39mdynamic, treedef, static, flatten):\n\u001b[1;32m 154\u001b[0m args \u001b[38;5;241m=\u001b[39m jtu\u001b[38;5;241m.\u001b[39mtree_unflatten(treedef, _combine(dynamic, static))\n\u001b[0;32m--> 155\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mrule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 156\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m flatten(out)\n", - "File \u001b[0;32m~/projects/connectivity/connectivity_analysis/code/python/.env/lib/python3.12/site-packages/lineax/_solve.py:116\u001b[0m, in \u001b[0;36m_linear_solve_abstract_eval\u001b[0;34m(operator, state, vector, options, solver, throw)\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[38;5;129m@eqxi\u001b[39m\u001b[38;5;241m.\u001b[39mfilter_primitive_def\n\u001b[1;32m 112\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_linear_solve_abstract_eval\u001b[39m(operator, state, vector, options, solver, throw):\n\u001b[1;32m 113\u001b[0m state, vector, options, solver \u001b[38;5;241m=\u001b[39m jtu\u001b[38;5;241m.\u001b[39mtree_map(\n\u001b[1;32m 114\u001b[0m _to_struct, (state, vector, options, solver)\n\u001b[1;32m 115\u001b[0m )\n\u001b[0;32m--> 116\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43meqx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfilter_eval_shape\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 117\u001b[0m \u001b[43m \u001b[49m\u001b[43m_linear_solve_impl\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 118\u001b[0m \u001b[43m 
\u001b[49m\u001b[43moperator\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 119\u001b[0m \u001b[43m \u001b[49m\u001b[43mstate\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 120\u001b[0m \u001b[43m \u001b[49m\u001b[43mvector\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 121\u001b[0m \u001b[43m \u001b[49m\u001b[43moptions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 122\u001b[0m \u001b[43m \u001b[49m\u001b[43msolver\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 123\u001b[0m \u001b[43m \u001b[49m\u001b[43mthrow\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 124\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_closure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 125\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 126\u001b[0m out \u001b[38;5;241m=\u001b[39m jtu\u001b[38;5;241m.\u001b[39mtree_map(_to_shapedarray, out)\n\u001b[1;32m 127\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out\n", - " \u001b[0;31m[... skipping hidden 14 frame]\u001b[0m\n", - "File \u001b[0;32m~/projects/connectivity/connectivity_analysis/code/python/.env/lib/python3.12/site-packages/lineax/_solve.py:87\u001b[0m, in \u001b[0;36m_linear_solve_impl\u001b[0;34m(_, state, vector, options, solver, throw, check_closure)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_linear_solve_impl\u001b[39m(_, state, vector, options, solver, throw, \u001b[38;5;241m*\u001b[39m, check_closure):\n\u001b[0;32m---> 87\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43msolver\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvector\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptions\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 88\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m check_closure:\n\u001b[1;32m 89\u001b[0m out \u001b[38;5;241m=\u001b[39m 
eqxi\u001b[38;5;241m.\u001b[39mnontraceable(\n\u001b[1;32m 90\u001b[0m out, name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlineax.linear_solve with respect to a closed-over value\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 91\u001b[0m )\n", - " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", - "File \u001b[0;32m~/projects/connectivity/connectivity_analysis/code/python/.env/lib/python3.12/site-packages/lineax/_solver/gmres.py:124\u001b[0m, in \u001b[0;36mGMRES.compute\u001b[0;34m(self, state, vector, options)\u001b[0m\n\u001b[1;32m 122\u001b[0m b_scale \u001b[38;5;241m=\u001b[39m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39matol \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrtol \u001b[38;5;241m*\u001b[39m ω(vector)\u001b[38;5;241m.\u001b[39mcall(jnp\u001b[38;5;241m.\u001b[39mabs))\u001b[38;5;241m.\u001b[39mω\n\u001b[1;32m 123\u001b[0m operator \u001b[38;5;241m=\u001b[39m state\n\u001b[0;32m--> 124\u001b[0m preconditioner, y0 \u001b[38;5;241m=\u001b[39m \u001b[43mpreconditioner_and_y0\u001b[49m\u001b[43m(\u001b[49m\u001b[43moperator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvector\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptions\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 125\u001b[0m leaves, _ \u001b[38;5;241m=\u001b[39m jtu\u001b[38;5;241m.\u001b[39mtree_flatten(vector)\n\u001b[1;32m 126\u001b[0m size \u001b[38;5;241m=\u001b[39m \u001b[38;5;28msum\u001b[39m(leaf\u001b[38;5;241m.\u001b[39msize \u001b[38;5;28;01mfor\u001b[39;00m leaf \u001b[38;5;129;01min\u001b[39;00m leaves)\n", - "File \u001b[0;32m~/projects/connectivity/connectivity_analysis/code/python/.env/lib/python3.12/site-packages/lineax/_solver/misc.py:55\u001b[0m, in \u001b[0;36mpreconditioner_and_y0\u001b[0;34m(operator, vector, options)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 
51\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe preconditioner must have `out_structure` that matches the \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 52\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124moperator\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124ms `in_structure`.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 53\u001b[0m )\n\u001b[1;32m 54\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_positive_semidefinite(preconditioner):\n\u001b[0;32m---> 55\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe preconditioner must be positive definite.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 56\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 57\u001b[0m y0 \u001b[38;5;241m=\u001b[39m options[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124my0\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n", - "\u001b[0;31mValueError\u001b[0m: The preconditioner must be positive definite." 
- ] - } - ], + "outputs": [], "source": [ "@jax.jit\n", "def get_diagonal(matrix):\n", @@ -146,17 +131,35 @@ " diag = diag.at[matrix.indices[:, 0]].add(diag_values)\n", " return diag\n", "jacobi = get_diagonal(A)\n", - "preconditioner = lx.FunctionLinearOperator(lambda x: x / jacobi, in_structure)\n", "\n", - "solver = lx.GMRES(atol=1e-5, rtol=1e-5)\n", - "x = lx.linear_solve(operator, b, solver=solver, options={\"preconditioner\": preconditioner}).value" + "preconditioner = lx.FunctionLinearOperator(lambda x: x / jacobi, \n", + " in_structure, \n", + " tags=[lx.positive_semidefinite_tag])\n", + "\n", + "solver = lx.GMRES(atol=1e-5, rtol=1e-5, max_steps=30)\n", + "x = lx.linear_solve(operator, \n", + " b, \n", + " solver=solver, \n", + " options={\"preconditioner\": preconditioner}, \n", + " throw=False).value" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 55, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "Array(19.014511, dtype=float32)" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# Check the residual norm\n", "error = jnp.linalg.norm(b - (A @ x))\n",