Skip to content

Pytree Polynomial Object Abstractions#48

Open
FelixBenning wants to merge 19 commits intof0uriest:mainfrom
FelixBenning:objects
Open

Pytree Polynomial Object Abstractions#48
FelixBenning wants to merge 19 commits intof0uriest:mainfrom
FelixBenning:objects

Conversation

@FelixBenning
Copy link

Implementation of #46

Comment on lines +655 to +662
else: # no diagnostic info is returned -> warning callback

def rank_warn(rank, order):
if rank != order:
msg = "The fit may be poorly conditioned"
warnings.warn(msg, np.exceptions.RankWarning, stacklevel=2)

jax.debug.callback(rank_warn, rank, order)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My first step was to try and add all tests from numpy that were left out because the objects where missing. One test checks if a RankWarning is emitted and I noticed that this logic was removed from the _fit function in polyutils in the orthax implementation. I added this functionality here. It uses a jax.debug.callback which is compatible with everything (jit, differentiation, etc.) It may have a performance impact (I have not tested it), but you can avoid this impact by using the full return value and processing the diagnostic info yourself. Overall the performance impact of a single numeric equality check should be negligible if you do not fit in a loop.

@FelixBenning FelixBenning marked this pull request as ready for review September 4, 2025 11:19
Comment on lines 285 to 301
if isinstance(other, ABCPolyBase):
if not isinstance(other, self.__class__):
raise TypeError("Polynomial types differ")

def error_if_unequal_domain(domain_equal):
if not domain_equal:
raise TypeError("Domains differ")
jax.debug.callback(error_if_unequal_domain, jnp.all(self.domain == other.domain))

def error_if_unequal_window(window_equal):
if not window_equal:
raise TypeError("Windows differ")
jax.debug.callback(error_if_unequal_window, jnp.all(self.window == other.window))

def error_if_unequal_symbol(self, other):
if self.symbol != other.symbol:
raise ValueError("Polynomial symbols differ")
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the most annoying thing so far: numpy treats two polynomials essentially as if they are of a different type if they do not have matching domain and window. This requires these callbacks to throw errors if this would be the case

Comment on lines 983 to 1004
if isinstance(x, (tuple, list)):
x = jnp.asarray(x)
if isinstance(x, jnp.ndarray) and tensor:
c = c.reshape(c.shape + (1,) * x.ndim)

# x may be an array or an object that implements scalar multiplication
# and addition, e.g. a Polynomial
if len(c) == 1:
c0 = c[0]
c1 = 0
return c[0] + 0 * x # return type(x),
elif len(c) == 2:
c0 = c[0]
c1 = c[1]
else:
x2 = 2 * x
c0 = c[-2] * jnp.ones_like(x)
c1 = c[-1] * jnp.ones_like(x)
return c[0] + c[1] * x

def body(i, val):
c0, c1 = val
tmp = c0
c0 = c[-i] - c1
c1 = tmp + c1 * x2
return c0, c1
x2 = 2 * x # type(x)
c0 = c[-2] # scalar
c1 = c[-1] # scalar

def body(i, val):
c0, c1 = val
tmp = c0
c0 = c[-i] - c1
c1 = tmp + c1 * x2 # c1 becomes type(x)
return c0, c1
Copy link
Author

@FelixBenning FelixBenning Sep 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I rewrote the chebval, polyval etc. functions in an attempt to allow polynomial types to be passed as x. Numpy uses this for type conversion (i.e. if you pass the identity function i.e. X in the target type, and then evaluate the polynomial at X while remaining in the target type, then you obtain the polynomial in the target type). This works since it is possible to multiply and add polynomials together, so you can pass a polynomial into chebval, etc. (and multiply them to scalar values). Unfortunately the multiplication of polynomials results in a polynomial with more coefficients, so the type is not stable which jax does not like.

  • Option 1: Abandon this attempt and simply write a custom convert method
  • Option 2: Use a normal python if loop that would be unrolled. This may not be such a bad idea since the length of the loop will generally be the length of the polynomial coefficients - unlikely to be extremely large right?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went ahead with Option 2. But as I said, the length of the polynomial coefficients is never going to be really big. There are virtually no usecases for thousands of coefficients in polynomials. So either the performance is non-critical anyway, or the performance hinges on fast evaluation after compilation in which cases the loop un-wrapping may actually be a performance benefit.

We could provide a loop implementation as a functional alternative if you want though.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Option 2 has the benefit that chebval etc. now accept polynomial classes as input (which is essentially claimed in the doc string already)

@FelixBenning
Copy link
Author

FelixBenning commented Sep 5, 2025

I removed an attempt to also provide an object for the general orthogonal polynomials in commit f1e3e3d, because the PR is already getting quite big. I am currently running all tests but unless something comes up this is a wrap from my part.

In future PRs we could tackle the general orthogonal polynomials and maybe improve the performance of jax's autodiff with a custom jvp via the polynomial deriv implementation. But I do not want to blow up the diff even more.

I will be away for much of the rest of september so I don't know if I can make changes, but I should be able to comment on any questions from time to time.

@FelixBenning
Copy link
Author

FelixBenning commented Sep 5, 2025

One last thing with respect to the following code in polyval (and chebval etc.)

    c = pu.as_series(c)
    if isinstance(x, (tuple, list)):
        x = jnp.asarray(x)
    if isinstance(x, jnp.ndarray) and tensor:
        c = c.reshape(c.shape + (1,) * x.ndim)

I was queasy whether or not this would break on numpy arrays but it seems like it works - I have no idea why.

isinstance(x, jnp.ndarray) fails on np.arrays but it appears that it works somehow. Maybe jax automatically converts numpy arrays into jnp.arrays for jitted functions and thus cannot tell the difference. This should maybe be tested somewhere

EDIT: yeah - unjitted it breaks.

Numpy has the opposite problem:

>>> from numpy.polynomial import polynomial as poly
>>> b
Array([[1, 2, 3],
       [2, 3, 4]], dtype=int32)
>>> poly.polyval(np.array(b),np.array(b), tensor=True)
array([[[ 3.,  5.,  7.],
        [ 5.,  7.,  9.]],

       [[ 5.,  8., 11.],
        [ 8., 11., 14.]],

       [[ 7., 11., 15.],
        [11., 15., 19.]]])
>>> poly.polyval(b,b, tensor=True)
Array([[ 3.,  8., 15.],
       [ 5., 11., 19.]], dtype=float32)

@FelixBenning
Copy link
Author

FelixBenning commented Sep 5, 2025

The test added in e74194e breaks if the @jit decorator is removed (and isinstance(x, jnp.ndarray) fails on numpy arrays) - so it gives warning if someone does something stupid.

EDIT: Thinking about it, the unjitted version should not behave differently that is basically summoning a Heisenbug. I created an issue on the numpy repo to ask them if they want to generalize this to array_likes (numpy/numpy#29680). Otherwise (for drop-in compatibility with numpy) I would simply add a check for numpy array. What do you think?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant