Fix class comparison stability for JAX uniqueness #3
- Update `_CounterMeta` in `thrml/pgm.py` to assign a unique `_class_id`.
- Use `_class_id` as a tie-breaker in `__lt__` to ensure strict, deterministic ordering.
- Add `tests/test_pgm_ordering.py` for verification.
Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Summary of Changes
Hello @igor-holt, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request addresses a potential source of non-deterministic behavior in JAX pytree flattening and other dictionary-based operations. It assigns a unique, incremental ID to each class at creation time.
Code Review
This pull request introduces a mechanism to ensure stable and unique ordering for dynamically created classes by adding a _class_id to _CounterMeta. This is a good solution to prevent non-deterministic behavior in JAX pytrees. The implementation is clean and is accompanied by a comprehensive set of new tests.
My main feedback is to improve the thread safety of the class creation counter to prevent potential race conditions in a multi-threaded environment. Other than that, the changes look solid.
@igor-holt I've opened a new pull request, #13, to work on those changes. Once the pull request is ready, I'll request review from you.
Pull request overview
This PR addresses non-deterministic sorting behavior in JAX pytree operations by introducing a unique class identifier system for AbstractNode subclasses. The change ensures that dynamically created classes with identical names and modules can still be deterministically ordered.
Changes:
- Added `_class_id` counter to `_CounterMeta` metaclass to assign unique IDs to each class at creation time
- Modified `__lt__` comparison method to include `_class_id` in the ordering tuple, ensuring uniqueness
- Added comprehensive tests to verify class ordering stability and uniqueness
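A minimal sketch of how a counter-based tie-breaker of this kind could look. `CounterMetaSketch` and `make_node_class` are hypothetical names written for illustration, not the actual `_CounterMeta` code in `thrml/pgm.py`; only the attribute names `_class_id` and `_class_creation_counter` and the `getattr(..., -1)` fallback are taken from the hunks quoted in this review.

```python
class CounterMetaSketch(type):
    """Hypothetical stand-in for _CounterMeta (assumption, not the thrml source)."""

    # Shared counter stored on the metaclass itself.
    _class_creation_counter = 0

    def __init__(cls, name, bases, namespace):
        super().__init__(name, bases, namespace)
        # Give every newly created class the next integer ID.
        cls._class_id = CounterMetaSketch._class_creation_counter
        CounterMetaSketch._class_creation_counter += 1

    def __lt__(cls, other):
        if not isinstance(other, type):
            return NotImplemented
        self_id = getattr(cls, "_class_id", -1)
        other_id = getattr(other, "_class_id", -1)
        # Name and module come first; the ID only breaks ties.
        return (cls.__module__, cls.__qualname__, self_id) < (
            other.__module__,
            other.__qualname__,
            other_id,
        )


def make_node_class():
    """Create a class dynamically; every call reuses the same name and module."""
    return CounterMetaSketch("Node", (), {})


first = make_node_class()
second = make_node_class()

# Both classes share module and qualname, so only _class_id separates them.
print(first < second, second < first)
print(sorted([second, first]) == [first, second])
```

Because the ID is the last element of the tuple, classes with different names still sort by name, and the counter only matters when module and qualname collide.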
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| thrml/pgm.py | Implements _class_id counter in _CounterMeta.__init__ and updates __lt__ to use it for unique class ordering |
| tests/test_pgm_ordering.py | New test file verifying deterministic sorting behavior for classes with identical module/qualname attributes |
The metaclass implements `__lt__` but does not override `__eq__` and `__hash__`. According to Python's data model, when you implement comparison methods like `__lt__`, it's recommended to also implement `__eq__` and `__hash__` for consistency, especially since classes will be used as dictionary keys (as seen in `BlockSpec.node_shape_dtypes`). While Python's default type comparison might work, explicitly defining these methods would ensure consistent behavior when classes are used in sets, dictionaries, or for equality comparisons. This is particularly important given that the PR is about ensuring deterministic behavior in JAX operations.
    def __eq__(cls, other):
        if not isinstance(other, type):
            return NotImplemented
        return cls is other

    def __hash__(cls):
        # Preserve default identity-based hashing for classes.
        return object.__hash__(cls)
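One detail worth keeping in mind with this suggestion: in Python, a class that defines `__eq__` without `__hash__` has its `__hash__` set to `None`, and the same applies to a metaclass. The sketch below uses two hypothetical metaclasses (`EqOnlyMeta` and `EqAndHashMeta`, neither of which exists in thrml) to show why the two methods should be added together if classes are to remain usable as dictionary keys.

```python
class EqOnlyMeta(type):
    """Hypothetical metaclass that overrides __eq__ but not __hash__."""

    def __eq__(cls, other):
        if not isinstance(other, type):
            return NotImplemented
        return cls is other


class EqAndHashMeta(type):
    """Hypothetical metaclass that overrides both, as the suggestion does."""

    def __eq__(cls, other):
        if not isinstance(other, type):
            return NotImplemented
        return cls is other

    def __hash__(cls):
        # Keep the default identity-based hash.
        return object.__hash__(cls)


def is_hashable(obj):
    """Return True if hash(obj) succeeds, False if it raises TypeError."""
    try:
        hash(obj)
    except TypeError:
        return False
    return True


Unhashable = EqOnlyMeta("Unhashable", (), {})
Hashable = EqAndHashMeta("Hashable", (), {})

print(is_hashable(Unhashable), is_hashable(Hashable))
lookup = {Hashable: "usable as a dict key"}
print(lookup[Hashable])
```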
    self_id = getattr(cls, "_class_id", -1)
    other_id = getattr(other, "_class_id", -1)
The use of getattr with a default value of -1 for _class_id could be problematic. If a class somehow doesn't have _class_id set (e.g., due to metaclass conflicts or inheritance issues), using -1 as the fallback means it will sort before all properly initialized classes. This could mask bugs or create unexpected ordering. Consider raising an error instead if _class_id is missing, or document why -1 is an acceptable fallback for classes not created through this metaclass.
    - self_id = getattr(cls, "_class_id", -1)
    - other_id = getattr(other, "_class_id", -1)
    + try:
    +     self_id = cls._class_id
    +     other_id = other._class_id
    + except AttributeError as exc:
    +     raise TypeError(
    +         "Class comparison is only supported between types created with _CounterMeta "
    +         "that define a '_class_id' attribute."
    +     ) from exc
tests/test_pgm_ordering.py
Outdated
    @@ -0,0 +1,55 @@
    import pytest
The pytest module is imported but not used. The tests use plain assert statements without any pytest-specific features like fixtures, parameterization, or markers. Consider removing the unused import or leveraging pytest features like @pytest.mark.parametrize for the test variations.
    - import pytest
tests/test_pgm_ordering.py
Outdated
    # 'A' < 'B' alphabetically
    # Assuming module is same
    assert A < B
    assert not (B < A)
The test should verify the behavior when comparing classes with `_CounterMeta` against classes that don't have `_CounterMeta`. The `__lt__` method raises NotImplementedError for non-type comparisons, but it's unclear how it behaves when comparing a `_CounterMeta` class with a regular type. Adding a test case for this scenario would ensure the comparison logic handles edge cases correctly.
        assert not (B < A)


    def test_ordering_with_non_countermeta_type():
        """Verify behavior when comparing _CounterMeta classes with regular types."""
        # Create a class using _CounterMeta
        MetaClass = _CounterMeta("MetaClass", (), {})

        # Regular (non-_CounterMeta) type
        class RegularClass:
            pass

        # Comparing a _CounterMeta class with a non-_CounterMeta type should
        # raise NotImplementedError to avoid ambiguous ordering.
        with pytest.raises(NotImplementedError):
            _ = MetaClass < RegularClass
        with pytest.raises(NotImplementedError):
            _ = RegularClass < MetaClass
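Which exception the suggested test should expect depends on how `__lt__` signals an unsupported operand. If it follows the usual convention of returning the `NotImplemented` singleton (an assumption here; the actual method body is not shown in this thread), Python raises `TypeError` once both operands decline, not `NotImplementedError`. The sketch below uses a hypothetical `OrderedMeta` that declines every operand not created by itself, purely to illustrate that mechanism.

```python
class OrderedMeta(type):
    """Hypothetical metaclass that only orders classes it created itself."""

    def __lt__(cls, other):
        if not isinstance(other, OrderedMeta):
            # Decline the comparison; Python then tries the reflected operation.
            return NotImplemented
        return cls.__qualname__ < other.__qualname__


Tracked = OrderedMeta("Tracked", (), {})


class Regular:
    pass


def comparison_error(left, right):
    """Return the exception type raised by `left < right`, or None if it succeeds."""
    try:
        left < right
    except Exception as exc:
        return type(exc)
    return None


print(comparison_error(Tracked, Regular))
print(comparison_error(Regular, Tracked))
```

Under that assumption, a test written with `pytest.raises(TypeError)` would match what the interpreter does.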
thrml/pgm.py
Outdated
    cls._class_id = _CounterMeta._class_creation_counter
    _CounterMeta._class_creation_counter += 1
The counter increment operation is not thread-safe. If classes are created concurrently (e.g., during module imports in a multi-threaded environment or with parallel loading), multiple classes could receive the same _class_id value, which would break the uniqueness guarantee that this fix is trying to establish. Consider using threading.Lock to protect the counter increment, or document that classes must not be created concurrently.
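A minimal sketch of the `threading.Lock` option, assuming the counter lives on the metaclass as in the hunk above. `LockedCounterMeta` and `create_classes_concurrently` are hypothetical names used for illustration only.

```python
import threading


class LockedCounterMeta(type):
    """Hypothetical variant of the counter metaclass with a guarded increment."""

    _class_creation_counter = 0
    _counter_lock = threading.Lock()

    def __init__(cls, name, bases, namespace):
        super().__init__(name, bases, namespace)
        # Read and increment under one lock so no two classes share an ID.
        with LockedCounterMeta._counter_lock:
            cls._class_id = LockedCounterMeta._class_creation_counter
            LockedCounterMeta._class_creation_counter += 1


def create_classes_concurrently(n_threads=8, per_thread=50):
    """Create identically named classes from several threads and collect them."""
    created = []

    def worker():
        for _ in range(per_thread):
            created.append(LockedCounterMeta("Node", (), {}))

    threads = [threading.Thread(target=worker) for _ in range(n_threads)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return created


classes = create_classes_concurrently()
ids = [c._class_id for c in classes]
print(len(ids), len(set(ids)))
```

The lock is only taken at class creation time, so it adds no cost to comparisons or sorting.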
tests/test_pgm_ordering.py
Outdated
    # 'A' < 'B' alphabetically
    # Assuming module is same
    assert A < B
    assert not (B < A)
Test is always false, because of this condition.
Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 131ced0a66
thrml/pgm.py
Outdated
    cls._class_id = _CounterMeta._class_creation_counter
    _CounterMeta._class_creation_counter += 1
Serialize class-id counter updates
The _class_id assignment uses an unsynchronized read/modify/write sequence, so if classes are generated from multiple threads (for example in plugin loading or parallel test code), two class creations can observe the same _class_creation_counter value and receive duplicate IDs. In that case the new tie-breaker no longer guarantees a strict ordering for classes with identical __module__/__qualname__, reintroducing the non-deterministic sort behavior this change is meant to eliminate.
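To make the consequence concrete, the sketch below forces two same-named classes to share an ID, which is the state a lost update would produce. `TieBreakMeta` is a hypothetical metaclass and the duplicate `_class_id` values are set by hand; nothing here is taken from `thrml/pgm.py`.

```python
class TieBreakMeta(type):
    """Hypothetical metaclass ordering classes by (module, qualname, _class_id)."""

    def __lt__(cls, other):
        if not isinstance(other, TieBreakMeta):
            return NotImplemented
        return (cls.__module__, cls.__qualname__, cls._class_id) < (
            other.__module__,
            other.__qualname__,
            other._class_id,
        )


# Simulate a race: both classes end up with the same ID.
first = TieBreakMeta("Node", (), {"_class_id": 7})
second = TieBreakMeta("Node", (), {"_class_id": 7})

# Neither class is "less than" the other, so the ordering is no longer strict.
print(first < second, second < first)

# A stable sort then just preserves the input order, which depends on
# whatever order the caller happened to supply.
print(sorted([first, second])[0] is first)
print(sorted([second, first])[0] is second)
```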
- Update `thrml/pgm.py` and `tests/test_pgm_ordering.py` to match `ruff` and `black` formatting requirements.
- Fix import ordering and whitespace.
Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Pull request overview
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
    # Comparison should prioritize name/module
    # 'A' < 'B' alphabetically
    self.assertTrue(A < B)
assertTrue(a < b) cannot provide an informative message. Using assertLess(a, b) instead will give more informative messages.
    - self.assertTrue(A < B)
    + self.assertLess(A, B)
    # Comparison should prioritize name/module
    # 'A' < 'B' alphabetically
    self.assertTrue(A < B)
    self.assertFalse(B < A)
assertFalse(a < b) cannot provide an informative message. Using assertGreaterEqual(a, b) instead will give more informative messages.
    - self.assertFalse(B < A)
    + self.assertGreaterEqual(B, A)
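A caveat on this replacement: `assertGreaterEqual(B, A)` evaluates `B >= A`, which is a different operator from the `<` the test was exercising. If the metaclass defines only `__lt__` (an assumption; the rest of the class is not shown in this thread), `>=` is unsupported and the assertion would error with `TypeError` instead of passing. The sketch below uses a hypothetical `LtOnlyMeta` to illustrate.

```python
class LtOnlyMeta(type):
    """Hypothetical metaclass that defines __lt__ and no other ordering method."""

    def __lt__(cls, other):
        if not isinstance(other, LtOnlyMeta):
            return NotImplemented
        return cls.__qualname__ < other.__qualname__


A = LtOnlyMeta("A", (), {})
B = LtOnlyMeta("B", (), {})


def ge_supported(left, right):
    """Return True if `left >= right` can be evaluated without a TypeError."""
    try:
        left >= right
    except TypeError:
        return False
    return True


print(A < B)               # the operator the metaclass implements
print(ge_supported(B, A))  # `>=` has no implementation to fall back on
```

In that situation, keeping `assertFalse(B < A)` or also defining `__ge__` on the metaclass would be needed before adopting the suggestion.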
@igor-holt I've opened a new pull request, #15, to work on those changes. Once the pull request is ready, I'll request review from you.
Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
This PR fixes potential non-deterministic sorting behavior in JAX pytree flattening or other dictionary-based operations by ensuring a unique ordering for `AbstractNode` subclasses, even when created dynamically with identical names. It introduces a `_class_id` counter to disambiguate such classes.
PR created automatically by Jules for task 9360936800683187626 started by @igor-holt