Description
I am looking to extend the diffrax example on the 1D heat equation to use a Tridiagonal solve. Unfortunately, lineax cannot tell that a generic 2D array has tridiagonal structure. If this is not possible, could we change optimistix so that it creates the desired LinearOperator from the Jacobian?
MWE for context:
import jax.numpy as jnp
import jax
import lineax
import optimistix

jax.config.update("jax_enable_x64", True)

def laplacian_numerator(y):
    return jnp.diff(jnp.diff(y), prepend=0., append=0.)

def diffus_step(y, args):
    return y + 0.1 * laplacian_numerator(y)

solver = optimistix.Newton(rtol=1e-8, atol=1e-8, linear_solver=lineax.Tridiagonal())
key = jax.random.PRNGKey(0)
rhs = jax.random.normal(key, (100,))

@jax.jit
def loss(y, args):
    return diffus_step(y, args) - rhs

sol = optimistix.root_find(lambda y, args: diffus_step(y, args) - rhs, solver, rhs, max_steps=1, throw=False)

This returns the following traceback:
...
--> 220 return iterative_solve(
221 fn,
222 solver,
223 y0,
224 args,
225 options,
226 max_steps=max_steps,
227 adjoint=adjoint,
228 throw=throw,
229 tags=tags,
230 f_struct=f_struct,
231 aux_struct=aux_struct,
232 rewrite_fn=_rewrite_fn,
...
53 "matrices"
54 )
55 return tridiagonal(operator), pack_structures(operator)
ValueError: `Tridiagonal` may only be used for linear solves with tridiagonal matrices