From 3310edd2cdb164e9fa0b1b33cfea63f653de1efd Mon Sep 17 00:00:00 2001 From: ChexDev Date: Mon, 8 Jul 2024 10:24:40 -0700 Subject: [PATCH] Sort devices explicitly by process index, then id (as opposed to IDs alone). IDs may be randomly generated, and are not guaranteed to be ordered based on their process index. PiperOrigin-RevId: 650294623 --- chex/_src/asserts_test.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/chex/_src/asserts_test.py b/chex/_src/asserts_test.py index fed49ceb..c172a1db 100644 --- a/chex/_src/asserts_test.py +++ b/chex/_src/asserts_test.py @@ -1200,12 +1200,14 @@ def test_assert_tree_is_on_device(self): asserts.assert_tree_is_on_device(tpu_2_tree, device=tpu_2) with self.assertRaisesRegex( - AssertionError, _get_err_regex(r"'a' resides on.*TpuDevice\(id=0") + AssertionError, + _get_err_regex(r"'a' resides on.*TpuDevice\(process_index=0, id=0"), ): asserts.assert_tree_is_on_device(tpu_1_tree, device=tpu_2) with self.assertRaisesRegex( - AssertionError, _get_err_regex(r"'a' resides on.*TpuDevice\(id=1") + AssertionError, + _get_err_regex(r"'a' resides on.*TpuDevice\(process_index=0, id=1"), ): asserts.assert_tree_is_on_device(tpu_2_tree, device=tpu_1) @@ -1735,7 +1737,7 @@ def test_assert_equal_pass(self, first, second): asserts.assert_equal(first, second) def test_assert_equal_pass_on_arrays(self): - # Not using named_parameters, becase JAX cannot be used before app.run(). + # Not using named_parameters, because JAX cannot be used before app.run(). asserts.assert_equal(jnp.ones([]), np.ones([])) asserts.assert_equal( jnp.ones([], dtype=jnp.int32), np.ones([], dtype=np.float64))