From 6b68715cf3132d32de029ea94decb9dd3a45dc55 Mon Sep 17 00:00:00 2001 From: Jake Harmon Date: Wed, 4 Dec 2024 15:46:01 -0800 Subject: [PATCH] Update references to JAX's GitHub repo JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax PiperOrigin-RevId: 702886821 --- chex/_src/asserts.py | 2 +- chex/_src/fake.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/chex/_src/asserts.py b/chex/_src/asserts.py index 7984ddac..11fefc3d 100644 --- a/chex/_src/asserts.py +++ b/chex/_src/asserts.py @@ -996,7 +996,7 @@ def assert_numerical_grads(f: Callable[..., Array], difference gradients. """ # Correct scaling. - # Remove after https://github.com/google/jax/issues/3130 is fixed. + # Remove after https://github.com/jax-ml/jax/issues/3130 is fixed. atol *= f_args[0].size # Mock `jax.lax.stop_gradient` because finite diff. method does not honour it. diff --git a/chex/_src/fake.py b/chex/_src/fake.py index 7d3d622e..9fe2a57d 100644 --- a/chex/_src/fake.py +++ b/chex/_src/fake.py @@ -57,7 +57,7 @@ def set_n_cpu_devices(n: Optional[int] = None) -> None: This allows `jax.pmap` to be tested on a single-CPU platform. This utility only takes effect before XLA backends are initialized, i.e. before any JAX operation is executed (including `jax.devices()` etc.). - See https://github.com/google/jax/issues/1408. + See https://github.com/jax-ml/jax/issues/1408. Args: n: A required number of CPU devices (``FLAGS.chex_n_cpu_devices`` is used by