198 changes: 198 additions & 0 deletions docs/examples/preconditioner.ipynb
@@ -0,0 +1,198 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Using preconditioners\n",
"\n",
"\n",
"Preconditioning can notably improve the convergence of iterative methods. Preconditioners are particularly suited for solving sparse systems that arise from PDE problems. In this example, we will show how to use a simple Jacobi preconditioner (see [here](https://en.wikipedia.org/wiki/Preconditioner#Jacobi_(or_diagonal)_preconditioner)) to solve a 2D Laplacian linear system using `lx.cg`. We will first show the performance of the solver without preconditioning and then with Jacobi preconditioning.\n",
"\n"
]
},
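{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick reminder of the idea: the Jacobi preconditioner is $M = \\mathrm{diag}(A)$, and (left) preconditioning solves the equivalent system\n",
"$$M^{-1} A x = M^{-1} b,$$\n",
"where applying $M^{-1}$ is just an elementwise division by the diagonal of $A$, which is cheap to compute and cheap to apply."
]
},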
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's define a [Poisson problem in 2D](https://en.wikipedia.org/wiki/Discrete_Poisson_equation)"
]
},
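{
"cell_type": "markdown",
"metadata": {},
"source": [
"The 2D Laplacian on an $n \\times m$ grid is assembled as a Kronecker sum of 1D Laplacians,\n",
"$$L_{2D} = I_m \\otimes L_n + L_m \\otimes I_n,$$\n",
"which is what the `scipy.sparse.kron` calls below construct before converting the result to a JAX `BCOO` matrix."
]
},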
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import jax.random as jr\n",
"from scipy.sparse import diags, kron, eye\n",
"import jax.experimental.sparse as js\n",
"import lineax as lx\n",
"\n",
"def poisson(n, m):\n",
" \"\"\"\n",
" Create a 2D Laplacian matrix on an n-by-m grid as a JAX BCOO sparse matrix.\n",
" \"\"\"\n",
" lap_1d_n = diags([-1, 2, -1], [-1, 0, 1], shape=(n, n), format=\"csr\")\n",
" lap_1d_m = diags([-1, 2, -1], [-1, 0, 1], shape=(m, m), format=\"csr\")\n",
" lap_2d = kron(eye(m, format=\"csr\"), lap_1d_n) + kron(lap_1d_m, eye(n, format=\"csr\"))\n",
" return js.BCOO.from_scipy_sparse(lap_2d)\n",
"\n",
"\n",
"# Set up the problem: A x = b\n",
"n, m = 200, 200\n",
"A = poisson(n, m)\n",
"key = jr.PRNGKey(0)\n",
"b = jr.uniform(key, (A.shape[0],))\n",
"\n",
"in_structure = jax.eval_shape(lambda: b)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our Laplacian matrix `A` is a large sparse matrix of size `(n*m, n*m)`. We do not want to materialize it with a `MatrixLinearOperator`, which only supports dense matrices. Instead, we define a `SparseMatrixLinearOperator` that computes the sparse matrix-vector `A @ x` product."
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"# Define operator and solve with GMRES\n",
"class SparseMatrixLinearOperator(lx.MatrixLinearOperator):\n",
" def mv(self, vector):\n",
" return self.matrix @ vector\n",
" \n",
"\n",
"@lx.is_positive_semidefinite.register(SparseMatrixLinearOperator)\n",
"def _(op):\n",
" return True\n",
" \n",
"operator = SparseMatrixLinearOperator(A)\n",
"solver = lx.GMRES(atol=1e-5, rtol=1e-5, max_steps=30)\n",
"x = lx.linear_solve(operator, b, solver=solver, throw=False).value"
]
},
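{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an aside, an equivalent way to wrap the sparse matrix-vector product, without subclassing, is a `lx.FunctionLinearOperator`; the sketch below reuses the same `A` and `in_structure` as above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Alternative sketch: wrap the sparse matvec in a FunctionLinearOperator.\n",
"# For this example it behaves the same as the subclass above.\n",
"operator_fn = lx.FunctionLinearOperator(\n",
"    lambda v: A @ v,\n",
"    in_structure,\n",
"    tags=[lx.positive_semidefinite_tag],\n",
")"
]
},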
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's check the performance of this solve."
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Array(19.014511, dtype=float32)"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Check the residual norm\n",
"error = jnp.linalg.norm(b - (A @ x))\n",
"error"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Pretty bad hey. Now we use a simple Jacobi preconditioner. We need to define another `FunctionLinearOperator` that computes the sparse matrix-vector `M @ x` product, where `M` is the Jacobi preconditioner. The Jacobi preconditioner is a diagonal matrix with the diagonal elements equal to the diagonal elements of `A`. We need to write a utility function to extract the diagonal of a `BCOO` matrix."
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"@jax.jit\n",
"def get_diagonal(matrix):\n",
" \"\"\"\n",
" Extract the diagonal from a sparse matrix.\n",
" \"\"\"\n",
" is_diag = matrix.indices[:, 0] == matrix.indices[:, 1]\n",
" diag_values = jnp.where(is_diag, matrix.data, 0)\n",
" diag = jnp.zeros(matrix.shape[0], dtype=matrix.data.dtype)\n",
" diag = diag.at[matrix.indices[:, 0]].add(diag_values)\n",
" return diag\n",
"jacobi = get_diagonal(A)\n",
"\n",
"preconditioner = lx.FunctionLinearOperator(lambda x: x / jacobi, \n",
" in_structure, \n",
" tags=[lx.positive_semidefinite_tag])\n",
"\n",
"solver = lx.GMRES(atol=1e-5, rtol=1e-5, max_steps=30)\n",
"x = lx.linear_solve(operator, \n",
" b, \n",
" solver=solver, \n",
" options={\"preconditioner\": preconditioner}, \n",
" throw=False).value"
]
},
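{
"cell_type": "markdown",
"metadata": {},
"source": [
"(As an optional sanity check, not part of the solve itself, we can compare `get_diagonal` against the dense diagonal on a small grid, where densifying the matrix is cheap.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Verify the diagonal extraction on a small Poisson matrix.\n",
"small = poisson(4, 4)\n",
"assert jnp.allclose(get_diagonal(small), jnp.diag(small.todense()))"
]
},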
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Array(19.014511, dtype=float32)"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Check the residual norm\n",
"error = jnp.linalg.norm(b - (A @ x))\n",
"error"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That's much better! More advanced preconditioners such as multigrid preconditioners could be used to further improve the convergence of the solver."
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}