Pytree Polynomial Object Abstractions #48
FelixBenning wants to merge 19 commits into f0uriest:main
Conversation
```python
else:  # no diagnostic info is returned -> warning callback

    def rank_warn(rank, order):
        if rank != order:
            msg = "The fit may be poorly conditioned"
            warnings.warn(msg, np.exceptions.RankWarning, stacklevel=2)

    jax.debug.callback(rank_warn, rank, order)
```
My first step was to try to add all the tests from numpy that were left out because the objects were missing. One test checks that a RankWarning is emitted, and I noticed that this logic was removed from the _fit function in polyutils in the orthax implementation. I added this functionality here. It uses a jax.debug.callback, which is compatible with everything (jit, differentiation, etc.). It may have a performance impact (I have not tested it), but you can avoid that impact by using the full return value and processing the diagnostic info yourself. Overall, the performance impact of a single numeric equality check should be negligible unless you fit in a loop.
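For reference, a minimal, self-contained sketch of the pattern (not the PR code; the `lstsq_rank` name and the `matrix_rank` stand-in are illustrative only): a host-side warning emitted from inside a jitted function via jax.debug.callback.

```python
import warnings

import jax
import jax.numpy as jnp
import numpy as np


def rank_warn(rank, order):
    # runs on the host with concrete values
    if rank != order:
        warnings.warn("The fit may be poorly conditioned", np.exceptions.RankWarning)


@jax.jit
def lstsq_rank(a):
    rank = jnp.linalg.matrix_rank(a)
    # the callback does not interfere with tracing, jit or autodiff
    jax.debug.callback(rank_warn, rank, a.shape[1])
    return rank


lstsq_rank(jnp.zeros((4, 3)))  # warns: rank 0 != order 3
```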
```python
if isinstance(other, ABCPolyBase):
    if not isinstance(other, self.__class__):
        raise TypeError("Polynomial types differ")

    def error_if_unequal_domain(domain_equal):
        if not domain_equal:
            raise TypeError("Domains differ")

    jax.debug.callback(error_if_unequal_domain, jnp.all(self.domain == other.domain))

    def error_if_unequal_window(window_equal):
        if not window_equal:
            raise TypeError("Windows differ")

    jax.debug.callback(error_if_unequal_window, jnp.all(self.window == other.window))


def error_if_unequal_symbol(self, other):
    if self.symbol != other.symbol:
        raise ValueError("Polynomial symbols differ")
```
This is the most annoying thing so far: numpy treats two polynomials essentially as if they are of different types if they do not have matching domain and window. This requires these callbacks to throw errors in that case.
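For reference, this is the numpy behaviour being matched (plain numpy, independent of this PR):

```python
import numpy.polynomial as npoly

p = npoly.Chebyshev([1, 2, 3], domain=[-1, 1])
q = npoly.Chebyshev([1, 2, 3], domain=[0, 2])

try:
    p + q
except TypeError as err:
    print(err)  # Domains differ
```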
orthax/chebyshev.py
```diff
     if isinstance(x, (tuple, list)):
         x = jnp.asarray(x)
     if isinstance(x, jnp.ndarray) and tensor:
         c = c.reshape(c.shape + (1,) * x.ndim)

+    # x may be an array or an object that implements scalar multiplication
+    # and addition, e.g. a Polynomial
     if len(c) == 1:
-        c0 = c[0]
-        c1 = 0
+        return c[0] + 0 * x  # return type(x),
     elif len(c) == 2:
-        c0 = c[0]
-        c1 = c[1]
+        return c[0] + c[1] * x
     else:
-        x2 = 2 * x
-        c0 = c[-2] * jnp.ones_like(x)
-        c1 = c[-1] * jnp.ones_like(x)
-
-        def body(i, val):
-            c0, c1 = val
-            tmp = c0
-            c0 = c[-i] - c1
-            c1 = tmp + c1 * x2
-            return c0, c1
+        x2 = 2 * x  # type(x)
+        c0 = c[-2]  # scalar
+        c1 = c[-1]  # scalar

+        def body(i, val):
+            c0, c1 = val
+            tmp = c0
+            c0 = c[-i] - c1
+            c1 = tmp + c1 * x2  # c1 becomes type(x)
+            return c0, c1
```
I rewrote the chebval, polyval, etc. functions in an attempt to allow polynomial types to be passed as x. Numpy uses this for type conversion: if you pass the identity polynomial X of the target type and evaluate the series at X while staying in the target type, you obtain the polynomial in the target type. This works because polynomials can be added and multiplied together (and multiplied by scalar values), so you can pass a polynomial into chebval, etc. Unfortunately, multiplying polynomials yields a polynomial with more coefficients, so the type is not stable, which jax does not like (see the sketch after the options below).
- Option 1: Abandon this attempt and simply write a custom `convert` method.
- Option 2: Use a normal Python `if`/loop that would be unrolled. This may not be such a bad idea, since the length of the loop will generally be the number of polynomial coefficients, which is unlikely to be extremely large, right?
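To make the type-stability issue concrete, here is a minimal, self-contained sketch (illustrative only, not code from this PR): lax.fori_loop requires the loop carry to keep a fixed shape and dtype, so a carry that grows each iteration, like repeated polynomial products, is rejected.

```python
import jax
import jax.numpy as jnp


def grow(i, carry):
    # the carry gains one element per iteration, so its shape is not stable
    return jnp.concatenate([carry, carry[-1:]])


try:
    jax.lax.fori_loop(0, 3, grow, jnp.ones(2))
except TypeError as err:
    print(err)  # carry output must have the same type structure as the input
```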
I went ahead with Option 2. But as I said, the number of polynomial coefficients is never going to be really big; there are virtually no use cases for thousands of coefficients in polynomials. So either the performance is non-critical anyway, or the performance hinges on fast evaluation after compilation, in which case the loop unrolling may actually be a performance benefit.
We could provide a loop implementation as a functional alternative if you want, though.
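For concreteness, a rough sketch of what the unrolled version looks like (simplified, not the exact PR code): `c` must have a length known at trace time, while `x` can be an array or any object supporting `+` and `*`.

```python
def chebval_unrolled(x, c):
    # Clenshaw recursion for a Chebyshev series, written with plain Python
    # control flow so that jit simply unrolls the loop
    if len(c) == 1:
        return c[0] + 0 * x
    elif len(c) == 2:
        return c[0] + c[1] * x
    x2 = 2 * x
    c0, c1 = c[-2], c[-1]
    # c0/c1 may change type (scalar -> array / Polynomial-like) between
    # iterations; that is fine because the loop is unrolled at trace time
    for i in range(3, len(c) + 1):
        c0, c1 = c[-i] - c1, c0 + c1 * x2
    return c0 + c1 * x
```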
Option 2 has the benefit that chebval etc. now accept polynomial classes as input (which is essentially claimed in the docstring already).
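For illustration, this is the numpy mechanism that then carries over (plain numpy shown here): evaluating a series at the identity polynomial of the target type performs the basis conversion.

```python
from numpy.polynomial import chebyshev as C, polynomial as P

# identity "x" in the target (power) basis
x = P.Polynomial([0, 1])

# evaluating the Chebyshev series 1*T0 + 2*T1 + 3*T2 at that Polynomial
# converts it to the power basis: -2 + 2*x + 6*x**2
print(C.chebval(x, [1, 2, 3]))
```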
5f0fe87 to f1e3e3d
I removed an attempt to also provide an object for the general orthogonal polynomials in commit f1e3e3d, because the PR is already getting quite big. I am currently running all tests, but unless something comes up this is a wrap on my part. In future PRs we could tackle the general orthogonal polynomials and maybe improve the performance of jax's autodiff with a custom jvp via the polynomial deriv implementation, but I do not want to blow up the diff even more. I will be away for much of the rest of September, so I don't know if I can make changes, but I should be able to comment on any questions from time to time.
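A rough sketch of the custom-jvp idea (an assumption of how it could look, not part of this PR; it assumes orthax mirrors numpy's chebval/chebder, and the `chebval_cjvp` name is made up):

```python
import jax
from orthax import chebyshev as cheb


@jax.custom_jvp
def chebval_cjvp(x, c):
    return cheb.chebval(x, c)


@chebval_cjvp.defjvp
def chebval_cjvp_jvp(primals, tangents):
    x, c = primals
    dx, dc = tangents
    primal_out = cheb.chebval(x, c)
    # d/dx chebval(x, c) = chebval(x, chebder(c)); chebval is linear in c,
    # so the coefficient tangent is simply chebval(x, dc)
    tangent_out = cheb.chebval(x, cheb.chebder(c)) * dx + cheb.chebval(x, dc)
    return primal_out, tangent_out
```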
One last thing with respect to the following code:

```python
c = pu.as_series(c)
if isinstance(x, (tuple, list)):
    x = jnp.asarray(x)
if isinstance(x, jnp.ndarray) and tensor:
    c = c.reshape(c.shape + (1,) * x.ndim)
```

I was queasy about whether or not this would break on numpy arrays, but it seems like it works; I have no idea why.
EDIT: yeah, unjitted it breaks. Numpy has the opposite problem:

```python
>>> from numpy.polynomial import polynomial as poly
>>> b
Array([[1, 2, 3],
       [2, 3, 4]], dtype=int32)
>>> poly.polyval(np.array(b), np.array(b), tensor=True)
array([[[ 3.,  5.,  7.],
        [ 5.,  7.,  9.]],

       [[ 5.,  8., 11.],
        [ 8., 11., 14.]],

       [[ 7., 11., 15.],
        [11., 15., 19.]]])
>>> poly.polyval(b, b, tensor=True)
Array([[ 3.,  8., 15.],
       [ 5., 11., 19.]], dtype=float32)
```
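A possible explanation for the jit/no-jit difference (my guess, worth double-checking): raw numpy arrays do not pass the isinstance(x, jnp.ndarray) check, while the tracers that jit substitutes for the inputs do, so the tensor reshape only happens on the jitted path.

```python
import jax
import jax.numpy as jnp
import numpy as np

x = np.arange(3.0)
print(isinstance(x, jnp.ndarray))  # False -> the reshape branch is skipped un-jitted


@jax.jit
def traced_check(x):
    # under jit the argument is a tracer, which *is* an instance of
    # jnp.ndarray / jax.Array, so the branch is taken at trace time
    return isinstance(x, jnp.ndarray)


print(traced_check(x))  # True
```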
The test added in e74194e breaks if the code is run un-jitted. EDIT: Thinking about it, the unjitted version should not behave differently; that is basically summoning a Heisenbug. I created an issue on the numpy repo to ask them if they want to generalize this to array_likes (numpy/numpy#29680). Otherwise (for drop-in compatibility with numpy) I would simply add a check for numpy arrays. What do you think?
Implementation of #46