Conversation
```python
    self._n_modes_via_ids_estimate = float(torch.unique(ids).numel())
    self._mode_stats_kind = "approx"
except Exception:
    warnings.warn("Warning: Failed to compute mode_stats (skipping).")
```
Better to use `logger.exception` here, to print the exception as well.
Also, it would be better to avoid catching `Exception` in general. Why can this fail?
This would catch the `ValueError` in the "exact" branch as well. Is this what we want? Should we catch at all?
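For reference, a minimal sketch of the suggested change, assuming a module-level logger and that the expected failure is the `ValueError` from the "exact" branch (both assumptions would need checking against the actual code):

```python
import logging

logger = logging.getLogger(__name__)

try:
    ...  # exact / approximate mode-stats computation goes here
except ValueError:
    # Unlike warnings.warn, logger.exception records the full traceback.
    logger.exception("Failed to compute mode_stats (skipping).")
```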
```python
# Cheap exact threshold (up to ~200k states)
if self.n_states <= 200_000:
    axes = [
        torch.arange(self.height, dtype=torch.long) for _ in range(self.ndim)
    ]
    grid = torch.cartesian_prod(*axes)
    rewards = self.reward_fn(grid)
```
How did you come up with this number? Doing the cartesian product seems memory-intensive.
This number might need to be lowered; it was arbitrary.
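For what it's worth, the memory pressure could also be bounded by scanning states in fixed-size chunks instead of materializing the full cartesian product. A sketch, not the PR's implementation (`reward_fn`, the chunk size, and the index decoding are illustrative):

```python
import torch

def max_reward_exact_chunked(reward_fn, height: int, ndim: int,
                             chunk_size: int = 65_536) -> float:
    """Scan all height**ndim states in chunks, tracking the running max reward."""
    n_states = height**ndim
    best = -float("inf")
    for start in range(0, n_states, chunk_size):
        flat = torch.arange(start, min(start + chunk_size, n_states))
        # Decode flat indices into mixed-radix (base-`height`) coordinates.
        coords = torch.empty(flat.numel(), ndim, dtype=torch.long)
        rem = flat.clone()
        for d in range(ndim - 1, -1, -1):
            coords[:, d] = rem % height
            rem //= height
        best = max(best, float(reward_fn(coords).max()))
    return best
```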
```python
except Exception:
    # Fall back to heuristic paths below
    pass
```
Maybe add a logger. I don't think it is a good idea in general to mask this much from the user: sometimes we compute the exact mode existence, sometimes we use a heuristic.
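For illustration, a small sketch of surfacing the fallback instead of silently passing (the exception type, names, and message are assumptions):

```python
import logging

logger = logging.getLogger(__name__)

def exact_or_heuristic(exact_fn, heuristic_fn):
    """Try the exact computation; log and record the fallback if it fails."""
    try:
        return exact_fn(), "exact"
    except MemoryError:  # assumed failure mode; narrower than bare Exception
        logger.warning("Exact computation failed; falling back to heuristic.")
        return heuristic_fn(), "approx"
```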
```python
for col in range(m):
    # Find pivot
    piv = None
    for r in range(row, k):
        if A[r, col]:
            piv = r
            break
    if piv is None:
        continue
    # Swap
    if piv != row:
        A[[row, piv]] = A[[piv, row]]
        c[[row, piv]] = c[[piv, row]]
    # Eliminate below
    for r in range(row + 1, k):
        if A[r, col]:
            A[r, :] ^= A[row, :]
            c[r] ^= c[row]
    row += 1
    if row == k:
        break
# Check for inconsistency: 0 = 1 rows
for r in range(k):
    if not A[r, :].any() and c[r]:
        return False
return True
```
I didn't check the details, to be honest, but it seems quite inefficient and not easily readable. Can we rely on scipy for this?
I'll look into it
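One note: SciPy itself has no GF(2) linear algebra, but the check can be vectorized with NumPy via the Rouché-Capelli rank criterion: A x = c is solvable over GF(2) iff rank(A) == rank([A | c]). A sketch (boolean arrays, XOR as row addition):

```python
import numpy as np

def gf2_rank(M: np.ndarray) -> int:
    """Rank of a 0/1 matrix over GF(2) via vectorized Gaussian elimination."""
    M = M.astype(bool)  # astype copies, so the caller's matrix is untouched
    rank = 0
    for col in range(M.shape[1]):
        pivots = np.nonzero(M[rank:, col])[0]
        if pivots.size == 0:
            continue
        piv = rank + pivots[0]
        M[[rank, piv]] = M[[piv, rank]]
        # XOR the pivot row into every other row with a 1 in this column.
        rows = M[:, col].copy()
        rows[rank] = False
        M[rows] ^= M[rank]
        rank += 1
        if rank == M.shape[0]:
            break
    return rank

def gf2_has_solution(A: np.ndarray, c: np.ndarray) -> bool:
    """A x = c is solvable over GF(2) iff augmenting c does not raise the rank."""
    return gf2_rank(A) == gf2_rank(np.hstack([A, c.reshape(-1, 1)]))
```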
| """ | ||
| with torch.no_grad(): | ||
| device = torch.device("cpu") | ||
| B = min(2048, max(128, 8 * self.ndim)) |
What are these numbers? Maybe use constants to improve clarity.
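A sketch of the suggested constants; the names are guesses at the intent and would need confirming with the author:

```python
MAX_BATCH_SIZE = 2_048  # hypothetical name: cap on the sampling batch
MIN_BATCH_SIZE = 128    # hypothetical name: floor on the sampling batch
SAMPLES_PER_DIM = 8     # hypothetical name: scale the batch with dimensionality

def mode_stats_batch_size(ndim: int) -> int:
    return min(MAX_BATCH_SIZE, max(MIN_BATCH_SIZE, SAMPLES_PER_DIM * ndim))
```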
```python
try:
    all_states = self.all_states
    if all_states is not None:
        mask = self.mode_mask(all_states)
        ids = self.mode_ids(all_states)
        ids = ids[mask]
        ids = ids[ids >= 0]
        return int(torch.unique(ids).numel())
except Exception:
    pass
if self._mode_stats_kind == "exact" and self._n_modes_via_ids_exact is not None:
    return int(self._n_modes_via_ids_exact)
if (
    self._mode_stats_kind == "approx"
    and self._n_modes_via_ids_estimate is not None
):
    return int(self._n_modes_via_ids_estimate)
return 2**self.ndim
```
Do we need to recompute this every time?
No, you're right, it should be stored.
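A minimal caching pattern for that (a sketch; the class and method names are placeholders for the method in question):

```python
from functools import cached_property

class ModeStats:
    @cached_property
    def n_modes(self) -> int:
        # Computed once on first access, then cached on the instance.
        return self._compute_n_modes()

    def _compute_n_modes(self) -> int:
        return 0  # placeholder for the existing exact/approx logic
```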
```python
except Exception:
    pass
```
Similar to the other comment: this is not nice for debuggability.
Hi @younik - I hear you, this is a big PR. The "splits" would have to be along tasks, though, so the resulting PRs would still be large. I appreciate your comments on the code. I think it would make sense to also look at the tasks (the stuff that's plotted in the notebook) to see whether they make sense; I'm not convinced by all of the tasks and would be open to removing a task or two. I think the one that works best for its intended purpose is the coprime reward.

In the above commit, I fixed the comments of Deceptive Reward and also fixed a pyright error.
I do think
saleml left a comment on the "hypergrid refactor" PR:
This is high-quality research code with excellent mathematical foundations and thorough testing. The main concerns are:
- Complexity barrier for new users
- Performance documentation gaps
- Some missing edge-case handling
A few questions and suggestions:

- The new reward functions are mathematically sophisticated (GF(2) algebra, prime factorization, etc.). While excellent for research, the barrier to entry is high. Suggestion: add a "Quick Start" section to the documentation showing simple use cases before diving into the mathematical details.
- The `_solve_gf2_has_solution` method uses Gaussian elimination, which could be slow for large constraint systems. Suggestion: add performance warnings in docstrings.
- Why GF(2)? The choice is elegant but not obvious. Could you add a paragraph in the documentation explaining why linear algebra over GF(2) is natural for compositional structure?
- What happens if a user picks the "impossible" preset with `ndim=2, height=16`? Should the factory functions validate compatibility?
- Could you consider adding type hints for the kwargs (see the usage sketch after this list)? Something like:

  ```python
  from typing import TypedDict

  class BitwiseXORRewardKwargs(TypedDict, total=False):
      R0: float
      tier_weights: list[float]
      dims_constrained: list[int]
      bits_per_tier: list[tuple[int, int]]
      parity_checks: list[dict] | None
  ```

- What do you think of adding visualization helpers?

  ```python
  def visualize_mode_structure(env: HyperGrid, sample_size: int = 10000):
      """Generate 2D/3D plots of mode distribution."""
      # Auto-generate plots similar to the notebook, but as an API
  ```
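If the `TypedDict` route is taken, it could also be wired into the factory signature with `typing.Unpack` (Python 3.11+, or `typing_extensions` before that); `make_bitwise_xor_reward` is an illustrative name:

```python
from typing import TypedDict, Unpack

class BitwiseXORRewardKwargs(TypedDict, total=False):
    # Abbreviated redefinition to keep this example self-contained.
    R0: float
    tier_weights: list[float]

def make_bitwise_xor_reward(**kwargs: Unpack[BitwiseXORRewardKwargs]) -> None:
    # Type checkers can now validate keyword arguments at call sites.
    ...
```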
I'd also like to suggest a structural consideration:
The original HyperGrid has become a pedagogical entrypoint of the GFlowNets library:
- It's the first environment new users encounter
- Its simplicity (grid + distance-based reward) makes it ideal for teaching core concepts
- Tutorial code often uses it as the "Hello World" of GFlowNets
- The cognitive load is intentionally minimal: "navigate a grid, reach high-reward corners"
Can we consider creating a separate file src/gfn/gym/compositional_hypergrid.py that:
- Inherits from HyperGrid to reuse the core grid mechanics
- Houses the new reward families (BitwiseXOR, MultiplicativeCoprime, TemplateMinkowski)
- Includes the sophisticated mode validation and statistics machinery
- Keeps the original simple and focused on accessibility
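A minimal skeleton of that split (module path as suggested above; the import and constructor signature are assumptions):

```python
# src/gfn/gym/compositional_hypergrid.py (proposed)
from gfn.gym import HyperGrid  # assumed import path

class CompositionalHyperGrid(HyperGrid):
    """HyperGrid variant housing the compositional reward families
    (BitwiseXOR, MultiplicativeCoprime, TemplateMinkowski) plus the
    mode-validation and statistics machinery."""

    def __init__(self, ndim: int, height: int, **kwargs):
        super().__init__(ndim=ndim, height=height, **kwargs)
        # Reward-family setup would live here, keeping HyperGrid itself simple.
```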
```python
        mode_stats_samples: Number of random samples used when
            `mode_stats="approx"`.
    """
    if height <= 4:
```
This was removed, but the condition is still relevant. Should this warning be reinstated, or is it now handled by `validate_modes`?
It should be handled by `validate_modes`.
```python
ax = (idx / Hm1 - 0.5).abs()
pdf = (1.0 / sqrt(2 * pi)) * torch.exp(-0.5 * (5 * ax) ** 2)
per_dim_discrete = float(((torch.cos(50 * ax) + 1.0) * pdf).max())
per_dim_base = per_dim_discrete if self.height > 4 else per_dim_peak
```
Magic number: the `height > 4` threshold. Should this be documented or parameterized?
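If parameterized, it could become a documented constant rather than an inline literal. A sketch with a made-up name and an assumed rationale:

```python
# Hypothetical name for the current `height > 4` magic number.
MIN_HEIGHT_FOR_DISCRETE_ESTIMATE = 4

def pick_per_dim_base(per_dim_discrete: float, per_dim_peak: float,
                      height: int) -> float:
    # Assumed rationale: very coarse grids make the discrete estimate unreliable.
    if height > MIN_HEIGHT_FOR_DISCRETE_ESTIMATE:
        return per_dim_discrete
    return per_dim_peak
```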
```python
        return bool((rr >= thr - EPS_REWARD_CMP).any().item())

    @staticmethod
    def _solve_gf2_has_solution(A: torch.Tensor, c: torch.Tensor) -> bool:
```
Could you explain the GF(2) algorithm in the docstring?
Also, this could be slow for large constraint systems.
Suggestion: add a complexity note in the docstring (O(k·m²) for a k×m matrix).
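A possible docstring covering both points, inferred from the code quoted above (worth checking against the actual semantics):

```python
import torch

def _solve_gf2_has_solution(A: torch.Tensor, c: torch.Tensor) -> bool:
    """Return True iff the linear system A x = c is solvable over GF(2).

    Runs Gaussian elimination with XOR as row addition (all arithmetic
    mod 2), then reports inconsistency iff some row reduces to 0 = 1,
    i.e. an all-zero row of A whose right-hand side c[r] is 1.

    Complexity: O(k * m**2) for a k x m constraint matrix, so this can
    be slow for large constraint systems.
    """
    ...
```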

Description