Skip to content

Fix class comparison stability for JAX uniqueness#3

Open
igor-holt wants to merge 4 commits intomainfrom
fix-jax-class-comparison-9360936800683187626
Open

Fix class comparison stability for JAX uniqueness#3
igor-holt wants to merge 4 commits intomainfrom
fix-jax-class-comparison-9360936800683187626

Conversation

@igor-holt
Copy link
Owner

This PR fixes potential non-deterministic sorting behavior in JAX pytree flattening or other dictionary-based operations by ensuring a unique ordering for AbstractNode subclasses, even when created dynamically with identical names. It introduces a _class_id counter to disambiguate such classes.


PR created automatically by Jules for task 9360936800683187626 started by @igor-holt

- Update `_CounterMeta` in `thrml/pgm.py` to assign a unique `_class_id`.
- Use `_class_id` as a tie-breaker in `__lt__` to ensure strict, deterministic ordering.
- Add `tests/test_pgm_ordering.py` for verification.

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
@google-labs-jules
Copy link

👋 Jules, reporting for duty! I'm here to lend a hand with this pull request.

When you start a review, I'll add a 👀 emoji to each comment to let you know I've read it. I'll focus on feedback directed at me and will do my best to stay out of conversations between you and other bots or reviewers to keep the noise down.

I'll push a commit with your requested changes shortly after. Please note there might be a delay between these steps, but rest assured I'm on the job!

For more direct control, you can switch me to Reactive Mode. When this mode is on, I will only act on comments where you specifically mention me with @jules. You can find this option in the Pull Request section of your global Jules UI settings. You can always switch back!

New to Jules? Learn more at jules.google/docs.


For security, I will only act on instructions from the user who triggered this task.

@gemini-code-assist
Copy link

Summary of Changes

Hello @igor-holt, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a potential source of non-deterministic behavior in JAX pytree flattening and other dictionary-based operations. By assigning a unique, incremental ID to classes created via _CounterMeta, it ensures a stable and predictable ordering for these classes, which is crucial for consistent system behavior. The changes prevent issues that could arise from dynamically generated classes having identical names but needing distinct identities for sorting purposes.

Highlights

  • Deterministic Class Ordering: Introduced a _class_id counter within the _CounterMeta metaclass to assign a unique identifier to each class upon creation, ensuring deterministic ordering even for dynamically generated classes with identical names.
  • Comparison Logic Update: Modified the __lt__ (less than) method in _CounterMeta to incorporate the newly added _class_id in its comparison logic, guaranteeing stable and unique sorting behavior for AbstractNode subclasses.
  • New Test Coverage: Added a new test file, tests/test_pgm_ordering.py, to thoroughly verify the uniqueness, strictness, and stability of the class ordering mechanism implemented by _CounterMeta.
Changelog
  • tests/test_pgm_ordering.py
    • Added new test file to verify class ordering uniqueness and stability.
    • Included test_class_ordering_uniqueness to check distinctness, strict ordering, and sorting stability for classes with identical names.
    • Added test_ordering_with_different_classes to confirm correct alphabetical ordering for classes with different names.
  • thrml/pgm.py
    • Added a class-level _class_creation_counter to _CounterMeta to track class instantiation order.
    • Implemented an __init__ method in _CounterMeta to assign a unique _class_id to each class instance using the _class_creation_counter.
    • Updated the __lt__ method in _CounterMeta to include the _class_id in the comparison tuple, ensuring unique and stable ordering for classes.
Activity
  • PR created automatically by Jules for task 9360936800683187626, started by @igor-holt.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a mechanism to ensure stable and unique ordering for dynamically created classes by adding a _class_id to _CounterMeta. This is a good solution to prevent non-deterministic behavior in JAX pytrees. The implementation is clean and is accompanied by a comprehensive set of new tests.

My main feedback is to improve the thread safety of the class creation counter to prevent potential race conditions in a multi-threaded environment. Other than that, the changes look solid.

Copy link

Copilot AI commented Feb 13, 2026

@igor-holt I've opened a new pull request, #13, to work on those changes. Once the pull request is ready, I'll request review from you.

Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR addresses non-deterministic sorting behavior in JAX pytree operations by introducing a unique class identifier system for AbstractNode subclasses. The change ensures that dynamically created classes with identical names and modules can still be deterministically ordered.

Changes:

  • Added _class_id counter to _CounterMeta metaclass to assign unique IDs to each class at creation time
  • Modified __lt__ comparison method to include _class_id in the ordering tuple, ensuring uniqueness
  • Added comprehensive tests to verify class ordering stability and uniqueness

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 6 comments.

File Description
thrml/pgm.py Implements _class_id counter in _CounterMeta.__init__ and updates __lt__ to use it for unique class ordering
tests/test_pgm_ordering.py New test file verifying deterministic sorting behavior for classes with identical module/qualname attributes

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines 37 to 38


Copy link

Copilot AI Feb 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The metaclass implements lt but does not override eq and hash. According to Python's data model, when you implement comparison methods like lt, it's recommended to also implement eq and hash for consistency, especially since classes will be used as dictionary keys (as seen in BlockSpec.node_shape_dtypes). While Python's default type comparison might work, explicitly defining these methods would ensure consistent behavior when classes are used in sets, dictionaries, or for equality comparisons. This is particularly important given that the PR is about ensuring deterministic behavior in JAX operations.

Suggested change
def __eq__(cls, other):
if not isinstance(other, type):
return NotImplemented
return cls is other
def __hash__(cls):
# Preserve default identity-based hashing for classes.
return object.__hash__(cls)

Copilot uses AI. Check for mistakes.
Comment on lines +33 to +34
self_id = getattr(cls, "_class_id", -1)
other_id = getattr(other, "_class_id", -1)
Copy link

Copilot AI Feb 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The use of getattr with a default value of -1 for _class_id could be problematic. If a class somehow doesn't have _class_id set (e.g., due to metaclass conflicts or inheritance issues), using -1 as the fallback means it will sort before all properly initialized classes. This could mask bugs or create unexpected ordering. Consider raising an error instead if _class_id is missing, or document why -1 is an acceptable fallback for classes not created through this metaclass.

Suggested change
self_id = getattr(cls, "_class_id", -1)
other_id = getattr(other, "_class_id", -1)
try:
self_id = cls._class_id
other_id = other._class_id
except AttributeError as exc:
raise TypeError(
"Class comparison is only supported between types created with _CounterMeta "
"that define a '_class_id' attribute."
) from exc

Copilot uses AI. Check for mistakes.
@@ -0,0 +1,55 @@
import pytest
Copy link

Copilot AI Feb 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The pytest module is imported but not used. The tests use plain assert statements without any pytest-specific features like fixtures, parameterization, or markers. Consider removing the unused import or leveraging pytest features like @pytest.mark.parametrize for the test variations.

Suggested change
import pytest

Copilot uses AI. Check for mistakes.
# 'A' < 'B' alphabetically
# Assuming module is same
assert A < B
assert not (B < A)
Copy link

Copilot AI Feb 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test should verify the behavior when comparing classes with _CounterMeta against classes that don't have _CounterMeta. The lt method raises NotImplementedError for non-type comparisons, but it's unclear how it behaves when comparing a _CounterMeta class with a regular type. Adding a test case for this scenario would ensure the comparison logic handles edge cases correctly.

Suggested change
assert not (B < A)
assert not (B < A)
def test_ordering_with_non_countermeta_type():
"""Verify behavior when comparing _CounterMeta classes with regular types."""
# Create a class using _CounterMeta
MetaClass = _CounterMeta("MetaClass", (), {})
# Regular (non-_CounterMeta) type
class RegularClass:
pass
# Comparing a _CounterMeta class with a non-_CounterMeta type should
# raise NotImplementedError to avoid ambiguous ordering.
with pytest.raises(NotImplementedError):
_ = MetaClass < RegularClass
with pytest.raises(NotImplementedError):
_ = RegularClass < MetaClass

Copilot uses AI. Check for mistakes.
thrml/pgm.py Outdated
Comment on lines 18 to 19
cls._class_id = _CounterMeta._class_creation_counter
_CounterMeta._class_creation_counter += 1
Copy link

Copilot AI Feb 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The counter increment operation is not thread-safe. If classes are created concurrently (e.g., during module imports in a multi-threaded environment or with parallel loading), multiple classes could receive the same _class_id value, which would break the uniqueness guarantee that this fix is trying to establish. Consider using threading.Lock to protect the counter increment, or document that classes must not be created concurrently.

Copilot uses AI. Check for mistakes.
# 'A' < 'B' alphabetically
# Assuming module is same
assert A < B
assert not (B < A)
Copy link

Copilot AI Feb 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test is always false, because of this condition.

Copilot uses AI. Check for mistakes.
Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Copy link

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 131ced0a66

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

thrml/pgm.py Outdated
Comment on lines 18 to 19
cls._class_id = _CounterMeta._class_creation_counter
_CounterMeta._class_creation_counter += 1

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Serialize class-id counter updates

The _class_id assignment uses an unsynchronized read/modify/write sequence, so if classes are generated from multiple threads (for example in plugin loading or parallel test code), two class creations can observe the same _class_creation_counter value and receive duplicate IDs. In that case the new tie-breaker no longer guarantees a strict ordering for classes with identical __module__/__qualname__, reintroducing the non-deterministic sort behavior this change is meant to eliminate.

Useful? React with 👍 / 👎.

- Update `thrml/pgm.py` and `tests/test_pgm_ordering.py` to match `ruff` and `black` formatting requirements.
- Fix import ordering and whitespace.

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Copilot AI review requested due to automatic review settings February 13, 2026 08:06
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.


# Comparison should prioritize name/module
# 'A' < 'B' alphabetically
self.assertTrue(A < B)
Copy link

Copilot AI Feb 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assertTrue(a < b) cannot provide an informative message. Using assertLess(a, b) instead will give more informative messages.

Suggested change
self.assertTrue(A < B)
self.assertLess(A, B)

Copilot uses AI. Check for mistakes.
# Comparison should prioritize name/module
# 'A' < 'B' alphabetically
self.assertTrue(A < B)
self.assertFalse(B < A)
Copy link

Copilot AI Feb 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assertFalse(a < b) cannot provide an informative message. Using assertGreaterEqual(a, b) instead will give more informative messages.

Suggested change
self.assertFalse(B < A)
self.assertGreaterEqual(B, A)

Copilot uses AI. Check for mistakes.
Copy link

Copilot AI commented Feb 14, 2026

@igor-holt I've opened a new pull request, #15, to work on those changes. Once the pull request is ready, I'll request review from you.

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants