
Is there a way to auto-recognise operator structure? #149

@jpbrodrick89


I am looking to extend the Diffrax example on the 1D heat equation to use a tridiagonal solve. Unfortunately, Lineax cannot tell that a generic 2D array has tridiagonal structure. If automatic detection is not possible, could we change Optimistix so that it creates the desired `LinearOperator` from the Jacobian?
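For illustration, here is a minimal pure-JAX sketch of why value-based detection is awkward. The helpers `looks_tridiagonal` and `pick_solver_name` are hypothetical and not part of Lineax. Checking the off-band entries of a concrete Jacobian works eagerly, but under `jax.jit` the Jacobian is a tracer, so its values cannot drive an ordinary Python `if` that picks a solver. This is one plausible reason structure is declared through tags rather than inferred.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def laplacian_numerator(y):
    # Zero-flux second difference, as in the MWE below.
    return jnp.diff(jnp.diff(y), prepend=0.0, append=0.0)


def diffus_step(y):
    return y + 0.1 * laplacian_numerator(y)


def looks_tridiagonal(matrix, tol=0.0):
    # Hypothetical helper: True if every entry outside the main, sub- and
    # super-diagonal has magnitude <= tol. Returns a JAX boolean array.
    n = matrix.shape[0]
    i, j = jnp.meshgrid(jnp.arange(n), jnp.arange(n), indexing="ij")
    off_band = jnp.where(jnp.abs(i - j) > 1, matrix, 0.0)
    return jnp.all(jnp.abs(off_band) <= tol)


def pick_solver_name(y):
    # Hypothetical dispatcher: a Python `if` needs a concrete boolean.
    jac = jax.jacfwd(diffus_step)(y)
    return "tridiagonal" if looks_tridiagonal(jac) else "dense"


y = jnp.linspace(0.0, 1.0, 8)
print(pick_solver_name(y))  # works eagerly: the Jacobian has concrete values

try:
    jax.jit(pick_solver_name)(y)
except jax.errors.ConcretizationTypeError:
    # Under jit the Jacobian is traced, so the branch cannot be decided.
    print("cannot branch on traced Jacobian values inside jit")
```

A structure check of this kind also costs a full dense Jacobian, which is exactly what a tridiagonal solve is meant to avoid.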

MWE for context:

import jax.numpy as jnp
import jax
import lineax
import optimistix
jax.config.update("jax_enable_x64", True)

def laplacian_numerator(y):
    # Zero-flux second difference: interior entries are y[i-1] - 2*y[i] + y[i+1].
    return jnp.diff(jnp.diff(y), prepend=0., append=0.)

def diffus_step(y, args):
    # Linear map y -> y + 0.1 * L(y); its Jacobian is tridiagonal.
    return y + 0.1 * laplacian_numerator(y)

solver = optimistix.Newton(rtol=1e-8, atol=1e-8, linear_solver=lineax.Tridiagonal())

key = jax.random.PRNGKey(0)
rhs = jax.random.normal(key, (100,))

def loss(y, args):
    # Root-finding residual: zero when diffus_step(y, args) == rhs.
    return diffus_step(y, args) - rhs

# The problem is linear, so a single Newton step should reach the root.
sol = optimistix.root_find(loss, solver, rhs, max_steps=1, throw=False)

This fails with the following traceback:

...
--> 220 return iterative_solve(
    221     fn,
    222     solver,
    223     y0,
    224     args,
    225     options,
    226     max_steps=max_steps,
    227     adjoint=adjoint,
    228     throw=throw,
    229     tags=tags,
    230     f_struct=f_struct,
    231     aux_struct=aux_struct,
    232     rewrite_fn=_rewrite_fn,
...
     53         "matrices"
     54     )
     55 return tridiagonal(operator), pack_structures(operator)

ValueError: `Tridiagonal` may only be used for linear solves with tridiagonal matrices
