Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion chex/_src/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,7 +996,7 @@ def assert_numerical_grads(f: Callable[..., Array],
difference gradients.
"""
# Correct scaling.
# Remove after https://github.com/google/jax/issues/3130 is fixed.
# Remove after https://github.com/jax-ml/jax/issues/3130 is fixed.
atol *= f_args[0].size

# Mock `jax.lax.stop_gradient` because finite diff. method does not honour it.
Expand Down
2 changes: 1 addition & 1 deletion chex/_src/fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def set_n_cpu_devices(n: Optional[int] = None) -> None:
This allows `jax.pmap` to be tested on a single-CPU platform.
This utility only takes effect before XLA backends are initialized, i.e.
before any JAX operation is executed (including `jax.devices()` etc.).
See https://github.com/google/jax/issues/1408.
See https://github.com/jax-ml/jax/issues/1408.

Args:
n: A required number of CPU devices (``FLAGS.chex_n_cpu_devices`` is used by
Expand Down
Loading