I noticed that every release of the dm-pix library requires `jax>=0.4.16`, even version `0.1.0`. Is that constraint necessary? If not, could it be relaxed so that I can use dm-pix with `jax==0.3.14`?