fix derived tag check rules#192
Conversation
patrick-kidger
left a comment
There was a problem hiding this comment.
Aha, thank you! Don't know how these slipped in 😅 the fixes all look correct to me. I have some minor comments but that's it.
lineax/_operator.py
Outdated
| return check(operator.operator) | ||
|
|
||
|
|
||
| def _scalar_sign(scalar) -> int | None: |
There was a problem hiding this comment.
An enum.Enum might be a bit neater here?
lineax/_operator.py
Outdated
|
|
||
|
|
||
| def _scalar_sign(scalar) -> int | None: | ||
| """Returns scalar sign if known at trace time otherwise None.""" |
There was a problem hiding this comment.
This means that the behaviour will change depending on whether or not we are inside JIT, which I think is undesirable. I think it'd probably be better to just always treat JAX arrays as unknown? The JAX-array-outside-JIT case is fairly edge-case anyway since so little computation ever happens outside JIT.
So I think we'd end up replacing the try-except with a isinstance(scalar, (int, float, np.ndarray, np.generic)).
WDYT?
There was a problem hiding this comment.
That would work too, but behaviour may still differ in or outside of jit if jax.jit is used instead eqx.filter_jit is used which could convert all these type to tracers? If you'd rather have it always unknown I can sympathise with that too.
There was a problem hiding this comment.
I'm okay without the jit/filter_jit equivalency here, since that's already true across a whole host of cases – and making that possible is really the point of filter_jit in the first place.
(Coming at this a few years later, I really do not like that we had to introduce filter_* though. Sometimes I wonder if we should have mandated trees with all-arrays and tackled this point with static fields instead. That ship has very definitely sailed, however.)
|
Another idea for |
Haha! This is a good observation. I think the correct separation of concerns here would be for It's fairly edge-case so that comes under "happy to take a PR on that" if you feel strongly! |
|
LGTM, merged! Thank you for the fix. 🎉 |
Just noticed that a few of these were incorrect for composed/derived linear operator, think all fixed as best we can for now.