From d3148f76fb849c2179457abd2976f96644f1acde Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Mon, 29 Sep 2025 01:40:40 -0400 Subject: [PATCH 01/27] added policy adaptors, factorized samplers to allow for modular adaptor logic, recurrent estimators, recurrent modules --- src/gfn/estimators.py | 110 +++++++- src/gfn/samplers.py | 499 ++++++++++++++++++++++++++++++------ src/gfn/utils/modules.py | 539 +++++++++++++++++++++++++++++++++++++++ testing/test_modules.py | 174 +++++++++++++ 4 files changed, 1246 insertions(+), 76 deletions(-) create mode 100644 testing/test_modules.py diff --git a/src/gfn/estimators.py b/src/gfn/estimators.py index 24f7c158..c9fde292 100644 --- a/src/gfn/estimators.py +++ b/src/gfn/estimators.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from collections import defaultdict -from typing import Any, Optional +from typing import Any, Callable, Optional, cast import torch import torch.nn as nn @@ -798,3 +798,111 @@ def expected_output_dim(self) -> Optional[int]: None, as the output_dim of a TensorDict is not well-defined. """ return None + + +class RecurrentDiscretePolicyEstimator(DiscretePolicyEstimator): + """Discrete policy estimator for recurrent architectures with explicit carry. + + Many sequence models (e.g., RNN/LSTM/GRU/Transformer in autoregressive mode) + maintain a recurrent hidden state ("carry") that must be threaded through + successive calls during sampling. This class formalizes that pattern for + GFlowNet policies by: + + - Exposing a forward signature ``forward(states, carry) -> (logits, carry)`` + so the policy can update and return the next carry at each step. + - Requiring an ``init_carry(batch_size, device)`` method to allocate the + initial hidden state for a rollout. + - Ensuring the per-step output (``logits`` over actions) is derived from the + latest token/time step while the internal model may process sequences. + + Interaction with the sampler/adapters + ------------------------------------- + The sampler uses a ``RecurrentEstimatorAdapter`` which calls this estimator + with the current carry, updates the carry on every step, and records + per-step artifacts. Non-recurrent estimators should use the default adapter + and the standard ``DiscretePolicyEstimator`` base class instead. + + Notes + ----- + - Forward is intended for on-policy generation; off-policy evaluation over + entire trajectories typically requires different batching and masking. + - ``init_carry`` is a hard requirement for compatibility with the recurrent + adapter. + """ + + def __init__( + self, + module: nn.Module, + n_actions: int, + preprocessor: Preprocessor | None = IdentityPreprocessor( + output_dim=None + ), # Addressed in https://github.com/GFNOrg/torchgfn/pull/399. + is_backward: bool = False, + ): + """Initializes a RecurrentDiscretePolicyEstimator. + + Args: + module: The neural network module to use. + n_actions: Total number of actions in the discrete environment. + preprocessor: Preprocessor object that transforms states to tensors. + """ + if preprocessor is None: + preprocessor = IdentityPreprocessor( + output_dim=None + ) # Addressed in https://github.com/GFNOrg/torchgfn/pull/399. + super().__init__( + module=module, + n_actions=n_actions, + preprocessor=preprocessor, + is_backward=is_backward, + ) + + def forward( + self, + states: States, + carry: dict[str, torch.Tensor], + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + """Forward pass of the module. + + Args: + states: The input states. + carry: The carry from the previous step. 
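+                For an LSTM-backed module, for example, this is a mapping such
+                as ``{"hidden": (num_layers, B, H), "cell": (num_layers, B, H)}``
+                (illustrative shapes); the exact keys and shapes are whatever
+                the module's ``init_carry`` allocates.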
+ + Returns: + The output of the module, as a tensor of shape (*batch_shape, output_dim). + """ + # TODO: Is this still true? NOTE: Can only be used for on-policy generation, + # not off-policy evaluation of entire trajectory. + # states.tensor.shape: (..., max_string_len) + current_input_len = states.tensor.shape[ + -1 + ].max() # TODO: Check if this is correct. + states_tensor = states.tensor[..., :current_input_len] # (..., string_len) + + # Compute sequence of logits and update carry. + logits, carry = self.module(states_tensor, carry) + + # Get the logits for the last token in the sequence. + logits = logits[:, -1, :] # (b, n_actions) + + if self.expected_output_dim is not None: + assert logits.shape[-1] == self.expected_output_dim, ( + f"Module output shape {logits.shape} does not match expected output " + f"dimension {self.expected_output_dim}" + ) + + return logits, carry + + def init_carry( + self, + batch_size: int, + device: torch.device, + ) -> dict[str, torch.Tensor]: + init_carry = getattr(self.module, "init_carry", None) + if not callable(init_carry): + raise NotImplementedError( + "Module does not implement init_carry(batch_size, device)." + ) + init_carry_fn = cast(Callable[[int, torch.device], Any], init_carry) + + return init_carry_fn(batch_size, device) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index bb6d48b6..95a0e431 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -1,6 +1,7 @@ -from typing import Any, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, cast import torch +from torch.distributions import Distribution from gfn.actions import Actions from gfn.containers import Trajectories @@ -16,6 +17,363 @@ from gfn.utils.prob_calculations import get_trajectory_pbs, get_trajectory_pfs +class EstimatorAdapter(Protocol): + """Adapter interface for estimator-specific policy behavior. + + Purpose + ------- + This Protocol defines the minimal interface the Sampler relies on, allowing + us to keep one generic sampling loop while plugging in different estimator + behaviors (e.g., non‑recurrent, recurrent with carry, tempered variants) + without modifying the Sampler. We use a Protocol (structural typing) so any + class that implements these members is accepted; no inheritance is required. + + Opaque context (ctx) + -------------------- + The adapter owns an opaque context object (``ctx``). The Sampler never + inspects it and simply passes it back to the adapter at each step. The + adapter is responsible for: + - initializing ``ctx`` once per rollout in ``init_context`` + - updating any internal state (e.g., recurrent ``carry``) during ``compute`` + - recording per‑step artifacts in ``record_step`` (e.g., log_probs, + estimator outputs), typically with mask-aware padding + - producing stacked outputs for ``Trajectories`` in ``finalize`` + + Guidance + -------- + - Allocate ``ctx`` once per rollout; mutate it in place for performance. + - Apply masking inside the adapter (``step_mask``) when slicing conditioning + or padding per‑step tensors back to full batch size. + - Keep Sampler oblivious to estimator details (conditioning, carry, etc.). + """ + + @property + def is_backward(self) -> bool: + ... # fmt: skip + + def init_context( + self, + batch_size: int, + device: torch.device, + conditioning: Optional[torch.Tensor] = None, + ) -> Any: + ... # fmt: skip + + def compute( + self, + states_active: States, + ctx: Any, + step_mask: torch.Tensor, + **policy_kwargs: Any, + ) -> tuple[Distribution, Any]: + ... 
# fmt: skip + + def record_step( + self, + ctx: Any, + step_mask: torch.Tensor, + sampled_actions: torch.Tensor, + dist: Distribution, + save_logprobs: bool, + save_estimator_outputs: bool, + ) -> None: + ... # fmt: skip + + def finalize(self, ctx: Any) -> dict[str, Optional[torch.Tensor]]: + ... # fmt: skip + + # Optional helper for `sample_actions` BC + def get_current_estimator_output(self, ctx: Any) -> Optional[torch.Tensor]: + ... # fmt: skip + + +class AdapterContext: + """Structured, mutable context owned by adapters. + + Uses fixed attributes for core fields and an `extras` dict for adapter- + specific extensions without changing the class shape. This keeps most + accesses fast and typed while preserving flexibility similar to dicts. + """ + + __slots__ = ( + "batch_size", + "device", + "conditioning", + "carry", + "trajectory_log_probs", + "trajectory_estimator_outputs", + "current_estimator_output", + "extras", + ) + + def __init__( + self, + batch_size: int, + device: torch.device, + conditioning: Optional[torch.Tensor] = None, + ) -> None: + self.batch_size = batch_size + self.device = device + self.conditioning = conditioning + self.carry = None + self.trajectory_log_probs: List[torch.Tensor] = [] + self.trajectory_estimator_outputs: List[torch.Tensor] = [] + self.current_estimator_output: Optional[torch.Tensor] = None + self.extras: Dict[str, Any] = {} + + +class DefaultEstimatorAdapter: + """Adapter for non-recurrent estimators (current default behavior). + + Overview + -------- + This adapter bridges the generic sampling loop and the "classic" non‑recurrent + estimators already used throughout the codebase. It exposes the minimal + interface required by the `EstimatorAdapter` Protocol while keeping the + sampler loop estimator-agnostic. + + Assumptions + ----------- + - The wrapped estimator is non‑recurrent (no carry between steps). + - If conditioning is provided, the estimator accepts `(states, conditioning)`; + otherwise it accepts `(states)`. + - The estimator provides `to_probability_distribution(states, est_out, **kw)` + returning a torch Distribution over actions for the masked states. + + Context Lifecycle (opaque to the Sampler) + ---------------------------------------- + The adapter owns an opaque context dict `ctx` which the Sampler never reads. + The context is created once per rollout and mutated in place: + + - init_context(batch_size, device, conditioning) -> ctx + Stores rollout invariants and optional conditioning. Also prepares per‑step + buffers for artifacts that may be recorded (log_probs, estimator_outputs). + + - compute(states_active, ctx, step_mask, **policy_kwargs) -> (dist, ctx) + 1) Selects the appropriate estimator call signature depending on whether + conditioning is present. If conditioning is present, the adapter slices + it with `step_mask` so shapes match `states_active`. + 2) Calls the estimator forward pass to obtain the raw `est_out`. + 3) Converts `est_out` into a torch Distribution with + `to_probability_distribution`. + 4) Saves `est_out` into `ctx["last_est_out"]` so it can optionally be + recorded by `record_step` or exposed to legacy callers of `sample_actions`. 
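+         (On the structured ``AdapterContext`` used below, this slot is the
+         ``current_estimator_output`` attribute rather than a dict key.)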
+ + - record_step(ctx, step_mask, sampled_actions, dist, save_logprobs, save_estimator_outputs) + Materializes optional per‑step artifacts into ctx-managed buffers with + mask‑aware padding back to the full rollout batch size `N`: + * Log‑probs: computes `dist.log_prob(sampled_actions)` for active rows only, + checks for `inf`, then writes into a 1D tensor of shape `(N,)` filled with + zeros and masked assignment for active positions. Appends this to a list + (one tensor per time step). + * Estimator outputs: if requested, pads the last estimator output + (`ctx["last_est_out"]`) to shape `(N, ...)` using `-inf` for inactive rows + and appends to a list (one tensor per time step). This matches existing + conventions elsewhere in the library. + + - finalize(ctx) -> {"log_probs": Tensor | None, "estimator_outputs": Tensor | None} + Stacks recorded per‑step lists along the time dimension into tensors of shape + `(T, N, ...)` suitable for `Trajectories`. Returns `None` for any artifact + that was never recorded. + + Masking & Shapes + ---------------- + - `states_active` always corresponds to `states[~dones]` inside the sampler. + - The adapter receives `step_mask` (shape `(N,)`) to slice any step‑dependent + inputs (e.g., conditioning) and to pad per‑step outputs to the full batch. + - Padded tensors use `0.0` for log‑probs and `-inf` for estimator outputs to + maintain compatibility with downstream code. + + Backward/Forward Direction + -------------------------- + - `is_backward` is forwarded from the underlying estimator so the sampler can + choose the appropriate environment transition (forward vs backward). + + Performance Notes + ----------------- + - `ctx` is allocated once per rollout and mutated in place to avoid per‑step + overhead. + - If you know trajectory length bounds, you can extend this adapter to + pre‑allocate fixed‑size storage in `init_context` rather than appending to + Python lists. + """ + + def __init__(self, estimator: Estimator) -> None: + """Initialize the adapter with a non-recurrent estimator. + + The estimator must expose `to_probability_distribution(states, est_out, **kw)` + and optionally accept conditioning via `estimator(states, conditioning)`. + """ + self._estimator = estimator + + @property + def is_backward(self) -> bool: + """Whether the wrapped estimator samples in the backward direction.""" + return getattr(self._estimator, "is_backward", False) + + def init_context( + self, + batch_size: int, + device: torch.device, + conditioning: Optional[torch.Tensor] = None, + ) -> AdapterContext: + """Create a new per-rollout context. + + Stores rollout invariants (batch size, device, optional conditioning) and + initializes empty buffers for per-step artifacts. + """ + return AdapterContext( + batch_size=batch_size, device=device, conditioning=conditioning + ) + + def compute( + self, + states_active: States, + ctx: Any, + step_mask: torch.Tensor, + **policy_kwargs: Any, + ) -> tuple[Distribution, Any]: + """Run the estimator for active rows and build an action Distribution. + + - Uses `step_mask` to slice conditioning to the active subset. + - Saves the raw estimator output in `ctx.current_estimator_output` for + optional recording in `record_step`. 
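+
+        Minimal sketch (assuming ``ctx`` came from ``init_context`` and
+        ``step_mask`` is aligned with the full rollout batch)::
+
+            dist, ctx = adapter.compute(states[step_mask], ctx, step_mask)
+            actions = dist.sample()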
+ """ + conditioning = ctx.conditioning # type: ignore[attr-defined] + if conditioning is not None: + with has_conditioning_exception_handler("estimator", self._estimator): + est_out = self._estimator(states_active, conditioning[step_mask]) + else: + with no_conditioning_exception_handler("estimator", self._estimator): + est_out = self._estimator(states_active) + + dist = self._estimator.to_probability_distribution( + states_active, est_out, **policy_kwargs + ) + ctx.current_estimator_output = est_out # type: ignore[attr-defined] + + return dist, ctx + + def record_step( + self, + ctx: Any, + step_mask: torch.Tensor, + sampled_actions: torch.Tensor, + dist: Distribution, + save_logprobs: bool, + save_estimator_outputs: bool, + ) -> None: + """Record per-step artifacts into the context's trajectory-level lists. + + - If requested, computes log-probs for the active rows and writes them into + a padded vector of shape (N,) before appending to the list + `trajectory_log_probs`. + - If requested, pads `current_estimator_output` to shape (N, ...) and appends + to the list `trajectory_estimator_outputs`. + """ + N = ctx.batch_size # type: ignore[attr-defined] + device = ctx.device # type: ignore[attr-defined] + + if save_logprobs: + lp_masked = dist.log_prob(sampled_actions) + if torch.any(torch.isinf(lp_masked)): + raise RuntimeError("Log probabilities are inf. This should not happen.") + step_lp = torch.full((N,), 0.0, device=device) + step_lp[step_mask] = lp_masked + ctx.trajectory_log_probs.append(step_lp) # type: ignore[attr-defined] + + if ( + save_estimator_outputs + and getattr(ctx, "current_estimator_output", None) is not None + ): + est_out = ctx.current_estimator_output # type: ignore[attr-defined] + padded = torch.full((N,) + est_out.shape[1:], -float("inf"), device=device) + padded[step_mask] = est_out + ctx.trajectory_estimator_outputs.append(padded) # type: ignore[attr-defined] + + def finalize(self, ctx: Any) -> dict[str, Optional[torch.Tensor]]: + """Stack all recorded per-step artifacts along time into trajectory-level tensors. + + Returns a dict with keys: + - 'log_probs': Tensor of shape (T, N) or None + - 'estimator_outputs': Tensor of shape (T, N, ...) or None + """ + log_probs = ( + torch.stack(ctx.trajectory_log_probs, dim=0) + if getattr(ctx, "trajectory_log_probs", []) + else None + ) + estimator_outputs = ( + torch.stack(ctx.trajectory_estimator_outputs, dim=0) + if getattr(ctx, "trajectory_estimator_outputs", []) + else None + ) + + return {"log_probs": log_probs, "estimator_outputs": estimator_outputs} + + def get_current_estimator_output(self, ctx: Any) -> Optional[torch.Tensor]: + """Expose the most recent per-step estimator output saved during `compute`.""" + return getattr(ctx, "current_estimator_output", None) + + +class RecurrentEstimatorAdapter(DefaultEstimatorAdapter): + """Adapter for recurrent estimators that require and update a carry.""" + + def __init__(self, estimator: Estimator) -> None: + super().__init__(estimator) + + def init_context( + self, + batch_size: int, + device: torch.device, + conditioning: Optional[torch.Tensor] = None, + ) -> AdapterContext: + """Create context and initialize recurrent carry, (estimator hidden state). + + Differs from the default adapter by allocating `ctx.carry` via + `estimator.init_carry(batch_size, device)`. 
+ """ + init_carry = getattr(self._estimator, "init_carry", None) + if not callable(init_carry): + raise TypeError( + "RecurrentEstimatorAdapter requires an estimator that implements " + "init_carry(batch_size: int, device: torch.device).\n" + "A) Recurrent estimators must expose an `init_carry` method.\n" + "B) RecurrentEstimatorAdapter is only compatible with estimators that " + "expose `init_carry`." + ) + ctx = super().init_context(batch_size, device, conditioning) + # Expect estimator to implement init_carry(batch_size, device) + init_carry_fn = cast(Callable[[int, torch.device], Any], init_carry) + ctx.carry = init_carry_fn(batch_size, device) + + return ctx + + def compute( + self, + states_active: States, + ctx: Any, + step_mask: torch.Tensor, + **policy_kwargs: Any, + ) -> tuple[Distribution, Any]: + """Run estimator with carry and update it. + + Differs from the default adapter by calling + `estimator(states_active, ctx.carry) -> (est_out, new_carry)`, storing the + updated carry and saving `current_estimator_output` before building the + Distribution. + """ + # Recurrent estimators are expected to accept (states, carry) -> (out, new_carry) + est_out, new_carry = self._estimator(states_active, ctx.carry) # type: ignore[attr-defined] + ctx.carry = new_carry # type: ignore[attr-defined] + dist = self._estimator.to_probability_distribution( + states_active, est_out, **policy_kwargs + ) + ctx.current_estimator_output = est_out # type: ignore[attr-defined] + + return dist, ctx + + class Sampler: """Wrapper for a PolicyEstimator that enables sampling from GFlowNet environments. @@ -28,7 +386,9 @@ class Sampler: probability distributions. """ - def __init__(self, estimator: Estimator) -> None: + def __init__( + self, estimator: Estimator, adapter: Optional[EstimatorAdapter] = None + ) -> None: """Initializes a Sampler with a PolicyEstimator. Args: @@ -36,6 +396,9 @@ def __init__(self, estimator: Estimator) -> None: probability distributions. """ self.estimator = estimator + self.adapter = ( + adapter if adapter is not None else DefaultEstimatorAdapter(estimator) + ) def sample_actions( self, @@ -44,6 +407,7 @@ def sample_actions( conditioning: torch.Tensor | None = None, save_estimator_outputs: bool = False, save_logprobs: bool = False, + ctx: Any | None = None, **policy_kwargs: Any, ) -> Tuple[Actions, torch.Tensor | None, torch.Tensor | None]: """Samples actions from the given states using the policy estimator. @@ -77,36 +441,47 @@ def sample_actions( - Optional tensor of log probabilities (if save_logprobs=True) - Optional tensor of estimator outputs (if save_estimator_outputs=True) """ - # TODO: Should estimators instead ignore None for the conditioning vector? 
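+        # The adapter owns the sampling context; create one lazily here so
+        # callers that step repeatedly can pass their own `ctx` back in.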
- if conditioning is not None: - with has_conditioning_exception_handler("estimator", self.estimator): - estimator_output = self.estimator(states, conditioning) - else: - with no_conditioning_exception_handler("estimator", self.estimator): - estimator_output = self.estimator(states) + if ctx is None: + ctx = self.adapter.init_context( + batch_size=states.batch_shape[0], + device=states.device, + conditioning=conditioning, + ) - dist = self.estimator.to_probability_distribution( - states, estimator_output, **policy_kwargs + step_mask = torch.ones( + states.batch_shape[0], dtype=torch.bool, device=states.device ) + dist, ctx = self.adapter.compute(states, ctx, step_mask, **policy_kwargs) with torch.no_grad(): - actions = dist.sample() + actions_tensor = dist.sample() if save_logprobs: - log_probs = dist.log_prob(actions) + log_probs = dist.log_prob(actions_tensor) if torch.any(torch.isinf(log_probs)): raise RuntimeError("Log probabilities are inf. This should not happen.") else: log_probs = None - actions = env.actions_from_tensor(actions) + # Allow adapter to record per-step artifacts for callers that reuse ctx. + self.adapter.record_step( + ctx=ctx, + step_mask=step_mask, + sampled_actions=actions_tensor, + dist=dist, + save_logprobs=save_logprobs, + save_estimator_outputs=save_estimator_outputs, + ) + + actions = env.actions_from_tensor(actions_tensor) - if not save_estimator_outputs: - estimator_output = None + estimator_output = None + if save_estimator_outputs and hasattr( + self.adapter, "get_current_estimator_output" + ): + estimator_output = self.adapter.get_current_estimator_output(ctx) assert log_probs is None or log_probs.shape == actions.batch_shape - # assert estimator_output is None or estimator_output.shape == actions.batch_shape - # TODO: check expected shape return actions, log_probs, estimator_output @@ -153,7 +528,7 @@ def sample_trajectories( For backward trajectories, the reward is computed at the initial state (s0) rather than the terminal state (sf). """ - if self.estimator.is_backward: + if self.adapter.is_backward: # [ASSUMPTION] When backward sampling, all provided states are the # terminating states (can be passed to log_reward fn) assert ( @@ -178,9 +553,7 @@ def sample_trajectories( ensure_same_device(states.device, conditioning.device) dones = ( - states.is_initial_state - if self.estimator.is_backward - else states.is_sink_state + states.is_initial_state if self.adapter.is_backward else states.is_sink_state ) # Define dummy actions to avoid errors when stacking empty lists. @@ -188,58 +561,44 @@ def sample_trajectories( trajectories_actions: List[Actions] = [ env.actions_from_batch_shape((n_trajectories,)) ] - trajectories_logprobs: List[torch.Tensor] = [ - torch.full((n_trajectories,), fill_value=0, device=device) - ] + # Placeholder kept for backward-compatibility of shapes; logprobs are + # recorded and stacked by the adapter. trajectories_terminating_idx = torch.zeros( n_trajectories, dtype=torch.long, device=device ) step = 0 - all_estimator_outputs = [] + ctx = self.adapter.init_context(n_trajectories, device, conditioning) while not all(dones): actions = env.actions_from_batch_shape((n_trajectories,)) - log_probs = torch.full((n_trajectories,), fill_value=0.0, device=device) - # This optionally allows you to retrieve the estimator_outputs collected - # during sampling. This is useful if, for example, you want to evaluate off - # policy actions later without repeating calculations to obtain the env - # distribution parameters. 
- if conditioning is not None: - masked_conditioning = conditioning[~dones] - else: - masked_conditioning = None + step_mask = ~dones - valid_actions, actions_log_probs, estimator_outputs = self.sample_actions( - env, - states[~dones], - masked_conditioning, - save_estimator_outputs=True if save_estimator_outputs else False, + # Compute distribution on active rows + dist, ctx = self.adapter.compute( + states[step_mask], ctx, step_mask, **policy_kwargs + ) + + # Sample actions for active rows + with torch.no_grad(): + valid_actions_tensor = dist.sample() + valid_actions = env.actions_from_tensor(valid_actions_tensor) + + # Let adapter record artifacts + self.adapter.record_step( + ctx=ctx, + step_mask=step_mask, + sampled_actions=valid_actions_tensor, + dist=dist, save_logprobs=save_logprobs, - **policy_kwargs, + save_estimator_outputs=save_estimator_outputs, ) - if estimator_outputs is not None: - # Place estimator outputs into a stackable tensor. Note that this - # will be replaced with torch.nested.nested_tensor in the future. - estimator_outputs_padded = torch.full( - (n_trajectories,) + estimator_outputs.shape[1:], - fill_value=-float("inf"), - device=device, - ) - estimator_outputs_padded[~dones] = estimator_outputs - all_estimator_outputs.append(estimator_outputs_padded) - actions[~dones] = valid_actions - if save_logprobs: - assert ( - actions_log_probs is not None - ), "actions_log_probs should not be None when save_logprobs is True" - log_probs[~dones] = actions_log_probs + actions[step_mask] = valid_actions trajectories_actions.append(actions) - trajectories_logprobs.append(log_probs) - if self.estimator.is_backward: + if self.adapter.is_backward: new_states = env._backward_step(states, actions) else: new_states = env._step(states, actions) @@ -265,7 +624,7 @@ def sample_trajectories( # to filter out the already done ones. new_dones = ( new_states.is_initial_state - if self.estimator.is_backward + if self.adapter.is_backward else new_states.is_sink_state ) & ~dones trajectories_terminating_idx[new_dones] = step @@ -279,21 +638,11 @@ def sample_trajectories( stacked_actions = env.Actions.stack(trajectories_actions)[ 1: ] # Drop dummy action - stacked_logprobs = ( - torch.stack(trajectories_logprobs, dim=0)[1:] # Drop dummy logprob - if save_logprobs - else None - ) - - # TODO: use torch.nested.nested_tensor(dtype, device, requires_grad). - stacked_estimator_outputs = ( - torch.stack(all_estimator_outputs, dim=0) if save_estimator_outputs else None - ) + # Ask adapter for stacked trajectory_artifacts (already shaped (T, N, ...)) + trajectory_artifacts = self.adapter.finalize(ctx) + stacked_logprobs = trajectory_artifacts.get("log_probs", None) + stacked_estimator_outputs = trajectory_artifacts.get("estimator_outputs", None) - # If there are no logprobs or estimator outputs, set them to None. - # TODO: This is a hack to avoid errors when no logprobs or estimator outputs are - # saved. This bug was introduced when I changed the dtypes library-wide -- why - # is this happening? 
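+        # Normalize empty artifacts to None so downstream code treats them as
+        # absent rather than as zero-length tensors.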
if stacked_logprobs is not None and len(stacked_logprobs) == 0: stacked_logprobs = None if stacked_estimator_outputs is not None and len(stacked_estimator_outputs) == 0: @@ -323,7 +672,7 @@ def sample_trajectories( conditioning=conditioning, actions=stacked_actions, terminating_idx=trajectories_terminating_idx, - is_backward=self.estimator.is_backward, + is_backward=self.adapter.is_backward, log_rewards=None, # will be calculated later log_probs=stacked_logprobs, estimator_outputs=stacked_estimator_outputs, diff --git a/src/gfn/utils/modules.py b/src/gfn/utils/modules.py index fddebb69..5ca9cc14 100644 --- a/src/gfn/utils/modules.py +++ b/src/gfn/utils/modules.py @@ -1,12 +1,15 @@ """This file contains some examples of modules that can be used with GFN.""" import math +from abc import ABC, abstractmethod from typing import Literal, Optional import torch import torch.nn as nn +import torch.nn.functional as F from linear_attention_transformer import LinearAttentionTransformer from tensordict import TensorDict +from torch import Tensor from torch_geometric.nn import DirGNNConv, GCNConv, GINConv from gfn.actions import GraphActions, GraphActionType @@ -962,3 +965,539 @@ def bias(self) -> torch.Tensor | None: return self.bias_mu else: return None + + +class AutoregressiveDiscreteSequenceModel(ABC, nn.Module): + + @abstractmethod + def init_carry( + self, + batch_size: int, + device: torch.device, + ) -> dict[str, torch.Tensor]: + """Initialize the carry for the sequence model. + + Args: + batch_size (int): Batch size. + device (torch.device): Device to allocate carry tensors on. + + Returns: + dict[str, torch.Tensor]: Initialized carry. + """ + + @abstractmethod + def forward( + self, + x: torch.Tensor, + carry: dict[str, torch.Tensor], + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + """Compute the logits for the next tokens in the sequence. + + Args: + x (torch.Tensor): (B, T) tensor of input token indices where ``T`` is the + number of newly supplied timesteps (``T`` may be 1 for incremental + decoding). + carry (dict[str, torch.Tensor]): Carry from previous steps for recurrent + processing (e.g., hidden states). + + Returns: + tuple[torch.Tensor, dict[str, torch.Tensor]]: Logits for the next token + at each supplied timestep with shape (B, T, vocab) and updated carry. 
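+
+        Example (sketch, one-token incremental decoding)::
+
+            carry = model.init_carry(tokens.size(0), tokens.device)
+            for t in range(tokens.size(1)):
+                logits, carry = model(tokens[:, t : t + 1], carry)  # (B, 1, vocab)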
+ """ + + @property + @abstractmethod + def vocab_size(self) -> int: + """Size of the vocabulary (excluding BOS token).""" + + +class RecurrentDiscreteSequenceModel(AutoregressiveDiscreteSequenceModel): + def __init__( + self, + vocab_size: int, + embedding_dim: int, + hidden_size: int, + num_layers: int = 1, + rnn_type: Literal["lstm", "gru"] = "lstm", + dropout: float = 0.0, + ) -> None: + super().__init__() + if num_layers <= 0: + raise ValueError("num_layers must be a positive integer.") + rnn_kind = rnn_type.lower() + if rnn_kind not in {"lstm", "gru"}: + raise ValueError("rnn_type must be 'lstm' or 'gru'.") + + if not 0.0 <= dropout <= 1.0: + raise ValueError("dropout must be in the range [0, 1].") + + self._vocab_size = vocab_size + self.embedding_dim = embedding_dim + self.hidden_size = hidden_size + self.num_layers = num_layers + self.rnn_type = rnn_kind + + self.embedding = nn.Embedding(vocab_size + 1, embedding_dim) # +1 for BOS token + rnn_dropout = dropout if num_layers > 1 else 0.0 + self.lstm: nn.LSTM | None + self.gru: nn.GRU | None + if rnn_kind == "lstm": + self.lstm = nn.LSTM( + input_size=embedding_dim, + hidden_size=hidden_size, + num_layers=num_layers, + batch_first=True, + dropout=rnn_dropout, + ) + self.gru = None + else: + self.gru = nn.GRU( + input_size=embedding_dim, + hidden_size=hidden_size, + num_layers=num_layers, + batch_first=True, + dropout=rnn_dropout, + ) + self.lstm = None + self.output_projection = nn.Linear(hidden_size, vocab_size) + + def init_carry( + self, + batch_size: int, + device: torch.device, + ) -> dict[str, torch.Tensor]: + carry: dict[str, torch.Tensor] = { + "hidden": torch.zeros( + self.num_layers, batch_size, self.hidden_size, device=device + ), + } + if self.rnn_type == "lstm": + carry["cell"] = torch.zeros( + self.num_layers, batch_size, self.hidden_size, device=device + ) + return carry + + def forward( + self, + x: torch.Tensor, + carry: dict[str, torch.Tensor], + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + if x.dim() != 2: + raise ValueError("Expected input tensor with shape (batch, timesteps).") + + batch, timesteps = x.size() + device = x.device + + if "hidden" not in carry: + raise KeyError("Carry must provide a 'hidden' state tensor.") + + hidden = carry["hidden"] + if hidden.size(1) != batch: + raise ValueError( + "Hidden state batch dimension does not match the provided tokens." + ) + if hidden.device != device: + raise ValueError( + "Hidden state tensor must live on the same device as input tokens." + ) + + embedded = self.embedding(x) + + if self.rnn_type == "lstm": + lstm = self.lstm + if lstm is None: + raise RuntimeError("LSTM module was not initialized.") + if "cell" not in carry: + raise KeyError("LSTM carry must provide a 'cell' state tensor.") + cell = carry["cell"] + if cell.size(1) != batch: + raise ValueError( + "Cell state batch dimension does not match the provided tokens." + ) + if cell.device != device: + raise ValueError( + "Cell state tensor must live on the same device as input tokens." 
+ ) + outputs, (hidden_next, cell_next) = lstm(embedded, (hidden, cell)) + updated_carry: dict[str, torch.Tensor] = { + "hidden": hidden_next, + "cell": cell_next, + } + else: + gru = self.gru + if gru is None: + raise RuntimeError("GRU module was not initialized.") + outputs, hidden_next = gru(embedded, hidden) + updated_carry = { + "hidden": hidden_next, + } + + logits = self.output_projection(outputs) + return logits, updated_carry + + @property + def vocab_size(self) -> int: + return self._vocab_size + + +class _AutoregressiveTransformerBlock(nn.Module): + def __init__( + self, + embed_dim: int, + num_heads: int, + ff_hidden_dim: int, + dropout: float, + ) -> None: + super().__init__() + if embed_dim % num_heads != 0: + raise ValueError("Embedding dimension must be divisible by number of heads.") + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + + self.norm1 = nn.LayerNorm(embed_dim) + self.norm2 = nn.LayerNorm(embed_dim) + + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.out_proj = nn.Linear(embed_dim, embed_dim) + + self.linear1 = nn.Linear(embed_dim, ff_hidden_dim) + self.linear2 = nn.Linear(ff_hidden_dim, embed_dim) + + self.attn_dropout = nn.Dropout(dropout) + self.residual_dropout = nn.Dropout(dropout) + self.ff_dropout = nn.Dropout(dropout) + + def forward( + self, + hidden: torch.Tensor, + key_carry: torch.Tensor, + value_carry: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + batch, timesteps, _ = hidden.size() + + normed_hidden = self.norm1(hidden) + + q = self.q_proj(normed_hidden) + k = self.k_proj(normed_hidden) + v = self.v_proj(normed_hidden) + + q = q.view(batch, timesteps, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(batch, timesteps, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(batch, timesteps, self.num_heads, self.head_dim).transpose(1, 2) + + carry_length = key_carry.size(2) + updated_key_carry = torch.cat((key_carry, k), dim=2) + updated_value_carry = torch.cat((value_carry, v), dim=2) + + attn_scores = torch.matmul(q, updated_key_carry.transpose(-2, -1)) / math.sqrt( + float(self.head_dim) + ) + + if timesteps > 1 or carry_length > 0: + total_kv_length = carry_length + timesteps + kv_positions = torch.arange( + total_kv_length, device=hidden.device, dtype=torch.long + ) + query_positions = torch.arange( + timesteps, device=hidden.device, dtype=torch.long + ).unsqueeze(1) + causal_mask = kv_positions.unsqueeze(0) <= (query_positions + carry_length) + attn_scores = attn_scores.masked_fill( + ~causal_mask.unsqueeze(0).unsqueeze(0), float("-inf") + ) + + attn_weights = torch.softmax(attn_scores, dim=-1) + attn_weights = self.attn_dropout(attn_weights) + attn_output = torch.matmul(attn_weights, updated_value_carry) + attn_output = attn_output.transpose(1, 2).reshape( + batch, timesteps, self.embed_dim + ) + attn_output = self.out_proj(attn_output) + + residual = hidden + hidden = residual + self.residual_dropout(attn_output) + + ff_input = self.norm2(hidden) + ff_hidden = self.linear1(ff_input) + ff_hidden = self.ff_dropout(F.gelu(ff_hidden)) + ff_hidden = self.linear2(ff_hidden) + + hidden = hidden + self.residual_dropout(ff_hidden) + return hidden, updated_key_carry, updated_value_carry + + +class TransformerDiscreteSequenceModel(AutoregressiveDiscreteSequenceModel): + def __init__( + self, + vocab_size: int, + embedding_dim: int, + num_heads: int, + 
ff_hidden_dim: int, + num_layers: int, + max_position_embeddings: int, + dropout: float = 0.0, + positional_embedding: Literal["learned", "sinusoidal"] = "learned", + ) -> None: + super().__init__() + if num_layers <= 0: + raise ValueError("num_layers must be positive.") + if max_position_embeddings <= 0: + raise ValueError("max_position_embeddings must be positive.") + if not 0.0 <= dropout <= 1.0: + raise ValueError("dropout must lie in [0, 1].") + if embedding_dim % num_heads != 0: + raise ValueError("embedding_dim must be divisible by num_heads.") + if positional_embedding not in {"learned", "sinusoidal"}: + raise ValueError("positional_embedding must be 'learned' or 'sinusoidal'.") + + self._vocab_size = vocab_size + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.ff_hidden_dim = ff_hidden_dim + self.num_layers = num_layers + self.max_position_embeddings = max_position_embeddings + self.head_dim = embedding_dim // num_heads + self._positional_embedding_type = positional_embedding + + self.token_embedding = nn.Embedding( + vocab_size + 1, embedding_dim + ) # +1 for BOS token + if self._positional_embedding_type == "learned": + self.position_embedding = nn.Embedding( + max_position_embeddings, embedding_dim + ) + else: + self.position_embedding = SinusoidalPositionalEmbedding( + embedding_dim=embedding_dim, + max_length=max_position_embeddings, + ) + self.embedding_dropout = nn.Dropout(dropout) + + blocks: list[_AutoregressiveTransformerBlock] = [] + for _ in range(num_layers): + blocks.append( + _AutoregressiveTransformerBlock( + embed_dim=embedding_dim, + num_heads=num_heads, + ff_hidden_dim=ff_hidden_dim, + dropout=dropout, + ) + ) + + self.layers = nn.ModuleList(blocks) + self.final_norm = nn.LayerNorm(embedding_dim) + self.output_projection = nn.Linear(embedding_dim, vocab_size) + self.key_names = [f"key_{idx}" for idx in range(num_layers)] + self.value_names = [f"value_{idx}" for idx in range(num_layers)] + + def init_carry( + self, + batch_size: int, + device: torch.device, + ) -> dict[str, torch.Tensor]: + weight = self.token_embedding.weight + carry: dict[str, torch.Tensor] = { + "position": torch.zeros(batch_size, dtype=torch.long, device=device), + } + empty_key = weight.new_empty(batch_size, self.num_heads, 0, self.head_dim).to( + device + ) + empty_value = weight.new_empty(batch_size, self.num_heads, 0, self.head_dim).to( + device + ) + for key_name, value_name in zip(self.key_names, self.value_names): + carry[key_name] = empty_key.clone() + carry[value_name] = empty_value.clone() + + return carry + + def forward( + self, + x: torch.Tensor, + carry: dict[str, torch.Tensor], + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + if x.dim() != 2: + raise ValueError("Expected input tensor with shape (batch, timesteps).") + + batch, timesteps = x.size() + device = x.device + if "position" not in carry: + raise KeyError("Carry must include a 'position' tensor.") + + positions = carry["position"] + if positions.size(0) != batch: + raise ValueError( + "Position carry batch dimension does not match the provided tokens." + ) + if positions.device != device: + raise ValueError( + "Position tensor must live on the same device as input tokens." + ) + if torch.any(positions >= self.max_position_embeddings): + raise ValueError( + "Position index exceeds configured positional embedding range." 
+ ) + + position_offsets = torch.arange(timesteps, device=device, dtype=positions.dtype) + position_indices = positions.unsqueeze(1) + position_offsets + if torch.any(position_indices >= self.max_position_embeddings): + raise ValueError( + "Position index exceeds configured positional embedding range." + ) + + hidden = self.token_embedding(x) + self.position_embedding(position_indices) + hidden = self.embedding_dropout(hidden) + + updated_carry: dict[str, torch.Tensor] = {} + + for idx, layer in enumerate(self.layers): + key_name = self.key_names[idx] + value_name = self.value_names[idx] + if key_name not in carry or value_name not in carry: + raise KeyError( + "Transformer carry is missing key/value tensors for layer" f" {idx}." + ) + key_carry = carry[key_name] + value_carry = carry[value_name] + if key_carry.size(0) != batch or key_carry.size(1) != self.num_heads: + raise ValueError( + "Key carry shape is incompatible with the provided tokens." + ) + if value_carry.size(0) != batch or value_carry.size(1) != self.num_heads: + raise ValueError( + "Value carry shape is incompatible with the provided tokens." + ) + if ( + key_carry.size(-1) != self.head_dim + or value_carry.size(-1) != self.head_dim + ): + raise ValueError("Key/value carry head dimension mismatch detected.") + if key_carry.device != device or value_carry.device != device: + raise ValueError("Key/value carry tensors must share the input device.") + hidden, updated_key_carry, updated_value_carry = layer( + hidden, key_carry, value_carry + ) + updated_carry[key_name] = updated_key_carry + updated_carry[value_name] = updated_value_carry + + hidden = self.final_norm(hidden) + logits = self.output_projection(hidden) + + updated_carry["position"] = positions + timesteps + return logits, updated_carry + + @property + def vocab_size(self) -> int: + return self._vocab_size + + +def sinusoidal_position_encoding( + length: int, + embedding_dim: int, + base: float = 10000.0, +) -> Tensor: + """Create 1D sinusoidal positional embeddings. + + Args: + length: Number of positions to encode. Must be non-negative. + embedding_dim: Dimensionality of each embedding. Must be positive. + base: Exponential base used to compute the angular frequencies. + + Returns: + A ``(length, embedding_dim)`` tensor of sinusoidal encodings. + + Raises: + ValueError: If ``length`` is negative, ``embedding_dim`` is not positive, + or ``base`` is not positive. + """ + + assert length >= 0, "length must be non-negative." + assert embedding_dim > 0, "embedding_dim must be positive." + assert base > 0, "base must be positive." + + if length == 0: + return torch.empty(0, embedding_dim) + + positions = torch.arange(length).unsqueeze(1) + div_input = torch.arange(0, embedding_dim, 2) + div_term = torch.exp(div_input * (-math.log(base) / embedding_dim)) + embeddings = torch.zeros(length, embedding_dim) + angles = positions * div_term + embeddings[:, 0::2] = torch.sin(angles) + + if embedding_dim % 2 == 0: + embeddings[:, 1::2] = torch.cos(angles) + else: + embeddings[:, 1::2] = torch.cos(angles)[:, : embedding_dim // 2] + + return embeddings + + +class SinusoidalPositionalEmbedding(nn.Module): + """Sinusoidal positional embeddings for transformer-style models. + + The module caches a precomputed table of embeddings and extends it on demand. + Forward accepts either a sequence length or explicit position indices. 
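+
+    Usage (sketch; ``pos`` is a hypothetical ``(B, T)`` index tensor)::
+
+        emb = SinusoidalPositionalEmbedding(embedding_dim=12, max_length=32)
+        table = emb(seq_len=5)       # first 5 rows of the table, shape (5, 12)
+        picked = emb(positions=pos)  # per-index lookup, shape (B, T, 12)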
+ """ + + def __init__( + self, + embedding_dim: int, + max_length: int = 2048, + base: float = 10000.0, + ) -> None: + super().__init__() + assert max_length >= 0, "max_length must be non-negative." + assert embedding_dim > 0, "embedding_dim must be positive." + assert base > 0, "base must be positive." + + self.embedding_dim = int(embedding_dim) + self.base = float(base) + + pe = sinusoidal_position_encoding(max_length, self.embedding_dim, base=self.base) + self._pe: Tensor + self.register_buffer("_pe", pe) + + @property + def pe(self) -> Tensor: + """Return the cached positional embedding table.""" + return self._pe + + def forward( + self, + positions: Optional[Tensor] = None, + seq_len: Optional[int] = None, + ) -> Tensor: + """Look up positional embeddings. + + Args: + positions: Optional tensor of position indices. Can have any shape, + and the returned embeddings will append ``embedding_dim`` to that + shape. Defaults to ``None``. + seq_len: Optional sequence length. When provided, returns the first + ``seq_len`` embeddings from the table. + + Returns: + Tensor of positional embeddings on the same device/dtype as the + cached table. + + Raises: + ValueError: If both or neither of ``positions`` and ``seq_len`` are + provided, or if indices exceed the cached range. + """ + + if (positions is None) == (seq_len is None): + raise ValueError("Provide exactly one of positions or seq_len.") + + if positions is not None: + flat_positions = positions.reshape(-1) + gathered = self._pe.index_select(0, flat_positions) + return gathered.view( + positions.shape[0], positions.shape[1], self.embedding_dim + ) + else: + return self._pe[:seq_len] diff --git a/testing/test_modules.py b/testing/test_modules.py new file mode 100644 index 00000000..5689f7a6 --- /dev/null +++ b/testing/test_modules.py @@ -0,0 +1,174 @@ +from typing import Literal + +import pytest +import torch + +from gfn.utils.modules import ( + RecurrentDiscreteSequenceModel, + TransformerDiscreteSequenceModel, +) + + +@pytest.mark.parametrize("rnn_type", ["lstm", "gru"]) +@pytest.mark.parametrize( + "device", + [ + "cpu", + pytest.param( + "cuda", + marks=pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA not available" + ), + ), + ], +) +def test_recurrent_smoke(rnn_type: Literal["lstm", "gru"], device: torch.device) -> None: + batch_size = 2 + vocab_size = 11 + total_steps = 4 + model = RecurrentDiscreteSequenceModel( + vocab_size=vocab_size, + embedding_dim=5, + hidden_size=7, + num_layers=2, + rnn_type=rnn_type, + dropout=0.0, + ).to(device) + model.eval() + + tokens = torch.randint(0, vocab_size, (batch_size, total_steps), device=device) + + def collect_logits( + chunk_sizes: list[int], + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + carry = model.init_carry(batch_size, device) + outputs: list[torch.Tensor] = [] + start = 0 + with torch.no_grad(): + for chunk in chunk_sizes: + end = start + chunk + logits, carry = model(tokens[:, start:end], carry) + outputs.append(logits) + start = end + if start != total_steps: + raise ValueError("Chunk sizes must cover the entire sequence length.") + return torch.cat(outputs, dim=1), carry + + logits_all, carry_all = collect_logits([total_steps]) + logits_single, carry_single = collect_logits([1] * total_steps) + logits_double, carry_double = collect_logits([2, 2]) + + scripted = torch.jit.script(model) + carry_script = model.init_carry(batch_size, device) + with torch.no_grad(): + logits_script, carry_script = scripted(tokens, carry_script) + + assert 
torch.allclose(logits_all, logits_single, atol=1e-6, rtol=1e-5) + assert torch.allclose(logits_all, logits_double, atol=1e-6, rtol=1e-5) + assert torch.allclose(logits_all, logits_script, atol=1e-6, rtol=1e-5) + + assert torch.allclose( + carry_all["hidden"], carry_single["hidden"], atol=1e-6, rtol=1e-5 + ) + assert torch.allclose( + carry_all["hidden"], carry_double["hidden"], atol=1e-6, rtol=1e-5 + ) + + if rnn_type == "lstm": + assert torch.allclose( + carry_all["cell"], carry_single["cell"], atol=1e-6, rtol=1e-5 + ) + assert torch.allclose( + carry_all["cell"], carry_double["cell"], atol=1e-6, rtol=1e-5 + ) + + +@pytest.mark.parametrize("positional_embedding", ["learned", "sinusoidal"]) +@pytest.mark.parametrize( + "device", + [ + "cpu", + pytest.param( + "cuda", + marks=pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA not available" + ), + ), + ], +) +def test_transformer_smoke( + positional_embedding: Literal["learned", "sinusoidal"], + device: torch.device, +) -> None: + batch_size = 3 + vocab_size = 13 + total_steps = 4 + model = TransformerDiscreteSequenceModel( + vocab_size=vocab_size, + embedding_dim=12, + num_heads=3, + ff_hidden_dim=24, + num_layers=2, + max_position_embeddings=32, + dropout=0.0, + positional_embedding=positional_embedding, + ).to(device) + model.eval() + + tokens = torch.randint(0, vocab_size, (batch_size, total_steps), device=device) + + def collect_logits( + chunk_sizes: list[int], + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + carry = model.init_carry(batch_size, device) + outputs: list[torch.Tensor] = [] + start = 0 + with torch.no_grad(): + for chunk in chunk_sizes: + end = start + chunk + logits, carry = model(tokens[:, start:end], carry) + outputs.append(logits) + start = end + if start != total_steps: + raise ValueError("Chunk sizes must cover the entire sequence length.") + return torch.cat(outputs, dim=1), carry + + logits_all, carry_all = collect_logits([total_steps]) + logits_single, carry_single = collect_logits([1] * total_steps) + logits_double, carry_double = collect_logits([2, 2]) + + scripted = torch.jit.script(model) + carry_script = model.init_carry(batch_size, device) + + with torch.no_grad(): + logits_script, carry_script = scripted(tokens, carry_script) + + assert torch.allclose(logits_all, logits_single, atol=1e-6, rtol=1e-5) + assert torch.allclose(logits_all, logits_double, atol=1e-6, rtol=1e-5) + assert torch.allclose(logits_all, logits_script, atol=1e-6, rtol=1e-5) + assert torch.equal(carry_all["position"], carry_single["position"]) + assert torch.equal(carry_all["position"], carry_double["position"]) + + def carry_matches( + ref: dict[str, torch.Tensor], other: dict[str, torch.Tensor] + ) -> bool: + for idx in range(model.num_layers): + key_name = model.key_names[idx] + value_name = model.value_names[idx] + if not torch.allclose(ref[key_name], other[key_name], atol=1e-6, rtol=1e-5): + return False + if not torch.allclose( + ref[value_name], other[value_name], atol=1e-6, rtol=1e-5 + ): + return False + return True + + assert carry_matches(carry_all, carry_single) + assert carry_matches(carry_all, carry_double) + + for idx in range(model.num_layers): + assert ( + carry_all[f"key_{idx}"].size(2) + == carry_all[f"value_{idx}"].size(2) + == total_steps + ) From 93a654a1103a881346d904ff48d40c88b3393509 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Mon, 29 Sep 2025 02:17:28 -0400 Subject: [PATCH 02/27] fixed _SeqStates --- src/gfn/estimators.py | 4 +- testing/test_samplers_and_trajectories.py | 259 
+++++++++++++++++++++- 2 files changed, 257 insertions(+), 6 deletions(-) diff --git a/src/gfn/estimators.py b/src/gfn/estimators.py index c9fde292..835afdf1 100644 --- a/src/gfn/estimators.py +++ b/src/gfn/estimators.py @@ -874,9 +874,7 @@ def forward( # TODO: Is this still true? NOTE: Can only be used for on-policy generation, # not off-policy evaluation of entire trajectory. # states.tensor.shape: (..., max_string_len) - current_input_len = states.tensor.shape[ - -1 - ].max() # TODO: Check if this is correct. + current_input_len = states.tensor.shape[-1] # TODO: Check if this is correct. states_tensor = states.tensor[..., :current_input_len] # (..., string_len) # Compute sequence of logits and update carry. diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index cc931765..dde8ba5a 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -1,7 +1,8 @@ -from typing import Literal, Tuple +from typing import Literal, Tuple, cast import pytest import torch +from torch.distributions import Categorical from gfn.containers import Trajectories, Transitions from gfn.containers.replay_buffer import ReplayBuffer @@ -19,8 +20,20 @@ KHotPreprocessor, OneHotPreprocessor, ) -from gfn.samplers import LocalSearchSampler, Sampler -from gfn.utils.modules import MLP, GraphActionGNN +from gfn.samplers import ( + AdapterContext, + DefaultEstimatorAdapter, + LocalSearchSampler, + RecurrentEstimatorAdapter, + Sampler, +) +from gfn.states import States +from gfn.utils.modules import ( + MLP, + GraphActionGNN, + RecurrentDiscreteSequenceModel, + TransformerDiscreteSequenceModel, +) from gfn.utils.prob_calculations import get_trajectory_pfs from gfn.utils.training import states_actions_tns_to_traj @@ -453,3 +466,243 @@ def test_states_actions_tns_to_traj(): # Test that we can add the trajectories to a replay buffer replay_buffer = ReplayBuffer(env, capacity=10) replay_buffer.add(trajs) + + +# ---------------------- Adapters: unit-level smoke tests ---------------------- + + +class _FakeStates: + def __init__(self, n: int, device: torch.device): + self.tensor = torch.zeros((n, 1), device=device) + + @property + def batch_shape(self): + return (self.tensor.shape[0],) + + +class _DummyEstimator: + is_backward = False + + def __call__(self, states: _FakeStates, conditioning: torch.Tensor | None = None): + n = states.batch_shape[0] + return torch.zeros((n, 3), device=states.tensor.device) + + def to_probability_distribution( + self, states: _FakeStates, est_out: torch.Tensor, **_: dict + ): + logits = torch.zeros((states.batch_shape[0], 3), device=states.tensor.device) + return Categorical(logits=logits) + + # no expected_output_dim required for adapter tests + + +class _DummyRecurrentEstimator: + is_backward = False + + def init_carry(self, batch_size: int, device: torch.device): + return {"hidden": torch.zeros((batch_size, 2), device=device)} + + def __call__(self, states: _FakeStates, carry: dict[str, torch.Tensor]): + n = states.batch_shape[0] + logits = torch.zeros((n, 3), device=states.tensor.device) + new_carry = {"hidden": carry["hidden"] + 1} + return logits, new_carry + + def to_probability_distribution( + self, states: _FakeStates, est_out: torch.Tensor, **_: dict + ): + logits = torch.zeros((states.batch_shape[0], 3), device=states.tensor.device) + return Categorical(logits=logits) + + # no expected_output_dim required for adapter tests + + +def test_adapter_context_basic(): + ctx = AdapterContext(batch_size=4, 
device=torch.device("cpu"), conditioning=None) + assert ctx.batch_size == 4 + assert ctx.device.type == "cpu" + # extras supports arbitrary entries + ctx.extras["foo"] = 123 + assert ctx.extras["foo"] == 123 + + +def test_default_adapter_compute_record_finalize(): + adapter = DefaultEstimatorAdapter(cast(Estimator, _DummyEstimator())) + device = torch.device("cpu") + n = 5 + states = _FakeStates(n, device) + ctx = adapter.init_context(n, device, conditioning=None) + + step_mask = torch.ones(n, dtype=torch.bool, device=device) + dist, ctx = adapter.compute(cast(States, states), ctx, step_mask) + actions = dist.sample() + adapter.record_step( + ctx, step_mask, actions, dist, save_logprobs=True, save_estimator_outputs=True + ) + out = adapter.finalize(ctx) + assert out["log_probs"] is not None and out["log_probs"].shape == (1, n) + assert out["estimator_outputs"] is not None and out["estimator_outputs"].shape[ + :2 + ] == (1, n) + + +def test_recurrent_adapter_requires_init_carry(): + class _BadEstimator: + is_backward = False + + adapter = RecurrentEstimatorAdapter(cast(Estimator, _BadEstimator())) + with pytest.raises(TypeError): + _ = adapter.init_context(2, torch.device("cpu"), None) + + +def test_recurrent_adapter_flow(): + adapter = RecurrentEstimatorAdapter(cast(Estimator, _DummyRecurrentEstimator())) + device = torch.device("cpu") + n = 3 + states = _FakeStates(n, device) + ctx = adapter.init_context(n, device, conditioning=None) + + step_mask = torch.ones(n, dtype=torch.bool, device=device) + dist, ctx = adapter.compute(cast(States, states), ctx, step_mask) + actions = dist.sample() + # carry should update when we record multiple steps + h0 = ctx.carry["hidden"].clone() + adapter.record_step( + ctx, step_mask, actions, dist, save_logprobs=True, save_estimator_outputs=True + ) + # second step + dist, ctx = adapter.compute(cast(States, states), ctx, step_mask) + actions = dist.sample() + adapter.record_step( + ctx, step_mask, actions, dist, save_logprobs=True, save_estimator_outputs=True + ) + h1 = ctx.carry["hidden"].clone() + assert torch.all(h1 == h0 + 1) + out = adapter.finalize(ctx) + assert out["log_probs"] is not None and out["log_probs"].shape == (2, n) + assert out["estimator_outputs"] is not None and out["estimator_outputs"].shape[ + :2 + ] == (2, n) + + +# ---------------------- Integration with real recurrent modules ---------------------- + + +class _SeqStates: + def __init__(self, tokens: torch.Tensor, n_actions: int): + self.tensor = tokens # (batch, seq_len) + b = tokens.shape[0] + device = tokens.device + self.forward_masks = torch.ones((b, n_actions), dtype=torch.bool, device=device) + self.backward_masks = torch.ones( + (b, max(n_actions - 1, 1)), dtype=torch.bool, device=device + ) + + @property + def batch_shape(self): + return (self.tensor.shape[0],) + + @property + def device(self): + return self.tensor.device + + +@pytest.mark.parametrize("rnn_type", ["lstm", "gru"]) +def test_integration_recurrent_sequence_model_with_adapter( + rnn_type: Literal["lstm", "gru"] +) -> None: + device = torch.device("cpu") + batch_size = 3 + vocab_size = 11 + seq_len = 4 + + model = RecurrentDiscreteSequenceModel( + vocab_size=vocab_size, + embedding_dim=8, + hidden_size=16, + num_layers=1, + rnn_type=rnn_type, + dropout=0.0, + ).to(device) + + from gfn.estimators import RecurrentDiscretePolicyEstimator + + estimator = RecurrentDiscretePolicyEstimator( + module=model, + n_actions=vocab_size, + is_backward=False, + ) + + adapter = RecurrentEstimatorAdapter(estimator) + ctx = 
adapter.init_context(batch_size, device, conditioning=None) + + tokens = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) + states = _SeqStates(tokens, vocab_size) + + # Run two steps and verify carry and artifact shapes + step_mask = torch.ones(batch_size, dtype=torch.bool, device=device) + for _ in range(2): + dist, ctx = adapter.compute(cast(States, states), ctx, step_mask) + actions = dist.sample() + adapter.record_step( + ctx, + step_mask, + actions, + dist, + save_logprobs=True, + save_estimator_outputs=True, + ) + + out = adapter.finalize(ctx) + assert out["log_probs"] is not None and out["log_probs"].shape[0] == 2 + assert ( + out["estimator_outputs"] is not None and out["estimator_outputs"].shape[0] == 2 + ) + + +@pytest.mark.parametrize("positional_embedding", ["learned", "sinusoidal"]) +def test_integration_transformer_sequence_model_with_adapter( + positional_embedding: Literal["learned", "sinusoidal"] +) -> None: + device = torch.device("cpu") + batch_size = 2 + vocab_size = 9 + seq_len = 5 + + model = TransformerDiscreteSequenceModel( + vocab_size=vocab_size, + embedding_dim=12, + num_heads=3, + ff_hidden_dim=24, + num_layers=1, + max_position_embeddings=32, + dropout=0.0, + positional_embedding=positional_embedding, + ).to(device) + + from gfn.estimators import RecurrentDiscretePolicyEstimator + + estimator = RecurrentDiscretePolicyEstimator( + module=model, + n_actions=vocab_size, + is_backward=False, + ) + + adapter = RecurrentEstimatorAdapter(estimator) + ctx = adapter.init_context(batch_size, device, conditioning=None) + + tokens = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) + states = _SeqStates(tokens, vocab_size) + + step_mask = torch.ones(batch_size, dtype=torch.bool, device=device) + dist, ctx = adapter.compute(cast(States, states), ctx, step_mask) + actions = dist.sample() + adapter.record_step( + ctx, step_mask, actions, dist, save_logprobs=True, save_estimator_outputs=True + ) + + out = adapter.finalize(ctx) + assert out["log_probs"] is not None and out["log_probs"].shape[0] == 1 + assert ( + out["estimator_outputs"] is not None and out["estimator_outputs"].shape[0] == 1 + ) From 14b110cb063e10572538fe479c0012babf2823b5 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Mon, 29 Sep 2025 13:19:34 -0400 Subject: [PATCH 03/27] Update input_dim to use preprocessor output_dim --- docs/source/guides/example.md | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/docs/source/guides/example.md b/docs/source/guides/example.md index a9609611..43d7de88 100644 --- a/docs/source/guides/example.md +++ b/docs/source/guides/example.md @@ -22,14 +22,13 @@ env = HyperGrid(ndim=4, height=8) # Grid of size 8x8x8x8 preprocessor = KHotPreprocessor(ndim=env.ndim, height=env.height) # 2 - We define the needed modules (neural networks). -input_dim = preprocessor.output_dim if preprocessor.output_dim is not None else env.state_shape[-1] module_PF = MLP( - input_dim=input_dim, + input_dim=preprocessor.output_dim, output_dim=env.n_actions ) # Neural network for the forward policy, with as many outputs as there are actions module_PB = MLP( - input_dim=input_dim, + input_dim=preprocessor.output_dim, output_dim=env.n_actions - 1, trunk=module_PF.trunk # We share all the parameters of P_F and P_B, except for the last layer ) @@ -81,14 +80,13 @@ preprocessor = KHotPreprocessor(ndim=env.ndim, height=env.height) # 2 - We define the needed modules (neural networks). 
# The environment has a preprocessor attribute, which is used to preprocess the state before feeding it to the policy estimator -input_dim = preprocessor.output_dim if preprocessor.output_dim is not None else env.state_shape[-1] module_PF = MLP( - input_dim=input_dim, + input_dim=preprocessor.output_dim, output_dim=env.n_actions ) # Neural network for the forward policy, with as many outputs as there are actions module_PB = MLP( - input_dim=input_dim, + input_dim=preprocessor.output_dim, output_dim=env.n_actions - 1, trunk=module_PF.trunk # We share all the parameters of P_F and P_B, except for the last layer ) From c7a3d8c10c4db12de7184a8f83c1a511ff357db4 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Mon, 29 Sep 2025 13:20:05 -0400 Subject: [PATCH 04/27] Update input_dim to use preprocessor's output_dim --- docs/source/guides/example.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/guides/example.md b/docs/source/guides/example.md index 43d7de88..9b82502f 100644 --- a/docs/source/guides/example.md +++ b/docs/source/guides/example.md @@ -91,7 +91,7 @@ module_PB = MLP( trunk=module_PF.trunk # We share all the parameters of P_F and P_B, except for the last layer ) module_logF = MLP( - input_dim=input_dim, + input_dim=preprocessor.output_dim, output_dim=1, # Important for ScalarEstimators! ) From a4fc53a5b86387757b2620d6833901272ed4b7b1 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Wed, 8 Oct 2025 13:49:40 -0400 Subject: [PATCH 05/27] added draft of chunking logic -- need to test on some discrete environments --- pyproject.toml | 1 + src/gfn/chunking/__init__.py | 0 src/gfn/chunking/adapters.py | 130 ++++++++++ src/gfn/chunking/chunkers.py | 150 ++++++++++++ src/gfn/chunking/policy.py | 184 ++++++++++++++ src/gfn/env.py | 450 ++++++++++++++++++++++++++++++++++- src/gfn/states.py | 65 ++++- testing/test_chunking.py | 375 +++++++++++++++++++++++++++++ 8 files changed, 1351 insertions(+), 4 deletions(-) create mode 100644 src/gfn/chunking/__init__.py create mode 100644 src/gfn/chunking/adapters.py create mode 100644 src/gfn/chunking/chunkers.py create mode 100644 src/gfn/chunking/policy.py create mode 100644 testing/test_chunking.py diff --git a/pyproject.toml b/pyproject.toml index f069d335..3cf0d7c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ tensordict = ">=0.6.1" torch = ">=2.6.0" torch_geometric = ">=2.6.1" dill = ">=0.3.8" +tokenizers = ">=0.15" # dev dependencies. black = { version = "24.3", optional = true } diff --git a/src/gfn/chunking/__init__.py b/src/gfn/chunking/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/gfn/chunking/adapters.py b/src/gfn/chunking/adapters.py new file mode 100644 index 00000000..b03e9f23 --- /dev/null +++ b/src/gfn/chunking/adapters.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +from typing import Any, Optional + +import torch +from torch.distributions import Categorical, Distribution + +from gfn.chunking.policy import ChunkedPolicy +from gfn.env import DiscreteEnv +from gfn.samplers import AdapterContext, EstimatorAdapter +from gfn.states import DiscreteStates + + +class ChunkedAdapter(EstimatorAdapter): + """EstimatorAdapter that produces macro-level distributions using ChunkedPolicy. + + Forward-only in this PR. TODO(backward): support backward chunking by switching + stepping and termination criteria to the backward direction. 
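+
+    A minimal usage sketch (names are illustrative; ``library`` is assumed to
+    expose ``n_actions`` and ``id_to_sequence`` as used below)::
+
+        adapter = ChunkedAdapter(env=env, policy=policy, library=library)
+        ctx = adapter.init_context(batch_size, device)
+        dist, ctx = adapter.compute(active_states, ctx, step_mask)
+        macro_actions = dist.sample()
+        adapter.record_step(
+            ctx, step_mask, macro_actions, dist,
+            save_logprobs=True, save_estimator_outputs=False,
+        )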
+ """ + + def __init__(self, env: DiscreteEnv, policy: ChunkedPolicy, library: Any) -> None: + self.env = env + self.policy = policy + self.library = library + self._is_backward = False # TODO(backward): allow backward chunking + + @property + def is_backward(self) -> bool: + return self._is_backward + + def init_context( + self, + batch_size: int, + device: torch.device, + conditioning: Optional[torch.Tensor] = None, + ) -> AdapterContext: + ctx = AdapterContext( + batch_size=batch_size, device=device, conditioning=conditioning + ) + ctx.extras["macro_log_probs"] = [] # List[(N,)] + return ctx + + def _strict_macro_mask(self, states_active: DiscreteStates) -> torch.Tensor: + """Strict mask by simulating each macro sequentially on each active state. + + Invalidates a macro if any sub-action is invalid or if sink is reached before + the sequence completes. Guarantees EXIT macro is valid if no macro is valid. + """ + B = states_active.batch_shape[0] + N = self.library.n_actions + device = states_active.device + mask = torch.zeros(B, N, dtype=torch.bool, device=device) + + for b in range(B): + s_curr = states_active[b : b + 1] + for j, seq in enumerate(self.library.id_to_sequence): + ok = True + s = s_curr + for k, a in enumerate(seq): + a_tensor = self.env.actions_from_tensor( + torch.tensor([[a]], device=device) + ) + if not self.env.is_action_valid(s, a_tensor): + ok = False + break + s_next = self.env._step(s, a_tensor) + if s_next.is_sink_state.item() and k != len(seq) - 1: + ok = False + break + s = s_next + mask[b, j] = ok + + # Ensure EXIT macro is available when none is valid + try: + exit_id = self.library.id_to_sequence.index([self.env.exit_action.item()]) + except ValueError: + exit_id = N - 1 + no_valid = ~mask.any(dim=1) + if no_valid.any(): + mask[no_valid] = False + mask[no_valid, exit_id] = True + return mask + + def compute( + self, + states_active: DiscreteStates, + ctx: Any, + step_mask: torch.Tensor, + **policy_kwargs: Any, + ) -> tuple[Distribution, Any]: + logits = self.policy.forward_logits(states_active) # (B_active, N) + macro_mask = self._strict_macro_mask(states_active) + masked_logits = torch.where( + macro_mask, logits, torch.full_like(logits, -float("inf")) + ) + dist = Categorical(logits=masked_logits) + ctx.current_estimator_output = None + return dist, ctx + + def record_step( + self, + ctx: Any, + step_mask: torch.Tensor, + sampled_actions: torch.Tensor, + dist: Distribution, + save_logprobs: bool, + save_estimator_outputs: bool, + ) -> None: + if save_logprobs: + lp_masked = dist.log_prob(sampled_actions) + step_lp = torch.full((ctx.batch_size,), 0.0, device=ctx.device) + step_lp[step_mask] = lp_masked + ctx.extras["macro_log_probs"].append(step_lp) + # No estimator outputs for macros by default + return + + def finalize(self, ctx: Any) -> dict[str, Optional[torch.Tensor]]: + out: dict[str, Optional[torch.Tensor]] = { + "log_probs": None, + "estimator_outputs": None, + } + macro_log_probs = ctx.extras.get("macro_log_probs", []) + if macro_log_probs: + out["macro_log_probs"] = torch.stack(macro_log_probs, dim=0) + else: + out["macro_log_probs"] = None + return out + + def get_current_estimator_output(self, ctx: Any): + return None diff --git a/src/gfn/chunking/chunkers.py b/src/gfn/chunking/chunkers.py new file mode 100644 index 00000000..c8a0e362 --- /dev/null +++ b/src/gfn/chunking/chunkers.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +import random +from abc import ABC, abstractmethod +from collections import Counter +from typing import 
TYPE_CHECKING, Any, Hashable, Sequence
+
+from tokenizers import Tokenizer
+from tokenizers.models import BPE, WordPiece
+from tokenizers.trainers import BpeTrainer, WordPieceTrainer
+
+if TYPE_CHECKING:  # Avoid runtime import to break circular deps with env/containers
+    from gfn.containers.trajectories import Trajectories
+
+
+class Chunker(ABC):
+    """Abstract base class for chunkers that propose new vocab tokens.
+
+    Chunkers operate on trajectories and environment context and return
+    a sequence of token keys (any Hashable) to be added to the env vocab.
+    """
+
+    @abstractmethod
+    def propose_tokens(
+        self,
+        env: "Any",
+        trajectories: Trajectories,
+        n_tokens_to_add: int,
+        remove_old: bool,
+    ) -> Sequence[Hashable]:
+        raise NotImplementedError
+
+
+class UniformChunker(Chunker):
+    """Proposes random bigrams of current non-exit tokens as tuples of ints."""
+
+    def propose_tokens(
+        self,
+        env: "Any",
+        trajectories: Trajectories,
+        n_tokens_to_add: int,
+        remove_old: bool,
+    ) -> Sequence[Hashable]:
+        # Build non-exit pool from current vocab ids.
+        non_exit_ids = [i for i in range(env.n_actions) if i != env.exit_token_id]
+        seen = set(env.vocab)
+        out: set[Hashable] = set()
+        # Cap the number of sampling attempts so the loop terminates even when
+        # every candidate bigram is already in the vocab.
+        for _ in range(10_000):
+            if len(out) >= n_tokens_to_add:
+                break
+            a, b = random.choice(non_exit_ids), random.choice(non_exit_ids)
+            candidate = (a, b)
+            if candidate not in seen:
+                out.add(candidate)
+        return list(out)
+
+
+class _StringMapping:
+    """Utility to map env keys to strings suitable for tokenizers."""
+
+    def __init__(self, delimiter: str = "") -> None:
+        self.delimiter = delimiter
+
+    def key_to_str(self, key: Hashable) -> str:
+        if isinstance(key, tuple):
+            return self.delimiter.join(str(x) for x in key)
+        return str(key)
+
+
+class BPEChunker(Chunker):
+    def __init__(self, unk_token: str = "[UNK]", delimiter: str = "") -> None:
+        self.unk_token = unk_token
+        self.mapper = _StringMapping(delimiter=delimiter)
+
+    def propose_tokens(
+        self,
+        env: "Any",
+        trajectories: Trajectories,
+        n_tokens_to_add: int,
+        remove_old: bool,
+        min_frequency: int = 5,
+    ) -> Sequence[Hashable]:
+        # Build corpus strings from trajectories via env tokenizer
+        corpus = env.trajectories_to_token_strings(trajectories)
+
+        # Build initial vocab from current env keys mapped to strings
+        vocab_dict = {self.mapper.key_to_str(k): i for i, k in enumerate(env.vocab)}
+        tokenizer = Tokenizer(BPE(vocab_dict, [], unk_token=self.unk_token))
+
+        target_vocab_size = len(env.vocab) - 1 + n_tokens_to_add
+        trainer = BpeTrainer(
+            vocab_size=target_vocab_size,  # type: ignore
+            special_tokens=[self.unk_token],  # type: ignore
+            min_frequency=min_frequency,  # type: ignore
+        )
+        tokenizer.train_from_iterator(corpus, trainer=trainer)
+
+        # Take the most common new tokens.
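+        # (The trained vocab is not ordered by corpus frequency, so we re-encode
+        # the corpus and count how often each candidate token actually occurs.)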
+        base_vocab = set(vocab_dict.keys())
+        encodings = tokenizer.encode_batch(corpus)
+        counts = Counter()
+        for enc in encodings:
+            for tok in enc.tokens:
+                if tok not in base_vocab and tok != self.unk_token and len(tok) > 0:
+                    counts[tok] += 1
+
+        top_new = [tok for tok, _ in counts.most_common(n_tokens_to_add)]
+        return top_new
+
+
+class WordPieceChunker(Chunker):
+    def __init__(self, unk_token: str = "[UNK]", delimiter: str = "") -> None:
+        self.unk_token = unk_token
+        self.mapper = _StringMapping(delimiter=delimiter)
+
+    def propose_tokens(
+        self,
+        env: "Any",
+        trajectories: Trajectories,
+        n_tokens_to_add: int,
+        remove_old: bool,
+        min_frequency: int = 5,
+    ) -> Sequence[Hashable]:
+        corpus = env.trajectories_to_token_strings(trajectories)
+        vocab_dict = {self.mapper.key_to_str(k): i for i, k in enumerate(env.vocab)}
+        tokenizer = Tokenizer(
+            WordPiece(
+                vocab=vocab_dict,
+                unk_token=self.unk_token,
+                max_input_chars_per_word=100,
+            )
+        )
+        target_vocab_size = len(env.vocab) - 1 + n_tokens_to_add
+        trainer = WordPieceTrainer(
+            vocab_size=target_vocab_size,
+            continuing_subword_prefix="##",  # Defined prefix (removed later).
+            special_tokens=[self.unk_token],
+            min_frequency=min_frequency,
+        )
+        tokenizer.train_from_iterator(corpus, trainer=trainer)

+        # Take the most common new tokens.
+        base_vocab = set(vocab_dict.keys())
+        encodings = tokenizer.encode_batch(corpus)
+        counts = Counter()
+        for enc in encodings:
+            for tok in enc.tokens:
+                if tok not in base_vocab and tok != self.unk_token and len(tok) > 0:
+                    # str.lstrip("##") strips a character *set*, not a prefix, so
+                    # remove the continuation marker explicitly.
+                    key = tok[2:] if tok.startswith("##") else tok
+                    counts[key] += 1
+
+        top_new = [tok for tok, _ in counts.most_common(n_tokens_to_add)]
+        return top_new
diff --git a/src/gfn/chunking/policy.py b/src/gfn/chunking/policy.py
new file mode 100644
index 00000000..8689d15b
--- /dev/null
+++ b/src/gfn/chunking/policy.py
@@ -0,0 +1,184 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Callable, List
+
+import torch
+from torch import nn
+
+from gfn.states import DiscreteStates
+
+if TYPE_CHECKING:
+    from gfn.env import ChunkedDiscreteEnvironment
+
+
+class ChunkedPolicy(nn.Module):
+    """Compute logits over a macro library via state and macro embeddings.
+
+    The `state_module` maps preprocessed states to a fixed-size embedding. The
+    `action_encoder` maps action sequences (macros) to the same embedding space.
+    Logits are the scaled dot products between state embeddings and macro embeddings.
+    """
+
+    def __init__(
+        self,
+        state_module: nn.Module,
+        action_encoder: ActionEncoder,
+        env: "ChunkedDiscreteEnvironment",
+        action_embedding_dim: int,
+        primitive_id_mapper: Callable[[int], int] | None = None,
+    ) -> None:
+        super().__init__()
+        self.state_module = state_module
+        self.action_encoder = action_encoder
+        self.env = env
+        self.action_embedding_dim = int(action_embedding_dim)
+        self.primitive_id_mapper: Callable[[int], int] = (
+            primitive_id_mapper if primitive_id_mapper is not None else (lambda x: x)
+        )
+        self.register_buffer("_library_embeddings", torch.empty(0))
+
+    @torch.no_grad()
+    def refresh_library_embeddings(self, device: torch.device) -> None:
+        # Build sequences from env vocab decoded to primitive ids
+        vocab = self.env.vocab
+        seqs: List[List[int]] = []
+        max_len = 0
+
+        # TODO: This should rely on the env tokenizer instead of primitive_id_mapper.
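+        # Each vocab key decodes to a primitive-id sequence (length 1 for
+        # primitives, >1 for macros); the exit sentinel decodes to an empty
+        # sequence and is given a length-1 placeholder below.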
+ for key in vocab: + decoded = list(self.env.decode_key_to_actions(key)) + mapped = [self.primitive_id_mapper(i) for i in decoded] + max_len = max(max_len, len(mapped)) + + # Ensure at least length 1 for encoder stability. + seqs.append(mapped if len(mapped) > 0 else [0]) # 0 is a placeholder FIXME. + + if len(seqs) == 0: + self._library_embeddings = torch.empty( + 0, self.action_embedding_dim, device=device + ) + return + + x = torch.full( + (len(seqs), max_len), + fill_value=0, + dtype=torch.long, + device=device, + ) + + # TODO: Vectorize this. + for i, s in enumerate(seqs): + if len(s) > 0: + x[i, : len(s)] = torch.tensor(s, dtype=torch.long, device=device) + + self._library_embeddings = self.action_encoder(x) # (N, D) + + def forward_logits(self, states: DiscreteStates) -> torch.Tensor: + state_emb = self.state_module(states) # (*B, D) + if ( + self._library_embeddings.numel() == 0 + or self._library_embeddings.shape[0] != self.env.n_actions + ): + self.refresh_library_embeddings(device=state_emb.device) + + logits = torch.einsum("bd,nd->bn", state_emb, self._library_embeddings) + logits = logits / (state_emb.shape[-1] ** 0.5) + + return logits + + +class ActionModel(nn.Module): + def __init__( + self, + n_primitive_actions: int, + hidden_dim: int = 256, + action_embedding_dimension: int = 128, + ) -> None: + super().__init__() + self.primitive_embedding = nn.Embedding(n_primitive_actions, hidden_dim) + self.rnn_encoder = nn.GRU(hidden_dim, hidden_dim, num_layers=2, batch_first=True) + self.out_layer = nn.Sequential( + nn.LayerNorm(hidden_dim), nn.Linear(hidden_dim, action_embedding_dimension) + ) + + def forward(self, x): # x: (B, L) + emb = self.primitive_embedding(x) + s, _ = self.rnn_encoder(emb) + out = s[:, -1] + out = self.out_layer(out) + return out + + +class PositionalEncoding(nn.Module): + """Minimal sinusoidal positional encoding for completeness. + + If a richer implementation exists elsewhere, prefer importing it. + """ + + pe: torch.Tensor # registered buffer + + def __init__(self, dim: int, dropout: float = 0.0, max_len: int = 512) -> None: + super().__init__() + self.dropout = nn.Dropout(dropout) + pe = torch.zeros(max_len, dim) + position = torch.arange(0, max_len, dtype=torch.get_default_dtype()).unsqueeze(1) + div_term = torch.exp( + torch.arange( + 0, + dim, + 2, + dtype=torch.get_default_dtype(), + ) + * (-torch.log(torch.tensor(10000.0)) / dim) + ) + + pe[:, 0::2] = torch.sin(position * div_term) + if dim % 2 == 0: + pe[:, 1::2] = torch.cos(position * div_term) + else: + pe[:, 1::2] = torch.cos(position * div_term)[:, : dim // 2] + + self.register_buffer("pe", pe) + + def forward(self, x): + # x: (B, L, D) + L = x.size(1) + x = x + self.pe[:L] + return self.dropout(x) + + +class ActionEncoder(nn.Module): + def __init__( + self, + n_primitive_actions: int, + action_embedding_dimension: int, + hidden_dim: int, + num_layers: int, + num_head: int, + max_len: int = 60, + dropout: float = 0.0, + ) -> None: + super().__init__() + self.pos = PositionalEncoding(hidden_dim, dropout=dropout, max_len=max_len + 1) + self.embedding = nn.Embedding(n_primitive_actions, hidden_dim) + encoder_layers = nn.TransformerEncoderLayer( + hidden_dim, num_head, hidden_dim, dropout=dropout, batch_first=True + ) + self.encoder = nn.TransformerEncoder(encoder_layers, num_layers) + + # TODO: For the action encoder to work properly with macros of variable length, + # do we need the embedding layer to be recurrent. Or can we just use a simple + # embedding layer? 
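+        # forward() mean-pools the encoder outputs over non-pad positions before
+        # this projection into the shared action-embedding space.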
+ self.action_embedding_layer = nn.Linear(hidden_dim, action_embedding_dimension) + + def forward(self, x_ids): # (B, L) with 0 = PAD + pad = x_ids == 0 # (B, L) bool + x = self.embedding(x_ids) # (B, L, D) + x = self.pos(x) # (B, L, D) + x = self.encoder(x, src_key_padding_mask=pad) # mask pads in attention + + mask = (~pad).unsqueeze(-1) # (B, L, 1) + denom = mask.sum(dim=1).clamp_min(1) # (B, 1) + pooled = (x * mask).sum(dim=1) / denom # (B, D) + + return self.action_embedding_layer(pooled) diff --git a/src/gfn/env.py b/src/gfn/env.py index e77e36bf..30881e57 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -1,11 +1,13 @@ from abc import ABC, abstractmethod -from typing import Optional, Tuple, cast +from contextlib import contextmanager +from typing import Any, Callable, Hashable, Optional, Sequence, Tuple, cast import torch from torch_geometric.data import Data as GeometricData from gfn.actions import Actions, GraphActions -from gfn.states import DiscreteStates, GraphStates, States +from gfn.chunking.chunkers import Chunker +from gfn.states import ChunkedStates, DiscreteStates, GraphStates, States from gfn.utils.common import ensure_same_device, set_seed # Errors @@ -325,6 +327,9 @@ def _step(self, states: States, actions: Actions) -> States: # For the indices where the new states are not sink states (i.e., where the # state is not already a sink and the action is not exit), update those # positions with the result of the environment's step function. + # TODO: Ensure that the step function returns a States instance with the same + # type as the States class. Right now, we initialize with -inf, which assumes + # a float type. For now, I will handle casting outside of this method. new_states = self.States.make_sink_states( states.batch_shape, device=states.device ) @@ -656,6 +661,21 @@ def _step(self, states: DiscreteStates, actions: Actions) -> DiscreteStates: """ new_states = super()._step(states, actions) new_states = cast(DiscreteStates, new_states) + + # Ensure dtypes of all tensor fields match the input states. + # TODO: We should probably fix this at the base Env class level, but + # I want to do it in a follow up PR. + if new_states.tensor.dtype != states.tensor.dtype: + new_states.tensor = new_states.tensor.to(dtype=states.tensor.dtype) + if new_states.forward_masks.dtype != states.forward_masks.dtype: + new_states.forward_masks = new_states.forward_masks.to( + dtype=states.forward_masks.dtype + ) + if new_states.backward_masks.dtype != states.backward_masks.dtype: + new_states.backward_masks = new_states.backward_masks.to( + dtype=states.backward_masks.dtype + ) + self.update_masks(new_states) return new_states @@ -673,6 +693,21 @@ def _backward_step(self, states: DiscreteStates, actions: Actions) -> DiscreteSt """ new_states = super()._backward_step(states, actions) new_states = cast(DiscreteStates, new_states) + + # Ensure dtypes of all tensor fields match the input states. + # TODO: We should probably fix this at the base Env class level, but + # I want to do it in a follow up PR. 
+ if new_states.tensor.dtype != states.tensor.dtype: + new_states.tensor = new_states.tensor.to(dtype=states.tensor.dtype) + if new_states.forward_masks.dtype != states.forward_masks.dtype: + new_states.forward_masks = new_states.forward_masks.to( + dtype=states.forward_masks.dtype + ) + if new_states.backward_masks.dtype != states.backward_masks.dtype: + new_states.backward_masks = new_states.backward_masks.to( + dtype=states.backward_masks.dtype + ) + self.update_masks(new_states) return new_states @@ -757,6 +792,417 @@ def terminating_states(self) -> DiscreteStates: ) +class ChunkedDiscreteEnvironment(DiscreteEnv): + """Discrete environment with chunking-aware action vocab management. + + Intended behavior: + - Exit invariance: the exit action is represented by a fixed, non-executable + sentinel key at its index and never moves when the vocab grows. + - Vocab growth: new primitive/macro tokens are appended (no reindexing of + existing actions). Token-to-id and id-to-token mappings are maintained by + the environment. + - Self-healing states: the custom ``States`` class returned by + :meth:`make_states_class` automatically resizes forward/backward masks to + ``env.n_actions`` and overlays environment-wide constraints (soft-disabled + actions and strict macro feasibility) whenever masks are created or the + batch shape changes. + - Chunker integration: arbitrary chunkers can propose new tokens; helpers + are provided to build string corpora from trajectories for text-based + chunkers (e.g., BPE/WordPiece). + """ + + def __init__( + self, + n_actions: int, + s0: torch.Tensor, + state_shape: Tuple | int, + *, + action_shape: Tuple | int = (1,), + dummy_action: Optional[torch.Tensor] = None, + exit_action: Optional[torch.Tensor] = None, + sf: Optional[torch.Tensor] = None, + check_action_validity: bool = True, + tokenizer: Optional[Callable[[Sequence[int]], str]] = None, + detokenizer: Optional[Callable[[str], Sequence[int]]] = None, + ) -> None: + # Delegate exit action handling to DiscreteEnv (base behavior). + super().__init__( + n_actions=n_actions, + s0=s0, + state_shape=state_shape, + action_shape=action_shape, + dummy_action=dummy_action, + exit_action=exit_action, + sf=sf, + check_action_validity=check_action_validity, + ) + + # Fixed exit action id derived from parent init. + self.exit_token_id: int = int(self.exit_action.item()) + + # Re-entrancy guard for macro overlay recursion + self._macro_overlay_depth: int = 0 + + # Hashable-keyed vocab: start with primitive ids, then exit sentinel at its index. + self._exit_key: str = "" + self.id_to_token_key: list[Hashable] = list(range(self.n_actions - 1)) + [ + self._exit_key + ] + self.token_key_to_id: dict[Hashable, int] = { + k: i for i, k in enumerate(self.id_to_token_key) + } + + # Initially, all non-exit tokens are considered atomic. + self._atomic_token_ids: set[int] = { + i for i in range(self.n_actions) if i != self.exit_token_id + } + self._disabled_token_ids: set[int] = set() + + self._tokenizer: Callable[[Sequence[int]], str] + self._tokenizer = tokenizer if tokenizer is not None else self._default_tokenizer + + # Optional detokenizer to convert string keys to primitive action sequences. + # TODO: the tokenizer should have methods for this (instead of a unique detokenizer). 
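+        # A detokenizer is only needed when string keys from text-based
+        # chunkers (e.g., BPE/WordPiece) must be executed as macros; without
+        # one, decode_key_to_actions rejects unknown string keys.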
+        self._detokenizer: Optional[Callable[[str], Sequence[int]]] = detokenizer
+
+    @staticmethod
+    def _default_tokenizer(ids: Sequence[int]) -> str:
+        # Example: [1, 30, 7] -> "1,30,7,"
+        return ",".join(str(i) for i in ids) + ","
+
+    @property
+    def vocab(self) -> list[Hashable]:
+        return list(self.id_to_token_key)
+
+    def add_tokens(self, new_keys: Sequence[Hashable]) -> list[int]:
+        """Append new token keys to the vocab, assigning fresh ids.
+
+        Args:
+            new_keys: Iterable of hashable token keys (ints, tuples of ints, or
+                strings) to add.
+
+        Returns:
+            The list of ids assigned to the newly added keys (in insertion order).
+        """
+        new_ids: list[int] = []
+        for key in new_keys:
+            if key in self.token_key_to_id:
+                continue
+
+            new_id = len(self.id_to_token_key)
+            self.id_to_token_key.append(key)
+            self.token_key_to_id[key] = new_id
+            new_ids.append(new_id)
+
+        if new_ids:
+            self.n_actions = len(self.id_to_token_key)
+
+        return new_ids
+
+    def disable_tokens(self, tokens_or_ids: Sequence[int]) -> None:
+        for tid in tokens_or_ids:
+            tid_i = int(tid)
+            if tid_i != self.exit_token_id:
+                self._disabled_token_ids.add(tid_i)
+
+    def enable_tokens(self, tokens_or_ids: Sequence[int]) -> None:
+        for tid in tokens_or_ids:
+            tid_i = int(tid)
+            self._disabled_token_ids.discard(tid_i)
+
+    def chunk_and_update_vocab(
+        self,
+        trajectories: Any,
+        chunker: Chunker,
+        n_tokens_to_add: int,
+        remove_old: bool = False,
+    ) -> list[int]:
+        """Calls a user-provided chunker and updates the vocab with proposed keys.
+
+        The chunker returns token keys (ints, tuples of ints, or strings) to be
+        appended. If remove_old is True, previously learned non-atomic tokens
+        that were not just added will be soft-disabled via masks.
+        """
+        # Expect a Chunker-like instance exposing propose_tokens(env, trajectories, n, remove_old)
+        proposed_keys = list(
+            chunker.propose_tokens(self, trajectories, n_tokens_to_add, remove_old)
+        )
+        new_ids = self.add_tokens(proposed_keys)
+        if remove_old:
+            keep = set(self._atomic_token_ids) | {self.exit_token_id} | set(new_ids)
+            to_disable = [i for i in range(self.n_actions) if i not in keep]
+            self.disable_tokens(to_disable)
+        return new_ids
+
+    def trajectories_to_action_sequences(self, trajs: Any) -> list[list[int]]:
+        # actions: (T, B) after squeeze, terminating_idx: (B,)
+        actions = trajs.actions.tensor.squeeze(-1)
+        term = trajs.terminating_idx
+        out: list[list[int]] = []
+        T, B = actions.shape
+        for i in range(B):
+            L = int(term[i].item())
+            idxs = [
+                int(a) for a in actions[:L, i].tolist() if int(a) != self.exit_token_id
+            ]
+            out.append(idxs)
+        return out
+
+    def trajectories_to_token_strings(self, trajs: Any) -> list[str]:
+        seqs = self.trajectories_to_action_sequences(trajs)
+        return [self._tokenizer(seq) for seq in seqs]
+
+    def apply_soft_disabled_to_forward_masks(self, states: DiscreteStates) -> None:
+        if self._disabled_token_ids:
+            ids = torch.tensor(sorted(self._disabled_token_ids), device=states.device)
+            states.forward_masks[..., ids] = False
+
+    @contextmanager
+    def macro_mask_guard(self):
+        """Temporarily disable macro overlay application to avoid recursion.
+
+        Increments an environment-local depth counter; while non-zero, calls to
+        apply_macro_forward_mask will no-op. Always decremented in a finally block.
+        """
+        self._macro_overlay_depth += 1
+        try:
+            yield
+        finally:
+            self._macro_overlay_depth -= 1
+            assert self._macro_overlay_depth >= 0
+
+    def make_states_class(self) -> type[DiscreteStates]:
+        """Returns the DiscreteStates class for this environment.
+
+        Returns:
+            A type of a subclass of DiscreteStates with environment-specific
+            functionalities.
+        """
+        env = self
+
+        class ChunkedEnvStates(ChunkedStates):
+            """States for chunked env that auto-resize and overlay masks.
+
+            Responsibilities:
+            - Keep ``n_actions`` synchronized with the parent environment.
+            - Lazily resize masks to match ``env.n_actions`` after construction,
+              extension, or padding.
+            - Apply environment overlays (soft-disables and strict macro feasibility)
+              after any (re)allocation of masks.
+            """

+            state_shape = env.state_shape
+            s0 = env.s0
+            sf = env.sf
+            make_random_states = env.make_random_states
+            n_actions = env.n_actions
+            device = env.device
+
+            # wire hooks into the shared ChunkedStates base
+            @staticmethod
+            def get_n_actions() -> int:
+                return env.n_actions
+
+            @staticmethod
+            def overlay_masks(s: "ChunkedStates") -> None:
+                env.apply_soft_disabled_to_forward_masks(s)
+                env.apply_macro_forward_mask(s)
+
+        return ChunkedEnvStates
+
+    @abstractmethod
+    def update_masks(self, states: DiscreteStates) -> None:
+        """Subclasses must compute env-specific masks and then call overlay helper.
+
+        Example pattern:
+            states.set_nonexit_action_masks(cond=..., allow_exit=True)
+            self.apply_soft_disabled_to_forward_masks(states)
+        """
+        ...
+
+    def decode_key_to_actions(self, key: Hashable) -> Sequence[int]:
+        """Decodes a vocab key (potential macro) into a sequence of primitive actions.
+
+        - int -> [int]
+        - tuple[int,...] -> list(tuple)
+        - str -> detokenizer(str) if provided; the exit sentinel decodes to an
+          empty sequence, and any other string without a detokenizer raises a
+          ValueError
+        """
+        if isinstance(key, int):
+            return [int(key)]
+        if isinstance(key, tuple):
+            # assume a tuple of ints
+            return [int(x) for x in key]
+        if isinstance(key, str):
+            if key == self._exit_key:
+                return []
+            if self._detokenizer is not None:
+                return list(self._detokenizer(key))
+
+        raise ValueError(f"Invalid key: {key}")
+
+    def _decode_action_id_to_sequence(self, action_id: int) -> Sequence[int]:
+        key = self.id_to_token_key[action_id]
+        return self.decode_key_to_actions(key)
+
+    def is_macro_id(self, action_id: int) -> bool:
+        seq = self._decode_action_id_to_sequence(action_id)
+        return len(seq) > 1
+
+    def _compute_macro_mask_flat(self, states: DiscreteStates) -> torch.Tensor:
+        """Compute macro feasibility for a 1D batch of ChunkedStates.
+
+        Returns: (B, n_actions) boolean mask for macros only (primitives/exit left True).
+        """
+        if not isinstance(states, ChunkedStates):
+            raise TypeError("compute macro mask requires ChunkedStates")
+
+        assert len(states.batch_shape) == 1
+        B = states.batch_shape[0]
+        macro_mask = torch.ones(
+            B, self.n_actions, dtype=torch.bool, device=states.device
+        )
+
+        # Collect macro sequences
+        macro_sequences: dict[int, Sequence[int]] = {}
+        for action_id in range(self.n_actions):
+            seq = self._decode_action_id_to_sequence(action_id)
+            if len(seq) > 1:
+                macro_sequences[action_id] = seq
+
+        if not macro_sequences:
+            return macro_mask  # no macros to validate.
+
+        for aid, seq in macro_sequences.items():
+            # Local working copy of states; do not mutate caller's states
+            s_curr = states.clone()
+            valid_vec = torch.ones(B, dtype=torch.bool, device=states.device)
+
+            for primitive_id in seq:
+                # Per-state validity for this primitive at the current step.
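+                # Cumulative rule: a macro stays feasible only while every
+                # prefix step is feasible, i.e. valid_k = valid_{k-1} & step_valid_k.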
+ step_valid = s_curr.forward_masks[:, primitive_id] + to_step = valid_vec & step_valid + valid_vec &= step_valid # Update cumulative validity + + if bool(to_step.any().item()): + idx = torch.where(to_step)[0] + n = idx.numel() + a_tensor = torch.full( + (n, *self.action_shape), + primitive_id, + device=states.device, + dtype=torch.long, + ) + a = self.actions_from_tensor(a_tensor) + next_sub = super()._step(s_curr[idx], a) + s_curr[idx] = next_sub + + # Refresh masks after stepping (guard prevents recursion) + self.update_masks(s_curr) + + macro_mask[:, aid] = valid_vec + + return macro_mask + + def compute_strict_macro_forward_mask(self, states: DiscreteStates) -> torch.Tensor: + """Returns a mask of shape (batch_shape*, n_actions) validating macros. + + Supports (B) and (T,B). Requires ChunkedStates. + """ + # Enforce ChunkedStates + if not isinstance(states, ChunkedStates): + raise TypeError("compute_strict_macro_forward_mask requires ChunkedStates") + + with self.macro_mask_guard(): + if len(states.batch_shape) == 1: + return self._compute_macro_mask_flat(states) + elif len(states.batch_shape) == 2: + T, B = states.batch_shape + # Horizon pre-check: macros longer than remaining steps are invalid + macro_lengths = torch.tensor( + [ + len(self._decode_action_id_to_sequence(i)) + for i in range(self.n_actions) + ], + device=states.device, + dtype=torch.long, + ) + t_idx = torch.arange(T, device=states.device) + remaining = (T - t_idx).view(T, 1) + horizon_ok = ( + macro_lengths.view(1, self.n_actions) <= remaining.view(T, 1) + ).to(torch.bool) + + # Compute feasibility ignoring horizon, then AND with horizon_ok + flat = states.flatten() + flat_mask = self._compute_macro_mask_flat(flat) # (T*B, n_actions) + tb_mask = flat_mask.view(T, B, self.n_actions) + # Broadcast horizon_ok over B + tb_mask = tb_mask & horizon_ok.view(T, 1, self.n_actions) + return tb_mask + else: + raise ValueError( + f"Expected batch_shape (B) or (T,B), got {states.batch_shape}" + ) + + def apply_macro_forward_mask(self, states: DiscreteStates) -> None: + # Skip macro overlay while inside macro feasibility computation + if getattr(self, "_macro_overlay_depth", 0) > 0: + return + macro_mask = self.compute_strict_macro_forward_mask(states) + states.forward_masks = states.forward_masks & macro_mask + + def _step(self, states: ChunkedStates, actions: Actions) -> ChunkedStates: + """Overrides base to unroll macro actions sequentially. + + Non-macro actions are delegated to the base implementation. + """ + assert states.batch_shape == actions.batch_shape + B = states.batch_shape[0] + + # Identify macro vs non-macro per batch element + action_ids = actions.tensor.view(B) + macro_flags = torch.zeros(B, dtype=torch.bool, device=states.device) + for i in range(B): + macro_flags[i] = self.is_macro_id(int(action_ids[i].item())) + + # Fast path: no macros found. 
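+        # (When every sampled action is a primitive, one vectorized base step
+        # handles the whole batch.)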
+        if not bool(macro_flags.any().item()):
+            return cast(ChunkedStates, super()._step(states, actions))
+
+        # Split states/actions
+        out_states = self.States.make_sink_states(
+            states.batch_shape, device=states.device
+        )
+
+        # Handle non-macro subset via base
+        non_macro_idx = ~macro_flags
+        if bool(non_macro_idx.any().item()):
+            nm_states = states[non_macro_idx]
+            nm_actions = actions[non_macro_idx]
+            nm_next = super()._step(nm_states, nm_actions)
+            out_states[non_macro_idx] = nm_next
+
+        # Handle macros by sequential unroll
+        if bool(macro_flags.any().item()):
+            macro_positions = torch.where(macro_flags)[0]
+            m_states = states[macro_flags]
+            m_action_ids = action_ids[macro_flags]
+            # iterate each macro in the smaller batch
+            curr = m_states
+            for j in range(curr.batch_shape[0]):
+                aid = int(m_action_ids[j].item())
+                seq = self._decode_action_id_to_sequence(aid)
+                s = curr[j : j + 1]
+                for primitive_id in seq:
+                    a_tensor = self.actions_from_tensor(
+                        torch.tensor([[primitive_id]], device=states.device)
+                    )
+                    s = super()._step(s, a_tensor)
+                # Write back via the integer batch position: indexing
+                # `out_states` with the boolean mask would return a copy, so
+                # the assignment would be silently lost.
+                pos = int(macro_positions[j].item())
+                out_states[pos : pos + 1] = s
+
+        # Update masks for the resulting batch
+        self.update_masks(cast(DiscreteStates, out_states))
+        return cast(ChunkedStates, out_states)
+
+
 class GraphEnv(Env):
     """Base class for graph-based environments.
diff --git a/src/gfn/states.py b/src/gfn/states.py
index a8bcc178..8abb4a3a 100644
--- a/src/gfn/states.py
+++ b/src/gfn/states.py
@@ -233,7 +233,14 @@ def __setitem__(
             index: Indices to set.
             states: States object containing the new states.
         """
-        self.tensor[index] = states.tensor
+        # Align dtype/device of the source to the destination slice to avoid
+        # runtime errors from mismatched tensor properties during indexed writes.
+        # Note: we intentionally do not mutate `states.tensor` in-place.
+        dest = self.tensor
+        src = states.tensor
+        if src.dtype != dest.dtype or src.device != dest.device:
+            src = src.to(device=dest.device, dtype=dest.dtype)
+        self.tensor[index] = src
 
     def clone(self) -> States:
         """Returns a clone of the current instance.
@@ -522,7 +529,9 @@ def __getitem__(
         return out
 
     def __setitem__(
-        self, index: int | Sequence[int] | Sequence[bool], states: DiscreteStates
+        self,
+        index: int | Sequence[int] | Sequence[bool] | torch.Tensor,
+        states: DiscreteStates,
     ) -> None:
         """Sets particular discrete states and their masks.
 
@@ -672,6 +681,58 @@ def init_forward_masks(self, set_ones: bool = True) -> None:
         self.forward_masks = torch.zeros(shape).to(self.device).bool()
 
 
+class ChunkedStates(DiscreteStates):
+    """Reusable ChunkedStates base used by chunking-aware environments.
+ + Env factories should return a subclass that binds env-specific class variables + (state_shape, s0, sf, n_actions, device) and two hooks: + - get_n_actions: Callable[[], int] returning current env.n_actions + - overlay_masks: Callable[[ChunkedStates], None] applying env overlays + """ + + # Hooks to be provided by the environment-specific subclass + get_n_actions: ClassVar[Callable[[], int]] = lambda: cast( + int, ChunkedStates.n_actions + ) + overlay_masks: ClassVar[Optional[Callable[["ChunkedStates"], None]]] = None + + def __init__(self, tensor, forward_masks=None, backward_masks=None): + super().__init__(tensor, forward_masks, backward_masks) + self._ensure_current() + + def _ensure_current(self) -> None: + # Keep class-level n_actions in sync with env via hook if available + try: + self.__class__.n_actions = int(self.__class__.get_n_actions()) + except Exception: + pass + + if (self.forward_masks.shape[-1] != self.__class__.n_actions) or ( + self.backward_masks.shape[-1] != self.__class__.n_actions - 1 + ): + self.forward_masks = torch.ones( + (*self.batch_shape, self.__class__.n_actions), + dtype=torch.bool, + device=self.device, + ) + self.backward_masks = torch.ones( + (*self.batch_shape, self.__class__.n_actions - 1), + dtype=torch.bool, + device=self.device, + ) + + if self.__class__.overlay_masks is not None: + self.__class__.overlay_masks(self) + + def pad_dim0_with_sf(self, required_first_dim: int) -> None: + super().pad_dim0_with_sf(required_first_dim) + self._ensure_current() + + def extend(self, other: "ChunkedStates") -> None: + super().extend(other) + self._ensure_current() + + class GraphStates(States): """Base class for graph-based state representations. diff --git a/testing/test_chunking.py b/testing/test_chunking.py new file mode 100644 index 00000000..baedf74b --- /dev/null +++ b/testing/test_chunking.py @@ -0,0 +1,375 @@ +from __future__ import annotations + +import random +from typing import List, Sequence, cast + +import torch +import torch.nn as nn + +from gfn.chunking.policy import ActionEncoder, ChunkedPolicy +from gfn.containers import Trajectories +from gfn.env import ChunkedDiscreteEnvironment +from gfn.states import ChunkedStates, DiscreteStates + +# from gfn.chunking.adapters import ChunkedAdapter + + +class SyntheticTokenEnv(ChunkedDiscreteEnvironment): + def __init__(self, device: torch.device = torch.device("cpu")): + # n_actions = 27: A=1,B=2,C=3,D=4,...,Z=26,EXIT=27 + n_actions = 27 + s0 = torch.tensor([0], device=device) + # Tokenizer maps primitive ints to letters for string-based chunkers. + + def _letters_tokenizer(seq: Sequence[int]) -> str: + alpha = {i: chr(ord("A") + i - 1) for i in range(1, 27)} # 1->A ... 
26->Z
+            alpha[27] = "EXIT"
+
+            return "".join(alpha.get(i, "") for i in seq)
+
+        super().__init__(
+            n_actions=n_actions,
+            s0=s0,
+            state_shape=(1,),
+            action_shape=(1,),
+            check_action_validity=True,
+            tokenizer=_letters_tokenizer,
+            exit_action=torch.tensor([27], device=device),
+        )
+
+    def update_masks(self, states: ChunkedStates) -> None:
+        # Forward mask: disallow [B,A] and [C,D]
+        # Backward mask: inverse constraints: if curr==A, disallow parent B; if curr==D, disallow parent C
+        device = states.device
+        batch_shape = states.batch_shape
+
+        # Initialize all true for both masks with correct batch shape
+        fwd = torch.ones((*batch_shape, self.n_actions), dtype=torch.bool, device=device)
+        bwd = torch.ones(
+            (*batch_shape, self.n_actions - 1), dtype=torch.bool, device=device
+        )
+
+        # Current token value per batch element (handles (B,1) and (T,B,1))
+        last = states.tensor.squeeze(-1)  # shape == (*batch_shape,)
+
+        # Forward constraints
+        disallow_A = last == 2  # if last == B --> disallow A
+        if disallow_A.any():
+            # Mask primitive A (index 1)
+            fwd[..., 1].masked_fill_(disallow_A, False)
+        disallow_D = last == 3  # if last == C --> disallow D
+        if disallow_D.any():
+            # Mask primitive D (index 4)
+            fwd[..., 4].masked_fill_(disallow_D, False)
+
+        # Backward constraints (no EXIT column)
+        disallow_parent_B = last == 1  # current == A -> disallow parent B
+        if disallow_parent_B.any():
+            bwd[..., 2].masked_fill_(disallow_parent_B, False)
+
+        disallow_parent_C = last == 4  # current == D -> disallow parent C
+        if disallow_parent_C.any():
+            bwd[..., 3].masked_fill_(disallow_parent_C, False)
+
+        states.forward_masks = fwd
+        states.backward_masks = bwd
+
+        # Overlay global disables and macro feasibility
+        self.apply_soft_disabled_to_forward_masks(states)
+        self.apply_macro_forward_mask(states)
+
+    def step(self, states: DiscreteStates, actions) -> DiscreteStates:
+        # Set state to the action token, unless EXIT; EXIT leads to sink (-1)
+        device = states.device
+        a = actions.tensor  # preserve shape (matches states.tensor)
+        new = states.tensor.clone()
+        # For any EXIT, set sink
+        is_exit = a == (self.n_actions - 1)
+        new[is_exit] = torch.tensor([-1], device=device, dtype=new.dtype)
+        # For non-exit, set to the chosen primitive token id
+        non_exit = ~is_exit
+        if a.dtype != new.dtype:
+            a = a.to(dtype=new.dtype)
+        new[non_exit] = a[non_exit]
+        out = self.states_from_tensor(new)
+        return out
+
+    def backward_step(self, states: DiscreteStates, actions) -> DiscreteStates:
+        # For synthetic tests, just revert to s0 for simplicity when not exit
+        device = states.device
+        a = actions.tensor.view(-1)
+        new = states.tensor.clone()
+        # Non-exit moves to a dummy parent; for tests we can use 0
+        new[a != (self.n_actions - 1)] = torch.tensor(
+            [0], device=device, dtype=new.dtype
+        )
+        out = self.states_from_tensor(new)
+        return out
+
+
+def generate_synthetic_corpus(
+    n_traj: int = 10000,
+    length: int = 20,
+    device: torch.device = torch.device("cpu"),
+) -> Trajectories:
+    # Build N trajectories respecting forward constraints and injecting chunks
+    # Tokens 1..26 only; no EXIT in the corpus
+    env = SyntheticTokenEnv(device)
+
+    actions_2d = torch.zeros(length, n_traj, dtype=torch.long, device=device)
+    term = torch.full((n_traj,), length, dtype=torch.long, device=device)
+
+    subseq = [2, 3, 1, 4, 2]  # "BCADB"
+    for i in range(n_traj):
+        seq: List[int] = []
+
+        # Choose an insertion index for BCADB that fits in the sequence
+        ins_start = random.randint(0, length - len(subseq))
+
+        t = 0
+        while t < length:
+
+            
# Add the subsequence here. + if t == ins_start: + seq.extend(subseq) + t += len(subseq) + continue + + last = seq[-1] if seq else 0 + candidates = list(range(1, 27)) + + # Forward constraints from SyntheticTokenEnv.update_masks + if last == 2: # after B, disallow A + candidates.remove(1) + if last == 3: # after C, disallow D + candidates.remove(4) + seq.append(random.choice(candidates)) + t += 1 + + actions_2d[:, i] = torch.tensor(seq[:length], dtype=torch.long, device=device) + + # Wrap into env-specific containers + actions = env.actions_from_tensor(actions_2d.unsqueeze(-1)) + + # Derive states by unrolling tokens as last-observation states + states_tensor = torch.zeros(length + 1, n_traj, 1, dtype=torch.long, device=device) + states_tensor[0, :, 0] = 0 # s0 + states_tensor[1:, :, 0] = actions_2d + states = cast(ChunkedStates, env.states_from_tensor(states_tensor)) + + return Trajectories( + env=env, + states=states, + actions=actions, + terminating_idx=term, + is_backward=False, + ) + + +class TinyStateModule(nn.Module): + def __init__(self, embed_dim: int = 16): + super().__init__() + self.embed = nn.Embedding( + 6, embed_dim + ) # allow -1,0..4 remapped in preprocessing + self.proj = nn.Linear(embed_dim, 32) + + def forward(self, states: DiscreteStates) -> torch.Tensor: + x = states.tensor.view(-1) + # Remap -1->0, keep 0..4 as is plus 1 offset to avoid negative indices + x = torch.clamp(x + 1, min=0) + e = self.embed(x) + out = self.proj(e) + return out + + +class _ConstState(nn.Module): + """Return a constant state embedding for any input batch. + + This isolates the test to the interaction between `ChunkedPolicy` and + `ActionEncoder`, avoiding any dependency on a learned state network. + """ + + def __init__(self, embed_dim: int) -> None: + super().__init__() + self.embedding = nn.Parameter(torch.zeros(embed_dim), requires_grad=False) + + def forward(self, states: DiscreteStates) -> torch.Tensor: + batch = states.batch_shape[0] + return self.embedding.expand(batch, -1) + + +def test_policy_encoder_growing_action_space_with_synthetic_env(): + # Ensure deterministic encoder behavior + torch.manual_seed(0) + + device = torch.device("cpu") + D = 32 + + # Reuse the synthetic environment with primitives A..Z and EXIT + env = SyntheticTokenEnv(device) + + # Build a real batch of discrete states (values don't matter for constant state net) + states = env.states_from_tensor(torch.tensor([[0], [1], [2]], device=device)) + + # Encoder maps sequences of primitive ids to action embeddings in R^D + encoder = ActionEncoder( + n_primitive_actions=env.n_actions, # primitives + EXIT + action_embedding_dimension=D, + hidden_dim=32, + num_layers=1, + num_head=4, + max_len=8, + dropout=0.0, + ) + + # Policy produces logits via scaled dot product between state and action embeddings + policy = ChunkedPolicy(_ConstState(D), encoder, env, action_embedding_dim=D) + + # First pass: primitives only + logits1 = policy.forward_logits(states) + assert logits1.shape == (states.batch_shape[0], env.n_actions) + emb1 = policy._library_embeddings.detach().clone() + assert torch.isfinite(emb1).all() + + # Grow action space by adding a length-5 macro (BCADB) + env.add_tokens([(2, 3, 1, 4, 2)]) + logits2 = policy.forward_logits(states) + assert logits2.shape == (states.batch_shape[0], env.n_actions) + emb2 = policy._library_embeddings.detach().clone() + + # Existing embeddings should be unchanged after refresh + assert torch.allclose(emb2[: emb1.shape[0]], emb1, atol=1e-1) + + # Grow again with a different-length 
macro + env.add_tokens([(3, 3, 3)]) + logits3 = policy.forward_logits(states) + assert logits3.shape == (states.batch_shape[0], env.n_actions) + emb3 = policy._library_embeddings.detach() + assert emb3.shape[0] == env.n_actions + assert torch.isfinite(emb3).all() + + # Check scaled dot-product formula: l_t = f_θ(A) q_t / sqrt(d) + state_emb = policy.state_module(states) + expected = torch.einsum("bd,nd->bn", state_emb, emb3) / (D**0.5) + assert torch.allclose(logits3, expected, atol=1e-6) + + +def test_mining_finds_chunks(): + from gfn.chunking.chunkers import BPEChunker, WordPieceChunker + + device = torch.device("cpu") + trajs = generate_synthetic_corpus(2000, 20, device) + env = trajs.env # SyntheticTokenEnv with letters tokenizer. + + # Propose with BPE and WordPiece; expect 'BCADB' among new tokens (present in all sequences) + bpe = BPEChunker(unk_token="[UNK]", delimiter="") + wp = WordPieceChunker(unk_token="[UNK]", delimiter="") + + proposed_bpe = set( + bpe.propose_tokens(env, trajs, n_tokens_to_add=50, remove_old=False) + ) + proposed_wp = set( + wp.propose_tokens(env, trajs, n_tokens_to_add=50, remove_old=False) + ) + + assert "BCADB" in proposed_bpe + assert "BCADB" in proposed_wp + + +def test_macro_masking(): + device = torch.device("cpu") + trajs = generate_synthetic_corpus(2000, 20, device) + env = trajs.env # SyntheticTokenEnv with letters tokenizer. + + # Add macro keys to env vocab (tuple form for executable macros) + new_ids = env.add_tokens([(2, 3, 1, 4, 2)]) + assert len(new_ids) == 1 + assert env.id_to_token_key[-1] == (2, 3, 1, 4, 2) # BCADB, the new macro-action. + assert env.id_to_token_key[-2] == "" + assert env.id_to_token_key[:26] == list(range(26)) # The alphabet. + + # Macro should be feasible from a generic state (start at s0) + state_0 = env.states_from_tensor(torch.tensor([[0]], device=device)) + env.update_masks(state_0) + macro_id = new_ids[0] # BCADB, allowed. + assert state_0.forward_masks[0, macro_id].item() + + # Macro should be infeasable from a generic state. + new_ids = env.add_tokens([(2, 3, 4, 4, 2)]) # BCDDB, disallowed (no C->D allowed). + state_0 = env.states_from_tensor(torch.tensor([[0]], device=device)) + env.update_masks(state_0) + macro_id = new_ids[0] # BCDDB, disallowed. 
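+    # The strict per-step simulation must flag BCDDB: it embeds the forbidden
+    # C->D transition regardless of the starting state.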
+ assert not state_0.forward_masks[0, macro_id].item() + + +def test_macro_mask_guard_no_recursion_batch_only(): + device = torch.device("cpu") + env = SyntheticTokenEnv(device) + # Simple (B,) batch + B = 2 + states_tensor = torch.zeros((B, 1), dtype=torch.long, device=device) + states = cast(ChunkedStates, env.states_from_tensor(states_tensor)) + + # Should not recurse + env.update_masks(states) + mask = env.compute_strict_macro_forward_mask(states) + assert mask.shape == (states.batch_shape[0], env.n_actions) + assert mask.dtype == torch.bool + assert getattr(env, "_macro_overlay_depth", 0) == 0 + + +def test_macro_mask_guard_no_recursion_trajectories(): + device = torch.device("cpu") + env = SyntheticTokenEnv(device) + # (T,B) batch + T, B = 3, 2 + states_tensor = torch.zeros((T, B, 1), dtype=torch.long, device=device) + states = cast(ChunkedStates, env.states_from_tensor(states_tensor)) + + # Should not recurse + env.update_masks(states) + mask = env.compute_strict_macro_forward_mask(states) + assert mask.shape == (T, B, env.n_actions) + assert mask.dtype == torch.bool + assert getattr(env, "_macro_overlay_depth", 0) == 0 + + +def test_horizon_mask_blocks_oversized_macro(): + device = torch.device("cpu") + env = SyntheticTokenEnv(device) + + # Register a length-3 macro + macro_ids = env.add_tokens([(2, 2, 2)]) + macro_id = macro_ids[0] + + # Build (T=3, B=2) states: remaining steps = 3,2,1 at t=0,1,2 + T, B = 3, 2 + states_tensor = torch.zeros((T, B, 1), dtype=torch.long, device=device) + states = cast(ChunkedStates, env.states_from_tensor(states_tensor)) + env.update_masks(states) + mask = env.compute_strict_macro_forward_mask(states) + + # At t=0: remaining=3 -> macro allowed by horizon check + assert mask[0, :, macro_id].all().item() + + # At t>=1: remaining < 3 -> macro disallowed + assert (~mask[1:, :, macro_id]).all().item() + + +def test_apply_macro_forward_mask_noop_under_guard(): + device = torch.device("cpu") + env = SyntheticTokenEnv(device) + B = 2 + states_tensor = torch.zeros((B, 1), dtype=torch.long, device=device) + states = cast(ChunkedStates, env.states_from_tensor(states_tensor)) + env.update_masks(states) + fwd_before = states.forward_masks.clone() + + # Manually enter guard + setattr(env, "_macro_overlay_depth", getattr(env, "_macro_overlay_depth", 0) + 1) + try: + env.apply_macro_forward_mask(states) + finally: + setattr(env, "_macro_overlay_depth", getattr(env, "_macro_overlay_depth", 1) - 1) + assert torch.equal(states.forward_masks, fwd_before) From 02856ee220faf307f5c1e70a9600d93dbd9b16ee Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Wed, 8 Oct 2025 23:34:40 -0400 Subject: [PATCH 06/27] added dtype casting to preprocessors --- src/gfn/preprocessors.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/gfn/preprocessors.py b/src/gfn/preprocessors.py index 20d63359..b0fbadd2 100644 --- a/src/gfn/preprocessors.py +++ b/src/gfn/preprocessors.py @@ -23,15 +23,21 @@ class Preprocessor(ABC): dimension will not be checked. """ - def __init__(self, output_dim: int | None) -> None: + def __init__( + self, output_dim: int | None, target_dtype: torch.dtype | None = None + ) -> None: """Initializes a Preprocessor with the specified output dimension. Args: output_dim: The dimensionality of the preprocessed output tensor, which is compatible with the neural network that will be used. If None, the output dimension will not be checked. + target_dtype: Optional dtype to cast tensor outputs to. 
When set, any + tensor returned by `preprocess` will be cast to this dtype in + `__call__` before returning. """ self.output_dim = output_dim + self.target_dtype = target_dtype @abstractmethod def preprocess(self, states: States) -> torch.Tensor: @@ -55,8 +61,11 @@ def __call__(self, states: States | GraphStates) -> torch.Tensor | GeometricBatc The preprocessed states as a tensor or GeometricBatch. """ out = self.preprocess(states) - if isinstance(out, torch.Tensor) and self.output_dim is not None: - assert out.shape[-1] == self.output_dim + if isinstance(out, torch.Tensor): + if self.output_dim is not None: + assert out.shape[-1] == self.output_dim + if self.target_dtype is not None and out.dtype != self.target_dtype: + out = out.to(self.target_dtype) return out From 1224bc0d07a0f3adcae5203e2f753c2044233687 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 9 Oct 2025 10:37:09 -0400 Subject: [PATCH 07/27] added vectorized and non-vectorized adapter-based probability calculation paths. simplified the API of adapters. --- docs/source/index.rst | 1 + src/gfn/samplers.py | 318 +++++++++---- src/gfn/utils/handlers.py | 25 +- src/gfn/utils/prob_calculations.py | 539 ++++++++++++++++------ testing/test_probability_calculations.py | 441 ++++++++++++++++++ testing/test_samplers_and_trajectories.py | 38 +- 6 files changed, 1117 insertions(+), 245 deletions(-) create mode 100644 testing/test_probability_calculations.py diff --git a/docs/source/index.rst b/docs/source/index.rst index be9ef88c..e41a3104 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -10,6 +10,7 @@ guides/example guides/states_actions_containers guides/modules_estimators_samplers + guides/estimator_adapters guides/losses guides/creating_environments guides/advanced diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 95a0e431..0eb6251b 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -10,10 +10,7 @@ from gfn.states import GraphStates, States from gfn.utils.common import ensure_same_device from gfn.utils.graphs import graph_states_share_storage -from gfn.utils.handlers import ( - has_conditioning_exception_handler, - no_conditioning_exception_handler, -) +from gfn.utils.handlers import check_cond_forward from gfn.utils.prob_calculations import get_trajectory_pbs, get_trajectory_pfs @@ -35,9 +32,9 @@ class that implements these members is accepted; no inheritance is required. adapter is responsible for: - initializing ``ctx`` once per rollout in ``init_context`` - updating any internal state (e.g., recurrent ``carry``) during ``compute`` - - recording per‑step artifacts in ``record_step`` (e.g., log_probs, + - recording per‑step artifacts in ``record`` (e.g., log_probs, estimator outputs), typically with mask-aware padding - - producing stacked outputs for ``Trajectories`` in ``finalize`` + Finalization is handled by the rollout context itself. Guidance -------- @@ -51,6 +48,10 @@ class that implements these members is accepted; no inheritance is required. def is_backward(self) -> bool: ... # fmt: skip + @property + def is_vectorized(self) -> bool: + ... # fmt: skip + def init_context( self, batch_size: int, @@ -68,7 +69,7 @@ def compute( ) -> tuple[Distribution, Any]: ... # fmt: skip - def record_step( + def record( self, ctx: Any, step_mask: torch.Tensor, @@ -79,7 +80,14 @@ def record_step( ) -> None: ... 
# fmt: skip - def finalize(self, ctx: Any) -> dict[str, Optional[torch.Tensor]]: + def log_prob_of_actions( + self, + states_active: States, + actions_active: torch.Tensor, + ctx: Any, + step_mask: torch.Tensor, + **policy_kwargs: Any, + ) -> tuple[torch.Tensor, Any]: ... # fmt: skip # Optional helper for `sample_actions` BC @@ -87,7 +95,7 @@ def get_current_estimator_output(self, ctx: Any) -> Optional[torch.Tensor]: ... # fmt: skip -class AdapterContext: +class RolloutContext: """Structured, mutable context owned by adapters. Uses fixed attributes for core fields and an `extras` dict for adapter- @@ -121,6 +129,47 @@ def __init__( self.current_estimator_output: Optional[torch.Tensor] = None self.extras: Dict[str, Any] = {} + def append_step( + self, + step_mask: torch.Tensor, + sampled_actions: torch.Tensor, + dist: Distribution, + save_logprobs: bool, + save_estimator_outputs: bool, + ) -> None: + """Record per-step artifacts into trajectory-level buffers owned by the context.""" + N = self.batch_size + device = self.device + + if save_logprobs: + lp_masked = dist.log_prob(sampled_actions) + if torch.any(torch.isinf(lp_masked)): + raise RuntimeError("Log probabilities are inf. This should not happen.") + step_lp = torch.full((N,), 0.0, device=device) + step_lp[step_mask] = lp_masked + self.trajectory_log_probs.append(step_lp) + + if save_estimator_outputs and self.current_estimator_output is not None: + est_out = self.current_estimator_output + padded = torch.full((N,) + est_out.shape[1:], -float("inf"), device=device) + padded[step_mask] = est_out + self.trajectory_estimator_outputs.append(padded) + + def finalize(self) -> dict[str, Optional[torch.Tensor]]: + """Stack recorded per-step artifacts along time into trajectory-level tensors.""" + log_probs = ( + torch.stack(self.trajectory_log_probs, dim=0) + if self.trajectory_log_probs + else None + ) + estimator_outputs = ( + torch.stack(self.trajectory_estimator_outputs, dim=0) + if self.trajectory_estimator_outputs + else None + ) + + return {"log_probs": log_probs, "estimator_outputs": estimator_outputs} + class DefaultEstimatorAdapter: """Adapter for non-recurrent estimators (current default behavior). @@ -142,8 +191,9 @@ class DefaultEstimatorAdapter: Context Lifecycle (opaque to the Sampler) ---------------------------------------- - The adapter owns an opaque context dict `ctx` which the Sampler never reads. - The context is created once per rollout and mutated in place: + The adapter owns an opaque rollout context `ctx` (see `RolloutContext`) which + the Sampler never reads. The context is created once per rollout and mutated + in place: - init_context(batch_size, device, conditioning) -> ctx Stores rollout invariants and optional conditioning. Also prepares per‑step @@ -156,22 +206,21 @@ class DefaultEstimatorAdapter: 2) Calls the estimator forward pass to obtain the raw `est_out`. 3) Converts `est_out` into a torch Distribution with `to_probability_distribution`. - 4) Saves `est_out` into `ctx["last_est_out"]` so it can optionally be - recorded by `record_step` or exposed to legacy callers of `sample_actions`. + 4) Saves `est_out` into `ctx.current_estimator_output` so it can optionally be + recorded by `record` or exposed to callers that need it. 
-    - record_step(ctx, step_mask, sampled_actions, dist, save_logprobs, save_estimator_outputs)
-      Materializes optional per‑step artifacts into ctx-managed buffers with
+    - record(ctx, step_mask, sampled_actions, dist, save_logprobs, save_estimator_outputs)
+      Materializes optional per‑step artifacts into context‑managed buffers with
       mask‑aware padding back to the full rollout batch size `N`:
       * Log‑probs: computes `dist.log_prob(sampled_actions)` for active rows only,
-        checks for `inf`, then writes into a 1D tensor of shape `(N,)` filled with
-        zeros and masked assignment for active positions. Appends this to a list
-        (one tensor per time step).
+        then writes into a 1D tensor of shape `(N,)` filled with zeros and masked
+        assignment for active positions. Appends this to a list (one tensor per time step).
       * Estimator outputs: if requested, pads the last estimator output
-        (`ctx["last_est_out"]`) to shape `(N, ...)` using `-inf` for inactive rows
-        and appends to a list (one tensor per time step). This matches existing
-        conventions elsewhere in the library.
+        (`ctx.current_estimator_output`) to shape `(N, ...)` using `-inf` for inactive rows
+        and appends to a list (one tensor per time step).
 
-    - finalize(ctx) -> {"log_probs": Tensor | None, "estimator_outputs": Tensor | None}
+    Finalization is performed by the context itself:
+    - ctx.finalize() -> {"log_probs": Tensor | None, "estimator_outputs": Tensor | None}
       Stacks recorded per‑step lists along the time dimension into tensors of
       shape `(T, N, ...)` suitable for `Trajectories`. Returns `None` for any
       artifact that was never recorded.
@@ -189,6 +238,13 @@ class DefaultEstimatorAdapter:
     - `is_backward` is forwarded from the underlying estimator so the sampler can
       choose the appropriate environment transition (forward vs backward).
 
+    Vectorized Probability Path
+    ---------------------------
+    - `is_vectorized` tells the probability calculators which path to use.
+    - Vectorized adapters always use faster paths in probability calculators.
+      Non-vectorized adapters (e.g., recurrent) use per-step paths with masking and
+      alignment identical to the legacy reference.
+
     Performance Notes
     -----------------
     - `ctx` is allocated once per rollout and mutated in place to avoid per‑step
@@ -211,18 +267,22 @@ def is_backward(self) -> bool:
         """Whether the wrapped estimator samples in the backward direction."""
         return getattr(self._estimator, "is_backward", False)
 
+    @property
+    def is_vectorized(self) -> bool:
+        return True
+
     def init_context(
         self,
         batch_size: int,
         device: torch.device,
         conditioning: Optional[torch.Tensor] = None,
-    ) -> AdapterContext:
+    ) -> RolloutContext:
         """Create a new per-rollout context.
 
         Stores rollout invariants (batch size, device, optional conditioning)
         and initializes empty buffers for per-step artifacts.
         """
-        return AdapterContext(
+        return RolloutContext(
             batch_size=batch_size, device=device, conditioning=conditioning
         )
 
@@ -240,12 +300,10 @@ def compute(
             optional recording in `record`.
""" conditioning = ctx.conditioning # type: ignore[attr-defined] - if conditioning is not None: - with has_conditioning_exception_handler("estimator", self._estimator): - est_out = self._estimator(states_active, conditioning[step_mask]) - else: - with no_conditioning_exception_handler("estimator", self._estimator): - est_out = self._estimator(states_active) + cond_active = conditioning[step_mask] if conditioning is not None else None + est_out = check_cond_forward( + self._estimator, "estimator", states_active, cond_active + ) dist = self._estimator.to_probability_distribution( states_active, est_out, **policy_kwargs @@ -254,7 +312,7 @@ def compute( return dist, ctx - def record_step( + def record( self, ctx: Any, step_mask: torch.Tensor, @@ -263,53 +321,66 @@ def record_step( save_logprobs: bool, save_estimator_outputs: bool, ) -> None: - """Record per-step artifacts into the context's trajectory-level lists. + """Record per-step artifacts into the context's trajectory-level lists.""" + # Delegate recording to the rollout context + ctx.append_step( # type: ignore[attr-defined] + step_mask=step_mask, + sampled_actions=sampled_actions, + dist=dist, + save_logprobs=save_logprobs, + save_estimator_outputs=save_estimator_outputs, + ) + + def log_prob_of_actions( + self, + states_active: States, + actions_active: torch.Tensor, + ctx: Any, + step_mask: torch.Tensor, + **policy_kwargs: Any, + ) -> tuple[torch.Tensor, Any]: + # Optional fast path: use caller-provided estimator outputs + precomputed = policy_kwargs.pop("precomputed_estimator_output", None) + if precomputed is not None: + est_out = precomputed + else: + conditioning = ctx.conditioning # type: ignore[attr-defined] + cond_active = conditioning[step_mask] if conditioning is not None else None + est_out = check_cond_forward( + self._estimator, "estimator", states_active, cond_active + ) + dist = self._estimator.to_probability_distribution( + states_active, est_out, **policy_kwargs + ) + ctx.current_estimator_output = est_out # type: ignore[attr-defined] - - If requested, computes log-probs for the active rows and writes them into - a padded vector of shape (N,) before appending to the list - `trajectory_log_probs`. - - If requested, pads `current_estimator_output` to shape (N, ...) and appends - to the list `trajectory_estimator_outputs`. - """ N = ctx.batch_size # type: ignore[attr-defined] device = ctx.device # type: ignore[attr-defined] + lp_masked = dist.log_prob(actions_active) + if torch.any(torch.isinf(lp_masked)): + raise RuntimeError("Log probabilities are inf. This should not happen.") + step_lp = torch.full((N,), 0.0, device=device) + step_lp[step_mask] = lp_masked + return step_lp, ctx + + # Vectorized helper to mirror legacy probability calculations without per-step context + def log_prob_vectorized( + self, + states: States, + actions_tensor: torch.Tensor, + conditioning: Optional[torch.Tensor] = None, + **policy_kwargs: Any, + ) -> torch.Tensor: + """Compute log_prob for a batch of (state, action) pairs in a vectorized way. - if save_logprobs: - lp_masked = dist.log_prob(sampled_actions) - if torch.any(torch.isinf(lp_masked)): - raise RuntimeError("Log probabilities are inf. 
This should not happen.") - step_lp = torch.full((N,), 0.0, device=device) - step_lp[step_mask] = lp_masked - ctx.trajectory_log_probs.append(step_lp) # type: ignore[attr-defined] - - if ( - save_estimator_outputs - and getattr(ctx, "current_estimator_output", None) is not None - ): - est_out = ctx.current_estimator_output # type: ignore[attr-defined] - padded = torch.full((N,) + est_out.shape[1:], -float("inf"), device=device) - padded[step_mask] = est_out - ctx.trajectory_estimator_outputs.append(padded) # type: ignore[attr-defined] - - def finalize(self, ctx: Any) -> dict[str, Optional[torch.Tensor]]: - """Stack all recorded per-step artifacts along time into trajectory-level tensors. - - Returns a dict with keys: - - 'log_probs': Tensor of shape (T, N) or None - - 'estimator_outputs': Tensor of shape (T, N, ...) or None + This mirrors the legacy vectorized path used in probability utils and uses + the adapter's estimator and distribution construction, including policy kwargs. """ - log_probs = ( - torch.stack(ctx.trajectory_log_probs, dim=0) - if getattr(ctx, "trajectory_log_probs", []) - else None - ) - estimator_outputs = ( - torch.stack(ctx.trajectory_estimator_outputs, dim=0) - if getattr(ctx, "trajectory_estimator_outputs", []) - else None + est_out = check_cond_forward(self._estimator, "estimator", states, conditioning) + dist = self._estimator.to_probability_distribution( + states, est_out, **policy_kwargs ) - - return {"log_probs": log_probs, "estimator_outputs": estimator_outputs} + return dist.log_prob(actions_tensor) def get_current_estimator_output(self, ctx: Any) -> Optional[torch.Tensor]: """Expose the most recent per-step estimator output saved during `compute`.""" @@ -317,17 +388,79 @@ def get_current_estimator_output(self, ctx: Any) -> Optional[torch.Tensor]: class RecurrentEstimatorAdapter(DefaultEstimatorAdapter): - """Adapter for recurrent estimators that require and update a carry.""" + """Adapter for recurrent estimators that require and update a carry. + + Overview + -------- + This adapter extends the default (non‑recurrent) behavior to handle models that + maintain a recurrent state ("carry"). It exposes the same surface as + `DefaultEstimatorAdapter`, with the following differences: + + - `is_vectorized = False`: Probability calculators will use a non‑vectorized + (per‑step) path that mirrors the legacy reference exactly (including masks and + state/action alignment), since a recurrent carry must be updated sequentially. + - The rollout `ctx` stores a `carry` that is initialized once via + `estimator.init_carry(batch_size, device)` and updated at every call to + `compute`/`log_prob_of_actions`. + + Context Lifecycle (opaque to the Sampler) + ---------------------------------------- + The adapter owns an opaque rollout context `ctx` (see `RolloutContext`) which the + Sampler never reads. The context is created once per rollout and mutated in place: + + - init_context(batch_size, device, conditioning) -> ctx + Stores rollout invariants and optional conditioning. Initializes recurrent + `carry` via `estimator.init_carry`. Also prepares per‑step buffers for optional + artifacts (log_probs, estimator_outputs). + + - compute(states_active, ctx, step_mask, **policy_kwargs) -> (dist, ctx) + 1) Calls the recurrent estimator as `(states_active, ctx.carry) -> (est_out, new_carry)` + and stores `new_carry` back into `ctx.carry`. + 2) Converts `est_out` into a torch Distribution with + `to_probability_distribution(states_active, est_out, **policy_kwargs)`. 
+
+      3) Saves `est_out` into `ctx.current_estimator_output` for optional recording.
+
+    - record(ctx, step_mask, sampled_actions, dist, save_logprobs, save_estimator_outputs)
+      Materializes optional per‑step artifacts into context‑managed buffers with
+      mask‑aware padding back to the full rollout batch size `N`:
+      * Log‑probs: computes `dist.log_prob(sampled_actions)` for active rows only,
+        then writes into a 1D tensor of shape `(N,)` filled with zeros and masked
+        assignment for active positions. Appends this to a list (one tensor per time step).
+      * Estimator outputs: if requested, pads `ctx.current_estimator_output` to shape
+        `(N, ...)` using `-inf` for inactive rows and appends to a list.
+
+    Finalization is performed by the context itself:
+    - ctx.finalize() -> {"log_probs": Tensor | None, "estimator_outputs": Tensor | None}
+      Stacks recorded per‑step lists along the time dimension into tensors of shape
+      `(T, N, ...)` suitable for `Trajectories`. Returns `None` for any artifact
+      that was never recorded.
+
+    Probability Calculators
+    -----------------------
+    Since `is_vectorized = False`, the PF/PB probability calculators use the
+    non‑vectorized, per‑step path that matches the legacy reference:
+    - Trajectory PF: `step_mask = ~states.is_sink_state[t] & ~actions.is_dummy[t]`.
+    - Trajectory PB: actions at time `t` are aligned with states at time `t+1`, and
+      `step_mask = ~states.is_sink_state[t+1] & ~states.is_initial_state[t+1]
+      & ~actions.is_dummy[t] & ~actions.is_exit[t]` with `t==0` skipped.
+    - Transitions: recurrent adapters are not supported there; the transition
+      calculators raise a `TypeError` (evaluate trajectories instead).
+    No mask indexing with action ids is used; distributions handle illegal actions.
+    """
 
     def __init__(self, estimator: Estimator) -> None:
         super().__init__(estimator)
 
+    @property
+    def is_vectorized(self) -> bool:
+        return False
+
     def init_context(
         self,
         batch_size: int,
         device: torch.device,
         conditioning: Optional[torch.Tensor] = None,
-    ) -> AdapterContext:
+    ) -> RolloutContext:
         """Create context and initialize the recurrent carry (estimator hidden state).
 
         Differs from the default adapter by allocating `ctx.carry` via
@@ -373,6 +506,31 @@ def compute(
 
         return dist, ctx
 
+    def log_prob_of_actions(
+        self,
+        states_active: States,
+        actions_active: torch.Tensor,
+        ctx: Any,
+        step_mask: torch.Tensor,
+        **policy_kwargs: Any,
+    ) -> tuple[torch.Tensor, Any]:
+        # Recurrent estimators are expected to accept (states, carry) -> (out, new_carry)
+        est_out, new_carry = self._estimator(states_active, ctx.carry)  # type: ignore[attr-defined]
+        ctx.carry = new_carry  # type: ignore[attr-defined]
+        dist = self._estimator.to_probability_distribution(
+            states_active, est_out, **policy_kwargs
+        )
+        ctx.current_estimator_output = est_out  # type: ignore[attr-defined]
+
+        N = ctx.batch_size  # type: ignore[attr-defined]
+        device = ctx.device  # type: ignore[attr-defined]
+        lp_masked = dist.log_prob(actions_active)
+        if torch.any(torch.isinf(lp_masked)):
+            raise RuntimeError("Log probabilities are inf. This should not happen.")
+        step_lp = torch.full((N,), 0.0, device=device)
+        step_lp[step_mask] = lp_masked
+        return step_lp, ctx
+
 
 class Sampler:
     """Wrapper for a PolicyEstimator that enables sampling from GFlowNet environments.
@@ -464,7 +622,7 @@ def sample_actions(
 
         log_probs = None
 
         # Allow adapter to record per-step artifacts for callers that reuse ctx.
- self.adapter.record_step( + self.adapter.record( ctx=ctx, step_mask=step_mask, sampled_actions=actions_tensor, @@ -585,7 +743,7 @@ def sample_trajectories( valid_actions = env.actions_from_tensor(valid_actions_tensor) # Let adapter record artifacts - self.adapter.record_step( + self.adapter.record( ctx=ctx, step_mask=step_mask, sampled_actions=valid_actions_tensor, @@ -599,9 +757,9 @@ def sample_trajectories( trajectories_actions.append(actions) if self.adapter.is_backward: - new_states = env._backward_step(states, actions) + new_states = env._backward_step(states, actions) # type: ignore[attr-defined] else: - new_states = env._step(states, actions) + new_states = env._step(states, actions) # type: ignore[attr-defined] # Ensure that the new state is a distinct object from the old state. assert new_states is not states @@ -638,8 +796,8 @@ def sample_trajectories( stacked_actions = env.Actions.stack(trajectories_actions)[ 1: ] # Drop dummy action - # Ask adapter for stacked trajectory_artifacts (already shaped (T, N, ...)) - trajectory_artifacts = self.adapter.finalize(ctx) + # Finalize stacked trajectory artifacts from the context (already shaped (T, N, ...)) + trajectory_artifacts = ctx.finalize() # type: ignore[attr-defined] stacked_logprobs = trajectory_artifacts.get("log_probs", None) stacked_estimator_outputs = trajectory_artifacts.get("estimator_outputs", None) diff --git a/src/gfn/utils/handlers.py b/src/gfn/utils/handlers.py index bd494a2a..56877c71 100644 --- a/src/gfn/utils/handlers.py +++ b/src/gfn/utils/handlers.py @@ -1,8 +1,31 @@ import warnings from contextlib import contextmanager -from typing import Any +from typing import Any, Optional + +import torch from gfn.containers import Container +from gfn.estimators import Estimator +from gfn.states import States + + +def check_cond_forward( + module: Estimator, + module_name: str, + states: States, + condition: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Call estimator forward with or without conditioning using error handlers. + + Uses the same exception handling policy as the legacy utility to keep + behavior consistent across adapters and probability code paths. + """ + if condition is not None: + with has_conditioning_exception_handler(module_name, module): + return module(states, condition) + else: + with no_conditioning_exception_handler(module_name, module): + return module(states) @contextmanager diff --git a/src/gfn/utils/prob_calculations.py b/src/gfn/utils/prob_calculations.py index 7f0f547f..ceb65fb8 100644 --- a/src/gfn/utils/prob_calculations.py +++ b/src/gfn/utils/prob_calculations.py @@ -1,43 +1,11 @@ -from typing import Tuple +import warnings +from typing import Any, Tuple import torch from gfn.containers import Trajectories, Transitions from gfn.estimators import Estimator -from gfn.states import States -from gfn.utils.handlers import ( - has_conditioning_exception_handler, - no_conditioning_exception_handler, -) - - -def check_cond_forward( - module: Estimator, - module_name: str, - states: States, - condition: torch.Tensor | None = None, -) -> torch.Tensor: - """Checks if conditioning is passed and calls the module's forward method accordingly. - - Args: - module: The GFN module to call. - module_name: The name of the module (for error messages). - states: The states to pass to the module. - condition: Optional conditioning tensor to pass to the module. - - Returns: - The output of the module's forward method. 
-
-    Raises:
-        TypeError: If conditioning is passed but the module does not accept it, or vice-versa.
-    """
-    if condition is not None:
-        with has_conditioning_exception_handler(module_name, module):
-            return module(states, condition)
-    else:
-        with no_conditioning_exception_handler(module_name, module):
-            return module(states)
-
+from gfn.utils.handlers import check_cond_forward
 
 # ------------
 # Trajectories
@@ -50,33 +18,66 @@ def get_trajectory_pfs_and_pbs(
     trajectories: Trajectories,
     fill_value: float = 0.0,
     recalculate_all_logprobs: bool = True,
+    pf_adapter: Any | None = None,
+    pb_adapter: Any | None = None,
+    **policy_kwargs: Any,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Calculates the log probabilities of forward and backward trajectories.
+    """Calculate PF and PB log-probabilities for trajectories.
+
+    This function delegates to :func:`get_trajectory_pfs` and
+    :func:`get_trajectory_pbs`, forwarding optional adapter(s) and policy kwargs.
+
+    Vectorized vs non-vectorized
+    - If the adapter is None or ``adapter.is_vectorized is True``, the legacy
+      vectorized path is used (fast path, strict parity with legacy code).
+    - If ``adapter.is_vectorized is False`` (e.g., recurrent), a non‑vectorized
+      per‑step path is used with legacy-accurate masks and alignment.
 
     Args:
-        pf: The forward policy estimator.
-        pb: The backward policy estimator, or None if the gflownet DAG is a tree, and
-            pb is therefore always 1.
-        trajectories: The trajectories to calculate probabilities for.
-        fill_value: The value to fill for invalid states (e.g., sink states).
-        recalculate_all_logprobs: Whether to recalculate log probabilities even if they
-            already exist in the trajectories object.
+        pf: Forward policy estimator.
+        pb: Backward policy estimator, or ``None`` if the DAG is a tree (PB=1).
+        trajectories: Trajectories container to evaluate.
+        fill_value: Fill used for invalid states (e.g., sink state positions).
+        recalculate_all_logprobs: If ``True``, recompute PF even if cached.
+        pf_adapter: Adapter for PF (vectorized vs non‑vectorized decision).
+        pb_adapter: Adapter for PB (vectorized vs non‑vectorized decision).
+        **policy_kwargs: Extra kwargs passed to estimator's
+            ``to_probability_distribution`` (e.g., temperature, epsilon, sf_bias).
 
     Returns:
-        A tuple containing two tensors: log_pf_trajectories and log_pb_trajectories.
+        Tuple[Tensor, Tensor]:
+        - PF log-probs with shape ``(T, N)``
+        - PB log-probs with shape ``(T, N)``
    """
    # fill value is the value used for invalid states (sink state usually)

    # uncomment next line for debugging
    # assert trajectories.states.is_sink_state[:-1].equal(trajectories.actions.is_dummy)

+    if pb_adapter is not None and not isinstance(pb_adapter, type(pf_adapter)):
+        warnings.warn(
+            (
+                "type(pb_adapter)={} and type(pf_adapter)={}, this is probably not what you want "
+                "unless you explicitly want to use different sampling logic for the two policies "
+                "(with different estimator architectures). This is very uncommon."
+ ).format(type(pb_adapter), type(pf_adapter)) + ) + log_pf_trajectories = get_trajectory_pfs( pf, trajectories, fill_value=fill_value, recalculate_all_logprobs=recalculate_all_logprobs, + adapter=pf_adapter, + **policy_kwargs, + ) + log_pb_trajectories = get_trajectory_pbs( + pb, + trajectories, + fill_value=fill_value, + adapter=pb_adapter, + **policy_kwargs, ) - log_pb_trajectories = get_trajectory_pbs(pb, trajectories, fill_value=fill_value) return log_pf_trajectories, log_pb_trajectories @@ -86,18 +87,29 @@ def get_trajectory_pfs( trajectories: Trajectories, fill_value: float = 0.0, recalculate_all_logprobs: bool = True, + adapter: Any | None = None, + **policy_kwargs: Any, ) -> torch.Tensor: - """Calculates the log probabilities of forward trajectories. + """Calculate PF log-probabilities for trajectories. + + Vectorized vs non-vectorized + - Vectorized when ``adapter is None`` or ``adapter.is_vectorized is True``: + uses the legacy vectorized implementation (strict parity with reference). + - Non‑vectorized when ``adapter.is_vectorized is False``: evaluates per‑step + using legacy masks (PF: ``~states.is_sink_state[t] & ~actions.is_dummy[t]``), + passing the active subset to the adapter without any action‑id mask indexing. Args: - pf: The forward policy estimator. - trajectories: The trajectories to calculate probabilities for. - fill_value: The value to fill for invalid states (e.g., sink states). - recalculate_all_logprobs: Whether to recalculate log probabilities even if they - already exist in the trajectories object. + pf: Forward policy estimator. + trajectories: Trajectories container to evaluate. + fill_value: Fill used for invalid states (e.g., sink state positions). + recalculate_all_logprobs: If ``True``, recompute PF even if cached. + adapter: Optional adapter controlling vectorized vs non‑vectorized path. + **policy_kwargs: Extra kwargs passed to + ``to_probability_distribution`` (e.g., temperature, epsilon). Returns: - A tensor containing the log probabilities of the forward trajectories. + Tensor of shape ``(T, N)`` containing PF log-probabilities. Raises: ValueError: If backward trajectories are provided. @@ -118,38 +130,97 @@ def get_trajectory_pfs( log_pf_trajectories = trajectories.log_probs assert log_pf_trajectories is not None else: - log_pf_trajectories = torch.full_like( - trajectories.actions.tensor[..., 0], - fill_value=fill_value, - dtype=torch.get_default_dtype(), # Floating point dtype. 
- ) - - if len(valid_states) == 0: - return log_pf_trajectories - - if trajectories.estimator_outputs is not None and not recalculate_all_logprobs: - estimator_outputs = trajectories.estimator_outputs[action_mask] + # Decide vectorized (legacy) vs non-vectorized (adapter per-step) + vectorized = adapter is None or getattr(adapter, "is_vectorized", True) + + if not vectorized: + # Adapter-driven path + N = trajectories.n_trajectories + device = trajectories.states.device + cond = trajectories.conditioning + if cond is not None and len(cond.shape) >= 2: + cond = cond[0] + ctx = adapter.init_context(int(N), device, cond) # type: ignore[arg-type] + + T = trajectories.max_length + log_pf_trajectories = torch.full( + (T, N), + fill_value=fill_value, + dtype=torch.get_default_dtype(), + device=device, + ) + + for t in range(T): + state_ok = ~trajectories.states.is_sink_state[t] + action_ok = ~trajectories.actions.is_dummy[t] + step_mask = state_ok & action_ok + if not torch.any(step_mask): + continue + step_states = trajectories.states[t][step_mask] + step_actions = trajectories.actions.tensor[t][step_mask] + # Optimization: forward cached estimator outputs when available + if ( + trajectories.estimator_outputs is not None + and not recalculate_all_logprobs + ): + precomputed = trajectories.estimator_outputs[t][step_mask] + step_lp, ctx = adapter.log_prob_of_actions( # type: ignore[union-attr] + step_states, + step_actions, + ctx, + step_mask, + precomputed_estimator_output=precomputed, + **policy_kwargs, + ) + else: + step_lp, ctx = adapter.log_prob_of_actions( # type: ignore[union-attr] + step_states, step_actions, ctx, step_mask, **policy_kwargs + ) + if fill_value != 0.0: + padded = torch.full( + (N,), fill_value, device=device, dtype=step_lp.dtype + ) + padded[step_mask] = step_lp[step_mask] + step_lp = padded + log_pf_trajectories[t] = step_lp else: - masked_cond = None - if trajectories.conditioning is not None: - cond_dim = (-1,) * len(trajectories.conditioning.shape) - traj_len = trajectories.states.tensor.shape[0] - masked_cond = trajectories.conditioning.unsqueeze(0).expand( - (traj_len,) + cond_dim - )[state_mask] - - estimator_outputs = check_cond_forward(pf, "pf", valid_states, masked_cond) - - # Calculates the log PF of the actions sampled off policy. - valid_log_pf_actions = pf.to_probability_distribution( - valid_states, estimator_outputs - ).log_prob( - valid_actions.tensor - ) # Using the actions sampled off-policy. 
- - log_pf_trajectories[action_mask] = valid_log_pf_actions.to( - log_pf_trajectories.dtype, copy=False - ) + # Legacy vectorized path (strict reference behavior) + log_pf_trajectories = torch.full_like( + trajectories.actions.tensor[..., 0], + fill_value=fill_value, + dtype=torch.get_default_dtype(), + ) + + if len(valid_states) == 0: + return log_pf_trajectories + + if ( + trajectories.estimator_outputs is not None + and not recalculate_all_logprobs + ): + # Reuse cached outputs to build the distribution + est_out = trajectories.estimator_outputs[action_mask] + dist = pf.to_probability_distribution( + valid_states, est_out, **policy_kwargs + ) + valid_log_pf_actions = dist.log_prob(valid_actions.tensor) + else: + # Build conditioning per-step shape to align with valid_states + masked_cond = None + if trajectories.conditioning is not None: + cond_dim = (-1,) * len(trajectories.conditioning.shape) + traj_len = trajectories.states.tensor.shape[0] + masked_cond = trajectories.conditioning.unsqueeze(0).expand( + (traj_len,) + cond_dim + )[state_mask] + est_out = check_cond_forward(pf, "pf", valid_states, masked_cond) + valid_log_pf_actions = pf.to_probability_distribution( + valid_states, est_out, **policy_kwargs + ).log_prob(valid_actions.tensor) + + log_pf_trajectories[action_mask] = valid_log_pf_actions.to( + log_pf_trajectories.dtype, copy=False + ) assert log_pf_trajectories.shape == ( trajectories.max_length, @@ -163,17 +234,30 @@ def get_trajectory_pbs( pb: Estimator | None, trajectories: Trajectories, fill_value: float = 0.0, + adapter: Any | None = None, + **policy_kwargs: Any, ) -> torch.Tensor: - """Calculates the log probabilities of backward trajectories. + """Calculate PB log-probabilities for trajectories. + + Vectorized vs non-vectorized + - Vectorized when ``adapter is None`` or ``adapter.is_vectorized is True``: + uses the legacy vectorized implementation (strict parity with reference). + - Non‑vectorized when ``adapter.is_vectorized is False``: evaluates per‑step + using legacy masks/alignment: + PB aligns actions at time ``t`` with states at time ``t+1`` and uses + ``~states.is_sink_state[t+1] & ~states.is_initial_state[t+1] + & ~actions.is_dummy[t] & ~actions.is_exit[t]``, skipping ``t==0``. Args: - pb: The backward policy estimator. - trajectories: The trajectories to calculate probabilities for. - fill_value: The value to fill for invalid states (e.g., sink states). - dtype: The dtype of the log probabilities. + pb: Backward policy estimator, or ``None`` for tree DAGs (PB=1). + trajectories: Trajectories container to evaluate. + fill_value: Fill used for invalid states (e.g., sink state positions). + adapter: Optional adapter controlling vectorized vs non‑vectorized path. + **policy_kwargs: Extra kwargs passed to + ``to_probability_distribution``. Returns: - A tensor containing the log probabilities of the backward trajectories. + Tensor of shape ``(T, N)`` containing PB log-probabilities. Raises: ValueError: If backward trajectories are provided. 
@@ -213,12 +297,72 @@ def get_trajectory_pbs( # We need to index it with the state_mask to get the valid states masked_cond = trajectories.conditioning[state_mask] - if pb is not None: - estimator_outputs = check_cond_forward(pb, "pb", valid_states, masked_cond) - valid_log_pb_actions = pb.to_probability_distribution( - valid_states, estimator_outputs - ).log_prob(valid_actions.tensor) + # Recurrent adapters are only valid for trajectories and never require pb + from gfn.samplers import RecurrentEstimatorAdapter # type: ignore + + is_recurrent = ( + isinstance(adapter, RecurrentEstimatorAdapter) if adapter is not None else False + ) + + if is_recurrent: + # With recurrent adapter, pb must be None (tree DAG); return zeros + assert pb is None, "When using a RecurrentEstimatorAdapter, pb must be None." + valid_log_pb_actions = torch.zeros_like(valid_actions.tensor) + + elif pb is not None: + # Choose vectorized (legacy) vs non-vectorized (adapter per-step) + # Vectorized path is used by default via DefaultEstimatorAdapter. + vectorized = adapter is None or getattr(adapter, "is_vectorized", True) + if adapter is None: + from gfn.samplers import DefaultEstimatorAdapter # Avoids circular import. + + adapter = DefaultEstimatorAdapter(pb) + + if not vectorized: + # Adapter-driven pb evaluation (non-recurrent) + N = trajectories.n_trajectories + device = trajectories.states.device + cond = trajectories.conditioning + if cond is not None and len(cond.shape) >= 2: + cond = cond[0] + ctx = adapter.init_context(int(N), device, cond) # type: ignore[arg-type] + + T = trajectories.max_length + # Iterate per-step with legacy-complete masking (state at t+1, action at t) + for t in range(T): + state_ok = (~trajectories.states.is_sink_state[t + 1]) & ( + ~trajectories.states.is_initial_state[t + 1] + ) + if t == 0: + # Legacy explicitly disables PB at t=0 + state_ok = torch.zeros_like(state_ok, dtype=torch.bool) + action_ok = (~trajectories.actions.is_dummy[t]) & ( + ~trajectories.actions.is_exit[t] + ) + step_mask = state_ok & action_ok + + if not torch.any(step_mask): + continue + step_states = trajectories.states[t + 1][step_mask] + step_actions = trajectories.actions.tensor[t][step_mask] + step_lp, ctx = adapter.log_prob_of_actions( + step_states, step_actions, ctx, step_mask, **policy_kwargs + ) + padded = torch.full((N,), fill_value, device=device, dtype=step_lp.dtype) + padded[step_mask] = step_lp[step_mask] + log_pb_trajectories[t] = padded + + return log_pb_trajectories + else: + # Legacy vectorized path + estimator_outputs = check_cond_forward(pb, "pb", valid_states, masked_cond) + valid_log_pb_actions = pb.to_probability_distribution( + valid_states, estimator_outputs + ).log_prob(valid_actions.tensor) + log_pb_trajectories[action_mask] = valid_log_pb_actions.to( + log_pb_trajectories.dtype, copy=False + ) else: # If pb is None, we assume that the gflownet DAG is a tree, and therefore # the backward policy probability is always 1 (log probs are 0). @@ -246,19 +390,29 @@ def get_transition_pfs_and_pbs( pb: Estimator | None, transitions: Transitions, recalculate_all_logprobs: bool = True, + pf_adapter: Any | None = None, + pb_adapter: Any | None = None, + **policy_kwargs: Any, ) -> Tuple[torch.Tensor, torch.Tensor]: - """Calculates the log probabilities of forward and backward transitions. + """Calculate PF and PB log-probabilities for transitions. + + Vectorized vs non-vectorized mirrors trajectories: + - Vectorized (adapter is None or ``is_vectorized=True``): legacy vectorized path. 
+ - Non‑vectorized (``is_vectorized=False``): per‑batch adapter call with legacy + masks; no action‑id mask indexing. Args: - pf: The forward policy estimator. - pb: The backward policy estimator, or None if the gflownet DAG is a tree, and - pb is therefore always 1. - transitions: The transitions to calculate probabilities for. - recalculate_all_logprobs: Whether to recalculate log probabilities even if they - already exist in the transitions object. + pf: Forward policy estimator. + pb: Backward policy estimator, or ``None`` for tree DAGs (PB=1). + transitions: Transitions container to evaluate. + recalculate_all_logprobs: If ``True``, recompute PF even if cached. + pf_adapter: Optional adapter for PF. + pb_adapter: Optional adapter for PB. + **policy_kwargs: Extra kwargs passed to + ``to_probability_distribution``. Returns: - A tuple containing two tensors: log_pf_transitions and log_pb_transitions. + Tuple[Tensor, Tensor]: PF and PB log-probabilities of shape ``(M,)``. Raises: ValueError: If backward transitions are provided. @@ -266,8 +420,21 @@ def get_transition_pfs_and_pbs( if transitions.is_backward: raise ValueError("Backward transitions are not supported") - log_pf_transitions = get_transition_pfs(pf, transitions, recalculate_all_logprobs) - log_pb_transitions = get_transition_pbs(pb, transitions) + if pb_adapter is not None and not isinstance(pb_adapter, type(pf_adapter)): + warnings.warn( + ( + "type(pb_adapter)={} and type(pf_adapter)={}, this is probably not what you want " + "unless you explicitly want to use different sampling logic for the two policies " + "(with different estimator architectures). This is very uncommon." + ).format(type(pb_adapter), type(pf_adapter)) + ) + + log_pf_transitions = get_transition_pfs( + pf, transitions, recalculate_all_logprobs, adapter=pf_adapter, **policy_kwargs + ) + log_pb_transitions = get_transition_pbs( + pb, transitions, adapter=pb_adapter, **policy_kwargs + ) assert log_pf_transitions.shape == (transitions.n_transitions,) assert log_pb_transitions.shape == (transitions.n_transitions,) @@ -276,18 +443,30 @@ def get_transition_pfs_and_pbs( def get_transition_pfs( - pf: Estimator, transitions: Transitions, recalculate_all_logprobs: bool = True + pf: Estimator, + transitions: Transitions, + recalculate_all_logprobs: bool = True, + adapter: Any | None = None, + **policy_kwargs: Any, ) -> torch.Tensor: - """Calculates the log probabilities of forward transitions. + """Calculate PF log-probabilities for transitions. + + Vectorized vs non-vectorized + - Vectorized when ``adapter is None`` or ``adapter.is_vectorized is True``: + legacy vectorized path. + - Non‑vectorized when ``adapter.is_vectorized is False``: single adapter call + with legacy masks and no action‑id indexing. Args: - pf: The forward policy estimator. - transitions: The transitions to calculate probabilities for. - recalculate_all_logprobs: Whether to recalculate log probabilities even if they - already exist in the transitions object. + pf: Forward policy estimator. + transitions: Transitions container to evaluate. + recalculate_all_logprobs: If ``True``, recompute PF even if cached. + adapter: Optional adapter controlling vectorized vs non‑vectorized path. + **policy_kwargs: Extra kwargs passed to + ``to_probability_distribution``. Returns: - A tensor containing the log probabilities of the forward transitions. + Tensor of shape ``(M,)`` containing PF log-probabilities. 
""" states = transitions.states actions = transitions.actions @@ -296,40 +475,76 @@ def get_transition_pfs( log_pf_actions = transitions.log_probs assert log_pf_actions is not None else: - # Evaluate the log PF of the actions, with optional conditioning. - # TODO: Inefficient duplication in case of tempered policy - # The Transitions container should then have some - # estimator_outputs attribute as well, to avoid duplication here ? - # See (#156). - estimator_outputs = check_cond_forward( - pf, "pf", states, transitions.conditioning - ) - - log_pf_actions = pf.to_probability_distribution( - states, estimator_outputs - ).log_prob(actions.tensor) + if adapter is not None or True: + from gfn.samplers import RecurrentEstimatorAdapter # type: ignore + + if adapter is None: + from gfn.samplers import DefaultEstimatorAdapter # type: ignore + + adapter = DefaultEstimatorAdapter(pf) + elif isinstance(adapter, RecurrentEstimatorAdapter): + raise TypeError( + "RecurrentEstimatorAdapter is only supported for Trajectories" + ) + assert adapter is not None + + N = transitions.n_transitions + device = transitions.states.device + cond = transitions.conditioning + ctx = adapter.init_context(int(N), device, cond) + mask = torch.ones(N, dtype=torch.bool, device=device) + + # Evaluate the log PF of the actions, with optional conditioning. + # TODO: Inefficient duplication in case of tempered policy + # The Transitions container should then have some + # estimator_outputs attribute as well, to avoid duplication here ? + # See (#156). + step_lp, _ = adapter.log_prob_of_actions( + states[mask], + actions.tensor[mask], + ctx, + mask, + **policy_kwargs, + ) + log_pf_actions = step_lp return log_pf_actions -def get_transition_pbs(pb: Estimator | None, transitions: Transitions) -> torch.Tensor: - """Calculates the log probabilities of backward transitions. +def get_transition_pbs( + pb: Estimator | None, + transitions: Transitions, + adapter: Any | None = None, + **policy_kwargs: Any, +) -> torch.Tensor: + """Calculate PB log-probabilities for transitions. + + Vectorized vs non-vectorized + - Vectorized when ``adapter is None`` or ``adapter.is_vectorized is True``: + legacy vectorized path. + - Non‑vectorized when ``adapter.is_vectorized is False``: single adapter call + with legacy masks and no action‑id indexing. Args: - pb: The backward policy Estimator, or None if the gflownet DAG is a tree, and - pb is therefore always 1. - transitions: The transitions to calculate probabilities for. + pb: Backward policy estimator, or ``None`` for tree DAGs (PB=1). + transitions: Transitions container to evaluate. + adapter: Optional adapter controlling vectorized vs non‑vectorized path. + **policy_kwargs: Extra kwargs passed to + ``to_probability_distribution``. + + Returns: + Tensor of shape ``(M,)`` containing PB log-probabilities. """ - # automatically removes invalid transitions (i.e. s_f -> s_f) - valid_next_states = transitions.next_states[~transitions.is_terminating] - non_exit_actions = transitions.actions[~transitions.actions.is_exit] - - # Evaluate the log PB of the actions, with optional conditioning. - masked_cond = ( - transitions.conditioning[~transitions.is_terminating] - if transitions.conditioning is not None - else None - ) + # # automatically removes invalid transitions (i.e. s_f -> s_f) + # valid_next_states = transitions.next_states[~transitions.is_terminating] + # non_exit_actions = transitions.actions[~transitions.actions.is_exit] + + # # Evaluate the log PB of the actions, with optional conditioning. 
+ # masked_cond = ( + # transitions.conditioning[~transitions.is_terminating] + # if transitions.conditioning is not None + # else None + # ) # TODO: We support a fill_value for trajectories, but not for transitions. # Should we add it here, or remove it for trajectories? @@ -337,17 +552,45 @@ def get_transition_pbs(pb: Estimator | None, transitions: Transitions) -> torch. (transitions.n_transitions,), device=transitions.states.device ) - # If pb is None, we assume that the gflownet DAG is a tree, and therefore - # the backward policy probability is always 1 (log probs are 0). - if pb is not None: - estimator_outputs = check_cond_forward(pb, "pb", valid_next_states, masked_cond) + if adapter is not None or True: + from gfn.samplers import RecurrentEstimatorAdapter # type: ignore + + if adapter is None and pb is not None: + from gfn.samplers import DefaultEstimatorAdapter # type: ignore - # Evaluate the log PB of the actions. - valid_log_pb_actions = pb.to_probability_distribution( - valid_next_states, estimator_outputs - ).log_prob(non_exit_actions.tensor) + adapter = DefaultEstimatorAdapter(pb) + elif isinstance(adapter, RecurrentEstimatorAdapter): + raise TypeError( + "RecurrentEstimatorAdapter is only supported for Trajectories" + ) + assert adapter is not None - if len(valid_next_states) != 0: - log_pb_actions[~transitions.is_terminating] = valid_log_pb_actions + # If pb is None, we assume that the gflownet DAG is a tree, and therefore + # the backward policy probability is always 1 (log probs are 0). + if pb is None: + return log_pb_actions + + N = transitions.n_transitions + device = transitions.states.device + cond = transitions.conditioning + ctx = adapter.init_context(int(N), device, cond) + # Legacy-complete masking for PB on transitions: + # require non-terminating next_states and non-exit actions simultaneously + # automatically removes invalid transitions (i.e. 
s_f -> s_f) + state_ok = ~transitions.is_terminating + action_ok = ~transitions.actions.is_exit + mask = state_ok & action_ok + + if not torch.any(mask): + return log_pb_actions + + step_lp, _ = adapter.log_prob_of_actions( + transitions.next_states[mask], + transitions.actions.tensor[mask], + ctx, + mask, + **policy_kwargs, + ) + log_pb_actions[mask] = step_lp[mask] return log_pb_actions diff --git a/testing/test_probability_calculations.py b/testing/test_probability_calculations.py new file mode 100644 index 00000000..e71a7825 --- /dev/null +++ b/testing/test_probability_calculations.py @@ -0,0 +1,441 @@ +import pytest +import torch + +from gfn.estimators import DiscretePolicyEstimator +from gfn.gym import HyperGrid +from gfn.preprocessors import IdentityPreprocessor +from gfn.samplers import DefaultEstimatorAdapter, Sampler +from gfn.utils.handlers import check_cond_forward +from gfn.utils.prob_calculations import ( + get_trajectory_pbs, + get_trajectory_pfs, + get_transition_pbs, + get_transition_pfs, +) + + +class NonVectorizedDefaultAdapter(DefaultEstimatorAdapter): + @property + def is_vectorized(self) -> bool: # type: ignore[override] + return False + + +def _legacy_get_trajectory_pfs( + pf: DiscretePolicyEstimator, + trajectories, + *, + fill_value: float = 0.0, + recalculate_all_logprobs: bool = True, +): + if trajectories.is_backward: + raise ValueError("Backward trajectories are not supported") + + state_mask = ~trajectories.states.is_sink_state + action_mask = ~trajectories.actions.is_dummy + + valid_states = trajectories.states[state_mask] + valid_actions = trajectories.actions[action_mask] + + if valid_states.batch_shape != valid_actions.batch_shape: + raise AssertionError("Something wrong happening with log_pf evaluations") + + log_pf_trajectories = torch.full_like( + trajectories.actions.tensor[..., 0], + fill_value=fill_value, + dtype=torch.get_default_dtype(), + ) + + if len(valid_states) == 0: + return log_pf_trajectories + + if trajectories.estimator_outputs is not None and not recalculate_all_logprobs: + estimator_outputs = trajectories.estimator_outputs[action_mask] + else: + masked_cond = None + if trajectories.conditioning is not None: + cond_dim = (-1,) * len(trajectories.conditioning.shape) + traj_len = trajectories.states.tensor.shape[0] + masked_cond = trajectories.conditioning.unsqueeze(0).expand( + (traj_len,) + cond_dim + )[state_mask] + + estimator_outputs = check_cond_forward(pf, "pf", valid_states, masked_cond) + + valid_log_pf_actions = pf.to_probability_distribution( + valid_states, estimator_outputs + ).log_prob(valid_actions.tensor) + + log_pf_trajectories[action_mask] = valid_log_pf_actions.to( + log_pf_trajectories.dtype, copy=False + ) + + assert log_pf_trajectories.shape == ( + trajectories.max_length, + trajectories.n_trajectories, + ) + return log_pf_trajectories + + +def _build_env_and_pf(n: int = 4): + env = HyperGrid(ndim=2, height=4) + preprocessor = IdentityPreprocessor( + output_dim=env.state_shape[-1], target_dtype=torch.get_default_dtype() + ) + pf_module = torch.nn.Sequential( + torch.nn.Linear(preprocessor.output_dim, 16), # type: ignore + torch.nn.ReLU(), + torch.nn.Linear(16, env.n_actions), + ) + pf_estimator = DiscretePolicyEstimator( + module=pf_module, + n_actions=env.n_actions, + is_backward=False, + preprocessor=preprocessor, + ) + sampler = Sampler(estimator=pf_estimator) + + return env, pf_estimator, sampler + + +@pytest.mark.parametrize("use_cached_outputs", [True, False]) +def 
test_get_trajectory_pfs_matches_legacy_with_default_adapter( + use_cached_outputs: bool, +): + env, pf_estimator, sampler = _build_env_and_pf() + + trajectories = sampler.sample_trajectories( + env, + n=5, + save_estimator_outputs=use_cached_outputs, + save_logprobs=False, + ) + + # Legacy calculation + legacy = _legacy_get_trajectory_pfs( + pf_estimator, + trajectories, + fill_value=0.0, + recalculate_all_logprobs=not use_cached_outputs, + ) + + # Adapter-backed calculation + adapter = DefaultEstimatorAdapter(pf_estimator) + modern = get_trajectory_pfs( + pf_estimator, + trajectories, + fill_value=0.0, + recalculate_all_logprobs=not use_cached_outputs, + adapter=adapter, + ) + + torch.testing.assert_close(modern, legacy) + + +def _legacy_get_trajectory_pbs( + pb: DiscretePolicyEstimator | None, + trajectories, + *, + fill_value: float = 0.0, +): + if trajectories.is_backward: + raise ValueError("Backward trajectories are not supported") + + log_pb_trajectories = torch.full_like( + trajectories.actions.tensor[..., 0], + fill_value=fill_value, + dtype=torch.get_default_dtype(), + ) + + state_mask = ( + ~trajectories.states.is_sink_state & ~trajectories.states.is_initial_state + ) + state_mask[0, :] = False + action_mask = ~trajectories.actions.is_dummy & ~trajectories.actions.is_exit + + valid_states = trajectories.states[state_mask] + valid_actions = trajectories.actions[action_mask] + + if valid_states.batch_shape != valid_actions.batch_shape: + raise AssertionError("Something wrong happening with log_pf evaluations") + + if len(valid_states) == 0: + return log_pb_trajectories + + masked_cond = None + if trajectories.conditioning is not None: + masked_cond = trajectories.conditioning[state_mask] + + if pb is not None: + estimator_outputs = check_cond_forward(pb, "pb", valid_states, masked_cond) + valid_log_pb_actions = pb.to_probability_distribution( + valid_states, estimator_outputs + ).log_prob(valid_actions.tensor) + else: + valid_log_pb_actions = torch.zeros_like(valid_actions.tensor) + + log_pb_trajectories[action_mask] = valid_log_pb_actions.to( + log_pb_trajectories.dtype, copy=False + ) + + assert log_pb_trajectories.shape == ( + trajectories.max_length, + trajectories.n_trajectories, + ) + return log_pb_trajectories + + +def _build_env_pf_pb(): + env = HyperGrid(ndim=2, height=4) + preprocessor = IdentityPreprocessor( + output_dim=env.state_shape[-1], target_dtype=torch.get_default_dtype() + ) + pf_module = torch.nn.Sequential( + torch.nn.Linear(preprocessor.output_dim, 16), # type: ignore + torch.nn.ReLU(), + torch.nn.Linear(16, env.n_actions), + ) + pb_module = torch.nn.Sequential( + torch.nn.Linear(preprocessor.output_dim, 16), # type: ignore + torch.nn.ReLU(), + torch.nn.Linear(16, env.n_actions - 1), + ) + pf_estimator = DiscretePolicyEstimator( + module=pf_module, + n_actions=env.n_actions, + is_backward=False, + preprocessor=preprocessor, + ) + pb_estimator = DiscretePolicyEstimator( + module=pb_module, + n_actions=env.n_actions, + is_backward=True, + preprocessor=preprocessor, + ) + pf_sampler = Sampler(estimator=pf_estimator) + return env, pf_estimator, pb_estimator, pf_sampler + + +def test_get_trajectory_pbs_matches_legacy_with_default_adapter(): + env, _, pb_estimator, pf_sampler = _build_env_pf_pb() + + trajectories = pf_sampler.sample_trajectories( + env, + n=6, + save_estimator_outputs=False, + save_logprobs=False, + ) + + legacy = _legacy_get_trajectory_pbs(pb_estimator, trajectories, fill_value=0.0) + + adapter = DefaultEstimatorAdapter(pb_estimator) + modern = 
get_trajectory_pbs( + pb_estimator, + trajectories, + fill_value=0.0, + adapter=adapter, + ) + + torch.testing.assert_close(modern, legacy) + + +@pytest.mark.parametrize("use_cached_outputs", [True, False]) +def test_trajectory_pf_vectorized_vs_nonvectorized_parity(use_cached_outputs: bool): + env, pf_estimator, sampler = _build_env_and_pf() + + trajectories = sampler.sample_trajectories( + env, + n=5, + save_estimator_outputs=use_cached_outputs, + save_logprobs=False, + ) + + # Vectorized (legacy) path: adapter None triggers vectorized + vec = get_trajectory_pfs( + pf_estimator, + trajectories, + recalculate_all_logprobs=not use_cached_outputs, + adapter=None, + ) + + # Non-vectorized path: force via NonVectorizedDefaultAdapter + nvec = get_trajectory_pfs( + pf_estimator, + trajectories, + recalculate_all_logprobs=not use_cached_outputs, + adapter=NonVectorizedDefaultAdapter(pf_estimator), + ) + + torch.testing.assert_close(vec, nvec) + + +def test_trajectory_pb_vectorized_vs_nonvectorized_parity(): + env, _, pb_estimator, pf_sampler = _build_env_pf_pb() + + trajectories = pf_sampler.sample_trajectories( + env, + n=6, + save_estimator_outputs=False, + save_logprobs=False, + ) + + # Vectorized + vec = get_trajectory_pbs(pb_estimator, trajectories, adapter=None) + # Non-vectorized forced + nvec = get_trajectory_pbs( + pb_estimator, trajectories, adapter=NonVectorizedDefaultAdapter(pb_estimator) + ) + + torch.testing.assert_close(vec, nvec) + + +def test_transition_pf_vectorized_vs_nonvectorized_parity(): + env, pf_estimator, _, pf_sampler = _build_env_pf_pb() + trajectories = pf_sampler.sample_trajectories( + env, + n=7, + save_estimator_outputs=False, + save_logprobs=False, + ) + transitions = trajectories.to_transitions() + + vec = get_transition_pfs( + pf_estimator, transitions, recalculate_all_logprobs=True, adapter=None + ) + nvec = get_transition_pfs( + pf_estimator, + transitions, + recalculate_all_logprobs=True, + adapter=NonVectorizedDefaultAdapter(pf_estimator), + ) + torch.testing.assert_close(vec, nvec) + + +def test_transition_pb_vectorized_vs_nonvectorized_parity(): + env, _, pb_estimator, pf_sampler = _build_env_pf_pb() + trajectories = pf_sampler.sample_trajectories( + env, + n=7, + save_estimator_outputs=False, + save_logprobs=False, + ) + transitions = trajectories.to_transitions() + + vec = get_transition_pbs(pb_estimator, transitions, adapter=None) + nvec = get_transition_pbs( + pb_estimator, transitions, adapter=NonVectorizedDefaultAdapter(pb_estimator) + ) + torch.testing.assert_close(vec, nvec) + + +def test_adapter_log_prob_of_actions_precomputed_matches_forward(): + env, pf_estimator, _ = _build_env_and_pf() + states = env.reset(batch_shape=(5,)) + + # Compute estimator outputs once (precomputed path) + est_out = check_cond_forward(pf_estimator, "pf", states, None) + dist = pf_estimator.to_probability_distribution(states, est_out) + with torch.no_grad(): + actions_tensor = dist.sample() + + adapter = DefaultEstimatorAdapter(pf_estimator) + ctx1 = adapter.init_context(batch_size=5, device=states.device, conditioning=None) + ctx2 = adapter.init_context(batch_size=5, device=states.device, conditioning=None) + step_mask = torch.ones(5, dtype=torch.bool, device=states.device) + + # Baseline: adapter recomputes estimator outputs internally + lp1, _ = adapter.log_prob_of_actions(states, actions_tensor, ctx1, step_mask) + + # Precomputed: adapter uses provided estimator outputs (fast path) + lp2, _ = adapter.log_prob_of_actions( + states, actions_tensor, ctx2, step_mask, 
precomputed_estimator_output=est_out + ) + + torch.testing.assert_close(lp1, lp2) + + +def _legacy_get_transition_pfs( + pf: DiscretePolicyEstimator, + transitions, + *, + recalculate_all_logprobs: bool = False, +): + states = transitions.states + actions = transitions.actions + + if transitions.has_log_probs and recalculate_all_logprobs is False: + log_pf_actions = transitions.log_probs + assert log_pf_actions is not None + return log_pf_actions + + estimator_outputs = check_cond_forward(pf, "pf", states, transitions.conditioning) + log_pf_actions = pf.to_probability_distribution(states, estimator_outputs).log_prob( + actions.tensor + ) + return log_pf_actions + + +def _legacy_get_transition_pbs(pb: DiscretePolicyEstimator | None, transitions): + valid_next_states = transitions.next_states[~transitions.is_terminating] + non_exit_actions = transitions.actions[~transitions.actions.is_exit] + masked_cond = ( + transitions.conditioning[~transitions.is_terminating] + if transitions.conditioning is not None + else None + ) + + log_pb_actions = torch.zeros( + (transitions.n_transitions,), device=transitions.states.device + ) + + if pb is not None: + estimator_outputs = check_cond_forward(pb, "pb", valid_next_states, masked_cond) + valid_log_pb_actions = pb.to_probability_distribution( + valid_next_states, estimator_outputs + ).log_prob(non_exit_actions.tensor) + if len(valid_next_states) != 0: + log_pb_actions[~transitions.is_terminating] = valid_log_pb_actions + + return log_pb_actions + + +def test_get_transition_pfs_matches_legacy_with_default_adapter(): + env, pf_estimator, _, pf_sampler = _build_env_pf_pb() + trajectories = pf_sampler.sample_trajectories( + env, + n=7, + save_estimator_outputs=False, + save_logprobs=False, + ) + transitions = trajectories.to_transitions() + + legacy = _legacy_get_transition_pfs(pf_estimator, transitions) + modern = get_transition_pfs( + pf_estimator, + transitions, + recalculate_all_logprobs=True, + adapter=DefaultEstimatorAdapter(pf_estimator), + ) + torch.testing.assert_close(modern, legacy) + + +def test_get_transition_pbs_matches_legacy_with_default_adapter(): + env, _, pb_estimator, pf_sampler = _build_env_pf_pb() + trajectories = pf_sampler.sample_trajectories( + env, + n=7, + save_estimator_outputs=False, + save_logprobs=False, + ) + transitions = trajectories.to_transitions() + + legacy = _legacy_get_transition_pbs(pb_estimator, transitions) + modern = get_transition_pbs( + pb_estimator, + transitions, + adapter=DefaultEstimatorAdapter(pb_estimator), + ) + torch.testing.assert_close(modern, legacy) + + +if __name__ == "__main__": + test_get_trajectory_pbs_matches_legacy_with_default_adapter() diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index a5cb0b91..71862dae 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -21,10 +21,10 @@ OneHotPreprocessor, ) from gfn.samplers import ( - AdapterContext, DefaultEstimatorAdapter, LocalSearchSampler, RecurrentEstimatorAdapter, + RolloutContext, Sampler, ) from gfn.states import States @@ -514,8 +514,8 @@ def to_probability_distribution( # no expected_output_dim required for adapter tests -def test_adapter_context_basic(): - ctx = AdapterContext(batch_size=4, device=torch.device("cpu"), conditioning=None) +def test_rollout_context_basic(): + ctx = RolloutContext(batch_size=4, device=torch.device("cpu"), conditioning=None) assert ctx.batch_size == 4 assert ctx.device.type == "cpu" # extras supports 
arbitrary entries @@ -533,10 +533,10 @@ def test_default_adapter_compute_record_finalize(): step_mask = torch.ones(n, dtype=torch.bool, device=device) dist, ctx = adapter.compute(cast(States, states), ctx, step_mask) actions = dist.sample() - adapter.record_step( + adapter.record( ctx, step_mask, actions, dist, save_logprobs=True, save_estimator_outputs=True ) - out = adapter.finalize(ctx) + out = ctx.finalize() assert out["log_probs"] is not None and out["log_probs"].shape == (1, n) assert out["estimator_outputs"] is not None and out["estimator_outputs"].shape[ :2 @@ -564,18 +564,18 @@ def test_recurrent_adapter_flow(): actions = dist.sample() # carry should update when we record multiple steps h0 = ctx.carry["hidden"].clone() - adapter.record_step( + adapter.record( ctx, step_mask, actions, dist, save_logprobs=True, save_estimator_outputs=True ) # second step dist, ctx = adapter.compute(cast(States, states), ctx, step_mask) actions = dist.sample() - adapter.record_step( + adapter.record( ctx, step_mask, actions, dist, save_logprobs=True, save_estimator_outputs=True ) h1 = ctx.carry["hidden"].clone() assert torch.all(h1 == h0 + 1) - out = adapter.finalize(ctx) + out = ctx.finalize() assert out["log_probs"] is not None and out["log_probs"].shape == (2, n) assert out["estimator_outputs"] is not None and out["estimator_outputs"].shape[ :2 @@ -641,7 +641,7 @@ def test_integration_recurrent_sequence_model_with_adapter( for _ in range(2): dist, ctx = adapter.compute(cast(States, states), ctx, step_mask) actions = dist.sample() - adapter.record_step( + adapter.record( ctx, step_mask, actions, @@ -650,11 +650,13 @@ def test_integration_recurrent_sequence_model_with_adapter( save_estimator_outputs=True, ) - out = adapter.finalize(ctx) - assert out["log_probs"] is not None and out["log_probs"].shape[0] == 2 - assert ( - out["estimator_outputs"] is not None and out["estimator_outputs"].shape[0] == 2 - ) + out = ctx.finalize() + log_probs = out["log_probs"] + estimator_outputs = out["estimator_outputs"] + assert log_probs is not None + assert log_probs.shape[0] == 2 + assert estimator_outputs is not None + assert estimator_outputs.shape[0] == 2 @pytest.mark.parametrize("positional_embedding", ["learned", "sinusoidal"]) @@ -694,12 +696,16 @@ def test_integration_transformer_sequence_model_with_adapter( step_mask = torch.ones(batch_size, dtype=torch.bool, device=device) dist, ctx = adapter.compute(cast(States, states), ctx, step_mask) actions = dist.sample() - adapter.record_step( + adapter.record( ctx, step_mask, actions, dist, save_logprobs=True, save_estimator_outputs=True ) - out = adapter.finalize(ctx) + out = ctx.finalize() assert out["log_probs"] is not None and out["log_probs"].shape[0] == 1 assert ( out["estimator_outputs"] is not None and out["estimator_outputs"].shape[0] == 1 ) + + +if __name__ == "__main__": + test_to_transition("DiscreteEBM") From 34202a85012baa819e15695fcfc2b4242920b739 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 9 Oct 2025 10:52:09 -0400 Subject: [PATCH 08/27] added documentation --- docs/source/guides/estimator_adapters.md | 121 +++++++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 docs/source/guides/estimator_adapters.md diff --git a/docs/source/guides/estimator_adapters.md b/docs/source/guides/estimator_adapters.md new file mode 100644 index 00000000..d5971886 --- /dev/null +++ b/docs/source/guides/estimator_adapters.md @@ -0,0 +1,121 @@ +# Estimator Adapters + +Adapters decouple the generic sampling and probability computation 
logic from estimator-specific details (conditioning shape, recurrent state/carry, distribution construction, artifact recording). They enable a single sampler and probability utilities to work across different estimator families. + +This guide explains: +- The Adapter protocol and lifecycle +- Vectorized vs non-vectorized probability paths +- How adapters integrate with the Sampler and probability calculators +- How to implement a new Adapter + +## Concepts and Goals + +An Adapter mediates between three places where estimator logic is needed: +1) The online sampling loop (Sampler) for trajectory rollouts +2) Probability calculators for trajectories (PF/PB) and transitions (PF/PB) +3) Optional artifact capture (per-step log-probs, estimator outputs) + +The Sampler remains estimator-agnostic. Adapters own any estimator-specific state (e.g., recurrent carry) and control how to run the estimator and build the policy distribution. + +## Adapter Protocol (call signature) + +Adapters conform to a structural Protocol with the following surface (see `gfn/samplers.py`): + +- Properties + - `is_backward: bool` — whether the wrapped estimator is a backward policy. + - `is_vectorized: bool` — whether the adapter supports vectorized probability calculations (no carry). Vectorized adapters always use the faster legacy vectorized paths in probability calculators. Non-vectorized adapters (e.g., recurrent) use per-step paths with masking and alignment identical to the legacy reference. + +- Methods + - `init_context(batch_size: int, device: torch.device, conditioning: Tensor|None) -> Any` + - Allocates a rollout context once per batch (Sampler). Stores invariants (batch size, device, optional conditioning) and initializes any adapter state (e.g., recurrent carry) along with per-step artifact buffers. + + - `compute(states_active: States, ctx: Any, step_mask: Tensor, **policy_kwargs) -> (Distribution, Any)` + - Runs the estimator forward on the active rows and returns a torch Distribution over actions. + - Must handle conditioning slicing with `step_mask` when applicable. + + - `record(ctx: Any, step_mask: Tensor, sampled_actions: Tensor, dist: Distribution, save_logprobs: bool, save_estimator_outputs: bool) -> None` + - Optionally record per-step artifacts into buffers owned by the context (e.g., log-probs, estimator outputs). Padding back to batch size happens here, using zeros for log-probs and `-inf` for estimator outputs to match existing conventions. + + - `log_prob_of_actions(states_active: States, actions_active: Tensor, ctx: Any, step_mask: Tensor, **policy_kwargs) -> (Tensor, Any)` + - Computes log-probs for a batch of (state, action) pairs corresponding to `True` entries of `step_mask` and returns a padded `(N,)` vector. + + - `get_current_estimator_output(ctx: Any) -> Tensor|None` + - Convenience to expose the last estimator output after `compute`. + +- Context + - The rollout context (created by `init_context`) owns: + - `batch_size`, `device`, optional `conditioning` + - Optional `carry` (recurrent hidden state) + - Per-step buffers: `trajectory_log_probs`, `trajectory_estimator_outputs` + - A `finalize()` method that stacks per-step artifacts into `(T, N, ...)` tensors for `Trajectories`. + +## Built-in Adapters + +- `DefaultEstimatorAdapter` + - `is_vectorized = True` + - No carry. Works with both the Sampler and vectorized probability calculators. + - In the Sampler, it slices conditioning by `step_mask`, runs the estimator, builds the Distribution, and optionally records artifacts. 
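+
+  For illustration, a minimal sketch of driving this adapter by hand (the `Sampler` performs these steps internally; `pf_estimator`, `states`, `batch_size`, and `device` are assumed to be defined):
+
+  ```python
+  import torch
+
+  adapter = DefaultEstimatorAdapter(pf_estimator)
+  ctx = adapter.init_context(batch_size, device, conditioning=None)
+  step_mask = torch.ones(batch_size, dtype=torch.bool, device=device)
+
+  dist, ctx = adapter.compute(states, ctx, step_mask)  # states: the active rows
+  actions = dist.sample()
+  adapter.record(ctx, step_mask, actions, dist, save_logprobs=True, save_estimator_outputs=True)
+
+  out = ctx.finalize()  # {"log_probs": (T, N), "estimator_outputs": (T, N, ...)}
+  ```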
+ +- `RecurrentEstimatorAdapter` + - `is_vectorized = False` + - Maintains a `carry` in the context (initialized via `estimator.init_carry(batch_size, device)`). + - In the Sampler, it calls the estimator as `(states_active, ctx.carry) -> (est_out, new_carry)`, stores `new_carry`, builds the Distribution, and optionally records artifacts. + +## Vectorized vs Non-Vectorized Probability Paths + +Probability calculators (PF/PB for trajectories and transitions) branch on `adapter.is_vectorized`: + +- Vectorized (fast path) + - Used when `adapter is None` or `adapter.is_vectorized is True`. + - Implements the legacy vectorized logic exactly (see the reference implementation below). + - No adapter calls are needed; the estimator is called on vectorized masks, and distributions compute `log_prob` over the masked actions. This path is the most efficient and is used during training when possible. + +- Non-Vectorized (per-step path) + - Used when `adapter.is_vectorized is False` (e.g., recurrent adapters). + - The calculators iterate per-step with legacy-accurate masks and alignment: + - PF (trajectories): `step_mask = ~states.is_sink_state[t] & ~actions.is_dummy[t]` + - PB (trajectories): align actions at time `t` with states at time `t+1`, and use `step_mask = ~states.is_sink_state[t+1] & ~states.is_initial_state[t+1] & ~actions.is_dummy[t] & ~actions.is_exit[t]` and skip `t==0`. + - Transitions: use the same masks as the legacy vectorized functions and make a single adapter call per batch. + - No mask indexing with action ids is used; masking is solely via the legacy boolean masks and the Distribution handles illegal actions internally. + +In both branches, behavior matches the legacy reference exactly, so tests compare outputs between vectorized and non-vectorized paths for parity. + +## Integration with the Sampler + +The Sampler uses the adapter lifecycle: +- `ctx = adapter.init_context(batch_size, device, conditioning)` +- While some trajectories are active: + - `(dist, ctx) = adapter.compute(states[step_mask], ctx, step_mask, **policy_kwargs)` + - Sample actions from `dist`; build actions for the full batch + - `adapter.record(ctx, step_mask, sampled_actions, dist, save_logprobs, save_estimator_outputs)` + - Step the environment forward/backward based on `adapter.is_backward` +- After rollout: `artifacts = ctx.finalize()` and populate `Trajectories`. + +## How to Implement a New Adapter + +1) Decide if your estimator needs a recurrent carry - some persistent state or cache that is utilized throughout the trajectory. + - If yes, set `is_vectorized = False` and implement `init_context` to initialize `carry`. Implement `compute` to update `carry` each step. + - If no, set `is_vectorized = True` and follow the default adapter pattern. + +2) Implement `compute` + - Handle conditioning slicing with `step_mask` when conditioning is provided. + - Call your estimator and construct a torch Distribution via `to_probability_distribution(states_active, est_out, **policy_kwargs)`. + +3) Implement `record` + - If you want to capture per-step log-probs and/or estimator outputs, compute them for active rows and pad back to `(N,)` (log-probs) or `(N, ...)` (estimator outputs) before appending to the context buffers. + +4) Implement `log_prob_of_actions` + - Given `(states_active, actions_active)` for the active rows, compute the Distribution (reusing the same forward logic) and return a padded `(N,)` vector of `log_prob`. 
+ - Do not modify masks here; calculators pass in `step_mask` already built from existing masks. + +5) Mark `is_backward` if your estimator is a backward policy; the sampler will step the environment backward accordingly. + +6) Performance Guidance + - For vectorized adapters, prefer the vectorized probability path (legacy implementation). It’s much faster and avoids per-step overhead. + - For non-vectorized adapters, keep per-step code minimal and avoid Python-side loops that can be vectorized. + +## Reference: Legacy Implementations + +The legacy, vectorized implementations are the gold standard. Adapters are designed to use those paths whenever possible (vectorized) and to exactly match their behavior when per-step evaluation is required (non-vectorized). See the reference for details: + +- `utils/prob_calculations.py` (master): [link](https://raw.githubusercontent.com/GFNOrg/torchgfn/refs/heads/master/src/gfn/utils/prob_calculations.py) \ No newline at end of file From d1db3bdfe1b6fe23f1fb83ef6a4d7625ef189822 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 9 Oct 2025 11:18:59 -0400 Subject: [PATCH 09/27] removed strange change to documentation --- docs/source/guides/example.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/source/guides/example.md b/docs/source/guides/example.md index 9b82502f..3adf0107 100644 --- a/docs/source/guides/example.md +++ b/docs/source/guides/example.md @@ -22,13 +22,14 @@ env = HyperGrid(ndim=4, height=8) # Grid of size 8x8x8x8 preprocessor = KHotPreprocessor(ndim=env.ndim, height=env.height) # 2 - We define the needed modules (neural networks). +input_dim = preprocessor.output_dim if preprocessor.output_dim is not None else env.state_shape[-1] module_PF = MLP( - input_dim=preprocessor.output_dim, + input_dim=input_dim, output_dim=env.n_actions ) # Neural network for the forward policy, with as many outputs as there are actions module_PB = MLP( - input_dim=preprocessor.output_dim, + input_dim=input_dim, output_dim=env.n_actions - 1, trunk=module_PF.trunk # We share all the parameters of P_F and P_B, except for the last layer ) @@ -80,13 +81,14 @@ preprocessor = KHotPreprocessor(ndim=env.ndim, height=env.height) # 2 - We define the needed modules (neural networks). 
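+# Note: preprocessor.output_dim can be None (e.g., for identity preprocessing), hence the fallback used when defining input_dim below.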
# The environment has a preprocessor attribute, which is used to preprocess the state before feeding it to the policy estimator +input_dim = preprocessor.output_dim if preprocessor.output_dim is not None else env.state_shape[-1] module_PF = MLP( - input_dim=preprocessor.output_dim, + input_dim=input_dim, output_dim=env.n_actions ) # Neural network for the forward policy, with as many outputs as there are actions module_PB = MLP( - input_dim=preprocessor.output_dim, + input_dim=input_dim, output_dim=env.n_actions - 1, trunk=module_PF.trunk # We share all the parameters of P_F and P_B, except for the last layer ) From 08bf6ebf3fc02765e81cea6ecb5cf34219525415 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 9 Oct 2025 11:20:56 -0400 Subject: [PATCH 10/27] removed strange change to documentation --- docs/source/guides/example.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/guides/example.md b/docs/source/guides/example.md index 3adf0107..a9609611 100644 --- a/docs/source/guides/example.md +++ b/docs/source/guides/example.md @@ -93,7 +93,7 @@ module_PB = MLP( trunk=module_PF.trunk # We share all the parameters of P_F and P_B, except for the last layer ) module_logF = MLP( - input_dim=preprocessor.output_dim, + input_dim=input_dim, output_dim=1, # Important for ScalarEstimators! ) From baa50e4c7fca576ab3c8e444f7c81ecaf0990420 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 9 Oct 2025 18:09:27 -0400 Subject: [PATCH 11/27] added basic recurrent bitsequence algorithm --- src/gfn/estimators.py | 38 ++++- src/gfn/gflownet/base.py | 50 +++++- src/gfn/gflownet/detailed_balance.py | 50 +++++- src/gfn/gflownet/flow_matching.py | 12 ++ src/gfn/gflownet/sub_trajectory_balance.py | 64 ++++---- src/gfn/gflownet/trajectory_balance.py | 15 +- .../examples/train_bitsequence_recurrent.py | 151 ++++++++++++++++++ 7 files changed, 327 insertions(+), 53 deletions(-) create mode 100644 tutorials/examples/train_bitsequence_recurrent.py diff --git a/src/gfn/estimators.py b/src/gfn/estimators.py index 835afdf1..df05f8e4 100644 --- a/src/gfn/estimators.py +++ b/src/gfn/estimators.py @@ -871,16 +871,38 @@ def forward( Returns: The output of the module, as a tensor of shape (*batch_shape, output_dim). """ - # TODO: Is this still true? NOTE: Can only be used for on-policy generation, - # not off-policy evaluation of entire trajectory. - # states.tensor.shape: (..., max_string_len) - current_input_len = states.tensor.shape[-1] # TODO: Check if this is correct. - states_tensor = states.tensor[..., :current_input_len] # (..., string_len) + # Prepare integer token sequences without -1 padding and use a BOS index. + # We infer the active sequence length per row from (token != -1). + tokens = states.tensor + if not torch.is_floating_point(tokens): + tokens = tokens.long() + else: + tokens = tokens.to(dtype=torch.long) + + # Replace padding (-1) with BOS index expected by the sequence model. + # RecurrentDiscreteSequenceModel reserves index == vocab_size for BOS. + bos_index = getattr(self.module, "vocab_size", self.n_actions - 1) + tokens = torch.where( + tokens < 0, torch.as_tensor(bos_index, device=tokens.device), tokens + ) + + # Determine a common prefix length across the (active) batch. + # Active rows in a rollout step share the same length; use max for safety. + # We still derive length from original states.tensor where -1 marks padding. 
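+        # E.g., if the longest active row is [3, 1, -1, -1] (two valid tokens),
+        # only its first two positions reach the sequence model below.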
+ original = states.tensor + valid_mask = original >= 0 + if valid_mask.ndim == 1: + max_len = int(valid_mask.sum().item()) + else: + max_len = int(valid_mask.sum(dim=-1).max().item()) + if max_len <= 0: + max_len = 1 # Ensure at least BOS is processed - # Compute sequence of logits and update carry. - logits, carry = self.module(states_tensor, carry) + # Trim to the common active prefix length and run the sequence model. + seq_input = tokens[..., :max_len] + logits, carry = self.module(seq_input, carry) - # Get the logits for the last token in the sequence. + # Use the logits corresponding to the last processed token. logits = logits[:, -1, :] # (b, n_actions) if self.expected_output_dim is not None: diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index 0c5efe64..40e46c98 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -171,7 +171,13 @@ class PFBasedGFlowNet(GFlowNet[TrainingSampleType], ABC): """ def __init__( - self, pf: Estimator, pb: Estimator | None, constant_pb: bool = False + self, + pf: Estimator, + pb: Estimator | None, + constant_pb: bool = False, + *, + pf_adapter: Any | None = None, + pb_adapter: Any | None = None, ) -> None: """Initializes a PFBasedGFlowNet instance. @@ -183,6 +189,11 @@ def __init__( gflownet DAG is a tree, and pb is therefore always 1. Must be set explicitly by user to ensure that pb is an Estimator except under this special case. + pf_adapter: Optional adapter for PF probability calculation/sampling (e.g., + recurrent). When provided, used both in the sampler and in + probability recomputation paths. + pb_adapter: Optional adapter for PB probability calculation. Used in + probability recomputation paths when `pb` is provided. """ super().__init__() # Technical note: pb may be constant for a variety of edge cases, for example, @@ -211,6 +222,10 @@ def __init__( self.pf = pf self.pb = pb self.constant_pb = constant_pb + # Optional adapters controlling estimator interactions via + # vectorized / non-vectorized probability paths. + self.pf_adapter = pf_adapter + self.pb_adapter = pb_adapter def sample_trajectories( self, @@ -234,7 +249,7 @@ def sample_trajectories( Returns: A Trajectories object containing the sampled trajectories. """ - sampler = Sampler(estimator=self.pf) + sampler = Sampler(estimator=self.pf, adapter=self.pf_adapter) trajectories = sampler.sample_trajectories( env, n=n, @@ -275,7 +290,13 @@ class TrajectoryBasedGFlowNet(PFBasedGFlowNet[Trajectories]): """ def __init__( - self, pf: Estimator, pb: Estimator | None, constant_pb: bool = False + self, + pf: Estimator, + pb: Estimator | None, + constant_pb: bool = False, + *, + pf_adapter: Any | None = None, + pb_adapter: Any | None = None, ) -> None: """Initializes a TrajectoryBasedGFlowNet instance. @@ -288,7 +309,13 @@ def __init__( explicitly by user to ensure that pb is an Estimator except under this special case. """ - super().__init__(pf, pb, constant_pb=constant_pb) + super().__init__( + pf, + pb, + constant_pb=constant_pb, + pf_adapter=pf_adapter, + pb_adapter=pb_adapter, + ) def get_pfs_and_pbs( self, @@ -301,8 +328,8 @@ def get_pfs_and_pbs( More specifically, it evaluates $\log P_F(s' \mid s)$ and $\log P_B(s \mid s')$ for each transition in each trajectory in the batch. - If recalculate_all_logprobs=True, we re-evaluate the logprobs of the trajectories - using the current self.pf. Otherwise, the following applies: + If recalculate_all_logprobs=True, we re-evaluate the logprobs of the + trajectories using the current self.pf. 
Otherwise, the following applies: - If trajectories have logprobs attribute, use them - this is usually for on-policy learning. - Elif trajectories have estimator_outputs attribute, transform them into @@ -311,6 +338,9 @@ def get_pfs_and_pbs( the current self.pf - this is usually for off-policy learning with replay buffer. + Uses the PF and PB adapters to evaluate the logprobs, with their optional + adapters if provided. + Args: trajectories: The Trajectories object to evaluate. fill_value: Value to use for invalid states (e.g., $s_f$ added to shorter @@ -322,7 +352,13 @@ def get_pfs_and_pbs( the log_pf and log_pb for each action in each trajectory. """ return get_trajectory_pfs_and_pbs( - self.pf, self.pb, trajectories, fill_value, recalculate_all_logprobs + self.pf, + self.pb, + trajectories, + fill_value, + recalculate_all_logprobs, + pf_adapter=self.pf_adapter, + pb_adapter=self.pb_adapter, ) def get_scores( diff --git a/src/gfn/gflownet/detailed_balance.py b/src/gfn/gflownet/detailed_balance.py index efb53d1b..7a7133b0 100644 --- a/src/gfn/gflownet/detailed_balance.py +++ b/src/gfn/gflownet/detailed_balance.py @@ -1,5 +1,5 @@ import math -from typing import Tuple +from typing import Any, Tuple import torch @@ -76,6 +76,9 @@ def __init__( log_reward_clip_min: float = -float("inf"), safe_log_prob_min: bool = True, constant_pb: bool = False, + *, + pf_adapter: Any | None = None, + pb_adapter: Any | None = None, ) -> None: """Initializes a DBGFlowNet instance. @@ -93,8 +96,18 @@ def __init__( gflownet DAG is a tree, and pb is therefore always 1. Must be set explicitly by user to ensure that pb is an Estimator except under this special case. + pf_adapter: Optional estimator adapter controlling PF probability + computation/sampling. + pb_adapter: Optional estimator adapter controlling PB probability + computation. """ - super().__init__(pf, pb, constant_pb=constant_pb) + super().__init__( + pf, + pb, + constant_pb=constant_pb, + pf_adapter=pf_adapter, + pb_adapter=pb_adapter, + ) assert any( isinstance(logF, cls) for cls in [ScalarEstimator, ConditionalScalarEstimator] @@ -148,7 +161,12 @@ def get_pfs_and_pbs( log_pb for each transition. """ return get_transition_pfs_and_pbs( - self.pf, self.pb, transitions, recalculate_all_logprobs + self.pf, + self.pb, + transitions, + recalculate_all_logprobs, + pf_adapter=self.pf_adapter, + pb_adapter=self.pb_adapter, ) def get_scores( @@ -301,10 +319,30 @@ class ModifiedDBGFlowNet(PFBasedGFlowNet[Transitions]): """ def __init__( - self, pf: Estimator, pb: Estimator | None, constant_pb: bool = False + self, + pf: Estimator, + pb: Estimator | None, + constant_pb: bool = False, + *, + pf_adapter: Any | None = None, + pb_adapter: Any | None = None, ) -> None: - """Initializes a ModifiedDBGFlowNet instance.""" - super().__init__(pf, pb, constant_pb=constant_pb) + """Initializes a ModifiedDBGFlowNet instance. + + Args: + pf: Forward policy estimator. + pb: Backward policy estimator or None. + constant_pb: See base class. + pf_adapter: Optional adapter for PF. + pb_adapter: Optional adapter for PB. 
+ """ + super().__init__( + pf, + pb, + constant_pb=constant_pb, + pf_adapter=pf_adapter, + pb_adapter=pb_adapter, + ) def get_scores( self, transitions: Transitions, recalculate_all_logprobs: bool = True diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index 79cc9a52..74cc6125 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -32,6 +32,14 @@ class FMGFlowNet(GFlowNet[StatesContainer[DiscreteStates]]): logF: A DiscretePolicyEstimator or ConditionalDiscretePolicyEstimator for estimating the log flow of the edges (states -> next_states). alpha: A scalar weight for the reward matching loss. + + Adapter note + ------------ + Flow Matching does not rely on PF/PB probability recomputation. Any trajectory + sampling provided by this class is for diagnostics/visualization and uses the + default (non-recurrent) adapter internally. Sampler adapters (e.g., + `RecurrentEstimatorAdapter`) are not exposed as configuration options for this + class. """ def __init__(self, logF: DiscretePolicyEstimator, alpha: float = 1.0): @@ -72,6 +80,10 @@ def sample_trajectories( Returns: A Trajectories object containing the sampled trajectories. + + Notes: + This helper uses the default sampler adapter; custom sampler adapters are + not supported for Flow Matching. """ if not env.is_discrete: raise NotImplementedError( diff --git a/src/gfn/gflownet/sub_trajectory_balance.py b/src/gfn/gflownet/sub_trajectory_balance.py index 34175654..8987451d 100644 --- a/src/gfn/gflownet/sub_trajectory_balance.py +++ b/src/gfn/gflownet/sub_trajectory_balance.py @@ -1,6 +1,6 @@ import math import warnings -from typing import List, Literal, Tuple, TypeAlias +from typing import Any, List, Literal, Tuple, TypeAlias import torch @@ -83,6 +83,9 @@ def __init__( log_reward_clip_min: float = -float("inf"), forward_looking: bool = False, constant_pb: bool = False, + *, + pf_adapter: Any | None = None, + pb_adapter: Any | None = None, ): """Initializes a SubTBGFlowNet instance. @@ -100,8 +103,18 @@ def __init__( gflownet DAG is a tree, and pb is therefore always 1. Must be set explicitly by user to ensure that pb is an Estimator except under this special case. + pf_adapter: Optional estimator adapter controlling PF probability + computation/sampling. + pb_adapter: Optional estimator adapter controlling PB probability + computation. """ - super().__init__(pf, pb, constant_pb=constant_pb) + super().__init__( + pf, + pb, + constant_pb=constant_pb, + pf_adapter=pf_adapter, + pb_adapter=pb_adapter, + ) assert any( isinstance(logF, cls) for cls in [ScalarEstimator, ConditionalScalarEstimator] @@ -157,14 +170,14 @@ def cumulative_logprobs( def calculate_preds( self, - log_pf_trajectories_cum: CumulativeLogProbsTensor, + log_pf_traj_cum: CumulativeLogProbsTensor, log_state_flows: LogStateFlowsTensor, i: int, ) -> PredictionsTensor: """Calculates the predictions tensor for the current sub-trajectory length. Args: - log_pf_trajectories_cum: Tensor of shape (max_length + 1, n_trajectories) + log_pf_traj_cum: Tensor of shape (max_length + 1, n_trajectories) containing the cumulative sum of logprobs of the forward actions for each trajectory. 
log_state_flows: Tensor of shape (max_length, n_trajectories) containing @@ -178,11 +191,7 @@ def calculate_preds( log_state_flows if i == 1 else log_state_flows[: -(i - 1)] ) - preds = ( - log_pf_trajectories_cum[i:] - - log_pf_trajectories_cum[:-i] - + current_log_state_flows - ) + preds = log_pf_traj_cum[i:] - log_pf_traj_cum[:-i] + current_log_state_flows return preds @@ -190,7 +199,7 @@ def calculate_targets( self, trajectories: Trajectories, preds: PredictionsTensor, - log_pb_trajectories_cum: CumulativeLogProbsTensor, + log_pb_traj_cum: CumulativeLogProbsTensor, log_state_flows: LogStateFlowsTensor, is_terminal_mask: MaskTensor, sink_states_mask: MaskTensor, @@ -202,7 +211,7 @@ def calculate_targets( trajectories: The batch of trajectories. preds: Tensor of shape (max_length + 1 - i, n_trajectories) containing the predictions for the current sub-trajectory length. - log_pb_trajectories_cum: Tensor of shape (max_length + 1, n_trajectories) + log_pb_traj_cum: Tensor of shape (max_length + 1, n_trajectories) containing the cumulative sum of logprobs of the backward actions for each trajectory. log_state_flows: Tensor of shape (max_length, n_trajectories) containing @@ -229,15 +238,16 @@ def calculate_targets( # We need to add to that the log-probabilities of the backward actions up-to # the sub-trajectory's terminating state if i > 1: - targets[is_terminal_mask[i - 1 :]] += ( - log_pb_trajectories_cum[i - 1 :] - log_pb_trajectories_cum[: -i + 1] - )[:-1][is_terminal_mask[i - 1 :]] + delta_pb = (log_pb_traj_cum[i - 1 :] - log_pb_traj_cum[: -i + 1])[:-1] + targets[is_terminal_mask[i - 1 :]] += delta_pb[is_terminal_mask[i - 1 :]] # The following creates the targets for the non-finishing sub-trajectories full_mask = sink_states_mask | is_terminal_mask + delta_pb2 = (log_pb_traj_cum[i:] - log_pb_traj_cum[:-i])[:-1] + rhs_mask = ~full_mask[i - 1 : -1] targets[~full_mask[i - 1 :]] = ( - log_pb_trajectories_cum[i:] - log_pb_trajectories_cum[:-i] - )[:-1][~full_mask[i - 1 : -1]] + log_state_flows[i:][~sink_states_mask[i:]] + delta_pb2[rhs_mask] + log_state_flows[i:][~sink_states_mask[i:]] + ) return targets @@ -445,11 +455,8 @@ def get_tb_contributions(self, trajectories: Trajectories) -> ContributionsTenso contributions = torch.zeros(n_rows, len(trajectories)) # Each trajectory contributes one element to the loss, equally weighted - terminating_idx = trajectories.terminating_idx - indices = ( - max_len * (terminating_idx - 1) - - (terminating_idx - 1) * (terminating_idx - 2) / 2 - ).long() + t_idx = trajectories.terminating_idx + indices = (max_len * (t_idx - 1) - (t_idx - 1) * (t_idx - 2) / 2).long() contributions.scatter_(0, indices.unsqueeze(0), 1) contributions = contributions / len(trajectories) @@ -498,16 +505,16 @@ def get_geometric_within_contributions( """ L = self.lamda max_len = trajectories.max_length - terminating_idx = trajectories.terminating_idx + t_idx = trajectories.terminating_idx # The following tensor represents the weights given to each possible # sub-trajectory length. 
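+        # i.e., a sub-trajectory of length k receives weight lambda**(k - 1).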
- contributions = ( - L ** torch.arange(max_len, device=terminating_idx.device).double() - ).to(torch.get_default_dtype()) + contributions = (L ** torch.arange(max_len, device=t_idx.device).double()).to( + torch.get_default_dtype() + ) contributions = contributions.unsqueeze(-1).repeat(1, len(trajectories)) contributions = contributions.repeat_interleave( - torch.arange(max_len, 0, -1, device=terminating_idx.device), + torch.arange(max_len, 0, -1, device=t_idx.device), dim=0, output_size=int(max_len * (max_len + 1) / 2), ) @@ -519,10 +526,7 @@ def get_geometric_within_contributions( per_trajectory_denom = ( 1.0 / (1 - L) ** 2 - * ( - L * (L ** terminating_idx.double() - 1) - + (1 - L) * terminating_idx.double() - ) + * (L * (L ** t_idx.double() - 1) + (1 - L) * t_idx.double()) ).to(torch.get_default_dtype()) contributions = contributions / per_trajectory_denom / len(trajectories) diff --git a/src/gfn/gflownet/trajectory_balance.py b/src/gfn/gflownet/trajectory_balance.py index 4fd0586a..6614500b 100644 --- a/src/gfn/gflownet/trajectory_balance.py +++ b/src/gfn/gflownet/trajectory_balance.py @@ -3,7 +3,7 @@ and the [Log Partition Variance loss](https://arxiv.org/abs/2302.05446). """ -from typing import cast +from typing import Any, cast import torch import torch.nn as nn @@ -47,6 +47,9 @@ def __init__( init_logZ: float = 0.0, log_reward_clip_min: float = -float("inf"), constant_pb: bool = False, + *, + pf_adapter: Any | None = None, + pb_adapter: Any | None = None, ): """Initializes a TBGFlowNet instance. @@ -61,8 +64,16 @@ def __init__( constant_pb: Whether to ignore pb e.g., the GFlowNet DAG is a tree, and pb is therefore always 1. Must be set explicitly by user to ensure that pb is an Estimator except under this special case. + pf_adapter: Optional estimator adapter controlling how PF probabilities are + computed and sampled (e.g., `RecurrentEstimatorAdapter`). When provided, + it is used both by the Sampler and by probability recomputation paths. + pb_adapter: Optional estimator adapter for PB probability computation. If + provided and `pb` is an Estimator, it will be used in probability + recomputation paths that require PB. """ - super().__init__(pf, pb, constant_pb=constant_pb) + super().__init__( + pf, pb, constant_pb=constant_pb, pf_adapter=pf_adapter, pb_adapter=pb_adapter + ) self.logZ = logZ or nn.Parameter(torch.tensor(init_logZ)) self.log_reward_clip_min = log_reward_clip_min diff --git a/tutorials/examples/train_bitsequence_recurrent.py b/tutorials/examples/train_bitsequence_recurrent.py new file mode 100644 index 00000000..4c97b18c --- /dev/null +++ b/tutorials/examples/train_bitsequence_recurrent.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python +""" +Minimal TB training on BitSequence with a recurrent policy. + +Key choices: +- RecurrentDiscretePolicyEstimator + RecurrentDiscreteSequenceModel +- Sampler uses RecurrentEstimatorAdapter (saves on-policy log-probs) +- TBGFlowNet with constant_pb=True (tree DAG), pb=None + +This is intentionally small and mirrors train_hypergrid_simple.py structure. 
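
Example invocation (flags are defined in the argparse config at the bottom):

    python tutorials/examples/train_bitsequence_recurrent.py --n_iterations 100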
+""" + +import argparse +from typing import cast + +import torch +from tqdm import tqdm + +from gfn.estimators import RecurrentDiscretePolicyEstimator +from gfn.gflownet import PFBasedGFlowNet, TBGFlowNet +from gfn.gym.bitSequence import BitSequence +from gfn.samplers import RecurrentEstimatorAdapter, Sampler +from gfn.states import DiscreteStates +from gfn.utils.common import set_seed +from gfn.utils.modules import RecurrentDiscreteSequenceModel +from gfn.utils.prob_calculations import get_trajectory_pfs + + +def estimated_dist(gflownet: PFBasedGFlowNet, env: BitSequence): + states = env.terminating_states + trajectories = env.trajectory_from_terminating_states(states.tensor) + log_pf_trajectories = get_trajectory_pfs( + pf=gflownet.pf, trajectories=trajectories, recalculate_all_logprobs=True + ) + pf = torch.exp(log_pf_trajectories.sum(dim=0)) + return pf + + +def main(args): + set_seed(args.seed) + device = torch.device( + "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" + ) + + # Environment + H = torch.randint( + 0, 2, (args.n_modes, args.seq_size), dtype=torch.long, device=device + ) + env = BitSequence( + word_size=args.word_size, + seq_size=args.seq_size, + n_modes=args.n_modes, + temperature=args.temperature, + H=H, + device_str=str(device), + seed=args.seed, + check_action_validity=__debug__, + ) + + # Model + Estimator + # Set vocab_size so projection outputs env.n_actions logits (includes exit). + model = RecurrentDiscreteSequenceModel( + vocab_size=env.n_actions, # projection -> env.n_actions + embedding_dim=args.embedding_dim, + hidden_size=args.hidden_size, + num_layers=args.num_layers, + rnn_type=args.rnn_type, + dropout=args.dropout, + ).to(device) + + pf_estimator = RecurrentDiscretePolicyEstimator( + module=model, + n_actions=env.n_actions, + is_backward=False, + ).to(device) + + # GFlowNet (Trajectory Balance), tree DAG -> pb=None, constant_pb=True + gflownet = TBGFlowNet(pf=pf_estimator, pb=None, init_logZ=0.0, constant_pb=True) + gflownet = gflownet.to(device) + + # Sampler with recurrent adapter; save log-probs for on-policy TB + sampler = Sampler( + estimator=pf_estimator, adapter=RecurrentEstimatorAdapter(pf_estimator) + ) + + # Optimizer: policy params + logZ + optimizer = torch.optim.Adam(gflownet.pf_pb_parameters(), lr=args.lr) + optimizer.add_param_group({"params": gflownet.logz_parameters(), "lr": args.lr_logz}) + + visited_terminating_states = env.states_from_batch_shape((0,)) + + for it in (pbar := tqdm(range(args.n_iterations), dynamic_ncols=True)): + trajectories = sampler.sample_trajectories( + env, + n=args.batch_size, + save_logprobs=True, # crucial: avoid recalculation, use adapter path + save_estimator_outputs=False, + epsilon=args.epsilon, + ) + + visited_terminating_states.extend( + cast(DiscreteStates, trajectories.terminating_states) + ) + + optimizer.zero_grad() + # Use saved log-probs from sampler; no need to recalc + loss = gflownet.loss(env, trajectories, recalculate_all_logprobs=False) + loss.backward() + + gflownet.assert_finite_gradients() + torch.nn.utils.clip_grad_norm_(gflownet.parameters(), 1.0) + optimizer.step() + gflownet.assert_finite_parameters() + + pbar.set_postfix({"loss": loss.item()}) + + # Final validation. 
+ gflownet = cast(PFBasedGFlowNet, gflownet) + l1_dist = torch.abs(estimated_dist(gflownet, env) - env.true_dist).mean().item() + print(f"Final L1 distance: {l1_dist}") + + return l1_dist + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--no_cuda", action="store_true", help="Disable CUDA use") + + # BitSequence config (keep small by default) + parser.add_argument("--word_size", type=int, default=1, help="Word size") + parser.add_argument("--seq_size", type=int, default=4, help="Sequence size") + parser.add_argument("--n_modes", type=int, default=2, help="Number of modes") + parser.add_argument("--temperature", type=float, default=1.0) + + # Model config + parser.add_argument("--embedding_dim", type=int, default=64) + parser.add_argument("--hidden_size", type=int, default=128) + parser.add_argument("--num_layers", type=int, default=1) + parser.add_argument("--rnn_type", type=str, choices=["lstm", "gru"], default="lstm") + parser.add_argument("--dropout", type=float, default=0.0) + + # Training config + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--lr_logz", type=float, default=1e-1) + parser.add_argument("--n_iterations", type=int, default=1000) + parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--epsilon", type=float, default=0.05) + + args = parser.parse_args() + main(args) From e8d3fc2b8180235d8535f2fc4b55a534d2633938 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 9 Oct 2025 21:27:32 -0400 Subject: [PATCH 12/27] added working bitsequence example for recurrent estimators and their adapters --- .../examples/train_bitsequence_recurrent.py | 58 ++++++++++++------- 1 file changed, 37 insertions(+), 21 deletions(-) diff --git a/tutorials/examples/train_bitsequence_recurrent.py b/tutorials/examples/train_bitsequence_recurrent.py index 4c97b18c..9c922a24 100644 --- a/tutorials/examples/train_bitsequence_recurrent.py +++ b/tutorials/examples/train_bitsequence_recurrent.py @@ -19,7 +19,7 @@ from gfn.estimators import RecurrentDiscretePolicyEstimator from gfn.gflownet import PFBasedGFlowNet, TBGFlowNet from gfn.gym.bitSequence import BitSequence -from gfn.samplers import RecurrentEstimatorAdapter, Sampler +from gfn.samplers import RecurrentEstimatorAdapter from gfn.states import DiscreteStates from gfn.utils.common import set_seed from gfn.utils.modules import RecurrentDiscreteSequenceModel @@ -30,10 +30,16 @@ def estimated_dist(gflownet: PFBasedGFlowNet, env: BitSequence): states = env.terminating_states trajectories = env.trajectory_from_terminating_states(states.tensor) log_pf_trajectories = get_trajectory_pfs( - pf=gflownet.pf, trajectories=trajectories, recalculate_all_logprobs=True + pf=gflownet.pf, + trajectories=trajectories, + recalculate_all_logprobs=True, + adapter=gflownet.pf_adapter, ) pf = torch.exp(log_pf_trajectories.sum(dim=0)) - return pf + + l1_dist = torch.abs(pf - env.true_dist).mean().item() + + return l1_dist def main(args): @@ -74,28 +80,33 @@ def main(args): is_backward=False, ).to(device) - # GFlowNet (Trajectory Balance), tree DAG -> pb=None, constant_pb=True - gflownet = TBGFlowNet(pf=pf_estimator, pb=None, init_logZ=0.0, constant_pb=True) - gflownet = gflownet.to(device) - - # Sampler with recurrent adapter; save log-probs for on-policy TB - sampler = Sampler( - estimator=pf_estimator, adapter=RecurrentEstimatorAdapter(pf_estimator) + # GFlowNet (Trajectory Balance), tree DAG -> pb=None, 
constant_pb=True, + # Use a recurrent adapter for the PF. + gflownet = TBGFlowNet( + pf=pf_estimator, + pb=None, + init_logZ=0.0, + constant_pb=True, + pf_adapter=RecurrentEstimatorAdapter(pf_estimator), ) + gflownet = gflownet.to(device) # Optimizer: policy params + logZ optimizer = torch.optim.Adam(gflownet.pf_pb_parameters(), lr=args.lr) optimizer.add_param_group({"params": gflownet.logz_parameters(), "lr": args.lr_logz}) visited_terminating_states = env.states_from_batch_shape((0,)) + l1_distances = [] + eval_freq = args.n_iterations // 10 # 10% of the iterations. + l1_dist = float("inf") for it in (pbar := tqdm(range(args.n_iterations), dynamic_ncols=True)): - trajectories = sampler.sample_trajectories( + trajectories = gflownet.sample_trajectories( env, n=args.batch_size, save_logprobs=True, # crucial: avoid recalculation, use adapter path save_estimator_outputs=False, - epsilon=args.epsilon, + epsilon=args.epsilon, # Off-policy sampling. ) visited_terminating_states.extend( @@ -112,12 +123,17 @@ def main(args): optimizer.step() gflownet.assert_finite_parameters() - pbar.set_postfix({"loss": loss.item()}) + if (it + 1) % eval_freq == 0 or it == 0: + l1_dist = estimated_dist(gflownet, env) + l1_distances.append(l1_dist) + + pbar.set_postfix({"loss": loss.item(), "l1_dist": l1_dist}) # Final validation. - gflownet = cast(PFBasedGFlowNet, gflownet) - l1_dist = torch.abs(estimated_dist(gflownet, env) - env.true_dist).mean().item() - print(f"Final L1 distance: {l1_dist}") + l1_dist = estimated_dist(gflownet, env) + l1_distances.append(l1_dist) + print(f"L1_dist training curve: {[f'{l1:.5f}' for l1 in l1_distances]}") + print(f"Final L1_dist: {l1_dist:.5f}") return l1_dist @@ -127,15 +143,15 @@ def main(args): parser.add_argument("--no_cuda", action="store_true", help="Disable CUDA use") # BitSequence config (keep small by default) - parser.add_argument("--word_size", type=int, default=1, help="Word size") - parser.add_argument("--seq_size", type=int, default=4, help="Sequence size") - parser.add_argument("--n_modes", type=int, default=2, help="Number of modes") + parser.add_argument("--word_size", type=int, default=3, help="Word size") + parser.add_argument("--seq_size", type=int, default=9, help="Sequence size") + parser.add_argument("--n_modes", type=int, default=5, help="Number of modes") parser.add_argument("--temperature", type=float, default=1.0) # Model config parser.add_argument("--embedding_dim", type=int, default=64) parser.add_argument("--hidden_size", type=int, default=128) - parser.add_argument("--num_layers", type=int, default=1) + parser.add_argument("--num_layers", type=int, default=3) parser.add_argument("--rnn_type", type=str, choices=["lstm", "gru"], default="lstm") parser.add_argument("--dropout", type=float, default=0.0) @@ -143,7 +159,7 @@ def main(args): parser.add_argument("--seed", type=int, default=0) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--lr_logz", type=float, default=1e-1) - parser.add_argument("--n_iterations", type=int, default=1000) + parser.add_argument("--n_iterations", type=int, default=500) parser.add_argument("--batch_size", type=int, default=16) parser.add_argument("--epsilon", type=float, default=0.05) From e0dd464ab5d22ef02ab4f363a1c51fd284ee7148 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 9 Oct 2025 23:19:21 -0400 Subject: [PATCH 13/27] fixed test --- src/gfn/gflownet/base.py | 20 ++ src/gfn/gflownet/detailed_balance.py | 13 + src/gfn/samplers.py | 8 + src/gfn/utils/prob_calculations.py | 29 ++- 
..._adaptor_estimator_gflownet_integration.py | 240 ++++++++++++++++++ testing/test_samplers_and_trajectories.py | 5 +- 6 files changed, 303 insertions(+), 12 deletions(-) create mode 100644 testing/test_adaptor_estimator_gflownet_integration.py diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index 40e46c98..8a6b1294 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -227,6 +227,26 @@ def __init__( self.pf_adapter = pf_adapter self.pb_adapter = pb_adapter + # Advisory: recurrent PF with non-recurrent PB is unusual + # (tree DAGs typically prefer pb=None with constant_pb=True). + # Import locally to avoid circular imports during module import time. + from gfn.estimators import RecurrentDiscretePolicyEstimator # type: ignore + + if isinstance(self.pf, RecurrentDiscretePolicyEstimator) and isinstance( + self.pb, Estimator + ): + warnings.warn( + "Using a recurrent PF, which is only valid for tree DAGs, with a " + "non-recurrent PB is unusual. " + "Consider using pb=None with constant_pb=True for tree DAGs.", + ) + # Disallow recurrent PB estimators universally. + if isinstance(self.pb, RecurrentDiscretePolicyEstimator): + raise TypeError( + "Recurrent PB estimators are not supported. Use a non-recurrent PB " + "or set pb=None with constant_pb=True for tree DAGs." + ) + def sample_trajectories( self, env: Env, diff --git a/src/gfn/gflownet/detailed_balance.py b/src/gfn/gflownet/detailed_balance.py index 7a7133b0..fd63431c 100644 --- a/src/gfn/gflownet/detailed_balance.py +++ b/src/gfn/gflownet/detailed_balance.py @@ -108,6 +108,19 @@ def __init__( pf_adapter=pf_adapter, pb_adapter=pb_adapter, ) + # Disallow recurrent PF or recurrent adapter for transition-based DB + from gfn.estimators import RecurrentDiscretePolicyEstimator # type: ignore + from gfn.samplers import RecurrentEstimatorAdapter # type: ignore + + if isinstance(self.pf, RecurrentDiscretePolicyEstimator): + raise TypeError( + "DBGFlowNet does not support recurrent PF estimators (transitions path cannot propagate carry)." + ) + if isinstance(self.pf_adapter, RecurrentEstimatorAdapter): + raise TypeError( + "DBGFlowNet does not support RecurrentEstimatorAdapter (transitions path cannot propagate carry)." + ) + assert any( isinstance(logF, cls) for cls in [ScalarEstimator, ConditionalScalarEstimator] diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 0eb6251b..24a02996 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -449,6 +449,14 @@ class RecurrentEstimatorAdapter(DefaultEstimatorAdapter): """ def __init__(self, estimator: Estimator) -> None: + # Validate that the estimator presents a recurrent interface + # We check for the presence of `init_carry` and a callable that accepts (states, carry). + init_carry = getattr(estimator, "init_carry", None) + if not callable(init_carry): + raise TypeError( + "RecurrentEstimatorAdapter requires an estimator implementing " + "init_carry(batch_size: int, device: torch.device)." 
+ ) super().__init__(estimator) @property diff --git a/src/gfn/utils/prob_calculations.py b/src/gfn/utils/prob_calculations.py index ceb65fb8..31c3ba79 100644 --- a/src/gfn/utils/prob_calculations.py +++ b/src/gfn/utils/prob_calculations.py @@ -300,14 +300,29 @@ def get_trajectory_pbs( # Recurrent adapters are only valid for trajectories and never require pb from gfn.samplers import RecurrentEstimatorAdapter # type: ignore - is_recurrent = ( - isinstance(adapter, RecurrentEstimatorAdapter) if adapter is not None else False - ) + if adapter is not None: + is_recurrent = isinstance(adapter, RecurrentEstimatorAdapter) + else: + is_recurrent = False - if is_recurrent: - # With recurrent adapter, pb must be None (tree DAG); return zeros + if is_recurrent or pb is None: + # With recurrent adapter, pb *must* be None (tree DAG); return zeros. assert pb is None, "When using a RecurrentEstimatorAdapter, pb must be None." + # If pb is None, we assume that the gflownet DAG is a tree, and therefore + # the backward policy probability is always 1 (log probs are 0). valid_log_pb_actions = torch.zeros_like(valid_actions.tensor) + valid_log_pb_actions = valid_log_pb_actions.squeeze(-1) # no padding. + + # TODO: Add logging in follow up PR. + # if os.getenv("GFN_DEBUG_REC_PB") == "1": + # print( + # "[DBG] pb=None path: valid_actions.shape=", + # tuple(valid_actions.tensor.shape), + # "valid_log_pb_actions.shape=", + # tuple(valid_log_pb_actions.shape), + # "target_len=", + # int(action_mask.sum().item()), + # ) elif pb is not None: # Choose vectorized (legacy) vs non-vectorized (adapter per-step) @@ -363,10 +378,6 @@ def get_trajectory_pbs( log_pb_trajectories[action_mask] = valid_log_pb_actions.to( log_pb_trajectories.dtype, copy=False ) - else: - # If pb is None, we assume that the gflownet DAG is a tree, and therefore - # the backward policy probability is always 1 (log probs are 0). 
- valid_log_pb_actions = torch.zeros_like(valid_actions.tensor) log_pb_trajectories[action_mask] = valid_log_pb_actions.to( log_pb_trajectories.dtype, copy=False diff --git a/testing/test_adaptor_estimator_gflownet_integration.py b/testing/test_adaptor_estimator_gflownet_integration.py new file mode 100644 index 00000000..05bd03c6 --- /dev/null +++ b/testing/test_adaptor_estimator_gflownet_integration.py @@ -0,0 +1,240 @@ +import warnings + +import pytest +import torch + +from gfn.estimators import ( + DiscretePolicyEstimator, + RecurrentDiscretePolicyEstimator, + ScalarEstimator, +) +from gfn.gflownet import DBGFlowNet, TBGFlowNet +from gfn.gym.bitSequence import BitSequence +from gfn.samplers import DefaultEstimatorAdapter, RecurrentEstimatorAdapter +from gfn.utils.modules import MLP, RecurrentDiscreteSequenceModel + + +def _make_bitsequence_env( + *, device: torch.device, word_size: int = 3, seq_size: int = 9, n_modes: int = 5 +) -> BitSequence: + H = torch.randint(0, 2, (n_modes, seq_size), dtype=torch.long, device=device) + env = BitSequence( + word_size=word_size, + seq_size=seq_size, + n_modes=n_modes, + temperature=1.0, + H=H, + device_str=str(device), + seed=0, + check_action_validity=True, + ) + return env + + +def _make_recurrent_pf( + env: BitSequence, device: torch.device +) -> RecurrentDiscretePolicyEstimator: + model = RecurrentDiscreteSequenceModel( + vocab_size=env.n_actions, + embedding_dim=16, + hidden_size=32, + num_layers=1, + rnn_type="lstm", + dropout=0.0, + ).to(device) + pf = RecurrentDiscretePolicyEstimator( + module=model, n_actions=env.n_actions, is_backward=False + ).to(device) + return pf + + +def _make_nonrecurrent_pf_pb(env: BitSequence, device: torch.device): + input_dim = ( + env.words_per_seq + ) # BitSequence states are integer words of length words_per_seq + pf_module = MLP( + input_dim=input_dim, output_dim=env.n_actions, hidden_dim=32, n_hidden_layers=1 + ).to(device) + pb_module = MLP( + input_dim=input_dim, + output_dim=env.n_actions - 1, + hidden_dim=32, + n_hidden_layers=1, + ).to(device) + pf = DiscretePolicyEstimator( + module=pf_module, n_actions=env.n_actions, is_backward=False + ).to(device) + pb = DiscretePolicyEstimator( + module=pb_module, n_actions=env.n_actions, is_backward=True + ).to(device) + return pf, pb + + +def test_recurrent_tb_passes_with_pb_none(): + device = torch.device("cpu") + env = _make_bitsequence_env(device=device) + pf = _make_recurrent_pf(env, device) + adapter = RecurrentEstimatorAdapter(pf) + gfn = TBGFlowNet(pf=pf, pb=None, init_logZ=0.0, constant_pb=True, pf_adapter=adapter) + + # sample and compute a loss to ensure end-to-end path works + trajectories = gfn.sample_trajectories( + env, n=4, save_logprobs=True, save_estimator_outputs=False + ) + loss = gfn.loss(env, trajectories, recalculate_all_logprobs=False) + assert torch.isfinite(loss) + + +def test_warn_on_recurrent_pf_with_nonrecurrent_pb(): + device = torch.device("cpu") + env = _make_bitsequence_env(device=device) + pf = _make_recurrent_pf(env, device) + pb_pf, pb = _make_nonrecurrent_pf_pb(env, device) + del pb_pf # unused + + adapter = RecurrentEstimatorAdapter(pf) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + _ = TBGFlowNet( + pf=pf, pb=pb, init_logZ=0.0, constant_pb=False, pf_adapter=adapter + ) + assert any("unusual" in str(x.message).lower() for x in w) + + +def test_error_on_recurrent_pb(): + device = torch.device("cpu") + env = _make_bitsequence_env(device=device) + pf_nonrec, _ = 
_make_nonrecurrent_pf_pb(env, device) + + # Build a recurrent PB + model = RecurrentDiscreteSequenceModel( + vocab_size=env.n_actions - 1, + embedding_dim=16, + hidden_size=32, + num_layers=1, + rnn_type="lstm", + dropout=0.0, + ).to(device) + pb_recurrent = RecurrentDiscretePolicyEstimator( + module=model, n_actions=env.n_actions, is_backward=True + ).to(device) + + with pytest.raises(TypeError, match="Recurrent PB estimators are not supported"): + _ = TBGFlowNet(pf=pf_nonrec, pb=pb_recurrent, init_logZ=0.0, constant_pb=False) + + +def test_db_gflownet_rejects_recurrent_pf_and_adapter(): + device = torch.device("cpu") + env = _make_bitsequence_env(device=device) + pf = _make_recurrent_pf(env, device) + + # recurrent PF should be rejected + logF_est = ScalarEstimator( + module=MLP( + input_dim=env.words_per_seq, + output_dim=1, + hidden_dim=16, + n_hidden_layers=1, + ).to(device) + ) + with pytest.raises(TypeError, match="does not support recurrent PF"): + _ = DBGFlowNet( + pf=pf, + pb=None, + logF=logF_est, + constant_pb=True, + ) # type: ignore[arg-type] + + # non-recurrent PF with recurrent adapter should also be rejected + pf_nonrec, _ = _make_nonrecurrent_pf_pb(env, device) + adapter = RecurrentEstimatorAdapter( + _make_recurrent_pf(env, device) + ) # construct valid adapter + with pytest.raises(TypeError, match="does not support RecurrentEstimatorAdapter"): + _ = DBGFlowNet( + pf=pf_nonrec, + pb=None, + logF=logF_est, + constant_pb=True, + pf_adapter=adapter, + ) # type: ignore[arg-type] + + +def test_nonrecurrent_tb_passes_with_pb_defined(): + device = torch.device("cpu") + env = _make_bitsequence_env(device=device) + pf, pb = _make_nonrecurrent_pf_pb(env, device) + gfn = TBGFlowNet( + pf=pf, + pb=pb, + init_logZ=0.0, + constant_pb=False, + pf_adapter=DefaultEstimatorAdapter(pf), + ) + + trajectories = gfn.sample_trajectories( + env, n=3, save_logprobs=True, save_estimator_outputs=False + ) + loss = gfn.loss(env, trajectories, recalculate_all_logprobs=False) + assert torch.isfinite(loss) + + +def test_adapter_rejects_nonrecurrent_estimator(): + device = torch.device("cpu") + env = _make_bitsequence_env(device=device) + pf, _ = _make_nonrecurrent_pf_pb(env, device) + with pytest.raises(TypeError, match="requires an estimator implementing init_carry"): + _ = RecurrentEstimatorAdapter(pf) + + +def test_pb_mlp_trunk_sharing_parity_on_transitions(): + device = torch.device("cpu") + env = _make_bitsequence_env(device=device) + + # Build non-recurrent PF for sampling + pf, _ = _make_nonrecurrent_pf_pb(env, device) + + # PB with trunk sharing from PF + pb_shared_module = MLP( + input_dim=env.words_per_seq, + output_dim=env.n_actions - 1, + hidden_dim=32, + n_hidden_layers=1, + trunk=pf.module.trunk, # type: ignore[attr-defined] + ).to(device) + pb_shared = DiscretePolicyEstimator( + module=pb_shared_module, n_actions=env.n_actions, is_backward=True + ).to(device) + + # PB independent with identical weights copied from shared version + pb_indep_module = MLP( + input_dim=env.words_per_seq, + output_dim=env.n_actions - 1, + hidden_dim=32, + n_hidden_layers=1, + ).to(device) + pb_indep_module.load_state_dict(pb_shared_module.state_dict()) + pb_indep = DiscretePolicyEstimator( + module=pb_indep_module, n_actions=env.n_actions, is_backward=True + ).to(device) + + # Sample trajectories and convert to transitions + from gfn.samplers import Sampler + + sampler = Sampler(estimator=pf) + trajectories = sampler.sample_trajectories( + env, n=5, save_logprobs=False, save_estimator_outputs=False + ) + 
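+    # Flatten the sampled trajectories into individual transitions, so PB
+    # log-probs can be compared transition-by-transition below.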
transitions = trajectories.to_transitions() + + # Compute PB log-probs using vectorized default adapters for each PB + from gfn.utils.prob_calculations import get_transition_pbs + + lp_shared = get_transition_pbs( + pb_shared, transitions, adapter=DefaultEstimatorAdapter(pb_shared) + ) + lp_indep = get_transition_pbs( + pb_indep, transitions, adapter=DefaultEstimatorAdapter(pb_indep) + ) + + torch.testing.assert_close(lp_shared, lp_indep) diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index 71862dae..3ae98668 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -547,9 +547,8 @@ def test_recurrent_adapter_requires_init_carry(): class _BadEstimator: is_backward = False - adapter = RecurrentEstimatorAdapter(cast(Estimator, _BadEstimator())) - with pytest.raises(TypeError): - _ = adapter.init_context(2, torch.device("cpu"), None) + with pytest.raises(TypeError, match="requires an estimator implementing init_carry"): + _ = RecurrentEstimatorAdapter(cast(Estimator, _BadEstimator())) def test_recurrent_adapter_flow(): From bca3df6c3bf2df737991b4cf5dfe04ea1e616603 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 11 Oct 2025 21:51:43 -0400 Subject: [PATCH 14/27] Update estimators.py tweak of how default preprocessor is defined --- src/gfn/estimators.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/gfn/estimators.py b/src/gfn/estimators.py index df05f8e4..0b950fe3 100644 --- a/src/gfn/estimators.py +++ b/src/gfn/estimators.py @@ -834,9 +834,7 @@ def __init__( self, module: nn.Module, n_actions: int, - preprocessor: Preprocessor | None = IdentityPreprocessor( - output_dim=None - ), # Addressed in https://github.com/GFNOrg/torchgfn/pull/399. + preprocessor: Preprocessor | None = None, is_backward: bool = False, ): """Initializes a RecurrentDiscretePolicyEstimator. @@ -847,9 +845,8 @@ def __init__( preprocessor: Preprocessor object that transforms states to tensors. """ if preprocessor is None: - preprocessor = IdentityPreprocessor( - output_dim=None - ) # Addressed in https://github.com/GFNOrg/torchgfn/pull/399. + preprocessor = IdentityPreprocessor(output_dim=None) + super().__init__( module=module, n_actions=n_actions, From fc6cb7a044e62381ae6aa0abb184e0dfa0d2b095 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sun, 12 Oct 2025 22:21:17 -0400 Subject: [PATCH 15/27] black / isort --- docs/source/guides/estimator_adapters.md | 19 +- src/gfn/samplers.py | 361 +++++++++------------- src/gfn/utils/prob_calculations.py | 39 ++- testing/test_probability_calculations.py | 9 +- testing/test_samplers_and_trajectories.py | 8 +- 5 files changed, 185 insertions(+), 251 deletions(-) diff --git a/docs/source/guides/estimator_adapters.md b/docs/source/guides/estimator_adapters.md index d5971886..d7f08150 100644 --- a/docs/source/guides/estimator_adapters.md +++ b/docs/source/guides/estimator_adapters.md @@ -3,7 +3,7 @@ Adapters decouple the generic sampling and probability computation logic from estimator-specific details (conditioning shape, recurrent state/carry, distribution construction, artifact recording). They enable a single sampler and probability utilities to work across different estimator families. 
 This guide explains:
-- The Adapter protocol and lifecycle
+- The Adapter and RolloutContext
 - Vectorized vs non-vectorized probability paths
 - How adapters integrate with the Sampler and probability calculators
 - How to implement a new Adapter
@@ -17,9 +17,9 @@ An Adapter mediates between three places where estimator logic is needed:
 
 The Sampler remains estimator-agnostic. Adapters own any estimator-specific state (e.g., recurrent carry) and control how to run the estimator and build the policy distribution.
 
-## Adapter Protocol (call signature)
+## Adapters
 
-Adapters conform to a structural Protocol with the following surface (see `gfn/samplers.py`):
+Adapters conform to an abstract class structure (see `gfn/samplers.py`):
 
 - Properties
   - `is_backward: bool` — whether the wrapped estimator is a backward policy.
@@ -39,6 +39,9 @@ Adapters conform to a structural Protocol with the following surface (see `gfn/s
   - `log_prob_of_actions(states_active: States, actions_active: Tensor, ctx: Any, step_mask: Tensor, **policy_kwargs) -> (Tensor, Any)`
     - Computes log-probs for a batch of (state, action) pairs corresponding to `True` entries of `step_mask` and returns a padded `(N,)` vector.
 
+  - `finalize(ctx) -> dict[str, Optional[torch.Tensor]]`
+    - Realizes the per-step buffers of the context object into tensors that can be used by the rest of the library (e.g., `Trajectories` objects).
+
   - `get_current_estimator_output(ctx: Any) -> Tensor|None`
     - Convenience to expose the last estimator output after `compute`.
 
@@ -47,7 +50,6 @@ Adapters conform to a structural Protocol with the following surface (see `gfn/s
     - `batch_size`, `device`, optional `conditioning`
     - Optional `carry` (recurrent hidden state)
     - Per-step buffers: `trajectory_log_probs`, `trajectory_estimator_outputs`
-    - A `finalize()` method that stacks per-step artifacts into `(T, N, ...)` tensors for `Trajectories`.
 
 ## Built-in Adapters
 
@@ -89,10 +91,12 @@ The Sampler uses the adapter lifecycle:
   - Sample actions from `dist`; build actions for the full batch
   - `adapter.record(ctx, step_mask, sampled_actions, dist, save_logprobs, save_estimator_outputs)`
   - Step the environment forward/backward based on `adapter.is_backward`
-- After rollout: `artifacts = ctx.finalize()` and populate `Trajectories`.
+- After rollout: `artifacts = adapter.finalize(ctx)` and populate `Trajectories`.
 
 ## How to Implement a New Adapter
 
+A new Adapter will likely only need changes to `compute`, `record`, and `log_prob_of_actions`; otherwise you can rely on the defaults. However, we detail all of the steps below for completeness:
+
 1) Decide if your estimator needs a recurrent carry - some persistent state or cache that is utilized throughout the trajectory.
    - If yes, set `is_vectorized = False` and implement `init_context` to initialize `carry`. Implement `compute` to update `carry` each step.
    - If no, set `is_vectorized = True` and follow the default adapter pattern.
@@ -108,6 +112,9 @@ The Sampler uses the adapter lifecycle:
 4) Implement `log_prob_of_actions`
    - Given `(states_active, actions_active)` for the active rows, compute the Distribution (reusing the same forward logic) and return a padded `(N,)` vector of `log_prob`.
    - Do not modify masks here; calculators pass in `step_mask` already built from existing masks.
+
+5) Implement `finalize`
+   - Given the contents of your context, return the trajectory-level objects required by the Sampler.
 
-5) Mark `is_backward` if your estimator is a backward policy; the sampler will step the environment backward accordingly.
+6) Mark `is_backward` if your estimator is a backward policy; the sampler will step the environment backward accordingly.
-6) Performance Guidance
+7) Performance Guidance
@@ -116,6 +123,6 @@ The Sampler uses the adapter lifecycle:
 
 ## Reference: Legacy Implementations
 
-The legacy, vectorized implementations are the gold standard. Adapters are designed to use those paths whenever possible (vectorized) and to exactly match their behavior when per-step evaluation is required (non-vectorized). See the reference for details:
+The move to adapters, while allowing for potentially much more complex forms of estimators, introduces significant complexity into the Sampler and probability calculation logic. The legacy, vectorized implementations of these operations are exactly re-implemented in the DefaultEstimatorAdapter, and the library is designed to use those paths whenever possible (i.e., using vectorized operations); we have also ensured that per-step evaluation (the non-vectorized path) exactly matches their behavior. These paths are also tested against the legacy code in `test_probability_calculations.py` to ensure correctness. See the reference for details:
 
 - `utils/prob_calculations.py` (master): [link](https://raw.githubusercontent.com/GFNOrg/torchgfn/refs/heads/master/src/gfn/utils/prob_calculations.py)
\ No newline at end of file
diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py
index 24a02996..a4887b41 100644
--- a/src/gfn/samplers.py
+++ b/src/gfn/samplers.py
@@ -1,4 +1,5 @@
-from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, cast
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Dict, List, Optional, Tuple, cast
 
 import torch
 from torch.distributions import Distribution
@@ -14,34 +15,27 @@
 from gfn.utils.prob_calculations import get_trajectory_pbs, get_trajectory_pfs
 
 
-class EstimatorAdapter(Protocol):
+class EstimatorAdapter(ABC):
     """Adapter interface for estimator-specific policy behavior.
 
-    Purpose
-    -------
-    This Protocol defines the minimal interface the Sampler relies on, allowing
-    us to keep one generic sampling loop while plugging in different estimator
+    This abstract base class defines the minimal interface the Sampler relies on,
+    allowing us to keep one generic sampling loop while plugging in different estimator
     behaviors (e.g., non‑recurrent, recurrent with carry, tempered variants)
-    without modifying the Sampler. We use a Protocol (structural typing) so any
-    class that implements these members is accepted; no inheritance is required.
-
-    Opaque context (ctx)
-    --------------------
-    The adapter owns an opaque context object (``ctx``). The Sampler never
-    inspects it and simply passes it back to the adapter at each step. The
-    adapter is responsible for:
-    - initializing ``ctx`` once per rollout in ``init_context``
-    - updating any internal state (e.g., recurrent ``carry``) during ``compute``
-    - recording per‑step artifacts in ``record`` (e.g., log_probs,
-      estimator outputs), typically with mask-aware padding
-    Finalization is handled by the rollout context itself.
-
-    Guidance
-    --------
-    - Allocate ``ctx`` once per rollout; mutate it in place for performance.
-    - Apply masking inside the adapter (``step_mask``) when slicing conditioning
-      or padding per‑step tensors back to full batch size.
-    - Keep Sampler oblivious to estimator details (conditioning, carry, etc.).
+    without modifying the Sampler.
+
+    The adapter owns an opaque RolloutContext object. The Sampler never inspects
+    it and simply passes it back to the adapter at each step. The adapter is
+    responsible for:
+    - initializing the context in `init_context`.
+    - updating any internal state (e.g., recurrent `carry`) during `compute`
+    - recording per‑step artifacts in `record` (e.g., log_probs,
+      estimator outputs), typically with mask-aware padding.
+    - outputting trajectory-length artifacts via `finalize(ctx)`
+
+    The context should be allocated once per rollout. Masking should be applied inside
+    the adapter (via `step_mask`) when slicing conditioning or padding per‑step
+    tensors back to full batch size. The Sampler can therefore be oblivious to
+    estimator details (conditioning, carry, etc.).
     """
 
     @property
@@ -52,6 +46,7 @@ def is_backward(self) -> bool:
     def is_vectorized(self) -> bool:
         ...  # fmt: skip
 
+    @abstractmethod
     def init_context(
         self,
         batch_size: int,
@@ -60,6 +55,7 @@ def init_context(
     ) -> Any:
         ...  # fmt: skip
 
+    @abstractmethod
     def compute(
         self,
         states_active: States,
@@ -69,6 +65,7 @@ def compute(
     ) -> tuple[Distribution, Any]:
         ...  # fmt: skip
 
+    @abstractmethod
     def record(
         self,
         ctx: Any,
@@ -80,6 +77,7 @@ def record(
     ) -> None:
         ...  # fmt: skip
 
+    @abstractmethod
     def log_prob_of_actions(
         self,
         states_active: States,
@@ -129,101 +127,50 @@ def __init__(
         self.current_estimator_output: Optional[torch.Tensor] = None
         self.extras: Dict[str, Any] = {}
 
-    def append_step(
-        self,
-        step_mask: torch.Tensor,
-        sampled_actions: torch.Tensor,
-        dist: Distribution,
-        save_logprobs: bool,
-        save_estimator_outputs: bool,
-    ) -> None:
-        """Record per-step artifacts into trajectory-level buffers owned by the context."""
-        N = self.batch_size
-        device = self.device
-
-        if save_logprobs:
-            lp_masked = dist.log_prob(sampled_actions)
-            if torch.any(torch.isinf(lp_masked)):
-                raise RuntimeError("Log probabilities are inf. This should not happen.")
-            step_lp = torch.full((N,), 0.0, device=device)
-            step_lp[step_mask] = lp_masked
-            self.trajectory_log_probs.append(step_lp)
-
-        if save_estimator_outputs and self.current_estimator_output is not None:
-            est_out = self.current_estimator_output
-            padded = torch.full((N,) + est_out.shape[1:], -float("inf"), device=device)
-            padded[step_mask] = est_out
-            self.trajectory_estimator_outputs.append(padded)
-
-    def finalize(self) -> dict[str, Optional[torch.Tensor]]:
-        """Stack recorded per-step artifacts along time into trajectory-level tensors."""
-        log_probs = (
-            torch.stack(self.trajectory_log_probs, dim=0)
-            if self.trajectory_log_probs
-            else None
-        )
-        estimator_outputs = (
-            torch.stack(self.trajectory_estimator_outputs, dim=0)
-            if self.trajectory_estimator_outputs
-            else None
-        )
-
-        return {"log_probs": log_probs, "estimator_outputs": estimator_outputs}
-
 
-class DefaultEstimatorAdapter:
+class DefaultEstimatorAdapter(EstimatorAdapter):
    """Adapter for non-recurrent estimators (current default behavior).
 
     Overview
     --------
-    This adapter bridges the generic sampling loop and the "classic" non‑recurrent
-    estimators already used throughout the codebase. It exposes the minimal
-    interface required by the `EstimatorAdapter` Protocol while keeping the
-    sampler loop estimator-agnostic.
-
-    Assumptions
-    -----------
-    - The wrapped estimator is non‑recurrent (no carry between steps).
+    This adapter bridges the generic sampling loop and the non-recurrent estimators
+    used throughout the codebase. It exposes the minimal interface required by the
+    `EstimatorAdapter` abstract base class while keeping the sampler loop
+    estimator-agnostic.
+
     - If conditioning is provided, the estimator accepts `(states, conditioning)`;
       otherwise it accepts `(states)`.
- The estimator provides `to_probability_distribution(states, est_out, **kw)`
      returning a torch Distribution over actions for the masked states.
 
-    Context Lifecycle (opaque to the Sampler)
-    ----------------------------------------
-    The adapter owns an opaque rollout context `ctx` (see `RolloutContext`) which
-    the Sampler never reads. The context is created once per rollout and mutated
-    in place:
+    The adapter owns an opaque rollout context `ctx` which the Sampler never reads. The
+    context (a `RolloutContext`) is created once per rollout and mutated in place:
 
-    - init_context(batch_size, device, conditioning) -> ctx
-      Stores rollout invariants and optional conditioning. Also prepares per‑step
-      buffers for artifacts that may be recorded (log_probs, estimator_outputs).
+    - init_context: stores rollout invariants and optional conditioning. Also prepares
+      per‑step buffers for trajectory-level artifacts (log_probs,
+      estimator_outputs).
 
-    - compute(states_active, ctx, step_mask, **policy_kwargs) -> (dist, ctx)
+    - compute:
      1) Selects the appropriate estimator call signature depending on whether
         conditioning is present. If conditioning is present, the adapter slices
         it with `step_mask` so shapes match `states_active`.
      2) Calls the estimator forward pass to obtain the raw `est_out`.
      3) Converts `est_out` into a torch Distribution with
         `to_probability_distribution`.
-      4) Saves `est_out` into `ctx.current_estimator_output` so it can optionally be
-         recorded by `record` or exposed to callers that need it.
+      4) Saves `est_out` into `ctx.current_estimator_output`.
 
-    - record(ctx, step_mask, sampled_actions, dist, save_logprobs, save_estimator_outputs)
+    - record:
      Materializes optional per‑step artifacts into context‑managed buffers with
-      mask‑aware padding back to the full rollout batch size `N`:
+      mask‑aware padding back to the full rollout batch size:
      * Log‑probs: computes `dist.log_prob(sampled_actions)` for active rows only,
-        then writes into a 1D tensor of shape `(N,)` filled with zeros and masked
+        then writes into a 1D tensor of length batch_size filled with zeros and masked
        assignment for active positions. Appends this to a list (one tensor per time step).
      * Estimator outputs: if requested, pads the last estimator output
-        (`ctx.current_estimator_output`) to shape `(N, ...)` using `-inf` for inactive rows
-        and appends to a list (one tensor per time step).
+        (`ctx.current_estimator_output`) to shape `(batch_size, ...)` using `-inf` for
+        inactive rows and appends to a list (one tensor per time step).
 
-    Finalization is performed by the context itself:
-    - ctx.finalize() -> {"log_probs": Tensor | None, "estimator_outputs": Tensor | None}
-      Stacks recorded per‑step lists along the time dimension into tensors of shape
-      `(T, N, ...)` suitable for `Trajectories`. Returns `None` for any artifact
-      that was never recorded.
+    - finalize: Stacks recorded per‑step lists along the time dimension into tensors of
+      shape `(trajectory_length, batch_size, ...)` suitable for `Trajectories`.
+      Returns `None` for any artifact that was never recorded.
 
     Masking & Shapes
     ----------------
@@ -269,6 +216,7 @@ def is_backward(self) -> bool:
 
     @property
     def is_vectorized(self) -> bool:
+        """Used for vectorized probability calculations."""
         return True
 
     def init_context(
@@ -299,8 +247,10 @@ def compute(
         - Saves the raw estimator output in `ctx.current_estimator_output` for
           optional recording in `record`.
""" - conditioning = ctx.conditioning # type: ignore[attr-defined] - cond_active = conditioning[step_mask] if conditioning is not None else None + cond_active = None + if ctx.conditioning is not None: + cond_active = ctx.conditioning[step_mask] + est_out = check_cond_forward( self._estimator, "estimator", states_active, cond_active ) @@ -308,7 +258,7 @@ def compute( dist = self._estimator.to_probability_distribution( states_active, est_out, **policy_kwargs ) - ctx.current_estimator_output = est_out # type: ignore[attr-defined] + ctx.current_estimator_output = est_out return dist, ctx @@ -322,14 +272,21 @@ def record( save_estimator_outputs: bool, ) -> None: """Record per-step artifacts into the context's trajectory-level lists.""" - # Delegate recording to the rollout context - ctx.append_step( # type: ignore[attr-defined] - step_mask=step_mask, - sampled_actions=sampled_actions, - dist=dist, - save_logprobs=save_logprobs, - save_estimator_outputs=save_estimator_outputs, - ) + if save_logprobs: + lp_masked = dist.log_prob(sampled_actions) + if torch.any(torch.isinf(lp_masked)): + raise RuntimeError("Log probabilities are inf. This should not happen.") + step_lp = torch.full((ctx.batch_size,), 0.0, device=ctx.device) + step_lp[step_mask] = lp_masked + ctx.trajectory_log_probs.append(step_lp) + + if save_estimator_outputs and ctx.current_estimator_output is not None: + est_out = ctx.current_estimator_output + padded = torch.full( + (ctx.batch_size,) + est_out.shape[1:], -float("inf"), device=ctx.device + ) + padded[step_mask] = est_out + ctx.trajectory_estimator_outputs.append(padded) def log_prob_of_actions( self, @@ -339,48 +296,51 @@ def log_prob_of_actions( step_mask: torch.Tensor, **policy_kwargs: Any, ) -> tuple[torch.Tensor, Any]: - # Optional fast path: use caller-provided estimator outputs - precomputed = policy_kwargs.pop("precomputed_estimator_output", None) + # Optional fast path: use caller-provided estimator outputs only when + # shapes match. + precomputed = self.get_current_estimator_output(ctx) if precomputed is not None: + expected_bs = states_active.batch_shape[0] + assert precomputed.shape[0] == expected_bs, ( + "current_estimator_output batch size does not match active states batch size. " + f"Got precomputed={precomputed.shape[0]}, expected={expected_bs}. " + "Likely stale reuse. Ensure PB clears ctx.current_estimator_output each step " + "and PF indexes trajectories.estimator_outputs[t][step_mask]." + ) est_out = precomputed - else: - conditioning = ctx.conditioning # type: ignore[attr-defined] - cond_active = conditioning[step_mask] if conditioning is not None else None - est_out = check_cond_forward( - self._estimator, "estimator", states_active, cond_active + dist = self._estimator.to_probability_distribution( + states_active, est_out, **policy_kwargs ) - dist = self._estimator.to_probability_distribution( - states_active, est_out, **policy_kwargs - ) - ctx.current_estimator_output = est_out # type: ignore[attr-defined] + else: + # Compute fresh estimator output when no valid precomputed output is + # provided. + # TODO: Should I toggle "save_estimator_outputs" here? + dist, ctx = self.compute(states_active, ctx, step_mask, **policy_kwargs) - N = ctx.batch_size # type: ignore[attr-defined] - device = ctx.device # type: ignore[attr-defined] - lp_masked = dist.log_prob(actions_active) - if torch.any(torch.isinf(lp_masked)): + masked_log_probs = dist.log_prob(actions_active) + + if torch.any(torch.isinf(masked_log_probs)): raise RuntimeError("Log probabilities are inf. 
This should not happen.") - step_lp = torch.full((N,), 0.0, device=device) - step_lp[step_mask] = lp_masked - return step_lp, ctx - # Vectorized helper to mirror legacy probability calculations without per-step context - def log_prob_vectorized( - self, - states: States, - actions_tensor: torch.Tensor, - conditioning: Optional[torch.Tensor] = None, - **policy_kwargs: Any, - ) -> torch.Tensor: - """Compute log_prob for a batch of (state, action) pairs in a vectorized way. + step_log_probs = torch.full((ctx.batch_size,), 0.0, device=ctx.device) + step_log_probs[step_mask] = masked_log_probs - This mirrors the legacy vectorized path used in probability utils and uses - the adapter's estimator and distribution construction, including policy kwargs. - """ - est_out = check_cond_forward(self._estimator, "estimator", states, conditioning) - dist = self._estimator.to_probability_distribution( - states, est_out, **policy_kwargs + return step_log_probs, ctx + + def finalize(self, ctx: Any) -> dict[str, Optional[torch.Tensor]]: + """Stack recorded per-step artifacts along time into trajectory-level tensors.""" + log_probs = ( + torch.stack(ctx.trajectory_log_probs, dim=0) + if ctx.trajectory_log_probs + else None + ) + estimator_outputs = ( + torch.stack(ctx.trajectory_estimator_outputs, dim=0) + if ctx.trajectory_estimator_outputs + else None ) - return dist.log_prob(actions_tensor) + + return {"log_probs": log_probs, "estimator_outputs": estimator_outputs} def get_current_estimator_output(self, ctx: Any) -> Optional[torch.Tensor]: """Expose the most recent per-step estimator output saved during `compute`.""" @@ -388,64 +348,24 @@ def get_current_estimator_output(self, ctx: Any) -> Optional[torch.Tensor]: class RecurrentEstimatorAdapter(DefaultEstimatorAdapter): - """Adapter for recurrent estimators that require and update a carry. - - Overview - -------- - This adapter extends the default (non‑recurrent) behavior to handle models that - maintain a recurrent state ("carry"). It exposes the same surface as - `DefaultEstimatorAdapter`, with the following differences: - - - `is_vectorized = False`: Probability calculators will use a non‑vectorized - (per‑step) path that mirrors the legacy reference exactly (including masks and - state/action alignment), since a recurrent carry must be updated sequentially. - - The rollout `ctx` stores a `carry` that is initialized once via - `estimator.init_carry(batch_size, device)` and updated at every call to - `compute`/`log_prob_of_actions`. - - Context Lifecycle (opaque to the Sampler) - ---------------------------------------- - The adapter owns an opaque rollout context `ctx` (see `RolloutContext`) which the - Sampler never reads. The context is created once per rollout and mutated in place: - - - init_context(batch_size, device, conditioning) -> ctx - Stores rollout invariants and optional conditioning. Initializes recurrent - `carry` via `estimator.init_carry`. Also prepares per‑step buffers for optional - artifacts (log_probs, estimator_outputs). - - - compute(states_active, ctx, step_mask, **policy_kwargs) -> (dist, ctx) - 1) Calls the recurrent estimator as `(states_active, ctx.carry) -> (est_out, new_carry)` - and stores `new_carry` back into `ctx.carry`. - 2) Converts `est_out` into a torch Distribution with - `to_probability_distribution(states_active, est_out, **policy_kwargs)`. - 3) Saves `est_out` into `ctx.current_estimator_output` for optional recording. 
- - - record(ctx, step_mask, sampled_actions, dist, save_logprobs, save_estimator_outputs) - Materializes optional per‑step artifacts into context‑managed buffers with - mask‑aware padding back to the full rollout batch size `N`: - * Log‑probs: computes `dist.log_prob(sampled_actions)` for active rows only, - then writes into a 1D tensor of shape `(N,)` filled with zeros and masked - assignment for active positions. Appends this to a list (one tensor per time step). - * Estimator outputs: if requested, pads `ctx.current_estimator_output` to shape - `(N, ...)` using `-inf` for inactive rows and appends to a list. - - Finalization is performed by the context itself: - - ctx.finalize() -> {"log_probs": Tensor | None, "estimator_outputs": Tensor | None} - Stacks recorded per‑step lists along the time dimension into tensors of shape - `(T, N, ...)` suitable for `Trajectories`. Returns `None` for any artifact - that was never recorded. - - Probability Calculators - ----------------------- - Since `is_vectorized = False`, the PF/PB probability calculators use the - non‑vectorized, per‑step path that matches the legacy reference: - - Trajectory PF: `step_mask = ~states.is_sink_state[t] & ~actions.is_dummy[t]`. - - Trajectory PB: actions at time `t` are aligned with states at time `t+1`, and - `step_mask = ~states.is_sink_state[t+1] & ~states.is_initial_state[t+1] - & ~actions.is_dummy[t] & ~actions.is_exit[t]` with `t==0` skipped. - - Transitions: the same legacy masks are used, and a single adapter call is made - per batch. - No mask indexing with action ids is used; distributions handle illegal actions. + """Adapter for recurrent estimators that maintain a rollout carry (hidden state). + + Differences from `DefaultEstimatorAdapter`: + - is_vectorized = False: runs sequential, per‑step probability calculators (legacy PF/PB + masks and alignment; PB aligns action at t with state at t+1, t==0 skipped). + - Rollout context manages `ctx.carry` which contains the hidden state of the + recurrent estimator. It is initialized once via `estimator.init_carry(batch_size, + device)` and updated every step. + - `compute(states, ctx, ...)` calls `estimator(states, ctx.carry) -> + (est_out, new_carry)`, updates `ctx.carry`, then builds the Distribution from + `est_out`. + - recording mirrors the default but pads per‑step tensors to batch size and + stacks into `(T, N, ...)`. + - No action‑id mask indexing; illegal actions are handled by the Distribution. + + Requires the estimator to implement: + - `init_carry(batch_size: int, device: torch.device)` + - a recurrent forward returning `(est_out, new_carry)`. """ def __init__(self, estimator: Estimator) -> None: @@ -459,6 +379,7 @@ def __init__(self, estimator: Estimator) -> None: ) super().__init__(estimator) + # TODO: Need to support vectorized probability calculations with Transformers. @property def is_vectorized(self) -> bool: return False @@ -484,7 +405,6 @@ def init_context( "expose `init_carry`." ) ctx = super().init_context(batch_size, device, conditioning) - # Expect estimator to implement init_carry(batch_size, device) init_carry_fn = cast(Callable[[int, torch.device], Any], init_carry) ctx.carry = init_carry_fn(batch_size, device) @@ -505,12 +425,12 @@ def compute( Distribution. 
""" # Recurrent estimators are expected to accept (states, carry) -> (out, new_carry) - est_out, new_carry = self._estimator(states_active, ctx.carry) # type: ignore[attr-defined] - ctx.carry = new_carry # type: ignore[attr-defined] + est_out, new_carry = self._estimator(states_active, ctx.carry) + ctx.carry = new_carry dist = self._estimator.to_probability_distribution( states_active, est_out, **policy_kwargs ) - ctx.current_estimator_output = est_out # type: ignore[attr-defined] + ctx.current_estimator_output = est_out return dist, ctx @@ -522,22 +442,16 @@ def log_prob_of_actions( step_mask: torch.Tensor, **policy_kwargs: Any, ) -> tuple[torch.Tensor, Any]: - # Recurrent estimators are expected to accept (states, carry) -> (out, new_carry) - est_out, new_carry = self._estimator(states_active, ctx.carry) # type: ignore[attr-defined] - ctx.carry = new_carry # type: ignore[attr-defined] - dist = self._estimator.to_probability_distribution( - states_active, est_out, **policy_kwargs - ) - ctx.current_estimator_output = est_out # type: ignore[attr-defined] + # TODO: Should I toggle "save_estimator_outputs" here? + # TODO: Need to look at how we handle logprobs for on-policy training. + dist, ctx = self.compute(states_active, ctx, step_mask, **policy_kwargs) - N = ctx.batch_size # type: ignore[attr-defined] - device = ctx.device # type: ignore[attr-defined] - lp_masked = dist.log_prob(actions_active) - if torch.any(torch.isinf(lp_masked)): + masked_log_probs = dist.log_prob(actions_active) + if torch.any(torch.isinf(masked_log_probs)): raise RuntimeError("Log probabilities are inf. This should not happen.") - step_lp = torch.full((N,), 0.0, device=device) - step_lp[step_mask] = lp_masked - return step_lp, ctx + step_log_probs = torch.full((ctx.batch_size,), 0.0, device=ctx.device) + step_log_probs[step_mask] = masked_log_probs + return step_log_probs, ctx class Sampler: @@ -642,10 +556,13 @@ def sample_actions( actions = env.actions_from_tensor(actions_tensor) estimator_output = None - if save_estimator_outputs and hasattr( - self.adapter, "get_current_estimator_output" - ): + if save_estimator_outputs: + if not hasattr(self.adapter, "get_current_estimator_output"): + raise TypeError( + "Adapter does not support get_current_estimator_output and save_estimator_outputs is True!" 
+ ) estimator_output = self.adapter.get_current_estimator_output(ctx) + assert estimator_output is not None assert log_probs is None or log_probs.shape == actions.batch_shape @@ -805,7 +722,7 @@ def sample_trajectories( 1: ] # Drop dummy action # Finalize stacked trajectory artifacts from the context (already shaped (T, N, ...)) - trajectory_artifacts = ctx.finalize() # type: ignore[attr-defined] + trajectory_artifacts = self.adapter.finalize(ctx) # type: ignore[attr-defined] stacked_logprobs = trajectory_artifacts.get("log_probs", None) stacked_estimator_outputs = trajectory_artifacts.get("estimator_outputs", None) diff --git a/src/gfn/utils/prob_calculations.py b/src/gfn/utils/prob_calculations.py index 31c3ba79..15fd744f 100644 --- a/src/gfn/utils/prob_calculations.py +++ b/src/gfn/utils/prob_calculations.py @@ -154,28 +154,31 @@ def get_trajectory_pfs( state_ok = ~trajectories.states.is_sink_state[t] action_ok = ~trajectories.actions.is_dummy[t] step_mask = state_ok & action_ok + if not torch.any(step_mask): continue + step_states = trajectories.states[t][step_mask] step_actions = trajectories.actions.tensor[t][step_mask] + # Optimization: forward cached estimator outputs when available if ( trajectories.estimator_outputs is not None and not recalculate_all_logprobs ): - precomputed = trajectories.estimator_outputs[t][step_mask] - step_lp, ctx = adapter.log_prob_of_actions( # type: ignore[union-attr] - step_states, - step_actions, - ctx, - step_mask, - precomputed_estimator_output=precomputed, - **policy_kwargs, - ) + ctx.current_estimator_output = trajectories.estimator_outputs[t][ + step_mask + ] else: - step_lp, ctx = adapter.log_prob_of_actions( # type: ignore[union-attr] - step_states, step_actions, ctx, step_mask, **policy_kwargs - ) + # Ensure we do not accidentally reuse estimator outputs from a + # previous time step. Precomputed outputs must be provided + # explicitly for the current step. + ctx.current_estimator_output = None + + # Calculate the log-probabilities of the actions. + step_lp, ctx = adapter.log_prob_of_actions( # type: ignore[union-attr] + step_states, step_actions, ctx, step_mask, **policy_kwargs + ) if fill_value != 0.0: padded = torch.full( (N,), fill_value, device=device, dtype=step_lp.dtype @@ -184,7 +187,7 @@ def get_trajectory_pfs( step_lp = padded log_pf_trajectories[t] = step_lp else: - # Legacy vectorized path (strict reference behavior) + # Vectorized path. log_pf_trajectories = torch.full_like( trajectories.actions.tensor[..., 0], fill_value=fill_value, @@ -338,6 +341,7 @@ def get_trajectory_pbs( N = trajectories.n_trajectories device = trajectories.states.device cond = trajectories.conditioning + if cond is not None and len(cond.shape) >= 2: cond = cond[0] ctx = adapter.init_context(int(N), device, cond) # type: ignore[arg-type] @@ -349,8 +353,9 @@ def get_trajectory_pbs( ~trajectories.states.is_initial_state[t + 1] ) if t == 0: - # Legacy explicitly disables PB at t=0 + # log PB is always zero for the transition s1 -> s0. state_ok = torch.zeros_like(state_ok, dtype=torch.bool) + action_ok = (~trajectories.actions.is_dummy[t]) & ( ~trajectories.actions.is_exit[t] ) @@ -358,11 +363,17 @@ def get_trajectory_pbs( if not torch.any(step_mask): continue + step_states = trajectories.states[t + 1][step_mask] step_actions = trajectories.actions.tensor[t][step_mask] + + # Prevent reusing last step's estimator output (batch size may differ, + # and estimator output caching isn't needed for PB). 
+        ctx.current_estimator_output = None
         step_lp, ctx = adapter.log_prob_of_actions(
             step_states, step_actions, ctx, step_mask, **policy_kwargs
         )
+
         padded = torch.full((N,), fill_value, device=device, dtype=step_lp.dtype)
         padded[step_mask] = step_lp[step_mask]
         log_pb_trajectories[t] = padded
diff --git a/testing/test_probability_calculations.py b/testing/test_probability_calculations.py
index e71a7825..e3e7952f 100644
--- a/testing/test_probability_calculations.py
+++ b/testing/test_probability_calculations.py
@@ -345,10 +345,9 @@ def test_adapter_log_prob_of_actions_precomputed_matches_forward():
     # Baseline: adapter recomputes estimator outputs internally
     lp1, _ = adapter.log_prob_of_actions(states, actions_tensor, ctx1, step_mask)
 
-    # Precomputed: adapter uses provided estimator outputs (fast path)
-    lp2, _ = adapter.log_prob_of_actions(
-        states, actions_tensor, ctx2, step_mask, precomputed_estimator_output=est_out
-    )
+    # Precomputed: adapter uses provided estimator outputs (fast path).
+    ctx2.current_estimator_output = est_out
+    lp2, _ = adapter.log_prob_of_actions(states, actions_tensor, ctx2, step_mask)
 
     torch.testing.assert_close(lp1, lp2)
 
@@ -438,4 +437,4 @@ def test_get_transition_pbs_matches_legacy_with_default_adapter():
 
 
 if __name__ == "__main__":
-    test_get_trajectory_pbs_matches_legacy_with_default_adapter()
+    test_trajectory_pb_vectorized_vs_nonvectorized_parity()
diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py
index 3ae98668..8245201c 100644
--- a/testing/test_samplers_and_trajectories.py
+++ b/testing/test_samplers_and_trajectories.py
@@ -536,7 +536,7 @@ def test_default_adapter_compute_record_finalize():
     adapter.record(
         ctx, step_mask, actions, dist, save_logprobs=True, save_estimator_outputs=True
     )
-    out = ctx.finalize()
+    out = adapter.finalize(ctx)
     assert out["log_probs"] is not None and out["log_probs"].shape == (1, n)
     assert out["estimator_outputs"] is not None and out["estimator_outputs"].shape[
         :2
@@ -574,7 +574,7 @@ def test_recurrent_adapter_flow():
     )
     h1 = ctx.carry["hidden"].clone()
     assert torch.all(h1 == h0 + 1)
-    out = ctx.finalize()
+    out = adapter.finalize(ctx)
     assert out["log_probs"] is not None and out["log_probs"].shape == (2, n)
     assert out["estimator_outputs"] is not None and out["estimator_outputs"].shape[
         :2
@@ -649,7 +649,7 @@ def test_integration_recurrent_sequence_model_with_adapter(
             save_estimator_outputs=True,
         )
 
-    out = ctx.finalize()
+    out = adapter.finalize(ctx)
     log_probs = out["log_probs"]
     estimator_outputs = out["estimator_outputs"]
     assert log_probs is not None
@@ -699,7 +699,7 @@ def test_integration_transformer_sequence_model_with_adapter(
         ctx, step_mask, actions, dist, save_logprobs=True, save_estimator_outputs=True
    )
 
-    out = ctx.finalize()
+    out = adapter.finalize(ctx)
     assert out["log_probs"] is not None and out["log_probs"].shape[0] == 1
     assert (
         out["estimator_outputs"] is not None and out["estimator_outputs"].shape[0] == 1
From e2dc289acbd434b8d313b76eaf13dcba0713a3df Mon Sep 17 00:00:00 2001
From: Joseph Viviano
Date: Sun, 12 Oct 2025 23:13:04 -0400
Subject: [PATCH 16/27] simplification of the context, adapter logic,
 compression of documentation, added some useful safeguard assertions, bugfix
 related to saving estimator_outputs path

---
 src/gfn/estimators.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/gfn/estimators.py b/src/gfn/estimators.py
index 0b950fe3..170f37d9 100644
--- a/src/gfn/estimators.py
+++ b/src/gfn/estimators.py
@@ -845,7 +845,7 @@ def __init__(
             preprocessor: Preprocessor object that transforms states to tensors.
         """
         if preprocessor is None:
-            preprocessor = IdentityPreprocessor(output_dim=None)
+            preprocessor = IdentityPreprocessor(output_dim=None)
 
         super().__init__(
             module=module,
From 3c2862f2d43b1e7a25be2b916de937f5f92e0e8e Mon Sep 17 00:00:00 2001
From: Joseph Viviano
Date: Mon, 13 Oct 2025 02:20:47 -0400
Subject: [PATCH 17/27] streamlined adapters under their own module

---
 src/gfn/adapters.py | 471 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 471 insertions(+)
 create mode 100644 src/gfn/adapters.py

diff --git a/src/gfn/adapters.py b/src/gfn/adapters.py
new file mode 100644
index 00000000..c09805f3
--- /dev/null
+++ b/src/gfn/adapters.py
@@ -0,0 +1,471 @@
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Dict, List, Optional, cast
+from inspect import signature
+
+import torch
+from torch.distributions import Distribution
+
+from gfn.estimators import Estimator
+from gfn.states import States
+from gfn.utils.handlers import check_cond_forward
+
+
+class EstimatorAdapter(ABC):
+    """Adapter interface for estimator-specific policy behavior.
+
+    This abstract base class defines the minimal interface the Sampler relies on,
+    allowing us to keep one generic sampling loop while plugging in different estimator
+    behaviors (e.g., non‑recurrent, recurrent with carry, tempered variants)
+    without modifying the Sampler.
+
+    The adapter owns an opaque RolloutContext object. The Sampler never inspects
+    it and simply passes it back to the adapter at each step. The adapter is
+    responsible for:
+    - initializing the context in `init_context`.
+    - computing the action distribution in `compute_dist`, while updating any
+      internal state (e.g., recurrent `carry`).
+    - computing the log probabilities of the actions in `log_probs`.
+    - recording per‑step artifacts in `record` (e.g., log_probs,
+      estimator outputs), typically with mask-aware padding.
+    - outputting trajectory-length artifacts via `finalize(ctx)`.
+
+    The context should be allocated once per rollout. Masking should be applied inside
+    the adapter (via `step_mask`) when slicing conditioning or padding per‑step
+    tensors back to full batch size. The Sampler can therefore be oblivious to
+    estimator details (conditioning, carry, etc.).
+    """
+
+    @property
+    def is_backward(self) -> bool:
+        ...  # fmt: skip
+
+    @property
+    def is_vectorized(self) -> bool:
+        ...  # fmt: skip
+
+    @abstractmethod
+    def init_context(
+        self,
+        batch_size: int,
+        device: torch.device,
+        conditioning: Optional[torch.Tensor] = None,
+    ) -> Any:
+        ...  # fmt: skip
+
+    @abstractmethod
+    def compute_dist(
+        self,
+        states_active: States,
+        ctx: Any,
+        step_mask: Optional[torch.Tensor] = None,
+        **policy_kwargs: Any,
+    ) -> tuple[Distribution, Any]:
+        ...  # fmt: skip
+
+    @abstractmethod
+    def log_probs(
+        self,
+        actions_active: torch.Tensor,
+        dist: Distribution,
+        ctx: Any,
+        step_mask: Optional[torch.Tensor] = None,
+        vectorized: bool = False,
+    ) -> tuple[torch.Tensor, Any]:
+        ...  # fmt: skip
+
+    @abstractmethod
+    def record(
+        self,
+        ctx: Any,
+        step_mask: torch.Tensor,
+        sampled_actions: torch.Tensor,
+        dist: Distribution,
+        log_probs: Optional[torch.Tensor],
+        save_estimator_outputs: bool,
+    ) -> None:
+        ...  # fmt: skip
+
+    # Optional helper for `sample_actions` BC
+    def get_current_estimator_output(self, ctx: Any) -> Optional[torch.Tensor]:
+        ...  # fmt: skip
+
+
+class RolloutContext:
+    """Structured, mutable context owned by adapters.
+
+    Uses fixed attributes for core fields and an `extras` dict for adapter-
+    specific extensions without changing the class shape. This keeps most
+    accesses fast and typed while preserving flexibility similar to dicts.
+    """
+
+    __slots__ = (
+        "batch_size",
+        "device",
+        "conditioning",
+        "carry",
+        "trajectory_log_probs",
+        "trajectory_estimator_outputs",
+        "current_estimator_output",
+        "extras",
+    )
+
+    def __init__(
+        self,
+        batch_size: int,
+        device: torch.device,
+        conditioning: Optional[torch.Tensor] = None,
+    ) -> None:
+        self.batch_size = batch_size
+        self.device = device
+        self.conditioning = conditioning
+        self.carry = None
+        self.trajectory_log_probs: List[torch.Tensor] = []
+        self.trajectory_estimator_outputs: List[torch.Tensor] = []
+        self.current_estimator_output: Optional[torch.Tensor] = None
+        self.extras: Dict[str, Any] = {}
+
+
+class DefaultEstimatorAdapter(EstimatorAdapter):
+    """Adapter for non-recurrent estimators (current default behavior).
+
+    Overview
+    --------
+    This adapter bridges the generic sampling loop and the non-recurrent estimators
+    used throughout the codebase. It exposes the minimal interface required by the
+    `EstimatorAdapter` abstract base class while keeping the sampler loop
+    estimator-agnostic.
+
+    - If conditioning is provided, the estimator accepts `(states, conditioning)`;
+      otherwise it accepts `(states)`.
+    - The estimator provides `to_probability_distribution(states, est_out, **kw)`
+      returning a torch Distribution over actions for the masked states.
+
+    The adapter owns an opaque rollout context `ctx` which the Sampler never reads. The
+    context (a `RolloutContext`) is created once per rollout and mutated in place:
+
+    - init_context: stores rollout invariants and optional conditioning. Also prepares
+      per‑step buffers for trajectory-level artifacts (log_probs,
+      estimator_outputs).
+
+    - compute_dist:
+      1) Selects the appropriate estimator call signature depending on whether
+         conditioning is present. If conditioning is present, the adapter slices
+         it with `step_mask` so shapes match `states_active`.
+      2) Calls the estimator forward pass to obtain the raw `est_out`.
+      3) Converts `est_out` into a torch Distribution with
+         `to_probability_distribution`.
+      4) Saves `est_out` into `ctx.current_estimator_output`.
+
+    - record:
+      Materializes optional per‑step artifacts into context‑managed buffers with
+      mask‑aware padding back to the full rollout batch size:
+      * Log‑probs: if provided, appends the padded `(batch_size,)` log-prob tensor
+        (produced by `log_probs`) to a list (one tensor per time step).
+      * Estimator outputs: if requested, pads the last estimator output
+        (`ctx.current_estimator_output`) to shape `(batch_size, ...)` using `-inf` for
+        inactive rows and appends to a list (one tensor per time step).
+
+    - finalize: Stacks recorded per‑step lists along the time dimension into tensors of
+      shape `(trajectory_length, batch_size, ...)` suitable for `Trajectories`.
+      Returns `None` for any artifact that was never recorded.
+
+    Masking & Shapes
+    ----------------
+    - `states_active` always corresponds to `states[~dones]` inside the sampler.
+    - The adapter receives `step_mask` (shape `(N,)`) to slice any step‑dependent
+      inputs (e.g., conditioning) and to pad per‑step outputs to the full batch.
+    - Padded tensors use `0.0` for log‑probs and `-inf` for estimator outputs to
+      maintain compatibility with downstream code.
+
+    Backward/Forward Direction
+    --------------------------
+    - `is_backward` is forwarded from the underlying estimator so the sampler can
+      choose the appropriate environment transition (forward vs backward).
+
+    Vectorized Probability Path
+    ---------------------------
+    - `is_vectorized` is used by the Sampler to choose the appropriate probability path.
+    - Vectorized adapters always use faster paths in probability calculators.
+      Non-vectorized adapters (e.g., recurrent) use per-step paths with masking and
+      alignment identical to the legacy reference.
+
+    Performance Notes
+    -----------------
+    - `ctx` is allocated once per rollout and mutated in place to avoid per‑step
+      overhead.
+    - If you know trajectory length bounds, you can extend this adapter to
+      pre‑allocate fixed‑size storage in `init_context` rather than appending to
+      Python lists.
+    """
+
+    def __init__(self, estimator: Estimator) -> None:
+        """Initialize the adapter with a non-recurrent estimator.
+
+        The estimator must expose `to_probability_distribution(states, est_out, **kw)`
+        and optionally accept conditioning via `estimator(states, conditioning)`.
+        """
+        self._estimator = estimator
+
+    @property
+    def is_backward(self) -> bool:
+        """Whether the wrapped estimator samples in the backward direction."""
+        return getattr(self._estimator, "is_backward", False)
+
+    @property
+    def is_vectorized(self) -> bool:
+        """Used for vectorized probability calculations."""
+        return True
+
+    def init_context(
+        self,
+        batch_size: int,
+        device: torch.device,
+        conditioning: Optional[torch.Tensor] = None,
+    ) -> RolloutContext:
+        """Create a new per-rollout context.
+
+        Stores rollout invariants (batch size, device, optional conditioning) and
+        initializes empty buffers for per-step artifacts.
+        """
+        return RolloutContext(
+            batch_size=batch_size, device=device, conditioning=conditioning
+        )
+
+    def compute_dist(
+        self,
+        states_active: States,
+        ctx: Any,
+        step_mask: Optional[torch.Tensor] = None,
+        **policy_kwargs: Any,
+    ) -> tuple[Distribution, Any]:
+        """Run the estimator for active rows and build an action Distribution.
+
+        - Uses `step_mask` to slice conditioning to the active subset.
+        - Saves the raw estimator output in `ctx.current_estimator_output` for
+          optional recording in `record`.
+        """
+        cond_active = None
+        if ctx.conditioning is not None:
+            if step_mask is None:
+                cond_active = ctx.conditioning
+            else:
+                cond_active = ctx.conditioning[step_mask]
+
+        estimator_outputs = check_cond_forward(
+            self._estimator, "estimator", states_active, cond_active
+        )
+
+        dist = self._estimator.to_probability_distribution(
+            states_active, estimator_outputs, **policy_kwargs
+        )
+
+        # TODO: Make optional.
+        ctx.current_estimator_output = estimator_outputs
+
+        return dist, ctx
+
+    def log_probs(
+        self,
+        actions_active: torch.Tensor,
+        dist: Distribution,
+        ctx: Any,
+        step_mask: Optional[torch.Tensor] = None,
+        vectorized: bool = False,
+    ) -> tuple[torch.Tensor, Any]:
+        """Compute log-probs, optionally padding back to full batch when non-vectorized."""
+        lp = dist.log_prob(actions_active)
+        if torch.any(torch.isinf(lp)):
+            raise RuntimeError("Log probabilities are inf. 
This should not happen.")
+
+        if vectorized:
+            return lp, ctx
+
+        assert step_mask is not None, "step_mask is required when vectorized=False"
+        step_lp = torch.full((ctx.batch_size,), 0.0, device=ctx.device, dtype=lp.dtype)
+        step_lp[step_mask] = lp
+
+        return step_lp, ctx
+
+    def record(
+        self,
+        ctx: Any,
+        step_mask: torch.Tensor,
+        sampled_actions: torch.Tensor,
+        dist: Distribution,
+        log_probs: Optional[torch.Tensor],
+        save_estimator_outputs: bool,
+    ) -> None:
+        """Record per-step artifacts into the context's trajectory-level lists."""
+        if log_probs is not None:
+            ctx.trajectory_log_probs.append(log_probs)
+
+        if save_estimator_outputs and ctx.current_estimator_output is not None:
+            estimator_outputs = ctx.current_estimator_output
+            padded = torch.full(
+                (ctx.batch_size,) + estimator_outputs.shape[1:], -float("inf"), device=ctx.device
+            )
+            padded[step_mask] = estimator_outputs
+            ctx.trajectory_estimator_outputs.append(padded)
+
+    def finalize(self, ctx: Any) -> dict[str, Optional[torch.Tensor]]:
+        """Stack recorded per-step artifacts along time into trajectory-level tensors."""
+        log_probs = (
+            torch.stack(ctx.trajectory_log_probs, dim=0)
+            if ctx.trajectory_log_probs
+            else None
+        )
+        estimator_outputs = (
+            torch.stack(ctx.trajectory_estimator_outputs, dim=0)
+            if ctx.trajectory_estimator_outputs
+            else None
+        )
+
+        return {"log_probs": log_probs, "estimator_outputs": estimator_outputs}
+
+    def get_current_estimator_output(self, ctx: Any) -> Optional[torch.Tensor]:
+        """Expose the most recent per-step estimator output saved during `compute_dist`."""
+        return getattr(ctx, "current_estimator_output", None)
+
+
+class RecurrentEstimatorAdapter(DefaultEstimatorAdapter):
+    """Adapter for recurrent estimators that maintain a rollout carry (hidden state).
+
+    Differences from `DefaultEstimatorAdapter`:
+    - is_vectorized = False: runs sequential, per‑step probability calculators for
+      PF/PB (PB aligns the action at t with the state at t+1; t==0 is skipped).
+    - Rollout context manages `ctx.carry` which contains the hidden state of the
+      recurrent estimator. It is initialized once via `estimator.init_carry(batch_size,
+      device)` and updated every step.
+    - `compute_dist(states, ctx, ...)` calls `estimator(states, ctx.carry) ->
+      (estimator_outputs, new_carry)`, updates `ctx.carry`, then builds the
+      Distribution from `estimator_outputs`.
+    - recording mirrors the default but pads per‑step tensors to batch size and
+      stacks into `(T, N, ...)`.
+    - No action‑id mask indexing; illegal actions are handled by the Distribution.
+
+    Requires the estimator to implement:
+    - `init_carry(batch_size: int, device: torch.device)`
+    - a recurrent forward returning `(estimator_outputs, new_carry)`.
+    """
+
+    def __init__(self, estimator: Estimator) -> None:
+        # Validate that the estimator presents a recurrent interface
+        # We check for the presence of `init_carry` and a callable that accepts (states, carry).
+        init_carry = getattr(estimator, "init_carry", None)
+        if not callable(init_carry):
+            raise TypeError(
+                "RecurrentEstimatorAdapter requires an estimator implementing "
+                "init_carry(batch_size: int, device: torch.device)."
+            )
+        super().__init__(estimator)
+
+    # TODO: Need to support vectorized probability calculations with Transformers.
+    @property
+    def is_vectorized(self) -> bool:
+        return False
+
+    def init_context(
+        self,
+        batch_size: int,
+        device: torch.device,
+        conditioning: Optional[torch.Tensor] = None,
+    ) -> RolloutContext:
+        """Create context and initialize the recurrent carry (the estimator hidden state).
+
+        Differs from the default adapter by allocating `ctx.carry` via
+        `estimator.init_carry(batch_size, device)`.
+        """
+        init_carry = getattr(self._estimator, "init_carry", None)
+        if not callable(init_carry):
+            raise TypeError(
+                "RecurrentEstimatorAdapter requires an estimator that implements "
+                "init_carry(batch_size: int, device: torch.device).\n"
+                "A) Recurrent estimators must expose an `init_carry` method.\n"
+                "B) RecurrentEstimatorAdapter is only compatible with estimators that "
+                "expose `init_carry`."
+            )
+        ctx = super().init_context(batch_size, device, conditioning)
+        init_carry_fn = cast(Callable[[int, torch.device], Any], init_carry)
+        ctx.carry = init_carry_fn(batch_size, device)
+
+        return ctx
+
+    def compute_dist(
+        self,
+        states_active: States,
+        ctx: Any,
+        step_mask: Optional[torch.Tensor] = None,
+        **policy_kwargs: Any,
+    ) -> tuple[Distribution, Any]:
+        """Run estimator with carry and update it.
+
+        Differs from the default adapter by calling
+        `estimator(states_active, ctx.carry) -> (est_out, new_carry)`, storing the
+        updated carry and saving `current_estimator_output` before building the
+        Distribution.
+        """
+        estimator_outputs, new_carry = self._estimator(states_active, ctx.carry)
+        ctx.carry = new_carry
+        dist = self._estimator.to_probability_distribution(
+            states_active,
+            estimator_outputs,
+            **policy_kwargs,
+        )
+
+        # TODO: Make optional.
+        ctx.current_estimator_output = estimator_outputs
+
+        return dist, ctx
+
+
+def maybe_instantiate_adapter(
+    estimator: Estimator,
+    adapter: Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None,
+) -> EstimatorAdapter:
+    """Maybe instantiate an adapter for a given estimator.
+
+    Args:
+        estimator: The estimator to instantiate an adapter for.
+        adapter: An adapter class, adapter instance, or callable to use for sampling
+            actions and computing probability distributions. If None, the default
+            adapter class for the estimator will be used.
+
+    Returns:
+        An adapter instance.
+    """
+    # If no adapter is provided, use the default adapter class for the estimator,
+    # which we need to retrieve and instantiate here.
+    if adapter is None:
+        adapter_cls = estimator.default_adapter_class
+        assert (
+            adapter_cls is not None
+        ), "Estimator has no default adapter class and no adapter was provided"
+        adapter_cls = cast(Callable[[Estimator], EstimatorAdapter], adapter_cls)
+        return adapter_cls(estimator)
+
+    # If an adapter class is provided, instantiate it with the estimator.
+    elif isinstance(adapter, type) and issubclass(adapter, EstimatorAdapter):
+
+        # We have to assume that the adapter class accepts exactly 1 argument
+        # (estimator).
+        sig = signature(adapter)
+
+        # Count parameters excluding 'self'
+        params = [p for p in sig.parameters.values() if p.name != "self"]
+        if len(params) != 1:
+            raise TypeError(
+                f"Adapter class {adapter.__name__} must accept exactly 1 argument "
+                f"(estimator) to use automatic adapter instantiation, "
+                f"but has {len(params)} parameters: {[p.name for p in params]}. "
+                f"You can provide an adapter instance to the GFlowNet instead."
+            )
+
+        adapter_factory = cast(Callable[[Estimator], EstimatorAdapter], adapter)
+        return adapter_factory(estimator)
+
+    # If an adapter instance is provided, use it.
+    elif isinstance(adapter, EstimatorAdapter):
+        return adapter
+
+    else:
+        raise ValueError(f"Invalid adapter type: {type(adapter)}")
From 4a23ea029dde274c80df43a5bd94812a0e7a5b07 Mon Sep 17 00:00:00 2001
From: Joseph Viviano
Date: Mon, 13 Oct 2025 04:47:12 -0400
Subject: [PATCH 18/27] typing

---
 src/gfn/adapters.py                           |  68 ++-
 src/gfn/estimators.py                         |  55 +-
 src/gfn/gflownet/base.py                      |  24 +-
 src/gfn/gflownet/detailed_balance.py          |  31 +-
 src/gfn/gflownet/flow_matching.py             |   4 +-
 src/gfn/gflownet/sub_trajectory_balance.py    |  11 +-
 src/gfn/gflownet/trajectory_balance.py        |  17 +-
 src/gfn/samplers.py                           | 483 +-----------------
 src/gfn/utils/handlers.py                     |   8 +-
 src/gfn/utils/prob_calculations.py            | 346 ++++++-------
 ..._adaptor_estimator_gflownet_integration.py |   2 +-
 testing/test_probability_calculations.py      |  17 +-
 testing/test_samplers_and_trajectories.py     |  63 +--
 tutorials/examples/test_scripts.py            |   2 +-
 .../examples/train_bitsequence_recurrent.py   |   2 +-
 15 files changed, 412 insertions(+), 721 deletions(-)

diff --git a/src/gfn/adapters.py b/src/gfn/adapters.py
index c09805f3..fac92e7f 100644
--- a/src/gfn/adapters.py
+++ b/src/gfn/adapters.py
@@ -1,14 +1,16 @@
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, List, Optional, cast
 from inspect import signature
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, cast
 
 import torch
 from torch.distributions import Distribution
 
-from gfn.estimators import Estimator
 from gfn.states import States
 from gfn.utils.handlers import check_cond_forward
 
+if TYPE_CHECKING:
+    from gfn.estimators import Estimator
+
 
 class EstimatorAdapter(ABC):
     """Adapter interface for estimator-specific policy behavior.
@@ -198,7 +200,7 @@ class while keeping the sampler loop estimator-agnostic.
         Python lists.
     """
 
-    def __init__(self, estimator: Estimator) -> None:
+    def __init__(self, estimator: "Estimator") -> None:
         """Initialize the adapter with a non-recurrent estimator.
 
         The estimator must expose `to_probability_distribution(states, est_out, **kw)`
@@ -240,21 +242,39 @@ def compute_dist(
     ) -> tuple[Distribution, Any]:
         """Run the estimator for active rows and build an action Distribution.
 
-        - Uses `step_mask` to slice conditioning to the active subset.
+        - Uses `step_mask` to slice conditioning to the active subset. When `step_mask`
+          is None, the estimator is running in a vectorized context.
         - Saves the raw estimator output in `ctx.current_estimator_output` for
           optional recording in `record`.
         """
-        cond_active = None
-        if ctx.conditioning is not None:
-            if step_mask is None:
-                cond_active = ctx.conditioning
-            else:
-                cond_active = ctx.conditioning[step_mask]
-
-        estimator_outputs = check_cond_forward(
-            self._estimator, "estimator", states_active, cond_active
-        )
+        precomputed_estimator_outputs = getattr(ctx, "current_estimator_output", None)
+
+        # Reuse precomputed outputs only in vectorized contexts (no step_mask).
+        if step_mask is None and precomputed_estimator_outputs is not None:
+            expected_bs = states_active.batch_shape[0]
+            if precomputed_estimator_outputs.shape[0] != expected_bs:
+                raise RuntimeError(
+                    "current_estimator_output batch size does not match active states. "
+                    f"Got {precomputed_estimator_outputs.shape[0]}, expected {expected_bs}. "
+                    "This indicates stale cache reuse; ensure per-step masking when setting "
+                    "ctx.current_estimator_output and clear it when not valid."
+                )
+            estimator_outputs = precomputed_estimator_outputs
+
+        # Otherwise, compute the estimator outputs.
+ else: + cond_active = None + if ctx.conditioning is not None: + if step_mask is None: + cond_active = ctx.conditioning + else: + cond_active = ctx.conditioning[step_mask] + + estimator_outputs = check_cond_forward( + self._estimator, "estimator", states_active, cond_active + ) + # Build the distribution. dist = self._estimator.to_probability_distribution( states_active, estimator_outputs, **policy_kwargs ) @@ -274,12 +294,14 @@ def log_probs( ) -> tuple[torch.Tensor, Any]: """Compute log-probs, optionally padding back to full batch when non-vectorized.""" lp = dist.log_prob(actions_active) - if torch.any(torch.isinf(lp)): - raise RuntimeError("Log probabilities are inf. This should not happen.") if vectorized: return lp, ctx + # Non-vectorized path strict check. None of these should be -inf after masking. + if torch.any(torch.isinf(lp)): + raise RuntimeError("Log probabilities are inf. This should not happen.") + assert step_mask is not None, "step_mask is required when vectorized=False" step_lp = torch.full((ctx.batch_size,), 0.0, device=ctx.device, dtype=lp.dtype) step_lp[step_mask] = lp @@ -302,7 +324,9 @@ def record( if save_estimator_outputs and ctx.current_estimator_output is not None: estimator_outputs = ctx.current_estimator_output padded = torch.full( - (ctx.batch_size,) + estimator_outputs.shape[1:], -float("inf"), device=ctx.device + (ctx.batch_size,) + estimator_outputs.shape[1:], + -float("inf"), + device=ctx.device, ) padded[step_mask] = estimator_outputs ctx.trajectory_estimator_outputs.append(padded) @@ -348,7 +372,7 @@ class RecurrentEstimatorAdapter(DefaultEstimatorAdapter): - a recurrent forward returning `(estimator_outputs, new_carry)`. """ - def __init__(self, estimator: Estimator) -> None: + def __init__(self, estimator: "Estimator") -> None: # Validate that the estimator presents a recurrent interface # We check for the presence of `init_carry` and a callable that accepts (states, carry). init_carry = getattr(estimator, "init_carry", None) @@ -419,8 +443,8 @@ def compute_dist( def maybe_instantiate_adapter( - estimator: Estimator, - adapter: Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None, + estimator: "Estimator", + adapter: Callable[["Estimator"], EstimatorAdapter] | EstimatorAdapter | None, ) -> EstimatorAdapter: """Maybe instantiate an adapter for a given estimator. @@ -440,7 +464,7 @@ def maybe_instantiate_adapter( assert ( adapter_cls is not None ), "Estimator has no default adapter class and no adapter was provided" - adapter_cls = cast(Callable[[Estimator], EstimatorAdapter], adapter_cls) + adapter_cls = cast(Callable[["Estimator"], EstimatorAdapter], adapter_cls) return adapter_cls(estimator) # If an adapter class is provided, instantiate it with the estimator. @@ -460,7 +484,7 @@ def maybe_instantiate_adapter( f"You can provide an adapter instance to the GFlowNet instead." ) - adapter_factory = cast(Callable[[Estimator], EstimatorAdapter], adapter) + adapter_factory = cast(Callable[["Estimator"], EstimatorAdapter], adapter) return adapter_factory(estimator) # If an adapter instance is provided, use it. 
diff --git a/src/gfn/estimators.py b/src/gfn/estimators.py index 170f37d9..825fabb5 100644 --- a/src/gfn/estimators.py +++ b/src/gfn/estimators.py @@ -8,6 +8,11 @@ from torch.distributions import Categorical, Distribution from gfn.actions import GraphActions, GraphActionType +from gfn.adapters import ( + DefaultEstimatorAdapter, + EstimatorAdapter, + RecurrentEstimatorAdapter, +) from gfn.preprocessors import IdentityPreprocessor, Preprocessor from gfn.states import DiscreteStates, States from gfn.utils.distributions import GraphActionDistribution, UnsqueezedCategorical @@ -51,8 +56,11 @@ class Estimator(ABC, nn.Module): `IdentityPreprocessor`. is_backward: Flag indicating whether this estimator is for backward policy, i.e., is used for predicting probability distributions over parents. + _default_adapter_class: The default adapter class for this estimator. """ + _default_adapter_class = DefaultEstimatorAdapter + def __init__( self, module: nn.Module, @@ -115,6 +123,11 @@ def expected_output_dim(self) -> Optional[int]: is not well-defined (e.g., when the output is a TensorDict for GraphActions). """ + @property + def default_adapter_class(self) -> type[EstimatorAdapter] | None: + """The default adapter class for this estimator.""" + return self._default_adapter_class + def to_probability_distribution( self, states: States, @@ -160,8 +173,11 @@ class ScalarEstimator(Estimator): that can be used as input to the module. is_backward: Always False for ScalarEstimator (since it's direction-agnostic). reduction_function: Function used to reduce multi-dimensional outputs to scalars. + _default_adapter_class: There is no default adapter class for this estimator. """ + _default_adapter_class = None + def __init__( self, module: nn.Module, @@ -172,8 +188,8 @@ def __init__( Args: module: The neural network module to use. - preprocessor: Preprocessor object that transforms states to tensors. If None, - uses `IdentityPreprocessor` with the module's input_dim. + preprocessor: Preprocessor object that transforms states to tensors. If + None, uses `IdentityPreprocessor` with the module's input_dim. reduction: String name of one of the REDUCTION_FUNCTIONS keys. """ super().__init__(module, preprocessor, False) @@ -219,6 +235,13 @@ class LogitBasedEstimator(Estimator): This class is used to define estimators that output logits, which can be used to construct probability distributions. + + Attributes: + module: The neural network module to use. + preprocessor: Preprocessor object that transforms raw States objects to tensors. + is_backward: Flag indicating whether this estimator is for backward policy, + i.e., is used for predicting probability distributions over parents. + _default_adapter_class: The default adapter class for this estimator. """ @staticmethod @@ -365,11 +388,17 @@ def _compute_logits_for_distribution( class ConditionalLogZEstimator(ScalarEstimator): """Conditional logZ estimator. - This estimator is used to estimate the logZ of a GFlowNet from a conditioning tensor. - Since conditioning is a tensor, it does not have a preprocessor. Reduction is used - to aggregate the outputs of the module into a single scalar. + This estimator is used to estimate the logZ of a GFlowNet from a conditioning + tensor. Since conditioning is a tensor, it does not have a preprocessor. Reduction is used to aggregate the outputs of the module into a single scalar. + + Attributes: + module: The neural network module to use. + reduction: String name of one of the REDUCTION_FUNCTIONS keys. 
+ _default_adapter_class: There is no default adapter class for this estimator. """ + _default_adapter_class = None + def __init__(self, module: nn.Module, reduction: str = "mean"): super().__init__(module, preprocessor=None, reduction=reduction) @@ -393,6 +422,7 @@ class DiscretePolicyEstimator(LogitBasedEstimator): preprocessor: Preprocessor object that transforms raw States objects to tensors. is_backward: Flag indicating whether this estimator is for backward policy, i.e., is used for predicting probability distributions over parents. + _default_adapter_class: The default adapter class for this estimator. """ def __init__( @@ -497,6 +527,7 @@ class ConditionalDiscretePolicyEstimator(DiscretePolicyEstimator): preprocessor: Preprocessor object that transforms raw States objects to tensors. is_backward: Flag indicating whether this estimator is for backward policy, i.e., is used for predicting probability distributions over parents. + _default_adapter_class: The default adapter class for this estimator. """ def __init__( @@ -577,8 +608,11 @@ class ConditionalScalarEstimator(ConditionalDiscretePolicyEstimator): is_backward: Always False for ConditionalScalarEstimator (since it's direction-agnostic). reduction_function: Function used to reduce multi-dimensional outputs to scalars. + _default_adapter_class: There is no default adapter class for this estimator. """ + _default_adapter_class = None + def __init__( self, state_module: nn.Module, @@ -674,6 +708,7 @@ class DiscreteGraphPolicyEstimator(LogitBasedEstimator): preprocessor: Preprocessor object that transforms GraphStates objects to tensors. is_backward: Flag indicating whether this estimator is for backward policy, i.e., is used for predicting probability distributions over parents. + _default_adapter_class: The default adapter class for this estimator. """ def to_probability_distribution( @@ -828,8 +863,18 @@ class RecurrentDiscretePolicyEstimator(DiscretePolicyEstimator): entire trajectories typically requires different batching and masking. - ``init_carry`` is a hard requirement for compatibility with the recurrent adapter. + + Attributes: + module: The neural network module to use. + n_actions: Total number of actions in the discrete environment. + preprocessor: Preprocessor object that transforms states to tensors. + is_backward: Flag indicating whether this estimator is for backward policy, + i.e., is used for predicting probability distributions over parents. + _default_adapter_class: The default adapter class for this estimator. 
""" + _default_adapter_class = RecurrentEstimatorAdapter + def __init__( self, module: nn.Module, diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index 8a6b1294..acf65f3a 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -1,11 +1,12 @@ import math import warnings from abc import ABC, abstractmethod -from typing import Any, Generic, Tuple, TypeVar +from typing import Any, Callable, Generic, Tuple, TypeVar import torch import torch.nn as nn +from gfn.adapters import EstimatorAdapter from gfn.containers import Container, Trajectories from gfn.env import Env from gfn.estimators import Estimator @@ -176,8 +177,12 @@ def __init__( pb: Estimator | None, constant_pb: bool = False, *, - pf_adapter: Any | None = None, - pb_adapter: Any | None = None, + pf_adapter: ( + Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None + ) = None, + pb_adapter: ( + Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None + ) = None, ) -> None: """Initializes a PFBasedGFlowNet instance. @@ -222,6 +227,7 @@ def __init__( self.pf = pf self.pb = pb self.constant_pb = constant_pb + # Optional adapters controlling estimator interactions via # vectorized / non-vectorized probability paths. self.pf_adapter = pf_adapter @@ -315,8 +321,12 @@ def __init__( pb: Estimator | None, constant_pb: bool = False, *, - pf_adapter: Any | None = None, - pb_adapter: Any | None = None, + pf_adapter: ( + Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None + ) = None, + pb_adapter: ( + Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None + ) = None, ) -> None: """Initializes a TrajectoryBasedGFlowNet instance. @@ -328,6 +338,10 @@ def __init__( gflownet DAG is a tree, and pb is therefore always 1. Must be set explicitly by user to ensure that pb is an Estimator except under this special case. + pf_adapter: Optional adapter for PF probability calculation/sampling. When + provided, used both in the sampler and in probability recomputation paths. + pb_adapter: Optional adapter for PB probability calculation. Used in + probability recomputation paths when `pb` is provided. """ super().__init__( pf, diff --git a/src/gfn/gflownet/detailed_balance.py b/src/gfn/gflownet/detailed_balance.py index fd63431c..ab7b1950 100644 --- a/src/gfn/gflownet/detailed_balance.py +++ b/src/gfn/gflownet/detailed_balance.py @@ -1,9 +1,10 @@ import math -from typing import Any, Tuple +from typing import Callable, Tuple import torch from gfn.actions import Actions +from gfn.adapters import EstimatorAdapter from gfn.containers import Trajectories, Transitions from gfn.env import Env from gfn.estimators import ConditionalScalarEstimator, Estimator, ScalarEstimator @@ -77,8 +78,12 @@ def __init__( safe_log_prob_min: bool = True, constant_pb: bool = False, *, - pf_adapter: Any | None = None, - pb_adapter: Any | None = None, + pf_adapter: ( + Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None + ) = None, + pb_adapter: ( + Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None + ) = None, ) -> None: """Initializes a DBGFlowNet instance. 
@@ -109,14 +114,21 @@ def __init__( pb_adapter=pb_adapter, ) # Disallow recurrent PF or recurrent adapter for transition-based DB + from gfn.adapters import RecurrentEstimatorAdapter # type: ignore from gfn.estimators import RecurrentDiscretePolicyEstimator # type: ignore - from gfn.samplers import RecurrentEstimatorAdapter # type: ignore if isinstance(self.pf, RecurrentDiscretePolicyEstimator): raise TypeError( "DBGFlowNet does not support recurrent PF estimators (transitions path cannot propagate carry)." ) - if isinstance(self.pf_adapter, RecurrentEstimatorAdapter): + + # Get the class whether pf_adapter is a class or instance + adapter_class = ( + self.pf_adapter + if isinstance(self.pf_adapter, type) + else type(self.pf_adapter) + ) + if issubclass(adapter_class, RecurrentEstimatorAdapter): raise TypeError( "DBGFlowNet does not support RecurrentEstimatorAdapter (transitions path cannot propagate carry)." ) @@ -125,6 +137,7 @@ def __init__( isinstance(logF, cls) for cls in [ScalarEstimator, ConditionalScalarEstimator] ), "logF must be a ScalarEstimator or derived" + self.logF = logF self.forward_looking = forward_looking self.log_reward_clip_min = log_reward_clip_min @@ -337,8 +350,12 @@ def __init__( pb: Estimator | None, constant_pb: bool = False, *, - pf_adapter: Any | None = None, - pb_adapter: Any | None = None, + pf_adapter: ( + Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None + ) = None, + pb_adapter: ( + Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None + ) = None, ) -> None: """Initializes a ModifiedDBGFlowNet instance. diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index 74cc6125..cebdd19c 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -36,8 +36,8 @@ class FMGFlowNet(GFlowNet[StatesContainer[DiscreteStates]]): Adapter note ------------ Flow Matching does not rely on PF/PB probability recomputation. Any trajectory - sampling provided by this class is for diagnostics/visualization and uses the - default (non-recurrent) adapter internally. Sampler adapters (e.g., + sampling provided by this class is for diagnostics/visualization and can only use + the default (non-recurrent) adapter. Sampler adapters (e.g. `RecurrentEstimatorAdapter`) are not exposed as configuration options for this class. """ diff --git a/src/gfn/gflownet/sub_trajectory_balance.py b/src/gfn/gflownet/sub_trajectory_balance.py index 8987451d..07697276 100644 --- a/src/gfn/gflownet/sub_trajectory_balance.py +++ b/src/gfn/gflownet/sub_trajectory_balance.py @@ -1,9 +1,10 @@ import math import warnings -from typing import Any, List, Literal, Tuple, TypeAlias +from typing import Callable, List, Literal, Tuple, TypeAlias import torch +from gfn.adapters import EstimatorAdapter from gfn.containers import Trajectories from gfn.env import Env from gfn.estimators import ConditionalScalarEstimator, Estimator, ScalarEstimator @@ -84,8 +85,12 @@ def __init__( forward_looking: bool = False, constant_pb: bool = False, *, - pf_adapter: Any | None = None, - pb_adapter: Any | None = None, + pf_adapter: ( + Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None + ) = None, + pb_adapter: ( + Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None + ) = None, ): """Initializes a SubTBGFlowNet instance. 
diff --git a/src/gfn/gflownet/trajectory_balance.py b/src/gfn/gflownet/trajectory_balance.py index 6614500b..91dd5c45 100644 --- a/src/gfn/gflownet/trajectory_balance.py +++ b/src/gfn/gflownet/trajectory_balance.py @@ -3,11 +3,12 @@ and the [Log Partition Variance loss](https://arxiv.org/abs/2302.05446). """ -from typing import Any, cast +from typing import Callable, cast import torch import torch.nn as nn +from gfn.adapters import EstimatorAdapter from gfn.containers import Trajectories from gfn.env import Env from gfn.estimators import Estimator, ScalarEstimator @@ -48,8 +49,12 @@ def __init__( log_reward_clip_min: float = -float("inf"), constant_pb: bool = False, *, - pf_adapter: Any | None = None, - pb_adapter: Any | None = None, + pf_adapter: ( + Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None + ) = None, + pb_adapter: ( + Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None + ) = None, ): """Initializes a TBGFlowNet instance. @@ -72,7 +77,11 @@ def __init__( recomputation paths that require PB. """ super().__init__( - pf, pb, constant_pb=constant_pb, pf_adapter=pf_adapter, pb_adapter=pb_adapter + pf, + pb, + constant_pb=constant_pb, + pf_adapter=pf_adapter, + pb_adapter=pb_adapter, ) self.logZ = logZ or nn.Parameter(torch.tensor(init_logZ)) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index a4887b41..1bed10a0 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -1,459 +1,18 @@ -from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional, Tuple, cast +from typing import Any, Callable, List, Optional, Tuple import torch -from torch.distributions import Distribution from gfn.actions import Actions +from gfn.adapters import EstimatorAdapter, maybe_instantiate_adapter from gfn.containers import Trajectories from gfn.env import Env from gfn.estimators import Estimator from gfn.states import GraphStates, States from gfn.utils.common import ensure_same_device from gfn.utils.graphs import graph_states_share_storage -from gfn.utils.handlers import check_cond_forward from gfn.utils.prob_calculations import get_trajectory_pbs, get_trajectory_pfs -class EstimatorAdapter(ABC): - """Adapter interface for estimator-specific policy behavior. - - This abstract base class defines the minimal interface the Sampler relies on, - allowing us to keep one generic sampling loop while plugging in different estimator - behaviors (e.g., non‑recurrent, recurrent with carry, tempered variants) - without modifying the Sampler. - - The adapter owns an opaque RolloutContext object. The Sampler never inspects - it and simply passes it back to the adapter at each step. The adapter is - responsible for: - - initializing the context in `init_context`. - - updating any internal state (e.g., recurrent `carry`) during `compute` - - recording per‑step artifacts in `record` (e.g., log_probs, - estimator outputs), typically with mask-aware padding. - - output trajectory-length artifacts via `finalize(ctx)` - - The context should be allocated once per rollout. Masking should be applied inside - the adapter (via `step_mask`) when slicing conditioning or padding per‑step - tensors back to full batch size. The Sampler can therefore be oblivious to - estimator details (conditioning, carry, etc.). - """ - - @property - def is_backward(self) -> bool: - ... # fmt: skip - - @property - def is_vectorized(self) -> bool: - ... 
# fmt: skip - - @abstractmethod - def init_context( - self, - batch_size: int, - device: torch.device, - conditioning: Optional[torch.Tensor] = None, - ) -> Any: - ... # fmt: skip - - @abstractmethod - def compute( - self, - states_active: States, - ctx: Any, - step_mask: torch.Tensor, - **policy_kwargs: Any, - ) -> tuple[Distribution, Any]: - ... # fmt: skip - - @abstractmethod - def record( - self, - ctx: Any, - step_mask: torch.Tensor, - sampled_actions: torch.Tensor, - dist: Distribution, - save_logprobs: bool, - save_estimator_outputs: bool, - ) -> None: - ... # fmt: skip - - @abstractmethod - def log_prob_of_actions( - self, - states_active: States, - actions_active: torch.Tensor, - ctx: Any, - step_mask: torch.Tensor, - **policy_kwargs: Any, - ) -> tuple[torch.Tensor, Any]: - ... # fmt: skip - - # Optional helper for `sample_actions` BC - def get_current_estimator_output(self, ctx: Any) -> Optional[torch.Tensor]: - ... # fmt: skip - - -class RolloutContext: - """Structured, mutable context owned by adapters. - - Uses fixed attributes for core fields and an `extras` dict for adapter- - specific extensions without changing the class shape. This keeps most - accesses fast and typed while preserving flexibility similar to dicts. - """ - - __slots__ = ( - "batch_size", - "device", - "conditioning", - "carry", - "trajectory_log_probs", - "trajectory_estimator_outputs", - "current_estimator_output", - "extras", - ) - - def __init__( - self, - batch_size: int, - device: torch.device, - conditioning: Optional[torch.Tensor] = None, - ) -> None: - self.batch_size = batch_size - self.device = device - self.conditioning = conditioning - self.carry = None - self.trajectory_log_probs: List[torch.Tensor] = [] - self.trajectory_estimator_outputs: List[torch.Tensor] = [] - self.current_estimator_output: Optional[torch.Tensor] = None - self.extras: Dict[str, Any] = {} - - -class DefaultEstimatorAdapter(EstimatorAdapter): - """Adapter for non-recurrent estimators (current default behavior). - - Overview - -------- - This adapter bridges the generic sampling loop and is used throughout the codebase. - It exposes the minimal interface required by the `EstimatorAdapter` abstract base - class while keeping the sampler loop estimator-agnostic. - - - If conditioning is provided, the estimator accepts `(states, conditioning)`; - otherwise it accepts `(states)`. - - The estimator provides `to_probability_distribution(states, est_out, **kw)` - returning a torch Distribution over actions for the masked states. - - The adapter owns an opaque rollout context `ctx` which the Sampler never reads. The - context (a TensorDict) is created once per rollout and mutated in place: - - - init_context: stores rollout invariants and optional conditioning. Also prepares - per‑step buffers for artifacts trajectory-level artifacts (log_probs, - estimator_outputs). - - - compute: - 1) Selects the appropriate estimator call signature depending on whether - conditioning is present. If conditioning is present, the adapter slices - it with `step_mask` so shapes match `states_active`. - 2) Calls the estimator forward pass to obtain the raw `est_out`. - 3) Converts `est_out` into a torch Distribution with - `to_probability_distribution`. - 4) Saves `est_out` into `ctx.current_estimator_output`. 
- - - record: - Materializes optional per‑step artifacts into context‑managed buffers with - mask‑aware padding back to the full rollout batch size: - * Log‑probs: computes `dist.log_prob(sampled_actions)` for active rows only, - then writes into a 1D tensor of length batch_size filled with zeros and masked - assignment for active positions. Appends this to a list (one tensor per time step). - * Estimator outputs: if requested, pads the last estimator output - (`ctx.current_estimator_output`) to shape `(batch_size, ...)` using `-inf` for - inactive rows and appends to a list (one tensor per time step). - - - finalize: Stacks recorded per‑step lists along the time dimension into tensors of - shape `(trajectory_legnth, batch_size, ...)` suitable for `Trajectories`. - Returns `None` for any artifact that was never recorded. - - Masking & Shapes - ---------------- - - `states_active` always corresponds to `states[~dones]` inside the sampler. - - The adapter receives `step_mask` (shape `(N,)`) to slice any step‑dependent - inputs (e.g., conditioning) and to pad per‑step outputs to the full batch. - - Padded tensors use `0.0` for log‑probs and `-inf` for estimator outputs to - maintain compatibility with downstream code. - - Backward/Forward Direction - -------------------------- - - `is_backward` is forwarded from the underlying estimator so the sampler can - choose the appropriate environment transition (forward vs backward). - - Vectorized Probability Path - -------------------------- - - `is_vectorized` is used by the Sampler to choose the appropriate probability path. - - Vectorized adapters always use faster paths in probability calculators. - Non-vectorized adapters (e.g., recurrent) use per-step paths with masking and - alignment identical to the legacy reference. - - Performance Notes - ----------------- - - `ctx` is allocated once per rollout and mutated in place to avoid per‑step - overhead. - - If you know trajectory length bounds, you can extend this adapter to - pre‑allocate fixed‑size storage in `init_context` rather than appending to - Python lists. - """ - - def __init__(self, estimator: Estimator) -> None: - """Initialize the adapter with a non-recurrent estimator. - - The estimator must expose `to_probability_distribution(states, est_out, **kw)` - and optionally accept conditioning via `estimator(states, conditioning)`. - """ - self._estimator = estimator - - @property - def is_backward(self) -> bool: - """Whether the wrapped estimator samples in the backward direction.""" - return getattr(self._estimator, "is_backward", False) - - @property - def is_vectorized(self) -> bool: - """Used for vectorized probability calculations.""" - return True - - def init_context( - self, - batch_size: int, - device: torch.device, - conditioning: Optional[torch.Tensor] = None, - ) -> RolloutContext: - """Create a new per-rollout context. - - Stores rollout invariants (batch size, device, optional conditioning) and - initializes empty buffers for per-step artifacts. - """ - return RolloutContext( - batch_size=batch_size, device=device, conditioning=conditioning - ) - - def compute( - self, - states_active: States, - ctx: Any, - step_mask: torch.Tensor, - **policy_kwargs: Any, - ) -> tuple[Distribution, Any]: - """Run the estimator for active rows and build an action Distribution. - - - Uses `step_mask` to slice conditioning to the active subset. - - Saves the raw estimator output in `ctx.current_estimator_output` for - optional recording in `record_step`. 
- """ - cond_active = None - if ctx.conditioning is not None: - cond_active = ctx.conditioning[step_mask] - - est_out = check_cond_forward( - self._estimator, "estimator", states_active, cond_active - ) - - dist = self._estimator.to_probability_distribution( - states_active, est_out, **policy_kwargs - ) - ctx.current_estimator_output = est_out - - return dist, ctx - - def record( - self, - ctx: Any, - step_mask: torch.Tensor, - sampled_actions: torch.Tensor, - dist: Distribution, - save_logprobs: bool, - save_estimator_outputs: bool, - ) -> None: - """Record per-step artifacts into the context's trajectory-level lists.""" - if save_logprobs: - lp_masked = dist.log_prob(sampled_actions) - if torch.any(torch.isinf(lp_masked)): - raise RuntimeError("Log probabilities are inf. This should not happen.") - step_lp = torch.full((ctx.batch_size,), 0.0, device=ctx.device) - step_lp[step_mask] = lp_masked - ctx.trajectory_log_probs.append(step_lp) - - if save_estimator_outputs and ctx.current_estimator_output is not None: - est_out = ctx.current_estimator_output - padded = torch.full( - (ctx.batch_size,) + est_out.shape[1:], -float("inf"), device=ctx.device - ) - padded[step_mask] = est_out - ctx.trajectory_estimator_outputs.append(padded) - - def log_prob_of_actions( - self, - states_active: States, - actions_active: torch.Tensor, - ctx: Any, - step_mask: torch.Tensor, - **policy_kwargs: Any, - ) -> tuple[torch.Tensor, Any]: - # Optional fast path: use caller-provided estimator outputs only when - # shapes match. - precomputed = self.get_current_estimator_output(ctx) - if precomputed is not None: - expected_bs = states_active.batch_shape[0] - assert precomputed.shape[0] == expected_bs, ( - "current_estimator_output batch size does not match active states batch size. " - f"Got precomputed={precomputed.shape[0]}, expected={expected_bs}. " - "Likely stale reuse. Ensure PB clears ctx.current_estimator_output each step " - "and PF indexes trajectories.estimator_outputs[t][step_mask]." - ) - est_out = precomputed - dist = self._estimator.to_probability_distribution( - states_active, est_out, **policy_kwargs - ) - else: - # Compute fresh estimator output when no valid precomputed output is - # provided. - # TODO: Should I toggle "save_estimator_outputs" here? - dist, ctx = self.compute(states_active, ctx, step_mask, **policy_kwargs) - - masked_log_probs = dist.log_prob(actions_active) - - if torch.any(torch.isinf(masked_log_probs)): - raise RuntimeError("Log probabilities are inf. This should not happen.") - - step_log_probs = torch.full((ctx.batch_size,), 0.0, device=ctx.device) - step_log_probs[step_mask] = masked_log_probs - - return step_log_probs, ctx - - def finalize(self, ctx: Any) -> dict[str, Optional[torch.Tensor]]: - """Stack recorded per-step artifacts along time into trajectory-level tensors.""" - log_probs = ( - torch.stack(ctx.trajectory_log_probs, dim=0) - if ctx.trajectory_log_probs - else None - ) - estimator_outputs = ( - torch.stack(ctx.trajectory_estimator_outputs, dim=0) - if ctx.trajectory_estimator_outputs - else None - ) - - return {"log_probs": log_probs, "estimator_outputs": estimator_outputs} - - def get_current_estimator_output(self, ctx: Any) -> Optional[torch.Tensor]: - """Expose the most recent per-step estimator output saved during `compute`.""" - return getattr(ctx, "current_estimator_output", None) - - -class RecurrentEstimatorAdapter(DefaultEstimatorAdapter): - """Adapter for recurrent estimators that maintain a rollout carry (hidden state). 
- - Differences from `DefaultEstimatorAdapter`: - - is_vectorized = False: runs sequential, per‑step probability calculators (legacy PF/PB - masks and alignment; PB aligns action at t with state at t+1, t==0 skipped). - - Rollout context manages `ctx.carry` which contains the hidden state of the - recurrent estimator. It is initialized once via `estimator.init_carry(batch_size, - device)` and updated every step. - - `compute(states, ctx, ...)` calls `estimator(states, ctx.carry) -> - (est_out, new_carry)`, updates `ctx.carry`, then builds the Distribution from - `est_out`. - - recording mirrors the default but pads per‑step tensors to batch size and - stacks into `(T, N, ...)`. - - No action‑id mask indexing; illegal actions are handled by the Distribution. - - Requires the estimator to implement: - - `init_carry(batch_size: int, device: torch.device)` - - a recurrent forward returning `(est_out, new_carry)`. - """ - - def __init__(self, estimator: Estimator) -> None: - # Validate that the estimator presents a recurrent interface - # We check for the presence of `init_carry` and a callable that accepts (states, carry). - init_carry = getattr(estimator, "init_carry", None) - if not callable(init_carry): - raise TypeError( - "RecurrentEstimatorAdapter requires an estimator implementing " - "init_carry(batch_size: int, device: torch.device)." - ) - super().__init__(estimator) - - # TODO: Need to support vectorized probability calculations with Transformers. - @property - def is_vectorized(self) -> bool: - return False - - def init_context( - self, - batch_size: int, - device: torch.device, - conditioning: Optional[torch.Tensor] = None, - ) -> RolloutContext: - """Create context and initialize recurrent carry, (estimator hidden state). - - Differs from the default adapter by allocating `ctx.carry` via - `estimator.init_carry(batch_size, device)`. - """ - init_carry = getattr(self._estimator, "init_carry", None) - if not callable(init_carry): - raise TypeError( - "RecurrentEstimatorAdapter requires an estimator that implements " - "init_carry(batch_size: int, device: torch.device).\n" - "A) Recurrent estimators must expose an `init_carry` method.\n" - "B) RecurrentEstimatorAdapter is only compatible with estimators that " - "expose `init_carry`." - ) - ctx = super().init_context(batch_size, device, conditioning) - init_carry_fn = cast(Callable[[int, torch.device], Any], init_carry) - ctx.carry = init_carry_fn(batch_size, device) - - return ctx - - def compute( - self, - states_active: States, - ctx: Any, - step_mask: torch.Tensor, - **policy_kwargs: Any, - ) -> tuple[Distribution, Any]: - """Run estimator with carry and update it. - - Differs from the default adapter by calling - `estimator(states_active, ctx.carry) -> (est_out, new_carry)`, storing the - updated carry and saving `current_estimator_output` before building the - Distribution. - """ - # Recurrent estimators are expected to accept (states, carry) -> (out, new_carry) - est_out, new_carry = self._estimator(states_active, ctx.carry) - ctx.carry = new_carry - dist = self._estimator.to_probability_distribution( - states_active, est_out, **policy_kwargs - ) - ctx.current_estimator_output = est_out - - return dist, ctx - - def log_prob_of_actions( - self, - states_active: States, - actions_active: torch.Tensor, - ctx: Any, - step_mask: torch.Tensor, - **policy_kwargs: Any, - ) -> tuple[torch.Tensor, Any]: - # TODO: Should I toggle "save_estimator_outputs" here? - # TODO: Need to look at how we handle logprobs for on-policy training. 
- dist, ctx = self.compute(states_active, ctx, step_mask, **policy_kwargs) - - masked_log_probs = dist.log_prob(actions_active) - if torch.any(torch.isinf(masked_log_probs)): - raise RuntimeError("Log probabilities are inf. This should not happen.") - step_log_probs = torch.full((ctx.batch_size,), 0.0, device=ctx.device) - step_log_probs[step_mask] = masked_log_probs - return step_log_probs, ctx - - class Sampler: """Wrapper for a PolicyEstimator that enables sampling from GFlowNet environments. @@ -467,18 +26,23 @@ class Sampler: """ def __init__( - self, estimator: Estimator, adapter: Optional[EstimatorAdapter] = None + self, + estimator: Estimator, + adapter: ( + Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None + ) = None, ) -> None: """Initializes a Sampler with a PolicyEstimator. Args: estimator: The PolicyEstimator to use for sampling actions and computing probability distributions. + adapter: An adapter class instance or callable to use for sampling actions + and computing probability distributions. If None, the default adapter + class for the estimator will be used. """ self.estimator = estimator - self.adapter = ( - adapter if adapter is not None else DefaultEstimatorAdapter(estimator) - ) + self.adapter = maybe_instantiate_adapter(estimator, adapter) def sample_actions( self, @@ -531,15 +95,16 @@ def sample_actions( step_mask = torch.ones( states.batch_shape[0], dtype=torch.bool, device=states.device ) - dist, ctx = self.adapter.compute(states, ctx, step_mask, **policy_kwargs) + dist, ctx = self.adapter.compute_dist(states, ctx, step_mask, **policy_kwargs) with torch.no_grad(): actions_tensor = dist.sample() if save_logprobs: - log_probs = dist.log_prob(actions_tensor) - if torch.any(torch.isinf(log_probs)): - raise RuntimeError("Log probabilities are inf. This should not happen.") + # Use adapter to compute step log-probs and pad to batch. + log_probs, ctx = self.adapter.log_probs( + actions_tensor, dist, ctx, step_mask, vectorized=False + ) else: log_probs = None @@ -549,7 +114,7 @@ def sample_actions( step_mask=step_mask, sampled_actions=actions_tensor, dist=dist, - save_logprobs=save_logprobs, + log_probs=log_probs, save_estimator_outputs=save_estimator_outputs, ) @@ -658,7 +223,7 @@ def sample_trajectories( step_mask = ~dones # Compute distribution on active rows - dist, ctx = self.adapter.compute( + dist, ctx = self.adapter.compute_dist( states[step_mask], ctx, step_mask, **policy_kwargs ) @@ -667,13 +232,21 @@ def sample_trajectories( valid_actions_tensor = dist.sample() valid_actions = env.actions_from_tensor(valid_actions_tensor) - # Let adapter record artifacts + if save_logprobs: + # Use adapter to compute step log-probs and pad to batch. + log_probs, ctx = self.adapter.log_probs( + valid_actions_tensor, dist, ctx, step_mask, vectorized=False + ) + else: + log_probs = None + + # Let adapter record artifacts. 
self.adapter.record( ctx=ctx, step_mask=step_mask, sampled_actions=valid_actions_tensor, dist=dist, - save_logprobs=save_logprobs, + log_probs=log_probs, save_estimator_outputs=save_estimator_outputs, ) diff --git a/src/gfn/utils/handlers.py b/src/gfn/utils/handlers.py index 56877c71..84ba9c72 100644 --- a/src/gfn/utils/handlers.py +++ b/src/gfn/utils/handlers.py @@ -1,16 +1,18 @@ import warnings from contextlib import contextmanager -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional import torch from gfn.containers import Container -from gfn.estimators import Estimator from gfn.states import States +if TYPE_CHECKING: + from gfn.estimators import Estimator # type: ignore + def check_cond_forward( - module: Estimator, + module: "Estimator", module_name: str, states: States, condition: Optional[torch.Tensor] = None, diff --git a/src/gfn/utils/prob_calculations.py b/src/gfn/utils/prob_calculations.py index 15fd744f..58372227 100644 --- a/src/gfn/utils/prob_calculations.py +++ b/src/gfn/utils/prob_calculations.py @@ -1,11 +1,11 @@ import warnings -from typing import Any, Tuple +from typing import Any, Callable, Tuple import torch +from gfn.adapters import EstimatorAdapter, maybe_instantiate_adapter from gfn.containers import Trajectories, Transitions from gfn.estimators import Estimator -from gfn.utils.handlers import check_cond_forward # ------------ # Trajectories @@ -18,8 +18,8 @@ def get_trajectory_pfs_and_pbs( trajectories: Trajectories, fill_value: float = 0.0, recalculate_all_logprobs: bool = True, - pf_adapter: Any | None = None, - pb_adapter: Any | None = None, + pf_adapter: Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None = None, + pb_adapter: Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None = None, **policy_kwargs: Any, ) -> Tuple[torch.Tensor, torch.Tensor]: """Calculate PF and PB log-probabilities for trajectories. @@ -40,7 +40,7 @@ def get_trajectory_pfs_and_pbs( fill_value: Fill used for invalid states (e.g., sink state positions). recalculate_all_logprobs: If ``True``, recompute PF even if cached. pf_adapter: Adapter for PF (vectorized vs non‑vectorized decision). - pb_adapter: OAdapter for PB (vectorized vs non‑vectorized decision). + pb_adapter: Adapter for PB (vectorized vs non‑vectorized decision). **policy_kwargs: Extra kwargs passed to estimator's ``to_probability_distribution`` (e.g., temperature, epsilon, sf_bias). @@ -54,7 +54,7 @@ def get_trajectory_pfs_and_pbs( # uncomment next line for debugging # assert trajectories.states.is_sink_state[:-1].equal(trajectories.actions.is_dummy) - if pb_adapter is not None and not isinstance(pb_adapter, type(pf_adapter)): + if pb_adapter is not None and not isinstance(pb_adapter, type(pf_adapter)): # type: ignore warnings.warn( ( "type(pb_adapter)={} and type(pf_adapter)={}, this is probably not what you want " @@ -87,7 +87,7 @@ def get_trajectory_pfs( trajectories: Trajectories, fill_value: float = 0.0, recalculate_all_logprobs: bool = True, - adapter: Any | None = None, + adapter: Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None = None, **policy_kwargs: Any, ) -> torch.Tensor: """Calculate PF log-probabilities for trajectories. @@ -114,6 +114,9 @@ def get_trajectory_pfs( Raises: ValueError: If backward trajectories are provided. 
""" + adapter = maybe_instantiate_adapter(pf, adapter) + assert adapter is not None + if trajectories.is_backward: raise ValueError("Backward trajectories are not supported") @@ -130,16 +133,17 @@ def get_trajectory_pfs( log_pf_trajectories = trajectories.log_probs assert log_pf_trajectories is not None else: + # Decide vectorized (legacy) vs non-vectorized (adapter per-step) - vectorized = adapter is None or getattr(adapter, "is_vectorized", True) + if not adapter.is_vectorized: - if not vectorized: # Adapter-driven path N = trajectories.n_trajectories device = trajectories.states.device cond = trajectories.conditioning if cond is not None and len(cond.shape) >= 2: cond = cond[0] + ctx = adapter.init_context(int(N), device, cond) # type: ignore[arg-type] T = trajectories.max_length @@ -175,17 +179,26 @@ def get_trajectory_pfs( # explicitly for the current step. ctx.current_estimator_output = None - # Calculate the log-probabilities of the actions. - step_lp, ctx = adapter.log_prob_of_actions( # type: ignore[union-attr] - step_states, step_actions, ctx, step_mask, **policy_kwargs + # Build distribution for active rows and compute step log-probs via + # adapter. + dist, ctx = adapter.compute_dist( + step_states, ctx, step_mask, **policy_kwargs + ) + step_log_probs, ctx = adapter.log_probs( + step_actions, dist, ctx, step_mask, vectorized=False ) + + # Pad back to full batch size. if fill_value != 0.0: padded = torch.full( - (N,), fill_value, device=device, dtype=step_lp.dtype + (N,), fill_value, device=device, dtype=step_log_probs.dtype ) - padded[step_mask] = step_lp[step_mask] - step_lp = padded - log_pf_trajectories[t] = step_lp + padded[step_mask] = step_log_probs[step_mask] + step_log_probs = padded + + # Store in trajectory-level tensor. + log_pf_trajectories[t] = step_log_probs + else: # Vectorized path. log_pf_trajectories = torch.full_like( @@ -197,30 +210,46 @@ def get_trajectory_pfs( if len(valid_states) == 0: return log_pf_trajectories + # Build conditioning per-step shape to align with valid_states + masked_cond = None + cond = trajectories.conditioning + + if cond is not None: + T = trajectories.states.tensor.shape[0] + # If conditioning already has time dim (T, N, ...), index directly. + if cond.shape[0] == T: + masked_cond = cond[state_mask] + else: + # Broadcast (N, ...) to (T, N, ...), then index. + masked_cond = cond.unsqueeze(0).expand((T,) + cond.shape)[state_mask] + + # Create a temporary context sized to valid rows. + ctx_v = adapter.init_context( + int(len(valid_states)), + trajectories.states.device, + conditioning=masked_cond, + ) + + # Optional estimator output cache reuse. 
if ( trajectories.estimator_outputs is not None and not recalculate_all_logprobs ): - # Reuse cached outputs to build the distribution - est_out = trajectories.estimator_outputs[action_mask] - dist = pf.to_probability_distribution( - valid_states, est_out, **policy_kwargs - ) - valid_log_pf_actions = dist.log_prob(valid_actions.tensor) - else: - # Build conditioning per-step shape to align with valid_states - masked_cond = None - if trajectories.conditioning is not None: - cond_dim = (-1,) * len(trajectories.conditioning.shape) - traj_len = trajectories.states.tensor.shape[0] - masked_cond = trajectories.conditioning.unsqueeze(0).expand( - (traj_len,) + cond_dim - )[state_mask] - est_out = check_cond_forward(pf, "pf", valid_states, masked_cond) - valid_log_pf_actions = pf.to_probability_distribution( - valid_states, est_out, **policy_kwargs - ).log_prob(valid_actions.tensor) + estimator_outputs = trajectories.estimator_outputs[action_mask] + ctx_v.current_estimator_output = estimator_outputs + + # Delegate to adapter for dist and vectorized log-prob calculation. + dist, ctx_v = adapter.compute_dist( + valid_states, + ctx_v, + step_mask=None, + **policy_kwargs, + ) + valid_log_pf_actions, _ = adapter.log_probs( + valid_actions.tensor, dist, ctx_v, step_mask=None, vectorized=True + ) + # Pad back to full batch size. log_pf_trajectories[action_mask] = valid_log_pf_actions.to( log_pf_trajectories.dtype, copy=False ) @@ -237,7 +266,7 @@ def get_trajectory_pbs( pb: Estimator | None, trajectories: Trajectories, fill_value: float = 0.0, - adapter: Any | None = None, + adapter: Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None = None, **policy_kwargs: Any, ) -> torch.Tensor: """Calculate PB log-probabilities for trajectories. @@ -294,49 +323,30 @@ def get_trajectory_pbs( # Using all non-initial states, calculate the backward policy, and the logprobs # of those actions. masked_cond = None - if trajectories.conditioning is not None: - # We need to index the conditioning vector to broadcast over the states. - # The conditioning tensor has shape (max_length, n_trajectories, 1) - # We need to index it with the state_mask to get the valid states - masked_cond = trajectories.conditioning[state_mask] - - # Recurrent adapters are only valid for trajectories and never require pb - from gfn.samplers import RecurrentEstimatorAdapter # type: ignore - - if adapter is not None: - is_recurrent = isinstance(adapter, RecurrentEstimatorAdapter) - else: - is_recurrent = False + cond = trajectories.conditioning + if cond is not None: + T = trajectories.states.tensor.shape[0] + if cond.shape[0] == T: + masked_cond = cond[state_mask] + else: + masked_cond = cond.unsqueeze(0).expand((T,) + cond.shape)[state_mask] - if is_recurrent or pb is None: - # With recurrent adapter, pb *must* be None (tree DAG); return zeros. - assert pb is None, "When using a RecurrentEstimatorAdapter, pb must be None." + # There is no backward policy in this case. + if pb is None: # If pb is None, we assume that the gflownet DAG is a tree, and therefore # the backward policy probability is always 1 (log probs are 0). valid_log_pb_actions = torch.zeros_like(valid_actions.tensor) valid_log_pb_actions = valid_log_pb_actions.squeeze(-1) # no padding. + log_pb_trajectories[action_mask] = valid_log_pb_actions.to( + log_pb_trajectories.dtype, copy=False + ) + + # There is a backward policy. + else: + adapter = maybe_instantiate_adapter(pb, adapter) - # TODO: Add logging in follow up PR. 
- # if os.getenv("GFN_DEBUG_REC_PB") == "1": - # print( - # "[DBG] pb=None path: valid_actions.shape=", - # tuple(valid_actions.tensor.shape), - # "valid_log_pb_actions.shape=", - # tuple(valid_log_pb_actions.shape), - # "target_len=", - # int(action_mask.sum().item()), - # ) - - elif pb is not None: - # Choose vectorized (legacy) vs non-vectorized (adapter per-step) - # Vectorized path is used by default via DefaultEstimatorAdapter. - vectorized = adapter is None or getattr(adapter, "is_vectorized", True) - if adapter is None: - from gfn.samplers import DefaultEstimatorAdapter # Avoids circular import. - - adapter = DefaultEstimatorAdapter(pb) - - if not vectorized: + # The backward policy requires step-wise evaluation. + if not adapter.is_vectorized: # Adapter-driven pb evaluation (non-recurrent) N = trajectories.n_trajectories device = trajectories.states.device @@ -346,12 +356,13 @@ def get_trajectory_pbs( cond = cond[0] ctx = adapter.init_context(int(N), device, cond) # type: ignore[arg-type] - T = trajectories.max_length # Iterate per-step with legacy-complete masking (state at t+1, action at t) - for t in range(T): - state_ok = (~trajectories.states.is_sink_state[t + 1]) & ( - ~trajectories.states.is_initial_state[t + 1] - ) + for t in range(trajectories.max_length): + # TODO: these checks are curious - I think one of them is never needed + # because for now we do not support reversed trajectories. + next_state_isnt_sink = ~trajectories.states.is_sink_state[t + 1] + next_state_isnt_initial = ~trajectories.states.is_initial_state[t + 1] + state_ok = next_state_isnt_sink & next_state_isnt_initial if t == 0: # log PB is always zero for the transition s1 -> s0. state_ok = torch.zeros_like(state_ok, dtype=torch.bool) @@ -370,30 +381,43 @@ def get_trajectory_pbs( # Prevent reusing last step's estimator output (batch size may differ, # and estimator output caching isn't needed for PB). ctx.current_estimator_output = None - step_lp, ctx = adapter.log_prob_of_actions( - step_states, step_actions, ctx, step_mask, **policy_kwargs + dist, ctx = adapter.compute_dist( + step_states, + ctx, + step_mask, + **policy_kwargs, + ) + step_lp, ctx = adapter.log_probs( + step_actions, dist, ctx, step_mask, vectorized=False ) - padded = torch.full((N,), fill_value, device=device, dtype=step_lp.dtype) + padded = torch.full( + (N,), + fill_value, + device=device, + dtype=step_lp.dtype, + ) padded[step_mask] = step_lp[step_mask] log_pb_trajectories[t] = padded - return log_pb_trajectories - + # The backward policy supports vectorized evaluation. 
else: - # Legacy vectorized path - estimator_outputs = check_cond_forward(pb, "pb", valid_states, masked_cond) - valid_log_pb_actions = pb.to_probability_distribution( - valid_states, estimator_outputs - ).log_prob(valid_actions.tensor) + ctx_v = adapter.init_context( + int(len(valid_states)), trajectories.states.device, conditioning=masked_cond # type: ignore[arg-type] + ) + dist, ctx_v = adapter.compute_dist( + valid_states, + ctx_v, + step_mask=None, + **policy_kwargs, + ) + valid_log_pb_actions, _ = adapter.log_probs( + valid_actions.tensor, dist, ctx_v, step_mask=None, vectorized=True + ) log_pb_trajectories[action_mask] = valid_log_pb_actions.to( log_pb_trajectories.dtype, copy=False ) - log_pb_trajectories[action_mask] = valid_log_pb_actions.to( - log_pb_trajectories.dtype, copy=False - ) - assert log_pb_trajectories.shape == ( trajectories.max_length, trajectories.n_trajectories, @@ -412,8 +436,8 @@ def get_transition_pfs_and_pbs( pb: Estimator | None, transitions: Transitions, recalculate_all_logprobs: bool = True, - pf_adapter: Any | None = None, - pb_adapter: Any | None = None, + pf_adapter: Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None = None, + pb_adapter: Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None = None, **policy_kwargs: Any, ) -> Tuple[torch.Tensor, torch.Tensor]: """Calculate PF and PB log-probabilities for transitions. @@ -442,7 +466,7 @@ def get_transition_pfs_and_pbs( if transitions.is_backward: raise ValueError("Backward transitions are not supported") - if pb_adapter is not None and not isinstance(pb_adapter, type(pf_adapter)): + if pb_adapter is not None and not isinstance(pb_adapter, type(pf_adapter)): # type: ignore warnings.warn( ( "type(pb_adapter)={} and type(pf_adapter)={}, this is probably not what you want " @@ -468,7 +492,7 @@ def get_transition_pfs( pf: Estimator, transitions: Transitions, recalculate_all_logprobs: bool = True, - adapter: Any | None = None, + adapter: Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None = None, **policy_kwargs: Any, ) -> torch.Tensor: """Calculate PF log-probabilities for transitions. @@ -490,6 +514,9 @@ def get_transition_pfs( Returns: Tensor of shape ``(M,)`` containing PF log-probabilities. """ + adapter = maybe_instantiate_adapter(pf, adapter) + assert adapter is not None + states = transitions.states actions = transitions.actions @@ -497,38 +524,29 @@ def get_transition_pfs( log_pf_actions = transitions.log_probs assert log_pf_actions is not None else: - if adapter is not None or True: - from gfn.samplers import RecurrentEstimatorAdapter # type: ignore - if adapter is None: - from gfn.samplers import DefaultEstimatorAdapter # type: ignore + from gfn.adapters import RecurrentEstimatorAdapter # type: ignore - adapter = DefaultEstimatorAdapter(pf) - elif isinstance(adapter, RecurrentEstimatorAdapter): - raise TypeError( - "RecurrentEstimatorAdapter is only supported for Trajectories" - ) - assert adapter is not None - - N = transitions.n_transitions - device = transitions.states.device - cond = transitions.conditioning - ctx = adapter.init_context(int(N), device, cond) - mask = torch.ones(N, dtype=torch.bool, device=device) - - # Evaluate the log PF of the actions, with optional conditioning. - # TODO: Inefficient duplication in case of tempered policy - # The Transitions container should then have some - # estimator_outputs attribute as well, to avoid duplication here ? - # See (#156). 
- step_lp, _ = adapter.log_prob_of_actions( - states[mask], - actions.tensor[mask], - ctx, - mask, - **policy_kwargs, + if isinstance(adapter, RecurrentEstimatorAdapter): + raise TypeError( + "RecurrentEstimatorAdapter is only supported for Trajectories" ) - log_pf_actions = step_lp + + N = transitions.n_transitions + device = transitions.states.device + cond = transitions.conditioning + ctx = adapter.init_context(int(N), device, cond) + mask = torch.ones(N, dtype=torch.bool, device=device) + + # Evaluate the log PF of the actions, with optional conditioning. + # TODO: Inefficient duplication in case of tempered policy + # The Transitions container should then have some + # estimator_outputs attribute as well, to avoid duplication here ? + # See (#156). + dist, ctx = adapter.compute_dist(states[mask], ctx, mask, **policy_kwargs) + log_pf_actions, _ = adapter.log_probs( + actions.tensor[mask], dist, ctx, mask, vectorized=False + ) return log_pf_actions @@ -536,7 +554,7 @@ def get_transition_pfs( def get_transition_pbs( pb: Estimator | None, transitions: Transitions, - adapter: Any | None = None, + adapter: Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None = None, **policy_kwargs: Any, ) -> torch.Tensor: """Calculate PB log-probabilities for transitions. @@ -557,16 +575,6 @@ def get_transition_pbs( Returns: Tensor of shape ``(M,)`` containing PB log-probabilities. """ - # # automatically removes invalid transitions (i.e. s_f -> s_f) - # valid_next_states = transitions.next_states[~transitions.is_terminating] - # non_exit_actions = transitions.actions[~transitions.actions.is_exit] - - # # Evaluate the log PB of the actions, with optional conditioning. - # masked_cond = ( - # transitions.conditioning[~transitions.is_terminating] - # if transitions.conditioning is not None - # else None - # ) # TODO: We support a fill_value for trajectories, but not for transitions. # Should we add it here, or remove it for trajectories? @@ -574,45 +582,39 @@ def get_transition_pbs( (transitions.n_transitions,), device=transitions.states.device ) - if adapter is not None or True: - from gfn.samplers import RecurrentEstimatorAdapter # type: ignore + # If pb is None, we assume that the gflownet DAG is a tree, and therefore + # the backward policy probability is always 1 (log probs are 0). + if pb is None: + return log_pb_actions - if adapter is None and pb is not None: - from gfn.samplers import DefaultEstimatorAdapter # type: ignore + adapter = maybe_instantiate_adapter(pb, adapter) + assert adapter is not None - adapter = DefaultEstimatorAdapter(pb) - elif isinstance(adapter, RecurrentEstimatorAdapter): - raise TypeError( - "RecurrentEstimatorAdapter is only supported for Trajectories" - ) - assert adapter is not None + from gfn.adapters import RecurrentEstimatorAdapter # type: ignore - # If pb is None, we assume that the gflownet DAG is a tree, and therefore - # the backward policy probability is always 1 (log probs are 0). - if pb is None: - return log_pb_actions + if isinstance(adapter, RecurrentEstimatorAdapter): + raise TypeError("RecurrentEstimatorAdapter is only supported for Trajectories") - N = transitions.n_transitions - device = transitions.states.device - cond = transitions.conditioning - ctx = adapter.init_context(int(N), device, cond) - # Legacy-complete masking for PB on transitions: - # require non-terminating next_states and non-exit actions simultaneously - # automatically removes invalid transitions (i.e. 
s_f -> s_f) - state_ok = ~transitions.is_terminating - action_ok = ~transitions.actions.is_exit - mask = state_ok & action_ok - - if not torch.any(mask): - return log_pb_actions - - step_lp, _ = adapter.log_prob_of_actions( - transitions.next_states[mask], - transitions.actions.tensor[mask], - ctx, - mask, - **policy_kwargs, - ) - log_pb_actions[mask] = step_lp[mask] + ctx = adapter.init_context( + int(transitions.n_transitions), + transitions.states.device, + transitions.conditioning, + ) + + # Legacy-complete masking for PB on transitions: + # require non-terminating next_states and non-exit actions simultaneously + # automatically removes invalid transitions (i.e. s_f -> s_f) + mask = ~transitions.is_terminating & ~transitions.actions.is_exit + + if not torch.any(mask): + return log_pb_actions + + dist, ctx = adapter.compute_dist( + transitions.next_states[mask], ctx, mask, **policy_kwargs + ) + step_lp, _ = adapter.log_probs( + transitions.actions.tensor[mask], dist, ctx, mask, vectorized=False + ) + log_pb_actions[mask] = step_lp[mask] return log_pb_actions diff --git a/testing/test_adaptor_estimator_gflownet_integration.py b/testing/test_adaptor_estimator_gflownet_integration.py index 05bd03c6..b547a0cb 100644 --- a/testing/test_adaptor_estimator_gflownet_integration.py +++ b/testing/test_adaptor_estimator_gflownet_integration.py @@ -3,6 +3,7 @@ import pytest import torch +from gfn.adapters import DefaultEstimatorAdapter, RecurrentEstimatorAdapter from gfn.estimators import ( DiscretePolicyEstimator, RecurrentDiscretePolicyEstimator, @@ -10,7 +11,6 @@ ) from gfn.gflownet import DBGFlowNet, TBGFlowNet from gfn.gym.bitSequence import BitSequence -from gfn.samplers import DefaultEstimatorAdapter, RecurrentEstimatorAdapter from gfn.utils.modules import MLP, RecurrentDiscreteSequenceModel diff --git a/testing/test_probability_calculations.py b/testing/test_probability_calculations.py index e3e7952f..47d2d817 100644 --- a/testing/test_probability_calculations.py +++ b/testing/test_probability_calculations.py @@ -1,10 +1,11 @@ import pytest import torch +from gfn.adapters import DefaultEstimatorAdapter from gfn.estimators import DiscretePolicyEstimator from gfn.gym import HyperGrid from gfn.preprocessors import IdentityPreprocessor -from gfn.samplers import DefaultEstimatorAdapter, Sampler +from gfn.samplers import Sampler from gfn.utils.handlers import check_cond_forward from gfn.utils.prob_calculations import ( get_trajectory_pbs, @@ -327,7 +328,7 @@ def test_transition_pb_vectorized_vs_nonvectorized_parity(): torch.testing.assert_close(vec, nvec) -def test_adapter_log_prob_of_actions_precomputed_matches_forward(): +def test_adapter_log_probs_precomputed_matches_forward(): env, pf_estimator, _ = _build_env_and_pf() states = env.reset(batch_shape=(5,)) @@ -343,11 +344,13 @@ def test_adapter_log_prob_of_actions_precomputed_matches_forward(): step_mask = torch.ones(5, dtype=torch.bool, device=states.device) # Baseline: adapter recomputes estimator outputs internally - lp1, _ = adapter.log_prob_of_actions(states, actions_tensor, ctx1, step_mask) + dist1, ctx1 = adapter.compute_dist(states, ctx1, step_mask) + lp1, _ = adapter.log_probs(actions_tensor, dist1, ctx1, step_mask, vectorized=False) - # Precomputed: adapter uses provided estimator outputs (fast path). 
+ # Precomputed: adapter reuses provided estimator outputs (fast path) ctx2.current_estimator_output = est_out - lp2, _ = adapter.log_prob_of_actions(states, actions_tensor, ctx2, step_mask) + dist2, ctx2 = adapter.compute_dist(states, ctx2, step_mask) + lp2, _ = adapter.log_probs(actions_tensor, dist2, ctx2, step_mask, vectorized=False) torch.testing.assert_close(lp1, lp2) @@ -434,7 +437,3 @@ def test_get_transition_pbs_matches_legacy_with_default_adapter(): adapter=DefaultEstimatorAdapter(pb_estimator), ) torch.testing.assert_close(modern, legacy) - - -if __name__ == "__main__": - test_trajectory_pb_vectorized_vs_nonvectorized_parity() diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index 8245201c..c247e123 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -4,6 +4,11 @@ import torch from torch.distributions import Categorical +from gfn.adapters import ( + DefaultEstimatorAdapter, + RecurrentEstimatorAdapter, + RolloutContext, +) from gfn.containers import Trajectories, Transitions from gfn.containers.replay_buffer import ReplayBuffer from gfn.estimators import ( @@ -21,10 +26,7 @@ OneHotPreprocessor, ) from gfn.samplers import ( - DefaultEstimatorAdapter, LocalSearchSampler, - RecurrentEstimatorAdapter, - RolloutContext, Sampler, ) from gfn.states import States @@ -376,20 +378,17 @@ def test_to_transition( n_components_s0=1, ) - try: - _ = trajectories.to_transitions() - - bwd_trajectories = Trajectories.reverse_backward_trajectories(bwd_trajectories) - # evaluate with pf_estimator - backward_traj_pfs = get_trajectory_pfs( - pf=pf_estimator, - trajectories=bwd_trajectories, - recalculate_all_logprobs=False, - ) - bwd_trajectories.log_probs = backward_traj_pfs - _ = bwd_trajectories.to_transitions() - except Exception as e: - raise ValueError(f"Error while testing {env_name}") from e + _ = trajectories.to_transitions() + + bwd_trajectories = Trajectories.reverse_backward_trajectories(bwd_trajectories) + # evaluate with pf_estimator + backward_traj_pfs = get_trajectory_pfs( + pf=pf_estimator, + trajectories=bwd_trajectories, + recalculate_all_logprobs=False, + ) + bwd_trajectories.log_probs = backward_traj_pfs + _ = bwd_trajectories.to_transitions() @pytest.mark.parametrize( @@ -531,10 +530,11 @@ def test_default_adapter_compute_record_finalize(): ctx = adapter.init_context(n, device, conditioning=None) step_mask = torch.ones(n, dtype=torch.bool, device=device) - dist, ctx = adapter.compute(cast(States, states), ctx, step_mask) + dist, ctx = adapter.compute_dist(cast(States, states), ctx, step_mask) actions = dist.sample() + log_probs, ctx = adapter.log_probs(actions, dist, ctx, step_mask, vectorized=False) adapter.record( - ctx, step_mask, actions, dist, save_logprobs=True, save_estimator_outputs=True + ctx, step_mask, actions, dist, log_probs=log_probs, save_estimator_outputs=True ) out = adapter.finalize(ctx) assert out["log_probs"] is not None and out["log_probs"].shape == (1, n) @@ -559,18 +559,20 @@ def test_recurrent_adapter_flow(): ctx = adapter.init_context(n, device, conditioning=None) step_mask = torch.ones(n, dtype=torch.bool, device=device) - dist, ctx = adapter.compute(cast(States, states), ctx, step_mask) + dist, ctx = adapter.compute_dist(cast(States, states), ctx, step_mask) actions = dist.sample() # carry should update when we record multiple steps h0 = ctx.carry["hidden"].clone() + lp, ctx = adapter.log_probs(actions, dist, ctx, step_mask, vectorized=False) 
adapter.record(
-        ctx, step_mask, actions, dist, save_logprobs=True, save_estimator_outputs=True
+        ctx, step_mask, actions, dist, log_probs=lp, save_estimator_outputs=True
     )
     # second step
-    dist, ctx = adapter.compute(cast(States, states), ctx, step_mask)
+    dist, ctx = adapter.compute_dist(cast(States, states), ctx, step_mask)
     actions = dist.sample()
+    lp, ctx = adapter.log_probs(actions, dist, ctx, step_mask, vectorized=False)
     adapter.record(
-        ctx, step_mask, actions, dist, save_logprobs=True, save_estimator_outputs=True
+        ctx, step_mask, actions, dist, log_probs=lp, save_estimator_outputs=True
     )
     h1 = ctx.carry["hidden"].clone()
     assert torch.all(h1 == h0 + 1)
@@ -638,14 +640,15 @@ def test_integration_recurrent_sequence_model_with_adapter(
     # Run two steps and verify carry and artifact shapes
     step_mask = torch.ones(batch_size, dtype=torch.bool, device=device)
     for _ in range(2):
-        dist, ctx = adapter.compute(cast(States, states), ctx, step_mask)
+        dist, ctx = adapter.compute_dist(cast(States, states), ctx, step_mask)
         actions = dist.sample()
+        lp, ctx = adapter.log_probs(actions, dist, ctx, step_mask, vectorized=False)
         adapter.record(
             ctx,
             step_mask,
             actions,
             dist,
-            save_logprobs=True,
+            log_probs=lp,
             save_estimator_outputs=True,
         )
@@ -693,10 +696,12 @@ def test_integration_transformer_sequence_model_with_adapter(
     states = _SeqStates(tokens, vocab_size)
     step_mask = torch.ones(batch_size, dtype=torch.bool, device=device)
-    dist, ctx = adapter.compute(cast(States, states), ctx, step_mask)
+
+    dist, ctx = adapter.compute_dist(cast(States, states), ctx, step_mask)
     actions = dist.sample()
+    lp, ctx = adapter.log_probs(actions, dist, ctx, step_mask, vectorized=False)
     adapter.record(
-        ctx, step_mask, actions, dist, save_logprobs=True, save_estimator_outputs=True
+        ctx, step_mask, actions, dist, log_probs=lp, save_estimator_outputs=True
     )
 
     out = adapter.finalize(ctx)
     assert out["log_probs"] is not None
     assert (
         out["estimator_outputs"] is not None and out["estimator_outputs"].shape[0] == 1
     )
-
-
-if __name__ == "__main__":
-    test_to_transition("DiscreteEBM")
diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py
index a77a48b6..0e4af5d1 100644
--- a/tutorials/examples/test_scripts.py
+++ b/tutorials/examples/test_scripts.py
@@ -742,4 +742,4 @@ def test_hypergrid_exploration_smoke():
 
 
 if __name__ == "__main__":
-    test_graph_triangle_smoke()
+    test_conditional_basic("tb")
diff --git a/tutorials/examples/train_bitsequence_recurrent.py b/tutorials/examples/train_bitsequence_recurrent.py
index 9c922a24..f5e8d63d 100644
--- a/tutorials/examples/train_bitsequence_recurrent.py
+++ b/tutorials/examples/train_bitsequence_recurrent.py
@@ -16,10 +16,10 @@
 import torch
 from tqdm import tqdm
 
+from gfn.adapters import RecurrentEstimatorAdapter
 from gfn.estimators import RecurrentDiscretePolicyEstimator
 from gfn.gflownet import PFBasedGFlowNet, TBGFlowNet
 from gfn.gym.bitSequence import BitSequence
-from gfn.samplers import RecurrentEstimatorAdapter
 from gfn.states import DiscreteStates
 from gfn.utils.common import set_seed
 from gfn.utils.modules import RecurrentDiscreteSequenceModel

From e2755e6f6b09b8c17b25bb17c14c8b392414222a Mon Sep 17 00:00:00 2001
From: Joseph Viviano
Date: Mon, 13 Oct 2025 05:06:38 -0400
Subject: [PATCH 19/27] removed strict type check

---
 docs/source/guides/estimator_adapters.md | 95 +++++++++++++-----------
 1 file changed, 50 insertions(+), 45 deletions(-)

diff --git a/docs/source/guides/estimator_adapters.md
b/docs/source/guides/estimator_adapters.md index d7f08150..c005f353 100644 --- a/docs/source/guides/estimator_adapters.md +++ b/docs/source/guides/estimator_adapters.md @@ -19,7 +19,7 @@ The Sampler remains estimator-agnostic. Adapters own any estimator-specific stat ## Adapters -Adapters conform to an abstract class structure (see `gfn/samplers.py`): +Adapters conform to an abstract class structure: - Properties - `is_backward: bool` — whether the wrapped estimator is a backward policy. @@ -29,21 +29,25 @@ Adapters conform to an abstract class structure (see `gfn/samplers.py`): - `init_context(batch_size: int, device: torch.device, conditioning: Tensor|None) -> Any` - Allocates a rollout context once per batch (Sampler). Stores invariants (batch size, device, optional conditioning) and initializes any adapter state (e.g., recurrent carry) along with per-step artifact buffers. - - `compute(states_active: States, ctx: Any, step_mask: Tensor, **policy_kwargs) -> (Distribution, Any)` - - Runs the estimator forward on the active rows and returns a torch Distribution over actions. - - Must handle conditioning slicing with `step_mask` when applicable. + - `compute_dist(states_active: States, ctx: Any, step_mask: Tensor|None, **policy_kwargs) -> (Distribution, Any)` + - Runs the estimator forward on the provided rows and returns a torch Distribution over actions. + - Slices `conditioning` with `step_mask` when provided (non‑vectorized); uses full conditioning when `step_mask=None` (vectorized). + - Sets `ctx.current_estimator_output` to the raw estimator output. Vectorized callers may prefill `ctx.current_estimator_output` to reuse cached outputs. - - `record(ctx: Any, step_mask: Tensor, sampled_actions: Tensor, dist: Distribution, save_logprobs: bool, save_estimator_outputs: bool) -> None` - - Optionally record per-step artifacts into buffers owned by the context (e.g., log-probs, estimator outputs). Padding back to batch size happens here, using zeros for log-probs and `-inf` for estimator outputs to match existing conventions. + - `log_probs(actions_active: Tensor, dist: Distribution, ctx: Any, step_mask: Tensor|None, vectorized: bool = False) -> (Tensor, Any)` + - Computes log-probs from `dist` for the given actions. + - When `vectorized=False`, returns a padded `(N,)` tensor (zeros where `~step_mask`), with a strict inf-check (raises on `±inf`). + - When `vectorized=True`, returns the raw `dist.log_prob(...)` without padding or inf-check (vectorized paths can legitimately include `-inf` for illegal actions). - - `log_prob_of_actions(states_active: States, actions_active: Tensor, ctx: Any, step_mask: Tensor, **policy_kwargs) -> (Tensor, Any)` - - Computes log-probs for a batch of (state, action) pairs corresponding to `True` entries of `step_mask` and returns a padded `(N,)` vector. + - `record(ctx: Any, step_mask: Tensor, sampled_actions: Tensor, dist: Distribution, log_probs: Optional[Tensor], save_estimator_outputs: bool) -> None` + - Records per-step artifacts owned by the context. It never recomputes log-probs; pass `log_probs=None` to skip recording them. + - Pads estimator outputs to `(N, ...)` using `-inf` before appending when `save_estimator_outputs=True`. - - `finalize(ctx) -> -> dict[str, Optional[torch.Tensor]]` - - Realizes the buffers of the context object into tensors which can be used by the rest of the library (e.g., Trajectories objects). + - `finalize(ctx) -> dict[str, Optional[Tensor]]` + - Stacks per-step buffers into trajectory-level tensors, e.g. 
`(T, N, ...)`, returning `{"log_probs": Tensor|None, "estimator_outputs": Tensor|None}`. - `get_current_estimator_output(ctx: Any) -> Tensor|None` - - Convenience to expose the last estimator output after `compute`. + - Returns the last estimator output saved during `compute_dist`. - Context - The rollout context (created by `init_context`) owns: @@ -65,61 +69,62 @@ Adapters conform to an abstract class structure (see `gfn/samplers.py`): ## Vectorized vs Non-Vectorized Probability Paths -Probability calculators (PF/PB for trajectories and transitions) branch on `adapter.is_vectorized`: +Probability calculators (PF/PB for trajectories and transitions) branch on `adapter.is_vectorized` but use the same two adapter calls in both paths: -- Vectorized (fast path) - - Used when `adapter is None` or `adapter.is_vectorized is True`. - - Implements the legacy vectorized logic exactly (see the reference implementation below). - - No adapter calls are needed; the estimator is called on vectorized masks, and distributions compute `log_prob` over the masked actions. This path is the most efficient and is used during training when possible. +- `compute_dist(states_active, ctx, step_mask=None or mask)` +- `log_probs(actions_active, dist, ctx, step_mask=None or mask, vectorized=...)` + +Key differences: -- Non-Vectorized (per-step path) - - Used when `adapter.is_vectorized is False` (e.g., recurrent adapters). - - The calculators iterate per-step with legacy-accurate masks and alignment: - - PF (trajectories): `step_mask = ~states.is_sink_state[t] & ~actions.is_dummy[t]` - - PB (trajectories): align actions at time `t` with states at time `t+1`, and use `step_mask = ~states.is_sink_state[t+1] & ~states.is_initial_state[t+1] & ~actions.is_dummy[t] & ~actions.is_exit[t]` and skip `t==0`. - - Transitions: use the same masks as the legacy vectorized functions and make a single adapter call per batch. - - No mask indexing with action ids is used; masking is solely via the legacy boolean masks and the Distribution handles illegal actions internally. +- Vectorized (fast path) + - `step_mask=None` and `vectorized=True`. + - May reuse cached estimator outputs by pre-setting `ctx.current_estimator_output` (e.g., PF with stored `trajectories.estimator_outputs`). + - `log_probs` returns raw `dist.log_prob(...)` and does not raise on `-inf` (illegal actions can produce `-inf`). -In both branches, behavior matches the legacy reference exactly, so tests compare outputs between vectorized and non-vectorized paths for parity. +- Non‑Vectorized (per-step path) + - Uses legacy-accurate boolean masks: + - PF (trajectories): `~states.is_sink_state[t] & ~actions.is_dummy[t]` + - PB (trajectories): align actions at `t` with states at `t+1`, using `~states.is_sink_state[t+1] & ~states.is_initial_state[t+1] & ~actions.is_dummy[t] & ~actions.is_exit[t]`, skipping `t==0`. + - Transitions: one per-batch call with legacy masks. + - `log_probs` pads back to `(N,)` at inactive rows and raises if any `±inf` remains after masking. 
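To make the two paths concrete, here is a minimal sketch of the two call sequences (it assumes an existing `DiscretePolicyEstimator` `pf` and a batch of `states` of size `n`; the variable names are illustrative, not part of the API):

```python
import torch

from gfn.adapters import DefaultEstimatorAdapter

adapter = DefaultEstimatorAdapter(pf)
device = torch.device("cpu")

# Vectorized fast path: step_mask=None, raw log-probs (may contain -inf for
# illegal actions).
ctx = adapter.init_context(n, device, conditioning=None)
dist, ctx = adapter.compute_dist(states, ctx, step_mask=None)
actions = dist.sample()
lp_vec, ctx = adapter.log_probs(actions, dist, ctx, step_mask=None, vectorized=True)

# Per-step path: boolean mask over active rows; the result is padded back to
# shape (N,) with zeros at inactive rows, and inf-checked.
ctx2 = adapter.init_context(n, device, conditioning=None)
step_mask = torch.ones(n, dtype=torch.bool, device=device)
dist2, ctx2 = adapter.compute_dist(states[step_mask], ctx2, step_mask)
lp_step, ctx2 = adapter.log_probs(actions, dist2, ctx2, step_mask, vectorized=False)

# With every row active and actions drawn from the policy, the two paths agree.
torch.testing.assert_close(lp_vec, lp_step)
```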
## Integration with the Sampler The Sampler uses the adapter lifecycle: - `ctx = adapter.init_context(batch_size, device, conditioning)` - While some trajectories are active: - - `(dist, ctx) = adapter.compute(states[step_mask], ctx, step_mask, **policy_kwargs)` + - `(dist, ctx) = adapter.compute_dist(states[step_mask], ctx, step_mask, **policy_kwargs)` - Sample actions from `dist`; build actions for the full batch - - `adapter.record(ctx, step_mask, sampled_actions, dist, save_logprobs, save_estimator_outputs)` + - `log_probs = adapter.log_probs(valid_actions_tensor, dist, ctx, step_mask, vectorized=False)` (or `None` if skipping) + - `adapter.record(ctx, step_mask, sampled_actions=valid_actions_tensor, dist=dist, log_probs=log_probs, save_estimator_outputs=...)` - Step the environment forward/backward based on `adapter.is_backward` - After rollout: `artifacts = adapter.finalize(ctx)` and populate `Trajectories`. ## How to Implement a New Adapter -A new Adapter will only likely need changes to `compute`, `record`, and `log_prob_of_actions`. You can rely otherwise on the defaults. However we detail all of the steps below for completeness: - -1) Decide if your estimator needs a recurrent carry - some persistent state or cache that is utilized throughout the trajectory. - - If yes, set `is_vectorized = False` and implement `init_context` to initialize `carry`. Implement `compute` to update `carry` each step. - - If no, set `is_vectorized = True` and follow the default adapter pattern. +1) Decide on vectorization: + - If your estimator maintains a recurrent carry, set `is_vectorized = False` and implement carry management in `init_context` and `compute_dist`. + - Otherwise set `is_vectorized = True` and follow the default adapter pattern. -2) Implement `compute` - - Handle conditioning slicing with `step_mask` when conditioning is provided. - - Call your estimator and construct a torch Distribution via `to_probability_distribution(states_active, est_out, **policy_kwargs)`. +2) Implement `init_context(batch_size, device, conditioning)` + - Save invariants and allocate any adapter-specific state. Initialize empty per-step buffers. -3) Implement `record` - - If you want to capture per-step log-probs and/or estimator outputs, compute them for active rows and pad back to `(N,)` (log-probs) or `(N, ...)` (estimator outputs) before appending to the context buffers. +3) Implement `compute_dist(states_active, ctx, step_mask, **policy_kwargs)` + - Slice `conditioning` by `step_mask` for non‑vectorized calls; use full conditioning when `step_mask=None`. + - Call your estimator, set `ctx.current_estimator_output`, and return a Distribution via `to_probability_distribution`. -4) Implement `log_prob_of_actions` - - Given `(states_active, actions_active)` for the active rows, compute the Distribution (reusing the same forward logic) and return a padded `(N,)` vector of `log_prob`. - - Do not modify masks here; calculators pass in `step_mask` already built from existing masks. +4) Implement `log_probs(actions_active, dist, ctx, step_mask, vectorized=False)` + - Non‑vectorized: strict inf-check, return a padded `(N,)` tensor. + - Vectorized: return raw `dist.log_prob(...)` (may include `-inf` for illegal actions). -5) Implement `finalize` - - Given the contents of your context, return the trajectory-level objects required by the Sampler. +5) Implement `record(ctx, step_mask, sampled_actions, dist, log_probs, save_estimator_outputs)` + - Never recompute log-probs here; only store what was provided. 
+ - When saving estimator outputs, pad to `(N, ...)` using `-inf`. -5) Mark `is_backward` if your estimator is a backward policy; the sampler will step the environment backward accordingly. +6) Implement `finalize(ctx)` + - Stack per-step buffers into `(T, N, ...)` tensors and return a dict of artifacts. -6) Performance Guidance - - For vectorized adapters, prefer the vectorized probability path (legacy implementation). It’s much faster and avoids per-step overhead. - - For non-vectorized adapters, keep per-step code minimal and avoid Python-side loops that can be vectorized. +7) Set `is_backward` appropriately so the Sampler chooses forward/backward environment steps. ## Reference: Legacy Implementations From ba6f0bddd8a444c76e43dace73436f3e90c356c5 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Mon, 13 Oct 2025 05:35:45 -0400 Subject: [PATCH 20/27] shrank docs --- src/gfn/adapters.py | 158 ++++++++++------------------- src/gfn/samplers.py | 96 ++++++++---------- src/gfn/utils/prob_calculations.py | 155 +++++++++++++--------------- 3 files changed, 166 insertions(+), 243 deletions(-) diff --git a/src/gfn/adapters.py b/src/gfn/adapters.py index fac92e7f..7c8952b9 100644 --- a/src/gfn/adapters.py +++ b/src/gfn/adapters.py @@ -15,26 +15,25 @@ class EstimatorAdapter(ABC): """Adapter interface for estimator-specific policy behavior. - This abstract base class defines the minimal interface the Sampler relies on, - allowing us to keep one generic sampling loop while plugging in different estimator - behaviors (e.g., non‑recurrent, recurrent with carry, tempered variants) - without modifying the Sampler. - - The adapter owns an opaque RolloutContext object. The Sampler never inspects - it and simply passes it back to the adapter at each step. The adapter is - responsible for: - - initializing the context in `init_context`. - - compute the action distribution, while updating any internal state (e.g., - recurrent `carry`) - - compute the log probabilities of the actions in `log_probs`. - - recording per‑step artifacts in `record` (e.g., log_probs, - estimator outputs), typically with mask-aware padding. - - output trajectory-length artifacts via `finalize(ctx)` - - The context should be allocated once per rollout. Masking should be applied inside - the adapter (via `step_mask`) when slicing conditioning or padding per‑step - tensors back to full batch size. The Sampler can therefore be oblivious to - estimator details (conditioning, carry, etc.). + Keeps the sampling loop generic; estimator-specific logic lives here. + + Responsibilities: + - init_context(batch_size, device, conditioning): allocate rollout context. + - compute_dist(states_active, ctx, step_mask, **kw): run estimator on active + rows, return a torch Distribution, and update `ctx` if needed (e.g., carry, + cached outputs). + - log_probs(actions_active, dist, ctx, step_mask, vectorized): compute + log-probabilities for active rows; when ``vectorized=False`` return a + batch-sized tensor padded via ``step_mask``. + - record(ctx, step_mask, sampled_actions, dist, log_probs, save_estimator_outputs): + optionally materialize per‑step artifacts in `ctx`. + - finalize(ctx): stack recorded artifacts along time and return a dict. + + Notes: + - The sampler never inspects `ctx`; masking and padding happen inside the + adapter. + - ``is_backward`` selects forward vs backward environment steps. + - ``is_vectorized`` selects fast vectorized vs per‑step probability paths. 
""" @property @@ -93,11 +92,10 @@ def get_current_estimator_output(self, ctx: Any) -> Optional[torch.Tensor]: class RolloutContext: - """Structured, mutable context owned by adapters. + """Structured per‑rollout state owned by adapters. - Uses fixed attributes for core fields and an `extras` dict for adapter- - specific extensions without changing the class shape. This keeps most - accesses fast and typed while preserving flexibility similar to dicts. + Holds rollout invariants and optional per‑step buffers; use ``extras`` for + adapter‑specific fields without changing the class shape. """ __slots__ = ( @@ -128,7 +126,7 @@ def __init__( class DefaultEstimatorAdapter(EstimatorAdapter): - """Adapter for non-recurrent estimators (current default behavior). + """Adapter for non‑recurrent estimators (default). Overview -------- @@ -136,68 +134,29 @@ class DefaultEstimatorAdapter(EstimatorAdapter): It exposes the minimal interface required by the `EstimatorAdapter` abstract base class while keeping the sampler loop estimator-agnostic. - - If conditioning is provided, the estimator accepts `(states, conditioning)`; - otherwise it accepts `(states)`. - - The estimator provides `to_probability_distribution(states, est_out, **kw)` - returning a torch Distribution over actions for the masked states. - - The adapter owns an opaque rollout context `ctx` which the Sampler never reads. The - context (a TensorDict) is created once per rollout and mutated in place: - - - init_context: stores rollout invariants and optional conditioning. Also prepares - per‑step buffers for artifacts trajectory-level artifacts (log_probs, - estimator_outputs). - - - compute: - 1) Selects the appropriate estimator call signature depending on whether - conditioning is present. If conditioning is present, the adapter slices - it with `step_mask` so shapes match `states_active`. - 2) Calls the estimator forward pass to obtain the raw `est_out`. - 3) Converts `est_out` into a torch Distribution with - `to_probability_distribution`. - 4) Saves `est_out` into `ctx.current_estimator_output`. - - - record: - Materializes optional per‑step artifacts into context‑managed buffers with - mask‑aware padding back to the full rollout batch size: - * Log‑probs: computes `dist.log_prob(sampled_actions)` for active rows only, - then writes into a 1D tensor of length batch_size filled with zeros and masked - assignment for active positions. Appends this to a list (one tensor per time step). - * Estimator outputs: if requested, pads the last estimator output - (`ctx.current_estimator_output`) to shape `(batch_size, ...)` using `-inf` for - inactive rows and appends to a list (one tensor per time step). - - - finalize: Stacks recorded per‑step lists along the time dimension into tensors of - shape `(trajectory_legnth, batch_size, ...)` suitable for `Trajectories`. - Returns `None` for any artifact that was never recorded. - - Masking & Shapes - ---------------- - - `states_active` always corresponds to `states[~dones]` inside the sampler. - - The adapter receives `step_mask` (shape `(N,)`) to slice any step‑dependent - inputs (e.g., conditioning) and to pad per‑step outputs to the full batch. - - Padded tensors use `0.0` for log‑probs and `-inf` for estimator outputs to - maintain compatibility with downstream code. - - Backward/Forward Direction - -------------------------- - - `is_backward` is forwarded from the underlying estimator so the sampler can - choose the appropriate environment transition (forward vs backward). 
- - Vectorized Probability Path - -------------------------- - - `is_vectorized` is used by the Sampler to choose the appropriate probability path. - - Vectorized adapters always use faster paths in probability calculators. - Non-vectorized adapters (e.g., recurrent) use per-step paths with masking and - alignment identical to the legacy reference. - - Performance Notes - ----------------- - - `ctx` is allocated once per rollout and mutated in place to avoid per‑step - overhead. - - If you know trajectory length bounds, you can extend this adapter to - pre‑allocate fixed‑size storage in `init_context` rather than appending to - Python lists. + Workflow with RolloutContext: + - ``init_context(batch_size, device, conditioning)``: store invariants and + allocate per‑step buffers. + - ``compute_dist(states_active, ctx, step_mask, **kw)``: slice conditioning + by ``step_mask`` when provided, run the estimator on active rows, cache + ``est_out`` in ``ctx.current_estimator_output``, and return a Distribution. + - ``log_probs(actions_active, dist, ctx, step_mask, vectorized)``: compute + log‑probs for active rows; when ``vectorized=False``, return a batch‑padded + tensor using ``step_mask``. + - ``record(...)``: append per‑step artifacts; pad log‑probs with ``0.0`` and + estimator outputs with ``-inf``. + - ``finalize(ctx)``: stack recorded lists along time to tensors of shape + ``(T, N, ...)``. + + Masking and path selection + - ``states_active == states[~dones]``; ``step_mask`` has shape ``(N,)``. + - ``is_backward`` is forwarded from the estimator. + - ``is_vectorized == True`` enables vectorized probability calculators when + available. + + Performance + - One context per rollout; mutate in place. If trajectory length bounds are + known, pre‑allocation of those buffers is possible. """ def __init__(self, estimator: "Estimator") -> None: @@ -352,24 +311,13 @@ def get_current_estimator_output(self, ctx: Any) -> Optional[torch.Tensor]: class RecurrentEstimatorAdapter(DefaultEstimatorAdapter): - """Adapter for recurrent estimators that maintain a rollout carry (hidden state). - - Differences from `DefaultEstimatorAdapter`: - - is_vectorized = False: runs sequential, per‑step probability calculators for - PF/PB. PB aligns action at t with state at t+1, t==0 skipped). - - Rollout context manages `ctx.carry` which contains the hidden state of the - recurrent estimator. It is initialized once via `estimator.init_carry(batch_size, - device)` and updated every step. - - `compute(states, ctx, ...)` calls `estimator(states, ctx.carry) -> - (estimator_outputs, new_carry)`, updates `ctx.carry`, then builds the - Distribution from `estimator_outputs`. - - recording mirrors the default but pads per‑step tensors to batch size and - stacks into `(T, N, ...)`. - - No action‑id mask indexing; illegal actions are handled by the Distribution. - - Requires the estimator to implement: - - `init_carry(batch_size: int, device: torch.device)` - - a recurrent forward returning `(estimator_outputs, new_carry)`. + """Adapter for recurrent estimators with rollout carry (hidden state). + + - Requires ``estimator.init_carry(batch_size, device)`` and a forward that + returns ``(estimator_outputs, new_carry)``. + - Maintains ``ctx.carry`` across steps and updates it each call. + - ``is_vectorized=False``; probability calculators use the per‑step path + with legacy masks/alignment. 
""" def __init__(self, estimator: "Estimator") -> None: diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 1bed10a0..9373f31e 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -14,15 +14,17 @@ class Sampler: - """Wrapper for a PolicyEstimator that enables sampling from GFlowNet environments. + """Adapter‑driven sampler for GFlowNet environments. - A Sampler encapsulates a PolicyEstimator and provides methods to sample individual - actions or complete trajectories from GFlowNet environments. It can be used for - both forward and backward sampling, depending on the estimator's configuration. + Delegates policy logic to an adapter: the adapter builds action + distributions, computes step log‑probs, and records artifacts into a + rollout context. Direction (forward/backward) is determined by + ``adapter.is_backward``. Attributes: - estimator: The PolicyEstimator used for sampling actions and computing - probability distributions. + estimator: The underlying policy estimator (adapter wraps it). + adapter: The adapter used to build action distributions, compute step log‑probs, + and record artifacts into a rollout context. """ def __init__( @@ -54,36 +56,29 @@ def sample_actions( ctx: Any | None = None, **policy_kwargs: Any, ) -> Tuple[Actions, torch.Tensor | None, torch.Tensor | None]: - """Samples actions from the given states using the policy estimator. + """Sample one step from ``states`` via the adapter. - This method samples actions from the probability distribution defined by the - policy estimator. - - When sampling off-policy, ensure to set `save_logprobs=False`. Log probabilities - for off-policy actions should be calculated separately during GFlowNet training. + Initializes or reuses a rollout context with ``adapter.init_context``, + builds a Distribution with ``adapter.compute_dist``, optionally computes + log‑probs with ``adapter.log_probs``, and lets ``adapter.record`` + persist per‑step artifacts. Args: - env: The environment where the states and actions are defined. - states: A batch of states to sample actions from. - conditioning: Optional tensor of conditioning information for conditional - policies. If provided, the estimator must support conditional sampling. - save_estimator_outputs: If True, returns the raw outputs from the estimator - before conversion to probability distributions. This is useful for - off-policy training with tempered policies. - save_logprobs: If True, calculates and returns the log probabilities of - the sampled actions under the policy distribution. This is useful for - on-policy training. - **policy_kwargs: Keyword arguments passed to the estimator's - `to_probability_distribution` method. Common parameters include: - - `temperature`: Scalar to divide logits by before softmax - - `epsilon`: Probability of choosing random actions (exploration) - - `sf_bias`: Bias to apply to exit action logits + env: Environment providing action/state conversion utilities. + states: Batch of states to act on. + conditioning: Optional conditioning for conditional policies. + save_estimator_outputs: If True, return the raw estimator outputs + cached by the adapter for this step. Useful for off-policy training + with tempered policies. + save_logprobs: If True, return per‑step log‑probs padded to batch. + Useful for on-policy training. + **policy_kwargs: Extra kwargs forwarded to + ``to_probability_distribution``. 
Returns: - A tuple containing: - - An Actions object with the sampled actions - - Optional tensor of log probabilities (if save_logprobs=True) - - Optional tensor of estimator outputs (if save_estimator_outputs=True) + ``(Actions, log_probs | None, estimator_outputs | None)``. The + estimator outputs come from + ``adapter.get_current_estimator_output(ctx)`` when requested. """ if ctx is None: ctx = self.adapter.init_context( @@ -143,34 +138,27 @@ def sample_trajectories( save_logprobs: bool = False, **policy_kwargs: Any, ) -> Trajectories: - """Samples complete trajectories from the environment. + """Roll out complete trajectories using the adapter. - This method samples trajectories by sequentially sampling actions from the - policy estimator. It supports both forward and backward sampling, depending on - the estimator's `is_backward` flag. If forward sampling, it samples until all - trajectories reach the sink state. If backward sampling, it samples until all - trajectories reach the initial state. + Reuses a single rollout context across steps, calling + ``compute_dist``/``log_probs``/``record`` each iteration and + ``finalize`` at the end to stack trajectory‑level artifacts. Uses + ``adapter.is_backward`` to choose the environment step function. Args: - env: The environment to sample trajectories from. - n: Number of trajectories to sample, all starting from s0. Must be - provided if `states` is None. - states: Initial states to start trajectories from. It should have batch_shape - of length 1 (no trajectory dim). If `None`, `n` must be provided and we - initialize `n` trajectories with the environment's initial state. - conditioning: Optional tensor of conditioning information for conditional - policies. Must match the batch shape of states. - save_estimator_outputs: If True, saves the estimator outputs for each - step. Useful for off-policy training with tempered policies. - save_logprobs: If True, calculates and saves the log probabilities of - sampled actions. Useful for on-policy training. - **policy_kwargs: Keyword arguments passed to the policy estimator. - See `sample_actions` for details. + env: Environment to sample in. + n: Number of trajectories if ``states`` is None. + states: Starting states (batch shape length 1) or ``None``. + conditioning: Optional conditioning aligned with the batch. + save_estimator_outputs: If True, store per‑step estimator outputs. Useful + for off-policy training with tempered policies. + save_logprobs: If True, store per‑step log‑probs. Useful for on-policy + training. + **policy_kwargs: Extra kwargs forwarded to the policy. Returns: - A Trajectories object containing the sampled trajectories with batch_shape - (max_length+1, n_trajectories) for states and (max_length, n_trajectories) - for actions. + A ``Trajectories`` with stacked states/actions and any artifacts + produced by ``adapter.finalize``. Note: For backward trajectories, the reward is computed at the initial state diff --git a/src/gfn/utils/prob_calculations.py b/src/gfn/utils/prob_calculations.py index 58372227..3d51271e 100644 --- a/src/gfn/utils/prob_calculations.py +++ b/src/gfn/utils/prob_calculations.py @@ -22,32 +22,29 @@ def get_trajectory_pfs_and_pbs( pb_adapter: Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None = None, **policy_kwargs: Any, ) -> Tuple[torch.Tensor, torch.Tensor]: - """Calculate PF and PB log-probabilities for trajectories. + """Calculate PF and PB log‑probabilities for trajectories via adapters. 
- This function delegates to :func:`get_trajectory_pfs` and - :func:`get_trajectory_pbs`, forwarding optional adapter(s) and policy kwargs. + Delegates to ``get_trajectory_pfs`` and ``get_trajectory_pbs`` while + forwarding optional adapters and policy kwargs. - Vectorized vs non-vectorized - - If the adapter is None or ``adapter.is_vectorized is True``, the legacy - vectorized path is used (fast path, strict parity with legacy code). - - If ``adapter.is_vectorized is False`` (e.g., recurrent), a non‑vectorized - per‑step path is used with legacy-accurate masks and alignment. + Vectorized vs non‑vectorized: + - ``adapter is None`` or ``adapter.is_vectorized=True`` → vectorized path + (fast, parity with legacy). + - ``adapter.is_vectorized=False`` (e.g., recurrent) → per‑step path with + legacy masks/alignment. Args: pf: Forward policy estimator. - pb: Backward policy estimator, or ``None`` if the DAG is a tree (PB=1). - trajectories: Trajectories container to evaluate. - fill_value: Fill used for invalid states (e.g., sink state positions). - recalculate_all_logprobs: If ``True``, recompute PF even if cached. - pf_adapter: Adapter for PF (vectorized vs non‑vectorized decision). - pb_adapter: Adapter for PB (vectorized vs non‑vectorized decision). - **policy_kwargs: Extra kwargs passed to estimator's - ``to_probability_distribution`` (e.g., temperature, epsilon, sf_bias). + pb: Backward policy estimator, or ``None`` for trees (PB=1). + trajectories: Trajectories to evaluate. + fill_value: Value used to pad invalid positions. + recalculate_all_logprobs: If True, recompute PF even if cached. + pf_adapter: Optional adapter for PF. + pb_adapter: Optional adapter for PB. + **policy_kwargs: Extra kwargs for ``to_probability_distribution``. Returns: - Tuple[Tensor, Tensor]: - - PF log-probs with shape ``(T, N)`` - - PB log-probs with shape ``(T, N)`` + ``(log_pf[T,N], log_pb[T,N])`` """ # fill value is the value used for invalid states (sink state usually) @@ -90,26 +87,24 @@ def get_trajectory_pfs( adapter: Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None = None, **policy_kwargs: Any, ) -> torch.Tensor: - """Calculate PF log-probabilities for trajectories. + """Calculate PF log‑probabilities for trajectories. - Vectorized vs non-vectorized - - Vectorized when ``adapter is None`` or ``adapter.is_vectorized is True``: - uses the legacy vectorized implementation (strict parity with reference). - - Non‑vectorized when ``adapter.is_vectorized is False``: evaluates per‑step - using legacy masks (PF: ``~states.is_sink_state[t] & ~actions.is_dummy[t]``), - passing the active subset to the adapter without any action‑id mask indexing. + - Vectorized: ``adapter is None`` or ``adapter.is_vectorized=True`` → legacy + vectorized implementation. + - Non‑vectorized: ``adapter.is_vectorized=False`` → per‑step evaluation with + legacy masks ``~is_sink_state[t] & ~is_dummy[t]``; no action‑id indexing. Args: pf: Forward policy estimator. - trajectories: Trajectories container to evaluate. - fill_value: Fill used for invalid states (e.g., sink state positions). - recalculate_all_logprobs: If ``True``, recompute PF even if cached. - adapter: Optional adapter controlling vectorized vs non‑vectorized path. - **policy_kwargs: Extra kwargs passed to - ``to_probability_distribution`` (e.g., temperature, epsilon). + trajectories: Trajectories to evaluate. + fill_value: Value used to pad invalid positions. + recalculate_all_logprobs: If True, recompute PF even if cached. Useful for + off-policy training. 
+ adapter: Optional adapter deciding the evaluation path. + **policy_kwargs: Extra kwargs for ``to_probability_distribution``. Returns: - Tensor of shape ``(T, N)`` containing PF log-probabilities. + ``log_pf`` of shape ``(T, N)``. Raises: ValueError: If backward trajectories are provided. @@ -269,27 +264,23 @@ def get_trajectory_pbs( adapter: Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None = None, **policy_kwargs: Any, ) -> torch.Tensor: - """Calculate PB log-probabilities for trajectories. + """Calculate PB log‑probabilities for trajectories. - Vectorized vs non-vectorized - - Vectorized when ``adapter is None`` or ``adapter.is_vectorized is True``: - uses the legacy vectorized implementation (strict parity with reference). - - Non‑vectorized when ``adapter.is_vectorized is False``: evaluates per‑step - using legacy masks/alignment: - PB aligns actions at time ``t`` with states at time ``t+1`` and uses - ``~states.is_sink_state[t+1] & ~states.is_initial_state[t+1] - & ~actions.is_dummy[t] & ~actions.is_exit[t]``, skipping ``t==0``. + - Vectorized: ``adapter is None`` or ``adapter.is_vectorized=True``. + - Non‑vectorized: ``adapter.is_vectorized=False`` with legacy alignment + (action at ``t`` with state at ``t+1``) and mask + ``~is_sink_state[t+1] & ~is_initial_state[t+1] & ~is_dummy[t] & ~is_exit[t]``; + skip ``t==0``. Args: - pb: Backward policy estimator, or ``None`` for tree DAGs (PB=1). - trajectories: Trajectories container to evaluate. - fill_value: Fill used for invalid states (e.g., sink state positions). - adapter: Optional adapter controlling vectorized vs non‑vectorized path. - **policy_kwargs: Extra kwargs passed to - ``to_probability_distribution``. + pb: Backward policy estimator, or ``None`` for trees (PB=1). + trajectories: Trajectories to evaluate. + fill_value: Value used to pad invalid positions. + adapter: Optional adapter deciding the evaluation path. + **policy_kwargs: Extra kwargs for ``to_probability_distribution``. Returns: - Tensor of shape ``(T, N)`` containing PB log-probabilities. + ``log_pb`` of shape ``(T, N)``. Raises: ValueError: If backward trajectories are provided. @@ -440,25 +431,25 @@ def get_transition_pfs_and_pbs( pb_adapter: Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None = None, **policy_kwargs: Any, ) -> Tuple[torch.Tensor, torch.Tensor]: - """Calculate PF and PB log-probabilities for transitions. + """Calculate PF and PB log‑probabilities for transitions. - Vectorized vs non-vectorized mirrors trajectories: - - Vectorized (adapter is None or ``is_vectorized=True``): legacy vectorized path. - - Non‑vectorized (``is_vectorized=False``): per‑batch adapter call with legacy - masks; no action‑id mask indexing. + Mirrors the trajectories logic: + - Vectorized when ``adapter is None`` or ``is_vectorized=True``. + - Non‑vectorized when ``is_vectorized=False``: per‑batch adapter call with + legacy masks; no action‑id indexing. Args: pf: Forward policy estimator. - pb: Backward policy estimator, or ``None`` for tree DAGs (PB=1). - transitions: Transitions container to evaluate. - recalculate_all_logprobs: If ``True``, recompute PF even if cached. + pb: Backward policy estimator, or ``None`` for trees (PB=1). + transitions: Transitions to evaluate. + recalculate_all_logprobs: If True, recompute PF even if cached. Useful for + off-policy training. pf_adapter: Optional adapter for PF. pb_adapter: Optional adapter for PB. - **policy_kwargs: Extra kwargs passed to - ``to_probability_distribution``. 
+ **policy_kwargs: Extra kwargs for ``to_probability_distribution``. Returns: - Tuple[Tensor, Tensor]: PF and PB log-probabilities of shape ``(M,)``. + ``(log_pf[M], log_pb[M])``. Raises: ValueError: If backward transitions are provided. @@ -495,24 +486,22 @@ def get_transition_pfs( adapter: Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None = None, **policy_kwargs: Any, ) -> torch.Tensor: - """Calculate PF log-probabilities for transitions. + """Calculate PF log‑probabilities for transitions. - Vectorized vs non-vectorized - - Vectorized when ``adapter is None`` or ``adapter.is_vectorized is True``: - legacy vectorized path. - - Non‑vectorized when ``adapter.is_vectorized is False``: single adapter call - with legacy masks and no action‑id indexing. + - Vectorized: ``adapter is None`` or ``is_vectorized=True``. + - Non‑vectorized: ``is_vectorized=False``; single adapter call with legacy + masks; no action‑id indexing. Args: pf: Forward policy estimator. - transitions: Transitions container to evaluate. - recalculate_all_logprobs: If ``True``, recompute PF even if cached. - adapter: Optional adapter controlling vectorized vs non‑vectorized path. - **policy_kwargs: Extra kwargs passed to - ``to_probability_distribution``. + transitions: Transitions to evaluate. + recalculate_all_logprobs: If True, recompute PF even if cached. Useful for + off-policy training. + adapter: Optional adapter deciding the evaluation path. + **policy_kwargs: Extra kwargs for ``to_probability_distribution``. Returns: - Tensor of shape ``(M,)`` containing PF log-probabilities. + ``log_pf`` of shape ``(M,)``. """ adapter = maybe_instantiate_adapter(pf, adapter) assert adapter is not None @@ -557,23 +546,21 @@ def get_transition_pbs( adapter: Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None = None, **policy_kwargs: Any, ) -> torch.Tensor: - """Calculate PB log-probabilities for transitions. + """Calculate PB log‑probabilities for transitions. - Vectorized vs non-vectorized - - Vectorized when ``adapter is None`` or ``adapter.is_vectorized is True``: - legacy vectorized path. - - Non‑vectorized when ``adapter.is_vectorized is False``: single adapter call - with legacy masks and no action‑id indexing. + - Vectorized: ``adapter is None`` or ``is_vectorized=True``. + - Non‑vectorized: ``is_vectorized=False``; single adapter call with legacy + masks; no action‑id indexing. Recurrent adapters are not supported for + transitions. Args: - pb: Backward policy estimator, or ``None`` for tree DAGs (PB=1). - transitions: Transitions container to evaluate. - adapter: Optional adapter controlling vectorized vs non‑vectorized path. - **policy_kwargs: Extra kwargs passed to - ``to_probability_distribution``. + pb: Backward policy estimator, or ``None`` for trees (PB=1). + transitions: Transitions to evaluate. + adapter: Optional adapter deciding the evaluation path. + **policy_kwargs: Extra kwargs for ``to_probability_distribution``. Returns: - Tensor of shape ``(M,)`` containing PB log-probabilities. + ``log_pb`` of shape ``(M,)``. """ # TODO: We support a fill_value for trajectories, but not for transitions. 
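A short usage sketch of the adapter-selected paths described in these docstrings (hypothetical: it assumes a vectorized `pf` estimator, a recurrent `recurrent_pf` estimator, and `trajectories` sampled elsewhere, e.g. via `Sampler.sample_trajectories`):

```python
from gfn.adapters import DefaultEstimatorAdapter, RecurrentEstimatorAdapter
from gfn.utils.prob_calculations import get_trajectory_pfs

# adapter=None (or any adapter with is_vectorized=True) selects the legacy
# vectorized path.
log_pf = get_trajectory_pfs(pf=pf, trajectories=trajectories)

# Equivalent, with the default adapter made explicit.
log_pf = get_trajectory_pfs(
    pf=pf, trajectories=trajectories, adapter=DefaultEstimatorAdapter(pf)
)

# A recurrent adapter reports is_vectorized=False, so the calculator walks the
# trajectories step by step, threading the carry with the legacy masks/alignment.
log_pf_recurrent = get_trajectory_pfs(
    pf=recurrent_pf,
    trajectories=trajectories,
    adapter=RecurrentEstimatorAdapter(recurrent_pf),
)

# Both results have shape (T, N): one log-prob per step per trajectory.
```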
From db36953e57ed5721a6644db1946dedc637aff13d Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Mon, 13 Oct 2025 12:04:45 -0400 Subject: [PATCH 21/27] added notes --- src/gfn/adapters.py | 2 ++ src/gfn/samplers.py | 7 ++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/gfn/adapters.py b/src/gfn/adapters.py index 7c8952b9..0d519c39 100644 --- a/src/gfn/adapters.py +++ b/src/gfn/adapters.py @@ -267,6 +267,7 @@ def log_probs( return step_lp, ctx + # To merge with log probs / calc dist. def record( self, ctx: Any, @@ -290,6 +291,7 @@ def record( padded[step_mask] = estimator_outputs ctx.trajectory_estimator_outputs.append(padded) + # To move to sampler. def finalize(self, ctx: Any) -> dict[str, Optional[torch.Tensor]]: """Stack recorded per-step artifacts along time into trajectory-level tensors.""" log_probs = ( diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 9373f31e..03539308 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -294,9 +294,10 @@ def sample_trajectories( # Broadcast conditioning tensor to match states batch shape if needed if conditioning is not None: - # The states have batch shape (max_length, n_trajectories) - # The conditioning tensor should have shape (n_trajectories,) or (n_trajectories, 1) - # We need to broadcast it to (max_length, n_trajectories, 1) for the estimator + # The states have batch shape (max_length, n_trajectories). The + # conditioning tensor should have shape (n_trajectories,) or + # (n_trajectories, 1). We need to broadcast it to (max_length, + # n_trajectories, 1) for the estimator if len(conditioning.shape) == 1: # conditioning has shape (n_trajectories,) conditioning = ( From 3226660602f732d31d13537c44559f02dfad645b Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Mon, 13 Oct 2025 13:11:28 -0400 Subject: [PATCH 22/27] removed finalize --- docs/source/guides/estimator_adapters.md | 13 +--- src/gfn/adapters.py | 19 ------ src/gfn/samplers.py | 45 ++++++++----- testing/test_samplers_and_trajectories.py | 79 ++++++++++++++++------- 4 files changed, 87 insertions(+), 69 deletions(-) diff --git a/docs/source/guides/estimator_adapters.md b/docs/source/guides/estimator_adapters.md index c005f353..ecafac58 100644 --- a/docs/source/guides/estimator_adapters.md +++ b/docs/source/guides/estimator_adapters.md @@ -43,9 +43,6 @@ Adapters conform to an abstract class structure: - Records per-step artifacts owned by the context. It never recomputes log-probs; pass `log_probs=None` to skip recording them. - Pads estimator outputs to `(N, ...)` using `-inf` before appending when `save_estimator_outputs=True`. - - `finalize(ctx) -> dict[str, Optional[Tensor]]` - - Stacks per-step buffers into trajectory-level tensors, e.g. `(T, N, ...)`, returning `{"log_probs": Tensor|None, "estimator_outputs": Tensor|None}`. - - `get_current_estimator_output(ctx: Any) -> Tensor|None` - Returns the last estimator output saved during `compute_dist`. 
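With `finalize` removed, callers stack the per-step buffers themselves. A sketch of the replacement pattern, mirroring the Sampler code introduced later in this patch (`ctx` is the rollout context created by `adapter.init_context` and filled during sampling):

```python
import torch

# Each buffer is a list holding one (N, ...) tensor per step; stacking along
# dim 0 yields trajectory-level tensors of shape (T, N, ...).
stacked_logprobs = (
    torch.stack(ctx.trajectory_log_probs, dim=0)
    if ctx.trajectory_log_probs
    else None
)
stacked_estimator_outputs = (
    torch.stack(ctx.trajectory_estimator_outputs, dim=0)
    if ctx.trajectory_estimator_outputs
    else None
)
```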
@@ -96,15 +93,12 @@ The Sampler uses the adapter lifecycle: - `(dist, ctx) = adapter.compute_dist(states[step_mask], ctx, step_mask, **policy_kwargs)` - Sample actions from `dist`; build actions for the full batch - `log_probs = adapter.log_probs(valid_actions_tensor, dist, ctx, step_mask, vectorized=False)` (or `None` if skipping) - - `adapter.record(ctx, step_mask, sampled_actions=valid_actions_tensor, dist=dist, log_probs=log_probs, save_estimator_outputs=...)` - Step the environment forward/backward based on `adapter.is_backward` -- After rollout: `artifacts = adapter.finalize(ctx)` and populate `Trajectories`. ## How to Implement a New Adapter - 1) Decide on vectorization: - - If your estimator maintains a recurrent carry, set `is_vectorized = False` and implement carry management in `init_context` and `compute_dist`. - - Otherwise set `is_vectorized = True` and follow the default adapter pattern. + - If your estimator requires non-vectorized roll out (e.g., a recurrent carry), set `is_vectorized = False` and implement carry management in `init_context` and `compute_dist`. + - Otherwise set `is_vectorized = True` and follow the default adapter pattern. This will be true in most cases. 2) Implement `init_context(batch_size, device, conditioning)` - Save invariants and allocate any adapter-specific state. Initialize empty per-step buffers. @@ -121,9 +115,6 @@ The Sampler uses the adapter lifecycle: - Never recompute log-probs here; only store what was provided. - When saving estimator outputs, pad to `(N, ...)` using `-inf`. -6) Implement `finalize(ctx)` - - Stack per-step buffers into `(T, N, ...)` tensors and return a dict of artifacts. - 7) Set `is_backward` appropriately so the Sampler chooses forward/backward environment steps. ## Reference: Legacy Implementations diff --git a/src/gfn/adapters.py b/src/gfn/adapters.py index 0d519c39..84445fce 100644 --- a/src/gfn/adapters.py +++ b/src/gfn/adapters.py @@ -27,7 +27,6 @@ class EstimatorAdapter(ABC): batch-sized tensor padded via ``step_mask``. - record(ctx, step_mask, sampled_actions, dist, log_probs, save_estimator_outputs): optionally materialize per‑step artifacts in `ctx`. - - finalize(ctx): stack recorded artifacts along time and return a dict. Notes: - The sampler never inspects `ctx`; masking and padding happen inside the @@ -145,8 +144,6 @@ class while keeping the sampler loop estimator-agnostic. tensor using ``step_mask``. - ``record(...)``: append per‑step artifacts; pad log‑probs with ``0.0`` and estimator outputs with ``-inf``. - - ``finalize(ctx)``: stack recorded lists along time to tensors of shape - ``(T, N, ...)``. Masking and path selection - ``states_active == states[~dones]``; ``step_mask`` has shape ``(N,)``. @@ -291,22 +288,6 @@ def record( padded[step_mask] = estimator_outputs ctx.trajectory_estimator_outputs.append(padded) - # To move to sampler. 
- def finalize(self, ctx: Any) -> dict[str, Optional[torch.Tensor]]: - """Stack recorded per-step artifacts along time into trajectory-level tensors.""" - log_probs = ( - torch.stack(ctx.trajectory_log_probs, dim=0) - if ctx.trajectory_log_probs - else None - ) - estimator_outputs = ( - torch.stack(ctx.trajectory_estimator_outputs, dim=0) - if ctx.trajectory_estimator_outputs - else None - ) - - return {"log_probs": log_probs, "estimator_outputs": estimator_outputs} - def get_current_estimator_output(self, ctx: Any) -> Optional[torch.Tensor]: """Expose the most recent per-step estimator output saved during `compute`.""" return getattr(ctx, "current_estimator_output", None) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 03539308..eecd1d18 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -141,8 +141,7 @@ def sample_trajectories( """Roll out complete trajectories using the adapter. Reuses a single rollout context across steps, calling - ``compute_dist``/``log_probs``/``record`` each iteration and - ``finalize`` at the end to stack trajectory‑level artifacts. Uses + ``compute_dist`` & ``log_probs`` each iteration. Uses ``adapter.is_backward`` to choose the environment step function. Args: @@ -157,8 +156,7 @@ def sample_trajectories( **policy_kwargs: Extra kwargs forwarded to the policy. Returns: - A ``Trajectories`` with stacked states/actions and any artifacts - produced by ``adapter.finalize``. + A ``Trajectories`` with stacked states/actions and any artifacts. Note: For backward trajectories, the reward is computed at the initial state @@ -277,20 +275,33 @@ def sample_trajectories( dones = dones | new_dones trajectories_states.append(states) - # Stack all states and actions + # Stack all states and actions. stacked_states = env.States.stack(trajectories_states) - stacked_actions = env.Actions.stack(trajectories_actions)[ - 1: - ] # Drop dummy action - # Finalize stacked trajectory artifacts from the context (already shaped (T, N, ...)) - trajectory_artifacts = self.adapter.finalize(ctx) # type: ignore[attr-defined] - stacked_logprobs = trajectory_artifacts.get("log_probs", None) - stacked_estimator_outputs = trajectory_artifacts.get("estimator_outputs", None) - - if stacked_logprobs is not None and len(stacked_logprobs) == 0: - stacked_logprobs = None - if stacked_estimator_outputs is not None and len(stacked_estimator_outputs) == 0: - stacked_estimator_outputs = None + + # Stack actions, drop dummy action. + stacked_actions = env.Actions.stack(trajectories_actions)[1:] + + # Get trajectory artifacts from the context (already shaped (T, N, ...)) + stacked_logprobs = ( + torch.stack(ctx.trajectory_log_probs, dim=0) + if ctx.trajectory_log_probs + else None + ) + stacked_estimator_outputs = ( + torch.stack(ctx.trajectory_estimator_outputs, dim=0) + if ctx.trajectory_estimator_outputs + else None + ) + + # Stacked logprobs and estimator outputs are only None if there are no + # valid trajectories. 
+ if stacked_logprobs is not None: + if len(stacked_logprobs) == 0: + stacked_logprobs = None + + if stacked_estimator_outputs is not None: + if len(stacked_estimator_outputs) == 0: + stacked_estimator_outputs = None # Broadcast conditioning tensor to match states batch shape if needed if conditioning is not None: diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index c247e123..914cceaf 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -522,7 +522,7 @@ def test_rollout_context_basic(): assert ctx.extras["foo"] == 123 -def test_default_adapter_compute_record_finalize(): +def test_default_adapter_compute_record(): adapter = DefaultEstimatorAdapter(cast(Estimator, _DummyEstimator())) device = torch.device("cpu") n = 5 @@ -536,11 +536,20 @@ def test_default_adapter_compute_record_finalize(): adapter.record( ctx, step_mask, actions, dist, log_probs=log_probs, save_estimator_outputs=True ) - out = adapter.finalize(ctx) - assert out["log_probs"] is not None and out["log_probs"].shape == (1, n) - assert out["estimator_outputs"] is not None and out["estimator_outputs"].shape[ - :2 - ] == (1, n) + stacked_logprobs = ( + torch.stack(ctx.trajectory_log_probs, dim=0) + if ctx.trajectory_log_probs + else None + ) + stacked_estimator_outputs = ( + torch.stack(ctx.trajectory_estimator_outputs, dim=0) + if ctx.trajectory_estimator_outputs + else None + ) + assert stacked_logprobs is not None + assert stacked_logprobs.shape == (1, n) + assert stacked_estimator_outputs is not None + assert stacked_estimator_outputs.shape[:2] == (1, n) def test_recurrent_adapter_requires_init_carry(): @@ -576,11 +585,20 @@ def test_recurrent_adapter_flow(): ) h1 = ctx.carry["hidden"].clone() assert torch.all(h1 == h0 + 1) - out = adapter.finalize(ctx) - assert out["log_probs"] is not None and out["log_probs"].shape == (2, n) - assert out["estimator_outputs"] is not None and out["estimator_outputs"].shape[ - :2 - ] == (2, n) + stacked_logprobs = ( + torch.stack(ctx.trajectory_log_probs, dim=0) + if ctx.trajectory_log_probs + else None + ) + stacked_estimator_outputs = ( + torch.stack(ctx.trajectory_estimator_outputs, dim=0) + if ctx.trajectory_estimator_outputs + else None + ) + assert stacked_logprobs is not None + assert stacked_logprobs.shape == (2, n) + assert stacked_estimator_outputs is not None + assert stacked_estimator_outputs.shape[:2] == (2, n) # ---------------------- Integration with real recurrent modules ---------------------- @@ -652,13 +670,21 @@ def test_integration_recurrent_sequence_model_with_adapter( save_estimator_outputs=True, ) - out = adapter.finalize(ctx) - log_probs = out["log_probs"] - estimator_outputs = out["estimator_outputs"] - assert log_probs is not None - assert log_probs.shape[0] == 2 - assert estimator_outputs is not None - assert estimator_outputs.shape[0] == 2 + stacked_logprobs = ( + torch.stack(ctx.trajectory_log_probs, dim=0) + if ctx.trajectory_log_probs + else None + ) + stacked_estimator_outputs = ( + torch.stack(ctx.trajectory_estimator_outputs, dim=0) + if ctx.trajectory_estimator_outputs + else None + ) + + assert stacked_logprobs is not None + assert stacked_logprobs.shape[0] == 2 + assert stacked_estimator_outputs is not None + assert stacked_estimator_outputs.shape[0] == 2 @pytest.mark.parametrize("positional_embedding", ["learned", "sinusoidal"]) @@ -704,8 +730,17 @@ def test_integration_transformer_sequence_model_with_adapter( ctx, step_mask, actions, dist, 
log_probs=lp, save_estimator_outputs=True ) - out = adapter.finalize(ctx) - assert out["log_probs"] is not None and out["log_probs"].shape[0] == 1 - assert ( - out["estimator_outputs"] is not None and out["estimator_outputs"].shape[0] == 1 + stacked_logprobs = ( + torch.stack(ctx.trajectory_log_probs, dim=0) + if ctx.trajectory_log_probs + else None + ) + stacked_estimator_outputs = ( + torch.stack(ctx.trajectory_estimator_outputs, dim=0) + if ctx.trajectory_estimator_outputs + else None ) + assert stacked_logprobs is not None + assert stacked_logprobs.shape[0] == 1 + assert stacked_estimator_outputs is not None + assert stacked_estimator_outputs.shape[0] == 1 From 4c2c1dfff3d1dd90c7a712a1f5cf131ed345cce5 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Mon, 13 Oct 2025 13:32:42 -0400 Subject: [PATCH 23/27] removed check_cond_forward --- src/gfn/adapters.py | 15 +++++-- src/gfn/utils/handlers.py | 27 +------------ testing/test_probability_calculations.py | 51 +++++++++++++++++++----- 3 files changed, 54 insertions(+), 39 deletions(-) diff --git a/src/gfn/adapters.py b/src/gfn/adapters.py index 84445fce..37e0f343 100644 --- a/src/gfn/adapters.py +++ b/src/gfn/adapters.py @@ -6,7 +6,10 @@ from torch.distributions import Distribution from gfn.states import States -from gfn.utils.handlers import check_cond_forward +from gfn.utils.handlers import ( + has_conditioning_exception_handler, + no_conditioning_exception_handler, +) if TYPE_CHECKING: from gfn.estimators import Estimator @@ -226,9 +229,13 @@ def compute_dist( else: cond_active = ctx.conditioning[step_mask] - estimator_outputs = check_cond_forward( - self._estimator, "estimator", states_active, cond_active - ) + # Call estimator with or without conditioning. + if cond_active is not None: + with has_conditioning_exception_handler("estimator", self._estimator): + estimator_outputs = self._estimator(states_active, cond_active) + else: + with no_conditioning_exception_handler("estimator", self._estimator): + estimator_outputs = self._estimator(states_active) # Build the distribution. dist = self._estimator.to_probability_distribution( diff --git a/src/gfn/utils/handlers.py b/src/gfn/utils/handlers.py index 84ba9c72..bd494a2a 100644 --- a/src/gfn/utils/handlers.py +++ b/src/gfn/utils/handlers.py @@ -1,33 +1,8 @@ import warnings from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Optional - -import torch +from typing import Any from gfn.containers import Container -from gfn.states import States - -if TYPE_CHECKING: - from gfn.estimators import Estimator # type: ignore - - -def check_cond_forward( - module: "Estimator", - module_name: str, - states: States, - condition: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """Call estimator forward with or without conditioning using error handlers. - - Uses the same exception handling policy as the legacy utility to keep - behavior consistent across adapters and probability code paths. 
- """ - if condition is not None: - with has_conditioning_exception_handler(module_name, module): - return module(states, condition) - else: - with no_conditioning_exception_handler(module_name, module): - return module(states) @contextmanager diff --git a/testing/test_probability_calculations.py b/testing/test_probability_calculations.py index 47d2d817..f5207561 100644 --- a/testing/test_probability_calculations.py +++ b/testing/test_probability_calculations.py @@ -6,7 +6,10 @@ from gfn.gym import HyperGrid from gfn.preprocessors import IdentityPreprocessor from gfn.samplers import Sampler -from gfn.utils.handlers import check_cond_forward +from gfn.utils.handlers import ( + has_conditioning_exception_handler, + no_conditioning_exception_handler, +) from gfn.utils.prob_calculations import ( get_trajectory_pbs, get_trajectory_pfs, @@ -60,7 +63,13 @@ def _legacy_get_trajectory_pfs( (traj_len,) + cond_dim )[state_mask] - estimator_outputs = check_cond_forward(pf, "pf", valid_states, masked_cond) + # Call estimator with or without conditioning. + if masked_cond is not None: + with has_conditioning_exception_handler("pf", pf): + estimator_outputs = pf(valid_states, masked_cond) + else: + with no_conditioning_exception_handler("pf", pf): + estimator_outputs = pf(valid_states) valid_log_pf_actions = pf.to_probability_distribution( valid_states, estimator_outputs @@ -167,7 +176,15 @@ def _legacy_get_trajectory_pbs( masked_cond = trajectories.conditioning[state_mask] if pb is not None: - estimator_outputs = check_cond_forward(pb, "pb", valid_states, masked_cond) + + # Call estimator with or without conditioning. + if masked_cond is not None: + with has_conditioning_exception_handler("pb", pb): + estimator_outputs = pb(valid_states, masked_cond) + else: + with no_conditioning_exception_handler("pb", pb): + estimator_outputs = pb(valid_states) + valid_log_pb_actions = pb.to_probability_distribution( valid_states, estimator_outputs ).log_prob(valid_actions.tensor) @@ -332,9 +349,11 @@ def test_adapter_log_probs_precomputed_matches_forward(): env, pf_estimator, _ = _build_env_and_pf() states = env.reset(batch_shape=(5,)) - # Compute estimator outputs once (precomputed path) - est_out = check_cond_forward(pf_estimator, "pf", states, None) - dist = pf_estimator.to_probability_distribution(states, est_out) + # Compute estimator outputs once (precomputed path) - no conditioning. + with no_conditioning_exception_handler("pf", pf_estimator): + estimator_outputs = pf_estimator(states) + + dist = pf_estimator.to_probability_distribution(states, estimator_outputs) with torch.no_grad(): actions_tensor = dist.sample() @@ -348,7 +367,7 @@ def test_adapter_log_probs_precomputed_matches_forward(): lp1, _ = adapter.log_probs(actions_tensor, dist1, ctx1, step_mask, vectorized=False) # Precomputed: adapter reuses provided estimator outputs (fast path) - ctx2.current_estimator_output = est_out + ctx2.current_estimator_output = estimator_outputs dist2, ctx2 = adapter.compute_dist(states, ctx2, step_mask) lp2, _ = adapter.log_probs(actions_tensor, dist2, ctx2, step_mask, vectorized=False) @@ -369,7 +388,14 @@ def _legacy_get_transition_pfs( assert log_pf_actions is not None return log_pf_actions - estimator_outputs = check_cond_forward(pf, "pf", states, transitions.conditioning) + # Call estimator with or without conditioning. 
+ if transitions.conditioning is not None: + with has_conditioning_exception_handler("pf", pf): + estimator_outputs = pf(states, transitions.conditioning) + else: + with no_conditioning_exception_handler("pf", pf): + estimator_outputs = pf(states) + log_pf_actions = pf.to_probability_distribution(states, estimator_outputs).log_prob( actions.tensor ) @@ -390,7 +416,14 @@ def _legacy_get_transition_pbs(pb: DiscretePolicyEstimator | None, transitions): ) if pb is not None: - estimator_outputs = check_cond_forward(pb, "pb", valid_next_states, masked_cond) + # Call estimator with or without conditioning. + if masked_cond is not None: + with has_conditioning_exception_handler("pb", pb): + estimator_outputs = pb(valid_next_states, masked_cond) + else: + with no_conditioning_exception_handler("pb", pb): + estimator_outputs = pb(valid_next_states) + valid_log_pb_actions = pb.to_probability_distribution( valid_next_states, estimator_outputs ).log_prob(non_exit_actions.tensor) From d066c97a7868ba57d1af8136459717a22e973a87 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Mon, 13 Oct 2025 15:00:19 -0400 Subject: [PATCH 24/27] removed record step --- docs/source/guides/estimator_adapters.md | 33 ++++---- src/gfn/adapters.py | 96 +++++++++++------------ src/gfn/samplers.py | 66 ++++++++-------- testing/test_samplers_and_trajectories.py | 50 ++++++------ 4 files changed, 116 insertions(+), 129 deletions(-) diff --git a/docs/source/guides/estimator_adapters.md b/docs/source/guides/estimator_adapters.md index ecafac58..7316176a 100644 --- a/docs/source/guides/estimator_adapters.md +++ b/docs/source/guides/estimator_adapters.md @@ -29,19 +29,16 @@ Adapters conform to an abstract class structure: - `init_context(batch_size: int, device: torch.device, conditioning: Tensor|None) -> Any` - Allocates a rollout context once per batch (Sampler). Stores invariants (batch size, device, optional conditioning) and initializes any adapter state (e.g., recurrent carry) along with per-step artifact buffers. - - `compute_dist(states_active: States, ctx: Any, step_mask: Tensor|None, **policy_kwargs) -> (Distribution, Any)` + - `compute_dist(states_active: States, ctx: Any, step_mask: Tensor|None, save_estimator_outputs: bool = False, **policy_kwargs) -> (Distribution, Any)` - Runs the estimator forward on the provided rows and returns a torch Distribution over actions. - Slices `conditioning` with `step_mask` when provided (non‑vectorized); uses full conditioning when `step_mask=None` (vectorized). - - Sets `ctx.current_estimator_output` to the raw estimator output. Vectorized callers may prefill `ctx.current_estimator_output` to reuse cached outputs. + - When `save_estimator_outputs=True`, sets `ctx.current_estimator_output` and appends a padded copy (pad with `-inf` on inactive rows) to `ctx.trajectory_estimator_outputs` for non‑vectorized calls. - - `log_probs(actions_active: Tensor, dist: Distribution, ctx: Any, step_mask: Tensor|None, vectorized: bool = False) -> (Tensor, Any)` + - `log_probs(actions_active: Tensor, dist: Distribution, ctx: Any, step_mask: Tensor|None, vectorized: bool = False, save_logprobs: bool = False) -> (Tensor, Any)` - Computes log-probs from `dist` for the given actions. - When `vectorized=False`, returns a padded `(N,)` tensor (zeros where `~step_mask`), with a strict inf-check (raises on `±inf`). - When `vectorized=True`, returns the raw `dist.log_prob(...)` without padding or inf-check (vectorized paths can legitimately include `-inf` for illegal actions). 
- - - `record(ctx: Any, step_mask: Tensor, sampled_actions: Tensor, dist: Distribution, log_probs: Optional[Tensor], save_estimator_outputs: bool) -> None` - - Records per-step artifacts owned by the context. It never recomputes log-probs; pass `log_probs=None` to skip recording them. - - Pads estimator outputs to `(N, ...)` using `-inf` before appending when `save_estimator_outputs=True`. + - When `save_logprobs=True`, appends the returned log-probs (padded or vectorized) to `ctx.trajectory_log_probs`. - `get_current_estimator_output(ctx: Any) -> Tensor|None` - Returns the last estimator output saved during `compute_dist`. @@ -57,12 +54,12 @@ Adapters conform to an abstract class structure: - `DefaultEstimatorAdapter` - `is_vectorized = True` - No carry. Works with both the Sampler and vectorized probability calculators. - - In the Sampler, it slices conditioning by `step_mask`, runs the estimator, builds the Distribution, and optionally records artifacts. + - In the Sampler, it slices conditioning by `step_mask`, runs the estimator, builds the Distribution, and records artifacts when flags are set. - `RecurrentEstimatorAdapter` - `is_vectorized = False` - Maintains a `carry` in the context (initialized via `estimator.init_carry(batch_size, device)`). - - In the Sampler, it calls the estimator as `(states_active, ctx.carry) -> (est_out, new_carry)`, stores `new_carry`, builds the Distribution, and optionally records artifacts. + - In the Sampler, it calls the estimator as `(states_active, ctx.carry) -> (est_out, new_carry)`, stores `new_carry`, builds the Distribution, and records artifacts when flags are set. ## Vectorized vs Non-Vectorized Probability Paths @@ -90,9 +87,9 @@ Key differences: The Sampler uses the adapter lifecycle: - `ctx = adapter.init_context(batch_size, device, conditioning)` - While some trajectories are active: - - `(dist, ctx) = adapter.compute_dist(states[step_mask], ctx, step_mask, **policy_kwargs)` + - `(dist, ctx) = adapter.compute_dist(states[step_mask], ctx, step_mask, save_estimator_outputs=..., **policy_kwargs)` - Sample actions from `dist`; build actions for the full batch - - `log_probs = adapter.log_probs(valid_actions_tensor, dist, ctx, step_mask, vectorized=False)` (or `None` if skipping) + - `log_probs = adapter.log_probs(valid_actions_tensor, dist, ctx, step_mask, vectorized=False, save_logprobs=...)` (or skip) - Step the environment forward/backward based on `adapter.is_backward` ## How to Implement a New Adapter @@ -103,19 +100,17 @@ The Sampler uses the adapter lifecycle: 2) Implement `init_context(batch_size, device, conditioning)` - Save invariants and allocate any adapter-specific state. Initialize empty per-step buffers. -3) Implement `compute_dist(states_active, ctx, step_mask, **policy_kwargs)` +3) Implement `compute_dist(states_active, ctx, step_mask, save_estimator_outputs=False, **policy_kwargs)` - Slice `conditioning` by `step_mask` for non‑vectorized calls; use full conditioning when `step_mask=None`. - - Call your estimator, set `ctx.current_estimator_output`, and return a Distribution via `to_probability_distribution`. + - Call your estimator, build and return a Distribution via `to_probability_distribution`. + - When `save_estimator_outputs=True`, set `ctx.current_estimator_output` and append a padded copy to `ctx.trajectory_estimator_outputs` for non‑vectorized calls. 
-4) Implement `log_probs(actions_active, dist, ctx, step_mask, vectorized=False)`
+4) Implement `log_probs(actions_active, dist, ctx, step_mask, vectorized=False, save_logprobs=False)`
   - Non‑vectorized: strict inf-check, return a padded `(N,)` tensor.
   - Vectorized: return raw `dist.log_prob(...)` (may include `-inf` for illegal actions).
+   - When `save_logprobs=True`, append the returned tensor to `ctx.trajectory_log_probs`.
 
-5) Implement `record(ctx, step_mask, sampled_actions, dist, log_probs, save_estimator_outputs)`
-   - Never recompute log-probs here; only store what was provided.
-   - When saving estimator outputs, pad to `(N, ...)` using `-inf`.
-
-7) Set `is_backward` appropriately so the Sampler chooses forward/backward environment steps.
+5) Set `is_backward` appropriately so the Sampler chooses forward/backward environment steps
+   (see the lifecycle sketch below).
 
 ## Reference: Legacy Implementations
 
diff --git a/src/gfn/adapters.py b/src/gfn/adapters.py
index 37e0f343..bc9624fe 100644
--- a/src/gfn/adapters.py
+++ b/src/gfn/adapters.py
@@ -22,14 +22,12 @@ class EstimatorAdapter(ABC):
 
     Responsibilities:
     - init_context(batch_size, device, conditioning): allocate rollout context.
-    - compute_dist(states_active, ctx, step_mask, **kw): run estimator on active
-      rows, return a torch Distribution, and update `ctx` if needed (e.g., carry,
-      cached outputs).
-    - log_probs(actions_active, dist, ctx, step_mask, vectorized): compute
-      log-probabilities for active rows; when ``vectorized=False`` return a
+    - compute_dist(states_active, ctx, step_mask, save_estimator_outputs, **kw):
+      run estimator on active rows, return a torch Distribution, and update `ctx`
+      if needed (e.g., carry, cached outputs).
+    - log_probs(actions_active, dist, ctx, step_mask, vectorized, save_logprobs):
+      compute log-probabilities for active rows; when ``vectorized=False`` return a
       batch-sized tensor padded via ``step_mask``.
-    - record(ctx, step_mask, sampled_actions, dist, log_probs, save_estimator_outputs):
-      optionally materialize per‑step artifacts in `ctx`.
 
     Notes:
     - The sampler never inspects `ctx`; masking and padding happen inside the
@@ -61,6 +59,7 @@ def compute_dist(
         states_active: States,
         ctx: Any,
         step_mask: Optional[torch.Tensor] = None,
+        save_estimator_outputs: bool = False,
         **policy_kwargs: Any,
     ) -> tuple[Distribution, Any]:
         ...  # fmt: skip
@@ -73,21 +72,10 @@ def log_probs(
         ctx: Any,
         step_mask: Optional[torch.Tensor] = None,
         vectorized: bool = False,
+        save_logprobs: bool = False,
     ) -> tuple[torch.Tensor, Any]:
         ...  # fmt: skip
 
-    @abstractmethod
-    def record(
-        self,
-        ctx: Any,
-        step_mask: torch.Tensor,
-        sampled_actions: torch.Tensor,
-        dist: Distribution,
-        log_probs: Optional[torch.Tensor],
-        save_estimator_outputs: bool,
-    ) -> None:
-        ...  # fmt: skip
-
     # Optional helper for `sample_actions` BC
     def get_current_estimator_output(self, ctx: Any) -> Optional[torch.Tensor]:
         ...  # fmt: skip
@@ -144,9 +132,7 @@ class while keeping the sampler loop estimator-agnostic.
       ``est_out`` in ``ctx.current_estimator_output``, and return a Distribution.
     - ``log_probs(actions_active, dist, ctx, step_mask, vectorized)``: compute
       log‑probs for active rows; when ``vectorized=False``, return a batch‑padded
-      tensor using ``step_mask``.
-    - ``record(...)``: append per‑step artifacts; pad log‑probs with ``0.0`` and
-      estimator outputs with ``-inf``.
+      tensor using ``step_mask``. Per‑step artifacts are recorded when flags are set.
 
     Masking and path selection
     - ``states_active == states[~dones]``; ``step_mask`` has shape ``(N,)``.
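A minimal end-to-end sketch of the lifecycle described above, assuming a discrete `env` and a `DiscretePolicyEstimator` `pf` already exist (both names are illustrative); it mirrors the updated adapter tests later in this patch:

```python
import torch

from gfn.adapters import DefaultEstimatorAdapter

# Assumed to exist for illustration: `env` (a discrete environment) and
# `pf` (a DiscretePolicyEstimator wrapping a small policy network).
adapter = DefaultEstimatorAdapter(pf)

states = env.reset(batch_shape=(4,))
step_mask = torch.ones(4, dtype=torch.bool, device=states.device)
ctx = adapter.init_context(batch_size=4, device=states.device, conditioning=None)

# One sampling step: build the distribution, sample, then compute padded log-probs.
dist, ctx = adapter.compute_dist(states, ctx, step_mask, save_estimator_outputs=True)
actions = dist.sample()
log_probs, ctx = adapter.log_probs(
    actions, dist, ctx, step_mask, vectorized=False, save_logprobs=True
)

# Per-step artifacts accumulate on the context, one entry per sampled step.
assert len(ctx.trajectory_log_probs) == 1
assert len(ctx.trajectory_estimator_outputs) == 1
```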
@@ -197,6 +183,7 @@ def compute_dist( states_active: States, ctx: Any, step_mask: Optional[torch.Tensor] = None, + save_estimator_outputs: bool = False, **policy_kwargs: Any, ) -> tuple[Distribution, Any]: """Run the estimator for active rows and build an action Distribution. @@ -242,8 +229,22 @@ def compute_dist( states_active, estimator_outputs, **policy_kwargs ) - # TODO: Make optional. - ctx.current_estimator_output = estimator_outputs + # Save current estimator output only when requested. + if save_estimator_outputs: + ctx.current_estimator_output = estimator_outputs + + # If we are in a non-vectorized path (masked), append a padded copy to trajectory. + if step_mask is not None: + padded = torch.full( + (ctx.batch_size,) + estimator_outputs.shape[1:], + -float("inf"), + device=ctx.device, + ) + padded[step_mask] = estimator_outputs + ctx.trajectory_estimator_outputs.append(padded) + + else: + ctx.current_estimator_output = None return dist, ctx @@ -254,11 +255,14 @@ def log_probs( ctx: Any, step_mask: Optional[torch.Tensor] = None, vectorized: bool = False, + save_logprobs: bool = False, ) -> tuple[torch.Tensor, Any]: """Compute log-probs, optionally padding back to full batch when non-vectorized.""" lp = dist.log_prob(actions_active) if vectorized: + if save_logprobs: + ctx.trajectory_log_probs.append(lp) return lp, ctx # Non-vectorized path strict check. None of these should be -inf after masking. @@ -269,31 +273,10 @@ def log_probs( step_lp = torch.full((ctx.batch_size,), 0.0, device=ctx.device, dtype=lp.dtype) step_lp[step_mask] = lp - return step_lp, ctx + if save_logprobs: + ctx.trajectory_log_probs.append(step_lp) - # To merge with log probs / calc dist. - def record( - self, - ctx: Any, - step_mask: torch.Tensor, - sampled_actions: torch.Tensor, - dist: Distribution, - log_probs: Optional[torch.Tensor], - save_estimator_outputs: bool, - ) -> None: - """Record per-step artifacts into the context's trajectory-level lists.""" - if log_probs is not None: - ctx.trajectory_log_probs.append(log_probs) - - if save_estimator_outputs and ctx.current_estimator_output is not None: - estimator_outputs = ctx.current_estimator_output - padded = torch.full( - (ctx.batch_size,) + estimator_outputs.shape[1:], - -float("inf"), - device=ctx.device, - ) - padded[step_mask] = estimator_outputs - ctx.trajectory_estimator_outputs.append(padded) + return step_lp, ctx def get_current_estimator_output(self, ctx: Any) -> Optional[torch.Tensor]: """Expose the most recent per-step estimator output saved during `compute`.""" @@ -357,6 +340,7 @@ def compute_dist( states_active: States, ctx: Any, step_mask: Optional[torch.Tensor] = None, + save_estimator_outputs: bool = False, **policy_kwargs: Any, ) -> tuple[Distribution, Any]: """Run estimator with carry and update it. @@ -374,8 +358,20 @@ def compute_dist( **policy_kwargs, ) - # TODO: Make optional. - ctx.current_estimator_output = estimator_outputs + # Save current estimator output only when requested. 
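+        # (Same bookkeeping as the default adapter: the cached tensor backs
+        # `get_current_estimator_output`, and masked steps append a -inf-padded
+        # copy so per-step outputs can later be stacked across the rollout.)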
+ if save_estimator_outputs: + ctx.current_estimator_output = estimator_outputs + + if step_mask is not None: + padded = torch.full( + (ctx.batch_size,) + estimator_outputs.shape[1:], + -float("inf"), + device=ctx.device, + ) + padded[step_mask] = estimator_outputs + ctx.trajectory_estimator_outputs.append(padded) + else: + ctx.current_estimator_output = None return dist, ctx diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index eecd1d18..45bcc972 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -18,13 +18,12 @@ class Sampler: Delegates policy logic to an adapter: the adapter builds action distributions, computes step log‑probs, and records artifacts into a - rollout context. Direction (forward/backward) is determined by + rollout context via method flags. Direction (forward/backward) is determined by ``adapter.is_backward``. Attributes: estimator: The underlying policy estimator (adapter wraps it). - adapter: The adapter used to build action distributions, compute step log‑probs, - and record artifacts into a rollout context. + adapter: The adapter used to build action distributions and compute step log‑probs. """ def __init__( @@ -59,9 +58,9 @@ def sample_actions( """Sample one step from ``states`` via the adapter. Initializes or reuses a rollout context with ``adapter.init_context``, - builds a Distribution with ``adapter.compute_dist``, optionally computes - log‑probs with ``adapter.log_probs``, and lets ``adapter.record`` - persist per‑step artifacts. + builds a Distribution with ``adapter.compute_dist``, and optionally computes + log‑probs with ``adapter.log_probs``. Per‑step artifacts are recorded by + the adapter when the corresponding flags are set. Args: env: Environment providing action/state conversion utilities. @@ -90,7 +89,13 @@ def sample_actions( step_mask = torch.ones( states.batch_shape[0], dtype=torch.bool, device=states.device ) - dist, ctx = self.adapter.compute_dist(states, ctx, step_mask, **policy_kwargs) + dist, ctx = self.adapter.compute_dist( + states, + ctx, + step_mask, + save_estimator_outputs=save_estimator_outputs, + **policy_kwargs, + ) with torch.no_grad(): actions_tensor = dist.sample() @@ -98,21 +103,16 @@ def sample_actions( if save_logprobs: # Use adapter to compute step log-probs and pad to batch. log_probs, ctx = self.adapter.log_probs( - actions_tensor, dist, ctx, step_mask, vectorized=False + actions_tensor, + dist, + ctx, + step_mask, + vectorized=False, + save_logprobs=True, ) else: log_probs = None - # Allow adapter to record per-step artifacts for callers that reuse ctx. - self.adapter.record( - ctx=ctx, - step_mask=step_mask, - sampled_actions=actions_tensor, - dist=dist, - log_probs=log_probs, - save_estimator_outputs=save_estimator_outputs, - ) - actions = env.actions_from_tensor(actions_tensor) estimator_output = None @@ -210,7 +210,11 @@ def sample_trajectories( # Compute distribution on active rows dist, ctx = self.adapter.compute_dist( - states[step_mask], ctx, step_mask, **policy_kwargs + states[step_mask], + ctx, + step_mask, + save_estimator_outputs=save_estimator_outputs, + **policy_kwargs, ) # Sample actions for active rows @@ -219,25 +223,17 @@ def sample_trajectories( valid_actions = env.actions_from_tensor(valid_actions_tensor) if save_logprobs: - # Use adapter to compute step log-probs and pad to batch. - log_probs, ctx = self.adapter.log_probs( - valid_actions_tensor, dist, ctx, step_mask, vectorized=False + # Use adapter to compute step log-probs and pad to batch (recorded in ctx). 
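+            # The returned padded tensor is discarded here; the per-step values
+            # are read back from ctx.trajectory_log_probs when the final
+            # Trajectories object is assembled.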
+ _, ctx = self.adapter.log_probs( + valid_actions_tensor, + dist, + ctx, + step_mask, + vectorized=False, + save_logprobs=True, ) - else: - log_probs = None - - # Let adapter record artifacts. - self.adapter.record( - ctx=ctx, - step_mask=step_mask, - sampled_actions=valid_actions_tensor, - dist=dist, - log_probs=log_probs, - save_estimator_outputs=save_estimator_outputs, - ) actions[step_mask] = valid_actions - trajectories_actions.append(actions) if self.adapter.is_backward: diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index 914cceaf..68b61921 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -530,11 +530,12 @@ def test_default_adapter_compute_record(): ctx = adapter.init_context(n, device, conditioning=None) step_mask = torch.ones(n, dtype=torch.bool, device=device) - dist, ctx = adapter.compute_dist(cast(States, states), ctx, step_mask) + dist, ctx = adapter.compute_dist( + cast(States, states), ctx, step_mask, save_estimator_outputs=True + ) actions = dist.sample() - log_probs, ctx = adapter.log_probs(actions, dist, ctx, step_mask, vectorized=False) - adapter.record( - ctx, step_mask, actions, dist, log_probs=log_probs, save_estimator_outputs=True + _, ctx = adapter.log_probs( + actions, dist, ctx, step_mask, vectorized=False, save_logprobs=True ) stacked_logprobs = ( torch.stack(ctx.trajectory_log_probs, dim=0) @@ -568,20 +569,22 @@ def test_recurrent_adapter_flow(): ctx = adapter.init_context(n, device, conditioning=None) step_mask = torch.ones(n, dtype=torch.bool, device=device) - dist, ctx = adapter.compute_dist(cast(States, states), ctx, step_mask) + dist, ctx = adapter.compute_dist( + cast(States, states), ctx, step_mask, save_estimator_outputs=True + ) actions = dist.sample() # carry should update when we record multiple steps h0 = ctx.carry["hidden"].clone() - lp, ctx = adapter.log_probs(actions, dist, ctx, step_mask, vectorized=False) - adapter.record( - ctx, step_mask, actions, dist, log_probs=lp, save_estimator_outputs=True + _, ctx = adapter.log_probs( + actions, dist, ctx, step_mask, vectorized=False, save_logprobs=True ) # second step - dist, ctx = adapter.compute_dist(cast(States, states), ctx, step_mask) + dist, ctx = adapter.compute_dist( + cast(States, states), ctx, step_mask, save_estimator_outputs=True + ) actions = dist.sample() - lp, ctx = adapter.log_probs(actions, dist, ctx, step_mask, vectorized=False) - adapter.record( - ctx, step_mask, actions, dist, log_probs=lp, save_estimator_outputs=True + _, ctx = adapter.log_probs( + actions, dist, ctx, step_mask, vectorized=False, save_logprobs=True ) h1 = ctx.carry["hidden"].clone() assert torch.all(h1 == h0 + 1) @@ -658,16 +661,12 @@ def test_integration_recurrent_sequence_model_with_adapter( # Run two steps and verify carry and artifact shapes step_mask = torch.ones(batch_size, dtype=torch.bool, device=device) for _ in range(2): - dist, ctx = adapter.compute_dist(cast(States, states), ctx, step_mask) + dist, ctx = adapter.compute_dist( + cast(States, states), ctx, step_mask, save_estimator_outputs=True + ) actions = dist.sample() - lp, ctx = adapter.log_probs(actions, dist, ctx, step_mask, vectorized=False) - adapter.record( - ctx, - step_mask, - actions, - dist, - log_probs=lp, - save_estimator_outputs=True, + _, ctx = adapter.log_probs( + actions, dist, ctx, step_mask, vectorized=False, save_logprobs=True ) stacked_logprobs = ( @@ -723,11 +722,12 @@ def 
test_integration_transformer_sequence_model_with_adapter( step_mask = torch.ones(batch_size, dtype=torch.bool, device=device) - dist, ctx = adapter.compute_dist(cast(States, states), ctx, step_mask) + dist, ctx = adapter.compute_dist( + cast(States, states), ctx, step_mask, save_estimator_outputs=True + ) actions = dist.sample() - lp, ctx = adapter.log_probs(actions, dist, ctx, step_mask, vectorized=False) - adapter.record( - ctx, step_mask, actions, dist, log_probs=lp, save_estimator_outputs=True + _, ctx = adapter.log_probs( + actions, dist, ctx, step_mask, vectorized=False, save_logprobs=True ) stacked_logprobs = ( From e638be999ff04ff0ebfa77a5ffa757124a832cf5 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Mon, 13 Oct 2025 22:24:20 -0400 Subject: [PATCH 25/27] lint errors --- src/gfn/adapters.py | 429 --------------------- src/gfn/estimators.py | 331 ++++++++++++++-- src/gfn/gflownet/base.py | 47 +-- src/gfn/gflownet/detailed_balance.py | 58 +-- src/gfn/gflownet/flow_matching.py | 22 +- src/gfn/gflownet/sub_trajectory_balance.py | 23 +- src/gfn/gflownet/trajectory_balance.py | 24 +- src/gfn/samplers.py | 123 +++--- src/gfn/utils/prob_calculations.py | 367 +++++++++--------- 9 files changed, 570 insertions(+), 854 deletions(-) delete mode 100644 src/gfn/adapters.py diff --git a/src/gfn/adapters.py b/src/gfn/adapters.py deleted file mode 100644 index bc9624fe..00000000 --- a/src/gfn/adapters.py +++ /dev/null @@ -1,429 +0,0 @@ -from abc import ABC, abstractmethod -from inspect import signature -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, cast - -import torch -from torch.distributions import Distribution - -from gfn.states import States -from gfn.utils.handlers import ( - has_conditioning_exception_handler, - no_conditioning_exception_handler, -) - -if TYPE_CHECKING: - from gfn.estimators import Estimator - - -class EstimatorAdapter(ABC): - """Adapter interface for estimator-specific policy behavior. - - Keeps the sampling loop generic; estimator-specific logic lives here. - - Responsibilities: - - init_context(batch_size, device, conditioning): allocate rollout context. - - compute_dist(states_active, ctx, step_mask, save_estimator_outputs, **kw): - run estimator on active rows, return a torch Distribution, and update `ctx` - if needed (e.g., carry, cached outputs). - - log_probs(actions_active, dist, ctx, step_mask, vectorized, save_logprobs): - compute log-probabilities for active rows; when ``vectorized=False`` return a - batch-sized tensor padded via ``step_mask``. - - Notes: - - The sampler never inspects `ctx`; masking and padding happen inside the - adapter. - - ``is_backward`` selects forward vs backward environment steps. - - ``is_vectorized`` selects fast vectorized vs per‑step probability paths. - """ - - @property - def is_backward(self) -> bool: - ... # fmt: skip - - @property - def is_vectorized(self) -> bool: - ... # fmt: skip - - @abstractmethod - def init_context( - self, - batch_size: int, - device: torch.device, - conditioning: Optional[torch.Tensor] = None, - ) -> Any: - ... # fmt: skip - - @abstractmethod - def compute_dist( - self, - states_active: States, - ctx: Any, - step_mask: Optional[torch.Tensor] = None, - save_estimator_outputs: bool = False, - **policy_kwargs: Any, - ) -> tuple[Distribution, Any]: - ... 
# fmt: skip - - @abstractmethod - def log_probs( - self, - actions_active: torch.Tensor, - dist: Distribution, - ctx: Any, - step_mask: Optional[torch.Tensor] = None, - vectorized: bool = False, - save_logprobs: bool = False, - ) -> tuple[torch.Tensor, Any]: - ... # fmt: skip - - # Optional helper for `sample_actions` BC - def get_current_estimator_output(self, ctx: Any) -> Optional[torch.Tensor]: - ... # fmt: skip - - -class RolloutContext: - """Structured per‑rollout state owned by adapters. - - Holds rollout invariants and optional per‑step buffers; use ``extras`` for - adapter‑specific fields without changing the class shape. - """ - - __slots__ = ( - "batch_size", - "device", - "conditioning", - "carry", - "trajectory_log_probs", - "trajectory_estimator_outputs", - "current_estimator_output", - "extras", - ) - - def __init__( - self, - batch_size: int, - device: torch.device, - conditioning: Optional[torch.Tensor] = None, - ) -> None: - self.batch_size = batch_size - self.device = device - self.conditioning = conditioning - self.carry = None - self.trajectory_log_probs: List[torch.Tensor] = [] - self.trajectory_estimator_outputs: List[torch.Tensor] = [] - self.current_estimator_output: Optional[torch.Tensor] = None - self.extras: Dict[str, Any] = {} - - -class DefaultEstimatorAdapter(EstimatorAdapter): - """Adapter for non‑recurrent estimators (default). - - Overview - -------- - This adapter bridges the generic sampling loop and is used throughout the codebase. - It exposes the minimal interface required by the `EstimatorAdapter` abstract base - class while keeping the sampler loop estimator-agnostic. - - Workflow with RolloutContext: - - ``init_context(batch_size, device, conditioning)``: store invariants and - allocate per‑step buffers. - - ``compute_dist(states_active, ctx, step_mask, **kw)``: slice conditioning - by ``step_mask`` when provided, run the estimator on active rows, cache - ``est_out`` in ``ctx.current_estimator_output``, and return a Distribution. - - ``log_probs(actions_active, dist, ctx, step_mask, vectorized)``: compute - log‑probs for active rows; when ``vectorized=False``, return a batch‑padded - tensor using ``step_mask``. Per‑step artifacts are recorded when flags are set. - - Masking and path selection - - ``states_active == states[~dones]``; ``step_mask`` has shape ``(N,)``. - - ``is_backward`` is forwarded from the estimator. - - ``is_vectorized == True`` enables vectorized probability calculators when - available. - - Performance - - One context per rollout; mutate in place. If trajectory length bounds are - known, pre‑allocation of those buffers is possible. - """ - - def __init__(self, estimator: "Estimator") -> None: - """Initialize the adapter with a non-recurrent estimator. - - The estimator must expose `to_probability_distribution(states, est_out, **kw)` - and optionally accept conditioning via `estimator(states, conditioning)`. - """ - self._estimator = estimator - - @property - def is_backward(self) -> bool: - """Whether the wrapped estimator samples in the backward direction.""" - return getattr(self._estimator, "is_backward", False) - - @property - def is_vectorized(self) -> bool: - """Used for vectorized probability calculations.""" - return True - - def init_context( - self, - batch_size: int, - device: torch.device, - conditioning: Optional[torch.Tensor] = None, - ) -> RolloutContext: - """Create a new per-rollout context. 
- - Stores rollout invariants (batch size, device, optional conditioning) and - initializes empty buffers for per-step artifacts. - """ - return RolloutContext( - batch_size=batch_size, device=device, conditioning=conditioning - ) - - def compute_dist( - self, - states_active: States, - ctx: Any, - step_mask: Optional[torch.Tensor] = None, - save_estimator_outputs: bool = False, - **policy_kwargs: Any, - ) -> tuple[Distribution, Any]: - """Run the estimator for active rows and build an action Distribution. - - - Uses `step_mask` to slice conditioning to the active subset. When `step_mask` - is None, the estimator running in a vectorized context. - - Saves the raw estimator output in `ctx.current_estimator_output` for - optional recording in `record_step`. - """ - precopmputed_estimator_outputs = getattr(ctx, "current_estimator_output", None) - - # Reuse precomputed outputs only in vectorized contexts (no step_mask). - if step_mask is None and precopmputed_estimator_outputs is not None: - expected_bs = states_active.batch_shape[0] - if precopmputed_estimator_outputs.shape[0] != expected_bs: - raise RuntimeError( - "current_estimator_output batch size does not match active states. " - f"Got {precopmputed_estimator_outputs.shape[0]}, expected {expected_bs}. " - "This indicates stale cache reuse; ensure per-step masking when setting " - "ctx.current_estimator_output and clear it when not valid." - ) - estimator_outputs = precopmputed_estimator_outputs - - # Otherwise, compute the estimator outputs. - else: - cond_active = None - if ctx.conditioning is not None: - if step_mask is None: - cond_active = ctx.conditioning - else: - cond_active = ctx.conditioning[step_mask] - - # Call estimator with or without conditioning. - if cond_active is not None: - with has_conditioning_exception_handler("estimator", self._estimator): - estimator_outputs = self._estimator(states_active, cond_active) - else: - with no_conditioning_exception_handler("estimator", self._estimator): - estimator_outputs = self._estimator(states_active) - - # Build the distribution. - dist = self._estimator.to_probability_distribution( - states_active, estimator_outputs, **policy_kwargs - ) - - # Save current estimator output only when requested. - if save_estimator_outputs: - ctx.current_estimator_output = estimator_outputs - - # If we are in a non-vectorized path (masked), append a padded copy to trajectory. - if step_mask is not None: - padded = torch.full( - (ctx.batch_size,) + estimator_outputs.shape[1:], - -float("inf"), - device=ctx.device, - ) - padded[step_mask] = estimator_outputs - ctx.trajectory_estimator_outputs.append(padded) - - else: - ctx.current_estimator_output = None - - return dist, ctx - - def log_probs( - self, - actions_active: torch.Tensor, - dist: Distribution, - ctx: Any, - step_mask: Optional[torch.Tensor] = None, - vectorized: bool = False, - save_logprobs: bool = False, - ) -> tuple[torch.Tensor, Any]: - """Compute log-probs, optionally padding back to full batch when non-vectorized.""" - lp = dist.log_prob(actions_active) - - if vectorized: - if save_logprobs: - ctx.trajectory_log_probs.append(lp) - return lp, ctx - - # Non-vectorized path strict check. None of these should be -inf after masking. - if torch.any(torch.isinf(lp)): - raise RuntimeError("Log probabilities are inf. 
This should not happen.") - - assert step_mask is not None, "step_mask is required when vectorized=False" - step_lp = torch.full((ctx.batch_size,), 0.0, device=ctx.device, dtype=lp.dtype) - step_lp[step_mask] = lp - - if save_logprobs: - ctx.trajectory_log_probs.append(step_lp) - - return step_lp, ctx - - def get_current_estimator_output(self, ctx: Any) -> Optional[torch.Tensor]: - """Expose the most recent per-step estimator output saved during `compute`.""" - return getattr(ctx, "current_estimator_output", None) - - -class RecurrentEstimatorAdapter(DefaultEstimatorAdapter): - """Adapter for recurrent estimators with rollout carry (hidden state). - - - Requires ``estimator.init_carry(batch_size, device)`` and a forward that - returns ``(estimator_outputs, new_carry)``. - - Maintains ``ctx.carry`` across steps and updates it each call. - - ``is_vectorized=False``; probability calculators use the per‑step path - with legacy masks/alignment. - """ - - def __init__(self, estimator: "Estimator") -> None: - # Validate that the estimator presents a recurrent interface - # We check for the presence of `init_carry` and a callable that accepts (states, carry). - init_carry = getattr(estimator, "init_carry", None) - if not callable(init_carry): - raise TypeError( - "RecurrentEstimatorAdapter requires an estimator implementing " - "init_carry(batch_size: int, device: torch.device)." - ) - super().__init__(estimator) - - # TODO: Need to support vectorized probability calculations with Transformers. - @property - def is_vectorized(self) -> bool: - return False - - def init_context( - self, - batch_size: int, - device: torch.device, - conditioning: Optional[torch.Tensor] = None, - ) -> RolloutContext: - """Create context and initialize recurrent carry, (estimator hidden state). - - Differs from the default adapter by allocating `ctx.carry` via - `estimator.init_carry(batch_size, device)`. - """ - init_carry = getattr(self._estimator, "init_carry", None) - if not callable(init_carry): - raise TypeError( - "RecurrentEstimatorAdapter requires an estimator that implements " - "init_carry(batch_size: int, device: torch.device).\n" - "A) Recurrent estimators must expose an `init_carry` method.\n" - "B) RecurrentEstimatorAdapter is only compatible with estimators that " - "expose `init_carry`." - ) - ctx = super().init_context(batch_size, device, conditioning) - init_carry_fn = cast(Callable[[int, torch.device], Any], init_carry) - ctx.carry = init_carry_fn(batch_size, device) - - return ctx - - def compute_dist( - self, - states_active: States, - ctx: Any, - step_mask: Optional[torch.Tensor] = None, - save_estimator_outputs: bool = False, - **policy_kwargs: Any, - ) -> tuple[Distribution, Any]: - """Run estimator with carry and update it. - - Differs from the default adapter by calling - `estimator(states_active, ctx.carry) -> (est_out, new_carry)`, storing the - updated carry and saving `current_estimator_output` before building the - Distribution. - """ - estimator_outputs, new_carry = self._estimator(states_active, ctx.carry) - ctx.carry = new_carry - dist = self._estimator.to_probability_distribution( - states_active, - estimator_outputs, - **policy_kwargs, - ) - - # Save current estimator output only when requested. 
- if save_estimator_outputs: - ctx.current_estimator_output = estimator_outputs - - if step_mask is not None: - padded = torch.full( - (ctx.batch_size,) + estimator_outputs.shape[1:], - -float("inf"), - device=ctx.device, - ) - padded[step_mask] = estimator_outputs - ctx.trajectory_estimator_outputs.append(padded) - else: - ctx.current_estimator_output = None - - return dist, ctx - - -def maybe_instantiate_adapter( - estimator: "Estimator", - adapter: Callable[["Estimator"], EstimatorAdapter] | EstimatorAdapter | None, -) -> EstimatorAdapter: - """Maybe instantiate an adapter for a given estimator. - - Args: - estimator: The estimator to instantiate an adapter for. - adapter: An adapter class instance or callable to use for sampling actions - and computing probability distributions. If None, the default adapter class - for the estimator will be used. - - Returns: - An adapter instance. - """ - # If no adapter is provided, use the default adapter class for the estimator, - # which we need to retrieve and instantiate here. - if adapter is None: - adapter_cls = estimator.default_adapter_class - assert ( - adapter_cls is not None - ), "Estimator has no default adapter class and no adapter was provided" - adapter_cls = cast(Callable[["Estimator"], EstimatorAdapter], adapter_cls) - return adapter_cls(estimator) - - # If an adapter class is provided, instantiate it with the estimator. - elif isinstance(adapter, type) and issubclass(adapter, EstimatorAdapter): - - # We have to assume that the adapter class accepts exactly 1 argument - # (estimator). - sig = signature(adapter) - - # Count parameters excluding 'self' - params = [p for p in sig.parameters.values() if p.name != "self"] - if len(params) != 1: - raise TypeError( - f"Adapter class {adapter.__name__} must accept exactly 1 argument " - f"(estimator) to use automatic adapter instantiation, " - f"but has {len(params)} parameters: {[p.name for p in params]}," - f"You can provide an adapter instance to the GFlowNet instead." - ) - - adapter_factory = cast(Callable[["Estimator"], EstimatorAdapter], adapter) - return adapter_factory(estimator) - - # If an adapter instance is provided, use it. - elif isinstance(adapter, EstimatorAdapter): - return adapter - - else: - raise ValueError(f"Invalid adapter type: {type(adapter)}") diff --git a/src/gfn/estimators.py b/src/gfn/estimators.py index 825fabb5..51a48b5a 100644 --- a/src/gfn/estimators.py +++ b/src/gfn/estimators.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from collections import defaultdict -from typing import Any, Callable, Optional, cast +from typing import Any, Callable, Dict, List, Optional, Protocol, cast, runtime_checkable import torch import torch.nn as nn @@ -8,14 +8,13 @@ from torch.distributions import Categorical, Distribution from gfn.actions import GraphActions, GraphActionType -from gfn.adapters import ( - DefaultEstimatorAdapter, - EstimatorAdapter, - RecurrentEstimatorAdapter, -) from gfn.preprocessors import IdentityPreprocessor, Preprocessor from gfn.states import DiscreteStates, States from gfn.utils.distributions import GraphActionDistribution, UnsqueezedCategorical +from gfn.utils.handlers import ( + has_conditioning_exception_handler, + no_conditioning_exception_handler, +) REDUCTION_FUNCTIONS = { "mean": torch.mean, @@ -24,6 +23,285 @@ } +class RolloutContext: + """Structured per‑rollout state owned by estimators. + + Holds rollout invariants and optional per‑step buffers; use ``extras`` for + estimator‑specific fields without changing the class shape. 
+    """
+
+    __slots__ = (
+        "batch_size",
+        "device",
+        "conditioning",
+        "carry",
+        "trajectory_log_probs",
+        "trajectory_estimator_outputs",
+        "current_estimator_output",
+        "extras",
+    )
+
+    def __init__(
+        self,
+        batch_size: int,
+        device: torch.device,
+        conditioning: Optional[torch.Tensor] = None,
+    ) -> None:
+        self.batch_size = batch_size
+        self.device = device
+        self.conditioning = conditioning
+        self.carry = None
+        self.trajectory_log_probs: List[torch.Tensor] = []
+        self.trajectory_estimator_outputs: List[torch.Tensor] = []
+        self.current_estimator_output: Optional[torch.Tensor] = None
+        self.extras: Dict[str, Any] = {}
+
+
+@runtime_checkable
+class PolicyEstimatorProtocol(Protocol):
+    """Static-typing surface for estimators that are policy-capable.
+
+    This protocol captures the methods provided by the PolicyMixin so that external
+    code (e.g., samplers/probability calculators) can use a precise type rather than
+    relying on dynamic attributes. This helps static analyzers avoid false positives
+    like "Tensor is not callable" when calling mixin methods.
+    """
+
+    is_vectorized: bool
+
+    def init_context(
+        self,
+        batch_size: int,
+        device: torch.device,
+        conditioning: Optional[torch.Tensor] = None,
+    ) -> Any: ...  # noqa: E704
+
+    def compute_dist(
+        self,
+        states_active: States,
+        ctx: Any,
+        step_mask: Optional[torch.Tensor] = None,
+        **policy_kwargs: Any,
+    ) -> tuple[Distribution, Any]: ...  # noqa: E704
+
+    def log_probs(
+        self,
+        actions_active: torch.Tensor,
+        dist: Distribution,
+        ctx: Any,
+        step_mask: Optional[torch.Tensor] = None,
+        vectorized: bool = False,
+        **kwargs: Any,
+    ) -> tuple[torch.Tensor, Any]: ...
+
+
+class PolicyMixin:
+    """Mixin enabling an `Estimator` to act as a policy (distribution over actions).
+
+    Provides the generic rollout API (`init_context`, `compute_dist`, `log_probs`)
+    directly on the estimator. Standard policies should inherit from this mixin.
+    """
+
+    @property
+    def is_vectorized(self) -> bool:
+        """Used for vectorized probability calculations."""
+        return True
+
+    def init_context(
+        self,
+        batch_size: int,
+        device: torch.device,
+        conditioning: Optional[torch.Tensor] = None,
+    ) -> RolloutContext:
+        """Create a new per-rollout context.
+
+        Stores rollout invariants (batch size, device, optional conditioning) and
+        initializes empty buffers for per-step artifacts.
+        """
+        return RolloutContext(
+            batch_size=batch_size, device=device, conditioning=conditioning
+        )
+
+    def compute_dist(
+        self,
+        states_active: States,
+        ctx: Any,
+        step_mask: Optional[torch.Tensor] = None,
+        save_estimator_outputs: bool = False,
+        **policy_kwargs: Any,
+    ) -> tuple[Distribution, Any]:
+        """Run the estimator for active rows and build an action Distribution.
+
+        Args:
+            states_active: The states to run the estimator on.
+            ctx: The rollout context for the current batch.
+            step_mask: The mask used to slice the conditioning to the active subset.
+            save_estimator_outputs: Whether to save the estimator outputs.
+            **policy_kwargs: Additional keyword arguments to pass to the estimator.
+
+        Returns:
+            A tuple containing the distribution and the context.
+
+        - Uses `step_mask` to slice conditioning to the active subset. When `step_mask`
+          is None, the estimator is running in a vectorized context.
+        - Saves the raw estimator output in `ctx.current_estimator_output` for
+          optional reuse by vectorized callers.
+        """
+        precomputed_estimator_outputs = getattr(ctx, "current_estimator_output", None)
+
+        # Reuse precomputed outputs only in vectorized contexts (no step_mask).
+        if step_mask is None and precomputed_estimator_outputs is not None:
+            expected_bs = states_active.batch_shape[0]
+            if precomputed_estimator_outputs.shape[0] != expected_bs:
+                raise RuntimeError(
+                    "current_estimator_output batch size does not match active states. "
+                    f"Got {precomputed_estimator_outputs.shape[0]}, expected {expected_bs}. "
+                    "This indicates stale cache reuse; ensure per-step masking when setting "
+                    "ctx.current_estimator_output and clear it when not valid."
+                )
+            estimator_outputs = precomputed_estimator_outputs
+
+        # Otherwise, compute the estimator outputs.
+        else:
+            cond_active = None
+            if ctx.conditioning is not None:
+                if step_mask is None:
+                    cond_active = ctx.conditioning
+                else:
+                    cond_active = ctx.conditioning[step_mask]
+
+            # Call estimator with or without conditioning.
+            if cond_active is not None:
+                with has_conditioning_exception_handler("estimator", self):
+                    estimator_outputs = self(states_active, cond_active)
+            else:
+                with no_conditioning_exception_handler("estimator", self):
+                    estimator_outputs = self(states_active)
+
+        # Build the distribution.
+        dist = self.to_probability_distribution(
+            states_active, estimator_outputs, **policy_kwargs
+        )
+
+        # Save current estimator output only when requested.
+        if save_estimator_outputs:
+            ctx.current_estimator_output = estimator_outputs
+
+            # If we are in a non-vectorized path (masked), append a padded copy to trajectory.
+            if step_mask is not None:
+                padded = torch.full(
+                    (ctx.batch_size,) + estimator_outputs.shape[1:],
+                    -float("inf"),
+                    device=ctx.device,
+                )
+                padded[step_mask] = estimator_outputs
+                ctx.trajectory_estimator_outputs.append(padded)
+
+        else:
+            ctx.current_estimator_output = None
+
+        return dist, ctx
+
+    def log_probs(
+        self,
+        actions_active: torch.Tensor,
+        dist: Distribution,
+        ctx: Any,
+        step_mask: Optional[torch.Tensor] = None,
+        vectorized: bool = False,
+        save_logprobs: bool = False,
+    ) -> tuple[torch.Tensor, Any]:
+        """Compute log-probs, optionally padding back to full batch when non-vectorized."""
+        lp = dist.log_prob(actions_active)
+
+        if vectorized:
+            if save_logprobs:
+                ctx.trajectory_log_probs.append(lp)
+            return lp, ctx
+
+        # Non-vectorized path strict check. None of these should be -inf after masking.
+        if torch.any(torch.isinf(lp)):
+            raise RuntimeError("Log probabilities are inf. 
This should not happen.") + + assert step_mask is not None, "step_mask is required when vectorized=False" + step_lp = torch.full((ctx.batch_size,), 0.0, device=ctx.device, dtype=lp.dtype) + step_lp[step_mask] = lp + + if save_logprobs: + ctx.trajectory_log_probs.append(step_lp) + + return step_lp, ctx + + def get_current_estimator_output(self, ctx: Any) -> Optional[torch.Tensor]: + """Expose the most recent per-step estimator output saved during `compute`.""" + return getattr(ctx, "current_estimator_output", None) + + +class RecurrentPolicyMixin(PolicyMixin): + """Mixin for recurrent policies that maintain and update a rollout carry.""" + + @property + def is_vectorized(self) -> bool: + return False + + def init_context( + self, + batch_size: int, + device: torch.device, + conditioning: Optional[torch.Tensor] = None, + ) -> RolloutContext: + ctx = super().init_context(batch_size, device, conditioning) + init_carry = getattr(self, "init_carry", None) + if not callable(init_carry): + raise TypeError( + "Recurrent policy requires init_carry(batch_size: int, device: torch.device)." + ) + init_carry_fn = cast(Callable[[int, torch.device], Any], init_carry) + ctx.carry = init_carry_fn(batch_size, device) + + return ctx + + def compute_dist( + self, + states_active: States, + ctx: Any, + step_mask: Optional[torch.Tensor] = None, + save_estimator_outputs: bool = False, + **policy_kwargs: Any, + ) -> tuple[Distribution, Any]: + """Run estimator with carry and update it. + + Differs from the default PolicyMixin by calling + `estimator(states_active, ctx.carry) -> (est_out, new_carry)`, storing the + updated carry and saving `current_estimator_output` before building the + Distribution. + """ + estimator_outputs, new_carry = self(states_active, ctx.carry) # type: ignore + ctx.carry = new_carry + dist = self.to_probability_distribution( + states_active, + estimator_outputs, + **policy_kwargs, + ) + + # Save current estimator output only when requested. + if save_estimator_outputs: + ctx.current_estimator_output = estimator_outputs + + if step_mask is not None: + padded = torch.full( + (ctx.batch_size,) + estimator_outputs.shape[1:], + -float("inf"), + device=ctx.device, + ) + padded[step_mask] = estimator_outputs + ctx.trajectory_estimator_outputs.append(padded) + else: + ctx.current_estimator_output = None + + return dist, ctx + + class Estimator(ABC, nn.Module): r"""Base class for modules mapping states to distributions or scalar values. @@ -56,11 +334,8 @@ class Estimator(ABC, nn.Module): `IdentityPreprocessor`. is_backward: Flag indicating whether this estimator is for backward policy, i.e., is used for predicting probability distributions over parents. - _default_adapter_class: The default adapter class for this estimator. """ - _default_adapter_class = DefaultEstimatorAdapter - def __init__( self, module: nn.Module, @@ -123,11 +398,6 @@ def expected_output_dim(self) -> Optional[int]: is not well-defined (e.g., when the output is a TensorDict for GraphActions). """ - @property - def default_adapter_class(self) -> type[EstimatorAdapter] | None: - """The default adapter class for this estimator.""" - return self._default_adapter_class - def to_probability_distribution( self, states: States, @@ -173,11 +443,8 @@ class ScalarEstimator(Estimator): that can be used as input to the module. is_backward: Always False for ScalarEstimator (since it's direction-agnostic). reduction_function: Function used to reduce multi-dimensional outputs to scalars. 
- _default_adapter_class: There is no default adapter class for this estimator. """ - _default_adapter_class = None - def __init__( self, module: nn.Module, @@ -241,7 +508,6 @@ class LogitBasedEstimator(Estimator): preprocessor: Preprocessor object that transforms raw States objects to tensors. is_backward: Flag indicating whether this estimator is for backward policy, i.e., is used for predicting probability distributions over parents. - _default_adapter_class: The default adapter class for this estimator. """ @staticmethod @@ -389,16 +655,14 @@ class ConditionalLogZEstimator(ScalarEstimator): """Conditional logZ estimator. This estimator is used to estimate the logZ of a GFlowNet from a conditioning - tensor. Since conditioning is a tensor, it does not have a preprocessor. Reduction is used to aggregate the outputs of the module into a single scalar. + tensor. Since conditioning is a tensor, it does not have a preprocessor. + Reduction is used to aggregate the outputs of the module into a single scalar. Attributes: module: The neural network module to use. reduction: String name of one of the REDUCTION_FUNCTIONS keys. - _default_adapter_class: There is no default adapter class for this estimator. """ - _default_adapter_class = None - def __init__(self, module: nn.Module, reduction: str = "mean"): super().__init__(module, preprocessor=None, reduction=reduction) @@ -406,7 +670,7 @@ def _calculate_module_output(self, input: torch.Tensor) -> torch.Tensor: return self.module(input) -class DiscretePolicyEstimator(LogitBasedEstimator): +class DiscretePolicyEstimator(PolicyMixin, LogitBasedEstimator): r"""Forward or backward policy estimators for discrete environments. Estimates either: @@ -422,7 +686,6 @@ class DiscretePolicyEstimator(LogitBasedEstimator): preprocessor: Preprocessor object that transforms raw States objects to tensors. is_backward: Flag indicating whether this estimator is for backward policy, i.e., is used for predicting probability distributions over parents. - _default_adapter_class: The default adapter class for this estimator. """ def __init__( @@ -527,7 +790,6 @@ class ConditionalDiscretePolicyEstimator(DiscretePolicyEstimator): preprocessor: Preprocessor object that transforms raw States objects to tensors. is_backward: Flag indicating whether this estimator is for backward policy, i.e., is used for predicting probability distributions over parents. - _default_adapter_class: The default adapter class for this estimator. """ def __init__( @@ -608,11 +870,8 @@ class ConditionalScalarEstimator(ConditionalDiscretePolicyEstimator): is_backward: Always False for ConditionalScalarEstimator (since it's direction-agnostic). reduction_function: Function used to reduce multi-dimensional outputs to scalars. - _default_adapter_class: There is no default adapter class for this estimator. """ - _default_adapter_class = None - def __init__( self, state_module: nn.Module, @@ -692,7 +951,7 @@ def to_probability_distribution( raise NotImplementedError -class DiscreteGraphPolicyEstimator(LogitBasedEstimator): +class DiscreteGraphPolicyEstimator(PolicyMixin, LogitBasedEstimator): r"""Forward or backward policy estimators for graph-based environments. Estimates either, where $s$ and $s'$ are graph states: @@ -708,7 +967,6 @@ class DiscreteGraphPolicyEstimator(LogitBasedEstimator): preprocessor: Preprocessor object that transforms GraphStates objects to tensors. 
is_backward: Flag indicating whether this estimator is for backward policy, i.e., is used for predicting probability distributions over parents. - _default_adapter_class: The default adapter class for this estimator. """ def to_probability_distribution( @@ -835,7 +1093,7 @@ def expected_output_dim(self) -> Optional[int]: return None -class RecurrentDiscretePolicyEstimator(DiscretePolicyEstimator): +class RecurrentDiscretePolicyEstimator(RecurrentPolicyMixin, DiscretePolicyEstimator): """Discrete policy estimator for recurrent architectures with explicit carry. Many sequence models (e.g., RNN/LSTM/GRU/Transformer in autoregressive mode) @@ -850,11 +1108,9 @@ class RecurrentDiscretePolicyEstimator(DiscretePolicyEstimator): - Ensuring the per-step output (``logits`` over actions) is derived from the latest token/time step while the internal model may process sequences. - Interaction with the sampler/adapters - ------------------------------------- - The sampler uses a ``RecurrentEstimatorAdapter`` which calls this estimator + The sampler uses a ``RecurrentPolicyMixin`` which calls this estimator with the current carry, updates the carry on every step, and records - per-step artifacts. Non-recurrent estimators should use the default adapter + per-step artifacts. Non-recurrent estimators should use the default PolicyMixin and the standard ``DiscretePolicyEstimator`` base class instead. Notes @@ -862,7 +1118,7 @@ class RecurrentDiscretePolicyEstimator(DiscretePolicyEstimator): - Forward is intended for on-policy generation; off-policy evaluation over entire trajectories typically requires different batching and masking. - ``init_carry`` is a hard requirement for compatibility with the recurrent - adapter. + PolicyMixin. Attributes: module: The neural network module to use. @@ -870,11 +1126,8 @@ class RecurrentDiscretePolicyEstimator(DiscretePolicyEstimator): preprocessor: Preprocessor object that transforms states to tensors. is_backward: Flag indicating whether this estimator is for backward policy, i.e., is used for predicting probability distributions over parents. - _default_adapter_class: The default adapter class for this estimator. """ - _default_adapter_class = RecurrentEstimatorAdapter - def __init__( self, module: nn.Module, diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index acf65f3a..bd228268 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -1,12 +1,11 @@ import math import warnings from abc import ABC, abstractmethod -from typing import Any, Callable, Generic, Tuple, TypeVar +from typing import Any, Generic, Tuple, TypeVar import torch import torch.nn as nn -from gfn.adapters import EstimatorAdapter from gfn.containers import Container, Trajectories from gfn.env import Env from gfn.estimators import Estimator @@ -176,13 +175,6 @@ def __init__( pf: Estimator, pb: Estimator | None, constant_pb: bool = False, - *, - pf_adapter: ( - Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None - ) = None, - pb_adapter: ( - Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None - ) = None, ) -> None: """Initializes a PFBasedGFlowNet instance. @@ -194,11 +186,7 @@ def __init__( gflownet DAG is a tree, and pb is therefore always 1. Must be set explicitly by user to ensure that pb is an Estimator except under this special case. - pf_adapter: Optional adapter for PF probability calculation/sampling (e.g., - recurrent). When provided, used both in the sampler and in - probability recomputation paths. 
- pb_adapter: Optional adapter for PB probability calculation. Used in - probability recomputation paths when `pb` is provided. + """ super().__init__() # Technical note: pb may be constant for a variety of edge cases, for example, @@ -228,11 +216,6 @@ def __init__( self.pb = pb self.constant_pb = constant_pb - # Optional adapters controlling estimator interactions via - # vectorized / non-vectorized probability paths. - self.pf_adapter = pf_adapter - self.pb_adapter = pb_adapter - # Advisory: recurrent PF with non-recurrent PB is unusual # (tree DAGs typically prefer pb=None with constant_pb=True). # Import locally to avoid circular imports during module import time. @@ -247,6 +230,7 @@ def __init__( "Consider using pb=None with constant_pb=True for tree DAGs.", ) # Disallow recurrent PB estimators universally. + # I'm not actually sure we should disallow this. if isinstance(self.pb, RecurrentDiscretePolicyEstimator): raise TypeError( "Recurrent PB estimators are not supported. Use a non-recurrent PB " @@ -275,7 +259,7 @@ def sample_trajectories( Returns: A Trajectories object containing the sampled trajectories. """ - sampler = Sampler(estimator=self.pf, adapter=self.pf_adapter) + sampler = Sampler(estimator=self.pf) trajectories = sampler.sample_trajectories( env, n=n, @@ -320,35 +304,23 @@ def __init__( pf: Estimator, pb: Estimator | None, constant_pb: bool = False, - *, - pf_adapter: ( - Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None - ) = None, - pb_adapter: ( - Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None - ) = None, ) -> None: """Initializes a TrajectoryBasedGFlowNet instance. Args: pf: The forward policy estimator. - pb: The backward policy estimator, or None if the gflownet DAG is a tree, and - pb is therefore always 1. + pb: The backward policy estimator, or None if the gflownet DAG is a tree, + and pb is therefore always 1. constant_pb: Whether to ignore the backward policy estimator, e.g., if the gflownet DAG is a tree, and pb is therefore always 1. Must be set explicitly by user to ensure that pb is an Estimator except under this special case. - pf_adapter: Optional adapter for PF probability calculation/sampling. When - provided, used both in the sampler and in probability recomputation paths. - pb_adapter: Optional adapter for PB probability calculation. Used in - probability recomputation paths when `pb` is provided. + """ super().__init__( pf, pb, constant_pb=constant_pb, - pf_adapter=pf_adapter, - pb_adapter=pb_adapter, ) def get_pfs_and_pbs( @@ -372,9 +344,6 @@ def get_pfs_and_pbs( the current self.pf - this is usually for off-policy learning with replay buffer. - Uses the PF and PB adapters to evaluate the logprobs, with their optional - adapters if provided. - Args: trajectories: The Trajectories object to evaluate. 
fill_value: Value to use for invalid states (e.g., $s_f$ added to shorter @@ -391,8 +360,6 @@ def get_pfs_and_pbs( trajectories, fill_value, recalculate_all_logprobs, - pf_adapter=self.pf_adapter, - pb_adapter=self.pb_adapter, ) def get_scores( diff --git a/src/gfn/gflownet/detailed_balance.py b/src/gfn/gflownet/detailed_balance.py index ab7b1950..8879c79a 100644 --- a/src/gfn/gflownet/detailed_balance.py +++ b/src/gfn/gflownet/detailed_balance.py @@ -1,10 +1,9 @@ import math -from typing import Callable, Tuple +from typing import Tuple import torch from gfn.actions import Actions -from gfn.adapters import EstimatorAdapter from gfn.containers import Trajectories, Transitions from gfn.env import Env from gfn.estimators import ConditionalScalarEstimator, Estimator, ScalarEstimator @@ -77,13 +76,6 @@ def __init__( log_reward_clip_min: float = -float("inf"), safe_log_prob_min: bool = True, constant_pb: bool = False, - *, - pf_adapter: ( - Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None - ) = None, - pb_adapter: ( - Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None - ) = None, ) -> None: """Initializes a DBGFlowNet instance. @@ -101,20 +93,11 @@ def __init__( gflownet DAG is a tree, and pb is therefore always 1. Must be set explicitly by user to ensure that pb is an Estimator except under this special case. - pf_adapter: Optional estimator adapter controlling PF probability - computation/sampling. - pb_adapter: Optional estimator adapter controlling PB probability - computation. + """ - super().__init__( - pf, - pb, - constant_pb=constant_pb, - pf_adapter=pf_adapter, - pb_adapter=pb_adapter, - ) - # Disallow recurrent PF or recurrent adapter for transition-based DB - from gfn.adapters import RecurrentEstimatorAdapter # type: ignore + super().__init__(pf, pb, constant_pb=constant_pb) + + # Disallow recurrent PF for transition-based DB from gfn.estimators import RecurrentDiscretePolicyEstimator # type: ignore if isinstance(self.pf, RecurrentDiscretePolicyEstimator): @@ -122,17 +105,6 @@ def __init__( "DBGFlowNet does not support recurrent PF estimators (transitions path cannot propagate carry)." ) - # Get the class whether pf_adapter is a class or instance - adapter_class = ( - self.pf_adapter - if isinstance(self.pf_adapter, type) - else type(self.pf_adapter) - ) - if issubclass(adapter_class, RecurrentEstimatorAdapter): - raise TypeError( - "DBGFlowNet does not support RecurrentEstimatorAdapter (transitions path cannot propagate carry)." - ) - assert any( isinstance(logF, cls) for cls in [ScalarEstimator, ConditionalScalarEstimator] @@ -191,8 +163,6 @@ def get_pfs_and_pbs( self.pb, transitions, recalculate_all_logprobs, - pf_adapter=self.pf_adapter, - pb_adapter=self.pb_adapter, ) def get_scores( @@ -349,13 +319,6 @@ def __init__( pf: Estimator, pb: Estimator | None, constant_pb: bool = False, - *, - pf_adapter: ( - Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None - ) = None, - pb_adapter: ( - Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None - ) = None, ) -> None: """Initializes a ModifiedDBGFlowNet instance. @@ -363,16 +326,9 @@ def __init__( pf: Forward policy estimator. pb: Backward policy estimator or None. constant_pb: See base class. - pf_adapter: Optional adapter for PF. - pb_adapter: Optional adapter for PB. 
+
         """
-        super().__init__(
-            pf,
-            pb,
-            constant_pb=constant_pb,
-            pf_adapter=pf_adapter,
-            pb_adapter=pb_adapter,
-        )
+        super().__init__(pf, pb, constant_pb=constant_pb)
 
     def get_scores(
         self, transitions: Transitions, recalculate_all_logprobs: bool = True
diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py
index cebdd19c..403f7ecb 100644
--- a/src/gfn/gflownet/flow_matching.py
+++ b/src/gfn/gflownet/flow_matching.py
@@ -5,7 +5,10 @@
 
 from gfn.containers import StatesContainer, Trajectories
 from gfn.env import DiscreteEnv
-from gfn.estimators import ConditionalDiscretePolicyEstimator, DiscretePolicyEstimator
+from gfn.estimators import (
+    DiscretePolicyEstimator,
+    PolicyMixin,
+)
 from gfn.gflownet.base import GFlowNet, loss_reduce
 from gfn.samplers import Sampler
 from gfn.states import DiscreteStates
@@ -33,13 +36,9 @@ class FMGFlowNet(GFlowNet[StatesContainer[DiscreteStates]]):
         estimating the log flow of the edges (states -> next_states).
         alpha: A scalar weight for the reward matching loss.
 
-    Adapter note
-    ------------
     Flow Matching does not rely on PF/PB probability recomputation. Any trajectory
     sampling provided by this class is for diagnostics/visualization and can only use
-    the default (non-recurrent) adapter. Sampler adapters (e.g.
-    `RecurrentEstimatorAdapter`) are not exposed as configuration options for this
-    class.
+    the default (non-recurrent) PolicyMixin interface.
     """
 
     def __init__(self, logF: DiscretePolicyEstimator, alpha: float = 1.0):
@@ -51,11 +50,10 @@ def __init__(self, logF: DiscretePolicyEstimator, alpha: float = 1.0):
             alpha: A scalar weight for the reward matching loss.
         """
         super().__init__()
-        assert isinstance(
-            logF,
-            DiscretePolicyEstimator | ConditionalDiscretePolicyEstimator,
-        ), "logF must be a DiscretePolicyEstimator or ConditionalDiscretePolicyEstimator"
+        assert isinstance(
+            logF, PolicyMixin
+        ), "logF must use the PolicyMixin interface"
+
         self.logF = logF
         self.alpha = alpha
 
@@ -80,10 +78,6 @@ def sample_trajectories(
 
         Returns:
             A Trajectories object containing the sampled trajectories.
-
-        Notes:
-            This helper uses the default sampler adapter; custom sampler adapters are
-            not supported for Flow Matching.
         """
         if not env.is_discrete:
             raise NotImplementedError(
diff --git a/src/gfn/gflownet/sub_trajectory_balance.py b/src/gfn/gflownet/sub_trajectory_balance.py
index 07697276..bcb1f919 100644
--- a/src/gfn/gflownet/sub_trajectory_balance.py
+++ b/src/gfn/gflownet/sub_trajectory_balance.py
@@ -1,10 +1,9 @@
 import math
 import warnings
-from typing import Callable, List, Literal, Tuple, TypeAlias
+from typing import List, Literal, Tuple, TypeAlias
 
 import torch
 
-from gfn.adapters import EstimatorAdapter
 from gfn.containers import Trajectories
 from gfn.env import Env
 from gfn.estimators import ConditionalScalarEstimator, Estimator, ScalarEstimator
@@ -84,13 +83,6 @@ def __init__(
         log_reward_clip_min: float = -float("inf"),
         forward_looking: bool = False,
         constant_pb: bool = False,
-        *,
-        pf_adapter: (
-            Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None
-        ) = None,
-        pb_adapter: (
-            Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None
-        ) = None,
     ):
         """Initializes a SubTBGFlowNet instance.
 
@@ -108,18 +100,9 @@ def __init__(
         gflownet DAG is a tree, and pb is therefore always 1. Must be set
             explicitly by user to ensure that pb is an Estimator except under this
             special case.
-        pf_adapter: Optional estimator adapter controlling PF probability
-            computation/sampling. 
- pb_adapter: Optional estimator adapter controlling PB probability - computation. + """ - super().__init__( - pf, - pb, - constant_pb=constant_pb, - pf_adapter=pf_adapter, - pb_adapter=pb_adapter, - ) + super().__init__(pf, pb, constant_pb=constant_pb) assert any( isinstance(logF, cls) for cls in [ScalarEstimator, ConditionalScalarEstimator] diff --git a/src/gfn/gflownet/trajectory_balance.py b/src/gfn/gflownet/trajectory_balance.py index 91dd5c45..4fd0586a 100644 --- a/src/gfn/gflownet/trajectory_balance.py +++ b/src/gfn/gflownet/trajectory_balance.py @@ -3,12 +3,11 @@ and the [Log Partition Variance loss](https://arxiv.org/abs/2302.05446). """ -from typing import Callable, cast +from typing import cast import torch import torch.nn as nn -from gfn.adapters import EstimatorAdapter from gfn.containers import Trajectories from gfn.env import Env from gfn.estimators import Estimator, ScalarEstimator @@ -48,13 +47,6 @@ def __init__( init_logZ: float = 0.0, log_reward_clip_min: float = -float("inf"), constant_pb: bool = False, - *, - pf_adapter: ( - Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None - ) = None, - pb_adapter: ( - Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None - ) = None, ): """Initializes a TBGFlowNet instance. @@ -69,20 +61,8 @@ def __init__( constant_pb: Whether to ignore pb e.g., the GFlowNet DAG is a tree, and pb is therefore always 1. Must be set explicitly by user to ensure that pb is an Estimator except under this special case. - pf_adapter: Optional estimator adapter controlling how PF probabilities are - computed and sampled (e.g., `RecurrentEstimatorAdapter`). When provided, - it is used both by the Sampler and by probability recomputation paths. - pb_adapter: Optional estimator adapter for PB probability computation. If - provided and `pb` is an Estimator, it will be used in probability - recomputation paths that require PB. """ - super().__init__( - pf, - pb, - constant_pb=constant_pb, - pf_adapter=pf_adapter, - pb_adapter=pb_adapter, - ) + super().__init__(pf, pb, constant_pb=constant_pb) self.logZ = logZ or nn.Parameter(torch.tensor(init_logZ)) self.log_reward_clip_min = log_reward_clip_min diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 45bcc972..686ef924 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -1,12 +1,11 @@ -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, List, Optional, Tuple, cast import torch from gfn.actions import Actions -from gfn.adapters import EstimatorAdapter, maybe_instantiate_adapter from gfn.containers import Trajectories from gfn.env import Env -from gfn.estimators import Estimator +from gfn.estimators import Estimator, PolicyEstimatorProtocol from gfn.states import GraphStates, States from gfn.utils.common import ensure_same_device from gfn.utils.graphs import graph_states_share_storage @@ -14,36 +13,21 @@ class Sampler: - """Adapter‑driven sampler for GFlowNet environments. + """Estimator‑driven sampler for GFlowNet environments. - Delegates policy logic to an adapter: the adapter builds action - distributions, computes step log‑probs, and records artifacts into a - rollout context via method flags. Direction (forward/backward) is determined by - ``adapter.is_backward``. + The estimator builds action distributions, computes step log‑probs, and records + artifacts into a rollout context via method flags. Direction (forward/backward) + is determined by ``estimator.is_backward``. 
Attributes: - estimator: The underlying policy estimator (adapter wraps it). - adapter: The adapter used to build action distributions and compute step log‑probs. + estimator: The underlying policy estimator. Must expose the methods contained + in the `PolicyMixin` mixin. """ - def __init__( - self, - estimator: Estimator, - adapter: ( - Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None - ) = None, - ) -> None: - """Initializes a Sampler with a PolicyEstimator. - - Args: - estimator: The PolicyEstimator to use for sampling actions and computing - probability distributions. - adapter: An adapter class instance or callable to use for sampling actions - and computing probability distributions. If None, the default adapter - class for the estimator will be used. - """ + def __init__(self, estimator: Estimator) -> None: + """Initializes a Sampler with a PolicyEstimator.""" self.estimator = estimator - self.adapter = maybe_instantiate_adapter(estimator, adapter) + # TODO: Assert that the estimator exposes the methods contained in the `PolicyMixin` mixin. def sample_actions( self, @@ -55,19 +39,19 @@ def sample_actions( ctx: Any | None = None, **policy_kwargs: Any, ) -> Tuple[Actions, torch.Tensor | None, torch.Tensor | None]: - """Sample one step from ``states`` via the adapter. + """Sample one step from ``states`` via the estimator. - Initializes or reuses a rollout context with ``adapter.init_context``, - builds a Distribution with ``adapter.compute_dist``, and optionally computes - log‑probs with ``adapter.log_probs``. Per‑step artifacts are recorded by - the adapter when the corresponding flags are set. + Initializes or reuses a rollout context with ``estimator.init_context``, + builds a Distribution with ``estimator.compute_dist``, and optionally computes + log‑probs with ``estimator.log_probs``. Per‑step artifacts are recorded by + the estimator when the corresponding flags are set. Args: env: Environment providing action/state conversion utilities. states: Batch of states to act on. conditioning: Optional conditioning for conditional policies. save_estimator_outputs: If True, return the raw estimator outputs - cached by the adapter for this step. Useful for off-policy training + cached by the PolicyMixin for this step. Useful for off-policy training with tempered policies. save_logprobs: If True, return per‑step log‑probs padded to batch. Useful for on-policy training. @@ -77,10 +61,21 @@ def sample_actions( Returns: ``(Actions, log_probs | None, estimator_outputs | None)``. The estimator outputs come from - ``adapter.get_current_estimator_output(ctx)`` when requested. + ``PolicyMixin.get_current_estimator_output(ctx)`` when requested. """ + # NOTE: Explicitly cast to the policy protocol so static analyzers know + # the estimator exposes the mixin methods (init_context/compute_dist/log_probs). + policy_estimator = cast(PolicyEstimatorProtocol, self.estimator) + # Runtime guard: ensure the estimator actually implements the required protocol methods. + # This keeps helpful error messages when a non‑policy estimator is supplied. 
+ for required in ("init_context", "compute_dist", "log_probs"): + if not hasattr(policy_estimator, required): + raise TypeError( + f"Estimator is not policy-capable (missing PolicyMixin method: {required})" + ) + if ctx is None: - ctx = self.adapter.init_context( + ctx = policy_estimator.init_context( batch_size=states.batch_shape[0], device=states.device, conditioning=conditioning, @@ -89,7 +84,7 @@ def sample_actions( step_mask = torch.ones( states.batch_shape[0], dtype=torch.bool, device=states.device ) - dist, ctx = self.adapter.compute_dist( + dist, ctx = policy_estimator.compute_dist( states, ctx, step_mask, @@ -101,8 +96,8 @@ def sample_actions( actions_tensor = dist.sample() if save_logprobs: - # Use adapter to compute step log-probs and pad to batch. - log_probs, ctx = self.adapter.log_probs( + # Use estimator to compute step log-probs and pad to batch. + log_probs, ctx = policy_estimator.log_probs( actions_tensor, dist, ctx, @@ -117,18 +112,19 @@ def sample_actions( estimator_output = None if save_estimator_outputs: - if not hasattr(self.adapter, "get_current_estimator_output"): + if not hasattr(policy_estimator, "get_current_estimator_output"): raise TypeError( - "Adapter does not support get_current_estimator_output and save_estimator_outputs is True!" + "Estimator does not support get_current_estimator_output and save_estimator_outputs is True!" ) - estimator_output = self.adapter.get_current_estimator_output(ctx) + estimator_output = policy_estimator.get_current_estimator_output(ctx) assert estimator_output is not None assert log_probs is None or log_probs.shape == actions.batch_shape return actions, log_probs, estimator_output - def sample_trajectories( + # TODO: How to avoid "Sampler.sample_trajectories' is too complex" error? + def sample_trajectories( # noqa: C901 self, env: Env, n: Optional[int] = None, @@ -138,11 +134,11 @@ def sample_trajectories( save_logprobs: bool = False, **policy_kwargs: Any, ) -> Trajectories: - """Roll out complete trajectories using the adapter. + """Roll out complete trajectories using the estimator. Reuses a single rollout context across steps, calling ``compute_dist`` & ``log_probs`` each iteration. Uses - ``adapter.is_backward`` to choose the environment step function. + ``estimator.is_backward`` to choose the environment step function. Args: env: Environment to sample in. @@ -162,7 +158,17 @@ def sample_trajectories( For backward trajectories, the reward is computed at the initial state (s0) rather than the terminal state (sf). """ - if self.adapter.is_backward: + # NOTE: Cast to the policy protocol for static typing across mixin methods/properties. + policy_estimator = cast(PolicyEstimatorProtocol, self.estimator) + # Runtime guard: ensure the estimator actually implements the required protocol + # method and raises an error when a non‑policy estimator is supplied. 
+ for required in ("init_context", "compute_dist", "log_probs"): + if not hasattr(policy_estimator, required): + raise TypeError( + f"Estimator is not policy-capable (missing PolicyMixin method: {required})" + ) + + if policy_estimator.is_backward: # [ASSUMPTION] When backward sampling, all provided states are the # terminating states (can be passed to log_reward fn) assert ( @@ -186,9 +192,10 @@ def sample_trajectories( assert states.batch_shape == conditioning.shape[: len(states.batch_shape)] ensure_same_device(states.device, conditioning.device) - dones = ( - states.is_initial_state if self.adapter.is_backward else states.is_sink_state - ) + if policy_estimator.is_backward: + dones = states.is_initial_state + else: + dones = states.is_sink_state # Define dummy actions to avoid errors when stacking empty lists. trajectories_states: List[States] = [states] @@ -196,20 +203,22 @@ def sample_trajectories( env.actions_from_batch_shape((n_trajectories,)) ] # Placeholder kept for backward-compatibility of shapes; logprobs are - # recorded and stacked by the adapter. + # recorded and stacked by the estimator via the context. trajectories_terminating_idx = torch.zeros( n_trajectories, dtype=torch.long, device=device ) step = 0 - ctx = self.adapter.init_context(n_trajectories, device, conditioning) + if not hasattr(policy_estimator, "init_context"): + raise TypeError("Estimator is not policy-capable (missing PolicyMixin)") + ctx = policy_estimator.init_context(n_trajectories, device, conditioning) while not all(dones): actions = env.actions_from_batch_shape((n_trajectories,)) step_mask = ~dones # Compute distribution on active rows - dist, ctx = self.adapter.compute_dist( + dist, ctx = policy_estimator.compute_dist( states[step_mask], ctx, step_mask, @@ -223,8 +232,8 @@ def sample_trajectories( valid_actions = env.actions_from_tensor(valid_actions_tensor) if save_logprobs: - # Use adapter to compute step log-probs and pad to batch (recorded in ctx). - _, ctx = self.adapter.log_probs( + # Use estimator to compute step log-probs and pad to batch (recorded in ctx). + _, ctx = policy_estimator.log_probs( valid_actions_tensor, dist, ctx, @@ -236,7 +245,7 @@ def sample_trajectories( actions[step_mask] = valid_actions trajectories_actions.append(actions) - if self.adapter.is_backward: + if policy_estimator.is_backward: new_states = env._backward_step(states, actions) # type: ignore[attr-defined] else: new_states = env._step(states, actions) # type: ignore[attr-defined] @@ -262,7 +271,7 @@ def sample_trajectories( # to filter out the already done ones. new_dones = ( new_states.is_initial_state - if self.adapter.is_backward + if policy_estimator.is_backward else new_states.is_sink_state ) & ~dones trajectories_terminating_idx[new_dones] = step @@ -324,7 +333,7 @@ def sample_trajectories( conditioning=conditioning, actions=stacked_actions, terminating_idx=trajectories_terminating_idx, - is_backward=self.adapter.is_backward, + is_backward=policy_estimator.is_backward, log_rewards=None, # will be calculated later log_probs=stacked_logprobs, estimator_outputs=stacked_estimator_outputs, @@ -355,7 +364,7 @@ def __init__( self, pf_estimator: Estimator, pb_estimator: Estimator, - ): + ) -> None: """Initializes a LocalSearchSampler with forward and backward estimators. 
Args: diff --git a/src/gfn/utils/prob_calculations.py b/src/gfn/utils/prob_calculations.py index 3d51271e..bc4f90cf 100644 --- a/src/gfn/utils/prob_calculations.py +++ b/src/gfn/utils/prob_calculations.py @@ -1,11 +1,44 @@ -import warnings -from typing import Any, Callable, Tuple +from typing import Any, Protocol, Tuple, cast, runtime_checkable import torch +from torch.distributions import Distribution -from gfn.adapters import EstimatorAdapter, maybe_instantiate_adapter from gfn.containers import Trajectories, Transitions -from gfn.estimators import Estimator +from gfn.estimators import Estimator, RecurrentPolicyMixin + + +# NOTE: We use a Protocol to make the policy-capable estimator interface explicit for the type checker. +# This avoids Pyright errors like "Object of type 'Tensor' is not callable" when calling +# estimator.compute_dist/log_probs that live on the PolicyMixin, not on the base Estimator. +@runtime_checkable +class PolicyEstimatorProtocol(Protocol): + is_vectorized: bool + + def init_context( # noqa: E704 + self, + batch_size: int, + device: torch.device, + conditioning: torch.Tensor | None = None, + ) -> Any: ... + + def compute_dist( # noqa: E704 + self, + states_active: Any, + ctx: Any, + step_mask: torch.Tensor | None = None, + **policy_kwargs: Any, + ) -> tuple[Distribution, Any]: ... + + def log_probs( # noqa: E704 + self, + actions_active: torch.Tensor, + dist: Distribution, + ctx: Any, + step_mask: torch.Tensor | None = None, + vectorized: bool = False, + **kwargs: Any, + ) -> tuple[torch.Tensor, Any]: ... + # ------------ # Trajectories @@ -18,20 +51,12 @@ def get_trajectory_pfs_and_pbs( trajectories: Trajectories, fill_value: float = 0.0, recalculate_all_logprobs: bool = True, - pf_adapter: Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None = None, - pb_adapter: Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None = None, **policy_kwargs: Any, ) -> Tuple[torch.Tensor, torch.Tensor]: - """Calculate PF and PB log‑probabilities for trajectories via adapters. + """Calculate PF and PB log‑probabilities for trajectories. Delegates to ``get_trajectory_pfs`` and ``get_trajectory_pbs`` while - forwarding optional adapters and policy kwargs. - - Vectorized vs non‑vectorized: - - ``adapter is None`` or ``adapter.is_vectorized=True`` → vectorized path - (fast, parity with legacy). - - ``adapter.is_vectorized=False`` (e.g., recurrent) → per‑step path with - legacy masks/alignment. + forwarding policy kwargs. Args: pf: Forward policy estimator. @@ -39,40 +64,26 @@ def get_trajectory_pfs_and_pbs( trajectories: Trajectories to evaluate. fill_value: Value used to pad invalid positions. recalculate_all_logprobs: If True, recompute PF even if cached. - pf_adapter: Optional adapter for PF. - pb_adapter: Optional adapter for PB. **policy_kwargs: Extra kwargs for ``to_probability_distribution``. Returns: ``(log_pf[T,N], log_pb[T,N])`` """ + # TODO: Remove this assertion and move to a test. # fill value is the value used for invalid states (sink state usually) - - # uncomment next line for debugging - # assert trajectories.states.is_sink_state[:-1].equal(trajectories.actions.is_dummy) - - if pb_adapter is not None and not isinstance(pb_adapter, type(pf_adapter)): # type: ignore - warnings.warn( - ( - "type(pb_adapter)={} and type(pf_adapter)={}, this is probably not what you want " - "unless you explicitly want to use different sampling logic for the two policies " - "(with different estimator architectures). This is very uncommon." 
-        ).format(type(pb_adapter), type(pf_adapter))
-    )
+    assert trajectories.states.is_sink_state[:-1].equal(trajectories.actions.is_dummy)
 
     log_pf_trajectories = get_trajectory_pfs(
         pf,
         trajectories,
         fill_value=fill_value,
         recalculate_all_logprobs=recalculate_all_logprobs,
-        adapter=pf_adapter,
         **policy_kwargs,
     )
     log_pb_trajectories = get_trajectory_pbs(
         pb,
         trajectories,
         fill_value=fill_value,
-        adapter=pb_adapter,
         **policy_kwargs,
     )
 
@@ -84,15 +95,13 @@ def get_trajectory_pfs(
     trajectories: Trajectories,
     fill_value: float = 0.0,
     recalculate_all_logprobs: bool = True,
-    adapter: Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None = None,
     **policy_kwargs: Any,
 ) -> torch.Tensor:
     """Calculate PF log‑probabilities for trajectories.
 
-    - Vectorized: ``adapter is None`` or ``adapter.is_vectorized=True`` → legacy
-      vectorized implementation.
-    - Non‑vectorized: ``adapter.is_vectorized=False`` → per‑step evaluation with
-      legacy masks ``~is_sink_state[t] & ~is_dummy[t]``; no action‑id indexing.
+    Non‑vectorized (per‑step) evaluation, using the masks
+    ``~is_sink_state[t] & ~is_dummy[t]`` and no action‑id indexing, is supported
+    when specifically needed (``estimator.is_vectorized=False``).
 
     Args:
         pf: Forward policy estimator.
         trajectories: Trajectories to evaluate.
         fill_value: Value used to pad invalid positions.
         recalculate_all_logprobs: If True, recompute PF even if cached. Useful for
             off-policy training.
-        adapter: Optional adapter deciding the evaluation path.
         **policy_kwargs: Extra kwargs for ``to_probability_distribution``.
 
     Returns:
@@ -109,8 +117,9 @@ def get_trajectory_pfs(
     Raises:
         ValueError: If backward trajectories are provided.
     """
-    adapter = maybe_instantiate_adapter(pf, adapter)
-    assert adapter is not None
+    # TODO: Ensure that the estimator is policy-capable here.
+    if not hasattr(pf, "init_context"):
+        raise TypeError("Estimator is not policy-capable (missing PolicyMixin)")
 
     if trajectories.is_backward:
         raise ValueError("Backward trajectories are not supported")
@@ -129,17 +138,29 @@ def get_trajectory_pfs(
         assert log_pf_trajectories is not None
 
     else:
-        # Decide vectorized (legacy) vs non-vectorized (adapter per-step)
-        if not adapter.is_vectorized:
+        # Decide vectorized vs non-vectorized based on estimator capability
+        # Tell the type-checker we expect the Policy mixin surface here.
+        policy_pf = cast(PolicyEstimatorProtocol, pf)
+        # Runtime guard: ensure the estimator actually implements the required protocol
+        # method and raises an error when a non‑policy estimator is supplied.
+        for required in ("init_context", "compute_dist", "log_probs"):
+            if not hasattr(policy_pf, required):
+                raise TypeError(
+                    f"Estimator is not policy-capable (missing PolicyMixin method: {required})"
+                )
+        is_vectorized = bool(getattr(policy_pf, "is_vectorized", True))
 
-            # Adapter-driven path
+        if not is_vectorized:
+            # Per-step path.
             N = trajectories.n_trajectories
             device = trajectories.states.device
             cond = trajectories.conditioning
+
+            # TODO: Why do we need this?
             if cond is not None and len(cond.shape) >= 2:
                 cond = cond[0]
-            ctx = adapter.init_context(int(N), device, cond)  # type: ignore[arg-type]
+            ctx = policy_pf.init_context(int(N), device, cond)  # type: ignore[arg-type]
 
             T = trajectories.max_length
             log_pf_trajectories = torch.full(
@@ -174,12 +195,11 @@ def get_trajectory_pfs(
             # explicitly for the current step.
             ctx.current_estimator_output = None
 
-            # Build distribution for active rows and compute step log-probs via
-            # adapter.
-            dist, ctx = adapter.compute_dist(
+            # Build distribution for active rows and compute step log-probs
+            dist, ctx = policy_pf.compute_dist(
                 step_states, ctx, step_mask, **policy_kwargs
             )
-            step_log_probs, ctx = adapter.log_probs(
+            step_log_probs, ctx = policy_pf.log_probs(
                 step_actions, dist, ctx, step_mask, vectorized=False
             )
 
@@ -218,8 +238,7 @@
             # Broadcast (N, ...) to (T, N, ...), then index.
             masked_cond = cond.unsqueeze(0).expand((T,) + cond.shape)[state_mask]
 
-        # Create a temporary context sized to valid rows.
-        ctx_v = adapter.init_context(
+        ctx_v = policy_pf.init_context(
             int(len(valid_states)),
             trajectories.states.device,
             conditioning=masked_cond,
@@ -233,14 +252,14 @@
             estimator_outputs = trajectories.estimator_outputs[action_mask]
             ctx_v.current_estimator_output = estimator_outputs
 
-        # Delegate to adapter for dist and vectorized log-prob calculation.
-        dist, ctx_v = adapter.compute_dist(
+        # Build distribution and compute vectorized log-probs
+        dist, ctx_v = policy_pf.compute_dist(
             valid_states,
             ctx_v,
             step_mask=None,
             **policy_kwargs,
         )
-        valid_log_pf_actions, _ = adapter.log_probs(
+        valid_log_pf_actions, _ = policy_pf.log_probs(
             valid_actions.tensor, dist, ctx_v, step_mask=None, vectorized=True
         )
 
@@ -261,22 +280,20 @@
 def get_trajectory_pbs(
     pb: Estimator | None,
     trajectories: Trajectories,
     fill_value: float = 0.0,
-    adapter: Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None = None,
     **policy_kwargs: Any,
 ) -> torch.Tensor:
     """Calculate PB log‑probabilities for trajectories.
 
-    - Vectorized: ``adapter is None`` or ``adapter.is_vectorized=True``.
-    - Non‑vectorized: ``adapter.is_vectorized=False`` with legacy alignment
-      (action at ``t`` with state at ``t+1``) and mask
-      ``~is_sink_state[t+1] & ~is_initial_state[t+1] & ~is_dummy[t] & ~is_exit[t]``;
-      skip ``t==0``.
+    Non‑vectorized (per‑step) evaluation, with alignment
+    (action at ``t`` with state at ``t+1``) and mask
+    ``~is_sink_state[t+1] & ~is_initial_state[t+1] & ~is_dummy[t] & ~is_exit[t]``
+    (skipping ``t==0``), is supported when specifically needed
+    (``estimator.is_vectorized=False``).
 
     Args:
         pb: Backward policy estimator, or ``None`` for trees (PB=1).
         trajectories: Trajectories to evaluate.
         fill_value: Value used to pad invalid positions.
-        adapter: Optional adapter deciding the evaluation path.
         **policy_kwargs: Extra kwargs for ``to_probability_distribution``.
 
     Returns:
@@ -332,83 +349,87 @@
         log_pb_trajectories.dtype, copy=False
     )
 
-    # There is a backward policy.
-    else:
-        adapter = maybe_instantiate_adapter(pb, adapter)
-
-        # The backward policy requires step-wise evaluation.
-        if not adapter.is_vectorized:
-            # Adapter-driven pb evaluation (non-recurrent)
-            N = trajectories.n_trajectories
-            device = trajectories.states.device
-            cond = trajectories.conditioning
-
-            if cond is not None and len(cond.shape) >= 2:
-                cond = cond[0]
-            ctx = adapter.init_context(int(N), device, cond)  # type: ignore[arg-type]
-
-            # Iterate per-step with legacy-complete masking (state at t+1, action at t)
-            for t in range(trajectories.max_length):
-                # TODO: these checks are curious - I think one of them is never needed
-                # because for now we do not support reversed trajectories.
-                next_state_isnt_sink = ~trajectories.states.is_sink_state[t + 1]
-                next_state_isnt_initial = ~trajectories.states.is_initial_state[t + 1]
-                state_ok = next_state_isnt_sink & next_state_isnt_initial
-                if t == 0:
-                    # log PB is always zero for the transition s1 -> s0.
- state_ok = torch.zeros_like(state_ok, dtype=torch.bool) - - action_ok = (~trajectories.actions.is_dummy[t]) & ( - ~trajectories.actions.is_exit[t] - ) - step_mask = state_ok & action_ok + assert log_pb_trajectories.shape == ( + trajectories.max_length, + trajectories.n_trajectories, + ) - if not torch.any(step_mask): - continue + return log_pb_trajectories - step_states = trajectories.states[t + 1][step_mask] - step_actions = trajectories.actions.tensor[t][step_mask] + # There is a backward policy. + policy_pb = cast(PolicyEstimatorProtocol, pb) + # Runtime guard: ensure the estimator actually implements the required protocol + # method and raises an error when a non‑policy estimator is supplied. + for required in ("init_context", "compute_dist", "log_probs"): + if not hasattr(policy_pb, required): + raise TypeError( + f"Estimator is not policy-capable (missing PolicyMixin method: {required})" + ) + is_vectorized = bool(getattr(policy_pb, "is_vectorized", True)) + + if not is_vectorized: + # Per-step pb evaluation (state at t+1, action at t) + N = trajectories.n_trajectories + device = trajectories.states.device + cond = trajectories.conditioning + if cond is not None and len(cond.shape) >= 2: + cond_step0 = cond[0] # TODO: Why do we need this? + ctx = policy_pb.init_context(int(N), device, cond_step0) # type: ignore[arg-type] + + # Iterate per-step with masking (state at t+1, action at t) + for t in range(trajectories.max_length): + # TODO: these checks are curious - I think one of them is never needed + # because for now we do not support reversed trajectories. + next_state_isnt_sink = ~trajectories.states.is_sink_state[t + 1] + next_state_isnt_initial = ~trajectories.states.is_initial_state[t + 1] + state_ok = next_state_isnt_sink & next_state_isnt_initial + if t == 0: + # log PB is always zero for the transition s1 -> s0. + state_ok = torch.zeros_like(state_ok, dtype=torch.bool) + + action_ok = (~trajectories.actions.is_dummy[t]) & ( + ~trajectories.actions.is_exit[t] + ) + step_mask = state_ok & action_ok - # Prevent reusing last step's estimator output (batch size may differ, - # and estimator output caching isn't needed for PB). - ctx.current_estimator_output = None - dist, ctx = adapter.compute_dist( - step_states, - ctx, - step_mask, - **policy_kwargs, - ) - step_lp, ctx = adapter.log_probs( - step_actions, dist, ctx, step_mask, vectorized=False - ) + if not torch.any(step_mask): + continue - padded = torch.full( - (N,), - fill_value, - device=device, - dtype=step_lp.dtype, - ) - padded[step_mask] = step_lp[step_mask] - log_pb_trajectories[t] = padded + step_states = trajectories.states[t + 1][step_mask] + step_actions = trajectories.actions.tensor[t][step_mask] - # The backward policy supports vectorized evaluation. - else: - ctx_v = adapter.init_context( - int(len(valid_states)), trajectories.states.device, conditioning=masked_cond # type: ignore[arg-type] + # Prevent reusing last step's estimator output (batch size may differ, + # and estimator output caching isn't needed for PB). 
+ ctx.current_estimator_output = None + dist, ctx = policy_pb.compute_dist( + step_states, ctx, step_mask, **policy_kwargs ) - dist, ctx_v = adapter.compute_dist( - valid_states, - ctx_v, - step_mask=None, - **policy_kwargs, - ) - valid_log_pb_actions, _ = adapter.log_probs( - valid_actions.tensor, dist, ctx_v, step_mask=None, vectorized=True - ) - log_pb_trajectories[action_mask] = valid_log_pb_actions.to( - log_pb_trajectories.dtype, copy=False + step_lp, ctx = policy_pb.log_probs( + step_actions, dist, ctx, step_mask, vectorized=False ) + padded = torch.full((N,), fill_value, device=device, dtype=step_lp.dtype) + padded[step_mask] = step_lp[step_mask] + log_pb_trajectories[t] = padded + + # The backward policy supports vectorized evaluation. + else: + ctx_v = policy_pb.init_context( + int(len(valid_states)), trajectories.states.device, conditioning=masked_cond # type: ignore[arg-type] + ) + dist, ctx_v = policy_pb.compute_dist( + valid_states, + ctx_v, + step_mask=None, + **policy_kwargs, + ) + valid_log_pb_actions, _ = policy_pb.log_probs( + valid_actions.tensor, dist, ctx_v, step_mask=None, vectorized=True + ) + log_pb_trajectories[action_mask] = valid_log_pb_actions.to( + log_pb_trajectories.dtype, copy=False + ) + assert log_pb_trajectories.shape == ( trajectories.max_length, trajectories.n_trajectories, @@ -427,25 +448,16 @@ def get_transition_pfs_and_pbs( pb: Estimator | None, transitions: Transitions, recalculate_all_logprobs: bool = True, - pf_adapter: Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None = None, - pb_adapter: Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None = None, **policy_kwargs: Any, ) -> Tuple[torch.Tensor, torch.Tensor]: """Calculate PF and PB log‑probabilities for transitions. - Mirrors the trajectories logic: - - Vectorized when ``adapter is None`` or ``is_vectorized=True``. - - Non‑vectorized when ``is_vectorized=False``: per‑batch adapter call with - legacy masks; no action‑id indexing. - Args: pf: Forward policy estimator. pb: Backward policy estimator, or ``None`` for trees (PB=1). transitions: Transitions to evaluate. recalculate_all_logprobs: If True, recompute PF even if cached. Useful for off-policy training. - pf_adapter: Optional adapter for PF. - pb_adapter: Optional adapter for PB. **policy_kwargs: Extra kwargs for ``to_probability_distribution``. Returns: @@ -457,21 +469,10 @@ def get_transition_pfs_and_pbs( if transitions.is_backward: raise ValueError("Backward transitions are not supported") - if pb_adapter is not None and not isinstance(pb_adapter, type(pf_adapter)): # type: ignore - warnings.warn( - ( - "type(pb_adapter)={} and type(pf_adapter)={}, this is probably not what you want " - "unless you explicitly want to use different sampling logic for the two policies " - "(with different estimator architectures). This is very uncommon." 
- ).format(type(pb_adapter), type(pf_adapter)) - ) - log_pf_transitions = get_transition_pfs( - pf, transitions, recalculate_all_logprobs, adapter=pf_adapter, **policy_kwargs - ) - log_pb_transitions = get_transition_pbs( - pb, transitions, adapter=pb_adapter, **policy_kwargs + pf, transitions, recalculate_all_logprobs, **policy_kwargs ) + log_pb_transitions = get_transition_pbs(pb, transitions, **policy_kwargs) assert log_pf_transitions.shape == (transitions.n_transitions,) assert log_pb_transitions.shape == (transitions.n_transitions,) @@ -483,29 +484,22 @@ def get_transition_pfs( pf: Estimator, transitions: Transitions, recalculate_all_logprobs: bool = True, - adapter: Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None = None, **policy_kwargs: Any, ) -> torch.Tensor: """Calculate PF log‑probabilities for transitions. - - Vectorized: ``adapter is None`` or ``is_vectorized=True``. - - Non‑vectorized: ``is_vectorized=False``; single adapter call with legacy - masks; no action‑id indexing. + Non‑vectorized: `single estimator call with legacy masks; no action‑id indexing. Args: pf: Forward policy estimator. transitions: Transitions to evaluate. recalculate_all_logprobs: If True, recompute PF even if cached. Useful for off-policy training. - adapter: Optional adapter deciding the evaluation path. **policy_kwargs: Extra kwargs for ``to_probability_distribution``. Returns: ``log_pf`` of shape ``(M,)``. """ - adapter = maybe_instantiate_adapter(pf, adapter) - assert adapter is not None - states = transitions.states actions = transitions.actions @@ -514,26 +508,32 @@ def get_transition_pfs( assert log_pf_actions is not None else: - from gfn.adapters import RecurrentEstimatorAdapter # type: ignore - - if isinstance(adapter, RecurrentEstimatorAdapter): - raise TypeError( - "RecurrentEstimatorAdapter is only supported for Trajectories" - ) + if isinstance(pf, RecurrentPolicyMixin): + raise TypeError("RecurrentPolicyMixin is only supported for Trajectories") N = transitions.n_transitions device = transitions.states.device cond = transitions.conditioning - ctx = adapter.init_context(int(N), device, cond) + + # For static typing, cast to the policy protocol before calling mixin methods. + policy_pf = cast(PolicyEstimatorProtocol, pf) + # Runtime guard: ensure the estimator actually implements the required protocol + # method and raises an error when a non‑policy estimator is supplied. + for required in ("init_context", "compute_dist", "log_probs"): + if not hasattr(policy_pf, required): + raise TypeError( + f"Estimator is not policy-capable (missing PolicyMixin method: {required})" + ) + ctx = policy_pf.init_context(int(N), device, cond) mask = torch.ones(N, dtype=torch.bool, device=device) - # Evaluate the log PF of the actions, with optional conditioning. + # Evaluate the log PF of the actions # TODO: Inefficient duplication in case of tempered policy # The Transitions container should then have some # estimator_outputs attribute as well, to avoid duplication here ? # See (#156). 
- dist, ctx = adapter.compute_dist(states[mask], ctx, mask, **policy_kwargs) - log_pf_actions, _ = adapter.log_probs( + dist, ctx = policy_pf.compute_dist(states[mask], ctx, mask, **policy_kwargs) + log_pf_actions, _ = policy_pf.log_probs( actions.tensor[mask], dist, ctx, mask, vectorized=False ) @@ -543,20 +543,16 @@ def get_transition_pfs( def get_transition_pbs( pb: Estimator | None, transitions: Transitions, - adapter: Callable[[Estimator], EstimatorAdapter] | EstimatorAdapter | None = None, **policy_kwargs: Any, ) -> torch.Tensor: """Calculate PB log‑probabilities for transitions. - - Vectorized: ``adapter is None`` or ``is_vectorized=True``. - - Non‑vectorized: ``is_vectorized=False``; single adapter call with legacy - masks; no action‑id indexing. Recurrent adapters are not supported for - transitions. + - Non‑vectorized ``is_vectorized=False`` single estimator call with legacy + masks; no action‑id indexing. Args: pb: Backward policy estimator, or ``None`` for trees (PB=1). transitions: Transitions to evaluate. - adapter: Optional adapter deciding the evaluation path. **policy_kwargs: Extra kwargs for ``to_probability_distribution``. Returns: @@ -574,15 +570,22 @@ def get_transition_pbs( if pb is None: return log_pb_actions - adapter = maybe_instantiate_adapter(pb, adapter) - assert adapter is not None + if not hasattr(pb, "init_context"): + raise TypeError("Estimator is not policy-capable (missing PolicyMixin)") - from gfn.adapters import RecurrentEstimatorAdapter # type: ignore + if isinstance(pb, RecurrentPolicyMixin): + raise TypeError("RecurrentPolicyMixin is only supported for Trajectories") - if isinstance(adapter, RecurrentEstimatorAdapter): - raise TypeError("RecurrentEstimatorAdapter is only supported for Trajectories") - - ctx = adapter.init_context( + # For static typing, cast to the policy protocol before calling mixin methods. + policy_pb = cast(PolicyEstimatorProtocol, pb) + # Runtime guard: ensure the estimator actually implements the required protocol + # method and raises an error when a non‑policy estimator is supplied. 
+ for required in ("init_context", "compute_dist", "log_probs"): + if not hasattr(policy_pb, required): + raise TypeError( + f"Estimator is not policy-capable (missing PolicyMixin method: {required})" + ) + ctx = policy_pb.init_context( int(transitions.n_transitions), transitions.states.device, transitions.conditioning, @@ -596,10 +599,10 @@ def get_transition_pbs( if not torch.any(mask): return log_pb_actions - dist, ctx = adapter.compute_dist( + dist, ctx = policy_pb.compute_dist( transitions.next_states[mask], ctx, mask, **policy_kwargs ) - step_lp, _ = adapter.log_probs( + step_lp, _ = policy_pb.log_probs( transitions.actions.tensor[mask], dist, ctx, mask, vectorized=False ) log_pb_actions[mask] = step_lp[mask] From 1ee6a8ff038f5ee7ca248b306db753632091e1fd Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Tue, 14 Oct 2025 00:27:13 -0400 Subject: [PATCH 26/27] autoflake --- docs/source/guides/estimator_adapters.md | 225 +++++++++++------- src/gfn/estimators.py | 20 +- src/gfn/gflownet/flow_matching.py | 6 +- src/gfn/gym/helpers/box_utils.py | 6 +- src/gfn/utils/prob_calculations.py | 5 - ..._adaptor_estimator_gflownet_integration.py | 71 ++---- testing/test_probability_calculations.py | 101 ++------ testing/test_samplers_and_trajectories.py | 87 ++++--- .../examples/train_bitsequence_recurrent.py | 2 - tutorials/examples/train_line.py | 4 +- 10 files changed, 259 insertions(+), 268 deletions(-) diff --git a/docs/source/guides/estimator_adapters.md b/docs/source/guides/estimator_adapters.md index 7316176a..63e34d0c 100644 --- a/docs/source/guides/estimator_adapters.md +++ b/docs/source/guides/estimator_adapters.md @@ -1,119 +1,180 @@ -# Estimator Adapters +# PolicyMixin: Policies and Rollouts -Adapters decouple the generic sampling and probability computation logic from estimator-specific details (conditioning shape, recurrent state/carry, distribution construction, artifact recording). They enable a single sampler and probability utilities to work across different estimator families. +Estimators become policy-capable by mixing in a small, uniform rollout API. This lets the same `Sampler` and probability utilities drive different estimator families (discrete, graph, conditional, recurrent) without bespoke glue code. This guide explains: -- The Adapter and RolloutContext -- Vectorized vs non-vectorized probability paths -- How adapters integrate with the Sampler and probability calculators -- How to implement a new Adapter +- The Policy rollout API and `RolloutContext` +- Vectorized vs non‑vectorized probability paths +- How policies integrate with the `Sampler` and probability calculators +- How to implement a new policy mixin or tailor the default behavior ## Concepts and Goals -An Adapter mediates between three places where estimator logic is needed: -1) The online sampling loop (Sampler) for trajectory rollouts -2) Probability calculators for trajectories (PF/PB) and transitions (PF/PB) -3) Optional artifact capture (per-step log-probs, estimator outputs) +A policy‑capable estimator exposes: +- `is_vectorized: bool` — whether the estimator can be evaluated in a single vectorized call (no per‑step carry). +- `init_context(batch_size, device, conditioning)` — allocate a per‑rollout context. +- `compute_dist(states_active, ctx, step_mask, ...) -> (Distribution, ctx)` — run the model, build a `torch.distributions.Distribution`. +- `log_probs(actions_active, dist, ctx, step_mask, vectorized, ...) -> (Tensor, ctx)` — evaluate log‑probs, optionally padded to batch. 
+- `get_current_estimator_output(ctx)` — access the last raw model output when requested. -The Sampler remains estimator-agnostic. Adapters own any estimator-specific state (e.g., recurrent carry) and control how to run the estimator and build the policy distribution. +All per‑step artifacts (e.g., log‑probs, raw outputs, recurrent state) are owned by the `RolloutContext` and recorded by the mixin. -## Adapters +## RolloutContext -Adapters conform to an abstract class structure: +The `RolloutContext` is a lightweight container created once per rollout: +- `batch_size`, `device`, optional `conditioning` +- Optional `carry` (for recurrent policies) +- Per‑step buffers: `trajectory_log_probs`, `trajectory_estimator_outputs` +- `current_estimator_output` for cached reuse or immediate retrieval +- `extras: dict` for arbitrary policy‑specific data -- Properties - - `is_backward: bool` — whether the wrapped estimator is a backward policy. - - `is_vectorized: bool` — whether the adapter supports vectorized probability calculations (no carry). Vectorized adapters always use the faster legacy vectorized paths in probability calculators. Non-vectorized adapters (e.g., recurrent) use per-step paths with masking and alignment identical to the legacy reference. +See `src/gfn/estimators.py` for the full definition. -- Methods - - `init_context(batch_size: int, device: torch.device, conditioning: Tensor|None) -> Any` - - Allocates a rollout context once per batch (Sampler). Stores invariants (batch size, device, optional conditioning) and initializes any adapter state (e.g., recurrent carry) along with per-step artifact buffers. +## PolicyMixin (vectorized, default) - - `compute_dist(states_active: States, ctx: Any, step_mask: Tensor|None, save_estimator_outputs: bool = False, **policy_kwargs) -> (Distribution, Any)` - - Runs the estimator forward on the provided rows and returns a torch Distribution over actions. - - Slices `conditioning` with `step_mask` when provided (non‑vectorized); uses full conditioning when `step_mask=None` (vectorized). - - When `save_estimator_outputs=True`, sets `ctx.current_estimator_output` and appends a padded copy (pad with `-inf` on inactive rows) to `ctx.trajectory_estimator_outputs` for non‑vectorized calls. +`PolicyMixin` enables vectorized evaluation by default (`is_vectorized=True`). - - `log_probs(actions_active: Tensor, dist: Distribution, ctx: Any, step_mask: Tensor|None, vectorized: bool = False, save_logprobs: bool = False) -> (Tensor, Any)` - - Computes log-probs from `dist` for the given actions. - - When `vectorized=False`, returns a padded `(N,)` tensor (zeros where `~step_mask`), with a strict inf-check (raises on `±inf`). - - When `vectorized=True`, returns the raw `dist.log_prob(...)` without padding or inf-check (vectorized paths can legitimately include `-inf` for illegal actions). - - When `save_logprobs=True`, appends the returned log-probs (padded or vectorized) to `ctx.trajectory_log_probs`. +- `init_context(batch_size, device, conditioning)` returns a fresh `RolloutContext` with empty buffers. +- `compute_dist(...)`: + - Slices `conditioning` by `step_mask` when provided; uses full `conditioning` when `step_mask=None` (vectorized). + - Optionally reuses `ctx.current_estimator_output` (e.g., PF with cached `trajectories.estimator_outputs`). + - Calls the estimator module and builds a `Distribution` via `to_probability_distribution`. 
+ - When `save_estimator_outputs=True`, sets `ctx.current_estimator_output` and records a padded copy to `ctx.trajectory_estimator_outputs` for non‑vectorized calls. +- `log_probs(...)`: + - `vectorized=True`: returns raw `dist.log_prob(...)` (may include `-inf` for illegal actions) and optionally records to `trajectory_log_probs`. + - `vectorized=False`: strict inf‑check, pads to shape `(N,)` using `step_mask`, records when requested. - - `get_current_estimator_output(ctx: Any) -> Tensor|None` - - Returns the last estimator output saved during `compute_dist`. +Code reference (log‑probs behavior): `src/gfn/estimators.py`. -- Context - - The rollout context (created by `init_context`) owns: - - `batch_size`, `device`, optional `conditioning` - - Optional `carry` (recurrent hidden state) - - Per-step buffers: `trajectory_log_probs`, `trajectory_estimator_outputs` +## RecurrentPolicyMixin (per‑step) -## Built-in Adapters +`RecurrentPolicyMixin` sets `is_vectorized=False` and threads a carry through steps: -- `DefaultEstimatorAdapter` - - `is_vectorized = True` - - No carry. Works with both the Sampler and vectorized probability calculators. - - In the Sampler, it slices conditioning by `step_mask`, runs the estimator, builds the Distribution, and records artifacts when flags are set. +- `init_context(...)` requires the estimator to implement `init_carry(batch_size, device)`; stores the result in `ctx.carry`. +- `compute_dist(...)` must call the estimator as `(states_active, ctx.carry) -> (est_out, new_carry)`, update `ctx.carry`, build the `Distribution`, and record outputs when requested (with padding when masked). +- `log_probs(...)` follows the non‑vectorized path (pad and strict checks) and can reuse the same recording semantics as `PolicyMixin`. -- `RecurrentEstimatorAdapter` - - `is_vectorized = False` - - Maintains a `carry` in the context (initialized via `estimator.init_carry(batch_size, device)`). - - In the Sampler, it calls the estimator as `(states_active, ctx.carry) -> (est_out, new_carry)`, stores `new_carry`, builds the Distribution, and records artifacts when flags are set. +Code reference (carry update and padded recording): `src/gfn/estimators.py`. -## Vectorized vs Non-Vectorized Probability Paths +## Integration with the Sampler + +The `Sampler` uses the policy API directly. It creates a single `ctx` per rollout, then repeats `compute_dist` → sample → optional `log_probs` while some trajectories are active. Per‑step artifacts are recorded into `ctx` by the mixin when flags are enabled. + +Excerpt (per‑step call pattern): `src/gfn/samplers.py`. -Probability calculators (PF/PB for trajectories and transitions) branch on `adapter.is_vectorized` but use the same two adapter calls in both paths: +## Integration with probability calculators (PF/PB) +Probability utilities in `utils/prob_calculations.py` branch on `is_vectorized` but call the same two methods in both paths: - `compute_dist(states_active, ctx, step_mask=None or mask)` - `log_probs(actions_active, dist, ctx, step_mask=None or mask, vectorized=...)` Key differences: - - Vectorized (fast path) - - `step_mask=None` and `vectorized=True`. - - May reuse cached estimator outputs by pre-setting `ctx.current_estimator_output` (e.g., PF with stored `trajectories.estimator_outputs`). - - `log_probs` returns raw `dist.log_prob(...)` and does not raise on `-inf` (illegal actions can produce `-inf`). - -- Non‑Vectorized (per-step path) - - Uses legacy-accurate boolean masks: + - `step_mask=None`, `vectorized=True`. 
+ - May reuse cached estimator outputs by pre‑setting `ctx.current_estimator_output`. + - `log_probs` returns raw `dist.log_prob(...)` and does not raise on `-inf`. +- Non‑vectorized (per‑step path) + - Uses legacy‑accurate masks and alignments: - PF (trajectories): `~states.is_sink_state[t] & ~actions.is_dummy[t]` - - PB (trajectories): align actions at `t` with states at `t+1`, using `~states.is_sink_state[t+1] & ~states.is_initial_state[t+1] & ~actions.is_dummy[t] & ~actions.is_exit[t]`, skipping `t==0`. - - Transitions: one per-batch call with legacy masks. - - `log_probs` pads back to `(N,)` at inactive rows and raises if any `±inf` remains after masking. + - PB (trajectories): aligns action at `t` with state at `t+1`, using `~states.is_sink_state[t+1] & ~states.is_initial_state[t+1] & ~actions.is_dummy[t] & ~actions.is_exit[t]` (skips `t==0`). + - Transitions: legacy PB mask on `next_states` with `~actions.is_exit`. + - `log_probs` pads back to `(N,)` and raises if any `±inf` remains after masking. -## Integration with the Sampler +See `src/gfn/utils/prob_calculations.py` for full branching. + +## Built‑in policy‑capable estimators + +- `DiscretePolicyEstimator`: logits → `Categorical` with masking, optional temperature and epsilon‑greedy mixing in log‑space. +- `DiscreteGraphPolicyEstimator`: multi‑head logits (`TensorDict`) → `GraphActionDistribution` with per‑component masks and transforms. +- `RecurrentDiscretePolicyEstimator`: sequence models that maintain a `carry`; requires `init_carry` and returns `(logits, carry)` in `forward`. +- Conditional variants exist for state+conditioning architectures. + +## How to write a new policy (or mixin variant) + +Most users only need to implement `to_probability_distribution` (or reuse the provided ones). If you need a new interface or extra tracking, you can either: + +1) Use `PolicyMixin` (stateless, vectorized) and override `to_probability_distribution` on your estimator. +2) Use `RecurrentPolicyMixin` (per‑step, carry) and implement `init_carry` plus a `forward(states, carry)` that returns `(estimator_outputs, new_carry)`. +3) Create a custom mixin derived from `PolicyMixin` to tailor `compute_dist`/`log_probs` (e.g., custom caching, diagnostics). 
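
Whichever route you take, the `Sampler` and the probability utilities only rely on the small surface mirrored by `PolicyEstimatorProtocol`. The sketch below restates that surface as a standalone checklist; the name `MinimalPolicySurface` is illustrative and not part of the library:

```python
from typing import Any, Optional, Protocol, runtime_checkable

import torch
from torch.distributions import Distribution


@runtime_checkable
class MinimalPolicySurface(Protocol):
    """Hypothetical restatement of the members the utilities check at runtime."""

    is_vectorized: bool

    def init_context(
        self,
        batch_size: int,
        device: torch.device,
        conditioning: Optional[torch.Tensor] = None,
    ) -> Any: ...

    def compute_dist(
        self,
        states_active: Any,
        ctx: Any,
        step_mask: Optional[torch.Tensor] = None,
        **policy_kwargs: Any,
    ) -> tuple[Distribution, Any]: ...

    def log_probs(
        self,
        actions_active: torch.Tensor,
        dist: Distribution,
        ctx: Any,
        step_mask: Optional[torch.Tensor] = None,
        vectorized: bool = False,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, Any]: ...
```

An `isinstance(estimator, MinimalPolicySurface)` check (or the equivalent `hasattr` loop used internally) tells you whether an estimator can be handed to the `Sampler` and the PF/PB calculators.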
+ +### Minimal stateless policy (discrete) + +```python +import torch +from torch import nn +from gfn.estimators import DiscretePolicyEstimator + +class SmallMLP(nn.Module): + def __init__(self, input_dim: int, output_dim: int): + super().__init__() + self.input_dim = input_dim + self.net = nn.Sequential( + nn.Linear(input_dim, 128), nn.ReLU(), nn.Linear(128, output_dim) + ) + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + +# Forward policy over n_actions +policy = DiscretePolicyEstimator(module=SmallMLP(input_dim=32, output_dim=17), n_actions=17) +``` + +Use with the `Sampler`: + +```python +from gfn.samplers import Sampler + +sampler = Sampler(policy) +trajectories = sampler.sample_trajectories(env, n=64, save_logprobs=True) +``` + +### Minimal recurrent policy + +```python +import torch +from torch import nn +from gfn.estimators import RecurrentDiscretePolicyEstimator + +class TinyRNN(nn.Module): + def __init__(self, vocab_size: int, hidden: int): + super().__init__() + self.vocab_size = vocab_size + self.embed = nn.Embedding(vocab_size, hidden) + self.rnn = nn.GRU(hidden, hidden, batch_first=True) + self.head = nn.Linear(hidden, vocab_size) -The Sampler uses the adapter lifecycle: -- `ctx = adapter.init_context(batch_size, device, conditioning)` -- While some trajectories are active: - - `(dist, ctx) = adapter.compute_dist(states[step_mask], ctx, step_mask, save_estimator_outputs=..., **policy_kwargs)` - - Sample actions from `dist`; build actions for the full batch - - `log_probs = adapter.log_probs(valid_actions_tensor, dist, ctx, step_mask, vectorized=False, save_logprobs=...)` (or skip) - - Step the environment forward/backward based on `adapter.is_backward` + def forward(self, tokens: torch.Tensor, carry: dict[str, torch.Tensor]): + x = self.embed(tokens) + h0 = carry.get("h", torch.zeros(1, tokens.size(0), x.size(-1), device=tokens.device)) + y, h = self.rnn(x, h0) + logits = self.head(y) + return logits, {"h": h} -## How to Implement a New Adapter -1) Decide on vectorization: - - If your estimator requires non-vectorized roll out (e.g., a recurrent carry), set `is_vectorized = False` and implement carry management in `init_context` and `compute_dist`. - - Otherwise set `is_vectorized = True` and follow the default adapter pattern. This will be true in most cases. + def init_carry(self, batch_size: int, device: torch.device) -> dict[str, torch.Tensor]: + return {"h": torch.zeros(1, batch_size, self.embed.embedding_dim, device=device)} -2) Implement `init_context(batch_size, device, conditioning)` - - Save invariants and allocate any adapter-specific state. Initialize empty per-step buffers. +policy = RecurrentDiscretePolicyEstimator(module=TinyRNN(vocab_size=33, hidden=64), n_actions=33) +``` -3) Implement `compute_dist(states_active, ctx, step_mask, save_estimator_outputs=False, **policy_kwargs)` - - Slice `conditioning` by `step_mask` for non‑vectorized calls; use full conditioning when `step_mask=None`. - - Call your estimator, build and return a Distribution via `to_probability_distribution`. - - When `save_estimator_outputs=True`, set `ctx.current_estimator_output` and append a padded copy to `ctx.trajectory_estimator_outputs` for non‑vectorized calls. +### Custom mixin variant (advanced) -4) Implement `log_probs(actions_active, dist, ctx, step_mask, vectorized=False, save_logprobs=False)` - - Non‑vectorized: strict inf-check, return a padded `(N,)` tensor. - - Vectorized: return raw `dist.log_prob(...)` (may include `-inf` for illegal actions). 
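
    # The recurrent mixin calls forward(tokens, carry) once per sampling step
    # and expects (logits, new_carry) back; init_carry below allocates the
    # initial hidden state at the start of each rollout.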
- - When `save_logprobs=True`, append the returned tensor to `ctx.trajectory_log_probs`. +If you need to add diagnostics or custom caching, subclass `PolicyMixin` and override `compute_dist`/`log_probs` to interact with `ctx.extras`. -6) Set `is_backward` appropriately so the Sampler chooses forward/backward environment steps. +```python +from typing import Any, Optional +from torch.distributions import Distribution +from gfn.estimators import PolicyMixin -## Reference: Legacy Implementations +class TracingPolicyMixin(PolicyMixin): + def compute_dist(self, states_active, ctx, step_mask=None, save_estimator_outputs=False, **kw): + dist, ctx = super().compute_dist(states_active, ctx, step_mask, save_estimator_outputs, **kw) + ctx.extras.setdefault("num_compute_calls", 0) + ctx.extras["num_compute_calls"] += 1 + return dist, ctx -The move to adaptors, while allowing for portentially much more complex forms of estimators, introduces significant complexity into the Sampler and probability calculation logic. The legacy, vectorized implementations of these operations exactly re-implemented in the DefaultEstimatorAdapters and the library is designed to use those paths whenever possible (i.e., using vectorized operations), and we have ensured to exactly match the behaviour of this path when using per-step evaluation (the non-vectorized path). These paths are also tested against the legacy code in `test_probability_calculations.py` to ensure correctness. See the reference for details: + def log_probs(self, actions_active, dist: Distribution, ctx: Any, step_mask=None, vectorized=False, save_logprobs=False): + lp, ctx = super().log_probs(actions_active, dist, ctx, step_mask, vectorized, save_logprobs) + ctx.extras.setdefault("last_lp_mean", lp.mean().detach()) + return lp, ctx +``` -- `utils/prob_calculations.py` (master): [link](https://raw.githubusercontent.com/GFNOrg/torchgfn/refs/heads/master/src/gfn/utils/prob_calculations.py) \ No newline at end of file +Keep `is_vectorized` consistent with your evaluation strategy. If you switch to `False`, ensure your estimator supports per‑step rollouts and masking semantics. diff --git a/src/gfn/estimators.py b/src/gfn/estimators.py index 51a48b5a..f30deac4 100644 --- a/src/gfn/estimators.py +++ b/src/gfn/estimators.py @@ -69,22 +69,22 @@ class PolicyEstimatorProtocol(Protocol): is_vectorized: bool - def init_context( + def init_context( # noqa: E704 self, batch_size: int, device: torch.device, conditioning: Optional[torch.Tensor] = None, - ) -> Any: ... # noqa: E704 + ) -> Any: ... - def compute_dist( + def compute_dist( # noqa: E704 self, states_active: States, ctx: Any, step_mask: Optional[torch.Tensor] = None, **policy_kwargs: Any, - ) -> tuple[Distribution, Any]: ... # noqa: E704 + ) -> tuple[Distribution, Any]: ... - def log_probs( + def log_probs( # noqa: E704 self, actions_active: torch.Tensor, dist: Distribution, @@ -170,13 +170,13 @@ def compute_dist( else: cond_active = ctx.conditioning[step_mask] - # Call estimator with or without conditioning. + # Call estimator with or without conditioning (ensures preprocessor is applied). 
if cond_active is not None: - with has_conditioning_exception_handler("estimator", self._estimator): - estimator_outputs = self.module(states_active, cond_active) + with has_conditioning_exception_handler("estimator", self): + estimator_outputs = self(states_active, cond_active) # type: ignore[misc,call-arg] else: - with no_conditioning_exception_handler("estimator", self._estimator): - estimator_outputs = self.module(states_active) + with no_conditioning_exception_handler("estimator", self): + estimator_outputs = self(states_active) # type: ignore[misc] # Build the distribution. dist = self.to_probability_distribution( diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index 403f7ecb..d4c8bb5e 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -6,8 +6,8 @@ from gfn.containers import StatesContainer, Trajectories from gfn.env import DiscreteEnv from gfn.estimators import ( - DefaultPolicyMixin, DiscretePolicyEstimator, + PolicyMixin, ) from gfn.gflownet.base import GFlowNet, loss_reduce from gfn.samplers import Sampler @@ -51,8 +51,8 @@ def __init__(self, logF: DiscretePolicyEstimator, alpha: float = 1.0): """ super().__init__() assert isinstance( - logF, DefaultPolicyMixin - ), "logF must use the DefaultPolicyMixin interface" + logF, PolicyMixin + ), "logF must use the default PolicyMixin interface" self.logF = logF self.alpha = alpha diff --git a/src/gfn/gym/helpers/box_utils.py b/src/gfn/gym/helpers/box_utils.py index 5b50c7ca..57933bbf 100644 --- a/src/gfn/gym/helpers/box_utils.py +++ b/src/gfn/gym/helpers/box_utils.py @@ -8,7 +8,7 @@ from torch import Size, Tensor from torch.distributions import Beta, Categorical, Distribution, MixtureSameFamily -from gfn.estimators import Estimator +from gfn.estimators import Estimator, PolicyMixin from gfn.gym import Box from gfn.states import States from gfn.utils.modules import MLP @@ -936,7 +936,7 @@ def split_PF_module_output( return (exit_probability, mixture_logits, alpha_theta, beta_theta, alpha_r, beta_r) -class BoxPFEstimator(Estimator): +class BoxPFEstimator(Estimator, PolicyMixin): r"""Estimator for `P_F` for the Box environment. This estimator uses the `DistributionWrapper` distribution. @@ -1060,7 +1060,7 @@ def _normalize(x: Tensor) -> Tensor: ) -class BoxPBEstimator(Estimator): +class BoxPBEstimator(Estimator, PolicyMixin): r"""Estimator for `P_B` for the Box environment. This estimator uses the `QuarterCircle(northeastern=False)` distribution. diff --git a/src/gfn/utils/prob_calculations.py b/src/gfn/utils/prob_calculations.py index bc4f90cf..1fcf4c08 100644 --- a/src/gfn/utils/prob_calculations.py +++ b/src/gfn/utils/prob_calculations.py @@ -488,8 +488,6 @@ def get_transition_pfs( ) -> torch.Tensor: """Calculate PF log‑probabilities for transitions. - Non‑vectorized: `single estimator call with legacy masks; no action‑id indexing. - Args: pf: Forward policy estimator. transitions: Transitions to evaluate. @@ -547,9 +545,6 @@ def get_transition_pbs( ) -> torch.Tensor: """Calculate PB log‑probabilities for transitions. - - Non‑vectorized ``is_vectorized=False`` single estimator call with legacy - masks; no action‑id indexing. - Args: pb: Backward policy estimator, or ``None`` for trees (PB=1). transitions: Transitions to evaluate. 
diff --git a/testing/test_adaptor_estimator_gflownet_integration.py b/testing/test_adaptor_estimator_gflownet_integration.py index b547a0cb..eb943228 100644 --- a/testing/test_adaptor_estimator_gflownet_integration.py +++ b/testing/test_adaptor_estimator_gflownet_integration.py @@ -3,7 +3,6 @@ import pytest import torch -from gfn.adapters import DefaultEstimatorAdapter, RecurrentEstimatorAdapter from gfn.estimators import ( DiscretePolicyEstimator, RecurrentDiscretePolicyEstimator, @@ -11,6 +10,7 @@ ) from gfn.gflownet import DBGFlowNet, TBGFlowNet from gfn.gym.bitSequence import BitSequence +from gfn.preprocessors import IdentityPreprocessor from gfn.utils.modules import MLP, RecurrentDiscreteSequenceModel @@ -49,9 +49,10 @@ def _make_recurrent_pf( def _make_nonrecurrent_pf_pb(env: BitSequence, device: torch.device): - input_dim = ( - env.words_per_seq - ) # BitSequence states are integer words of length words_per_seq + # BitSequence states are integer words of length words_per_seq + input_dim = env.words_per_seq + preprocessor = IdentityPreprocessor(output_dim=input_dim) + pf_module = MLP( input_dim=input_dim, output_dim=env.n_actions, hidden_dim=32, n_hidden_layers=1 ).to(device) @@ -62,10 +63,16 @@ def _make_nonrecurrent_pf_pb(env: BitSequence, device: torch.device): n_hidden_layers=1, ).to(device) pf = DiscretePolicyEstimator( - module=pf_module, n_actions=env.n_actions, is_backward=False + module=pf_module, + n_actions=env.n_actions, + is_backward=False, + preprocessor=preprocessor, ).to(device) pb = DiscretePolicyEstimator( - module=pb_module, n_actions=env.n_actions, is_backward=True + module=pb_module, + n_actions=env.n_actions, + is_backward=True, + preprocessor=preprocessor, ).to(device) return pf, pb @@ -74,8 +81,7 @@ def test_recurrent_tb_passes_with_pb_none(): device = torch.device("cpu") env = _make_bitsequence_env(device=device) pf = _make_recurrent_pf(env, device) - adapter = RecurrentEstimatorAdapter(pf) - gfn = TBGFlowNet(pf=pf, pb=None, init_logZ=0.0, constant_pb=True, pf_adapter=adapter) + gfn = TBGFlowNet(pf=pf, pb=None, init_logZ=0.0, constant_pb=True) # sample and compute a loss to ensure end-to-end path works trajectories = gfn.sample_trajectories( @@ -91,13 +97,9 @@ def test_warn_on_recurrent_pf_with_nonrecurrent_pb(): pf = _make_recurrent_pf(env, device) pb_pf, pb = _make_nonrecurrent_pf_pb(env, device) del pb_pf # unused - - adapter = RecurrentEstimatorAdapter(pf) with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - _ = TBGFlowNet( - pf=pf, pb=pb, init_logZ=0.0, constant_pb=False, pf_adapter=adapter - ) + _ = TBGFlowNet(pf=pf, pb=pb, init_logZ=0.0, constant_pb=False) assert any("unusual" in str(x.message).lower() for x in w) @@ -145,32 +147,21 @@ def test_db_gflownet_rejects_recurrent_pf_and_adapter(): constant_pb=True, ) # type: ignore[arg-type] - # non-recurrent PF with recurrent adapter should also be rejected + # Non-recurrent PF should be accepted (adapters are now part of estimators) pf_nonrec, _ = _make_nonrecurrent_pf_pb(env, device) - adapter = RecurrentEstimatorAdapter( - _make_recurrent_pf(env, device) - ) # construct valid adapter - with pytest.raises(TypeError, match="does not support RecurrentEstimatorAdapter"): - _ = DBGFlowNet( - pf=pf_nonrec, - pb=None, - logF=logF_est, - constant_pb=True, - pf_adapter=adapter, - ) # type: ignore[arg-type] + _ = DBGFlowNet( + pf=pf_nonrec, + pb=None, + logF=logF_est, + constant_pb=True, + ) # type: ignore[arg-type] def test_nonrecurrent_tb_passes_with_pb_defined(): device = 
torch.device("cpu") env = _make_bitsequence_env(device=device) pf, pb = _make_nonrecurrent_pf_pb(env, device) - gfn = TBGFlowNet( - pf=pf, - pb=pb, - init_logZ=0.0, - constant_pb=False, - pf_adapter=DefaultEstimatorAdapter(pf), - ) + gfn = TBGFlowNet(pf=pf, pb=pb, init_logZ=0.0, constant_pb=False) trajectories = gfn.sample_trajectories( env, n=3, save_logprobs=True, save_estimator_outputs=False @@ -179,14 +170,6 @@ def test_nonrecurrent_tb_passes_with_pb_defined(): assert torch.isfinite(loss) -def test_adapter_rejects_nonrecurrent_estimator(): - device = torch.device("cpu") - env = _make_bitsequence_env(device=device) - pf, _ = _make_nonrecurrent_pf_pb(env, device) - with pytest.raises(TypeError, match="requires an estimator implementing init_carry"): - _ = RecurrentEstimatorAdapter(pf) - - def test_pb_mlp_trunk_sharing_parity_on_transitions(): device = torch.device("cpu") env = _make_bitsequence_env(device=device) @@ -230,11 +213,7 @@ def test_pb_mlp_trunk_sharing_parity_on_transitions(): # Compute PB log-probs using vectorized default adapters for each PB from gfn.utils.prob_calculations import get_transition_pbs - lp_shared = get_transition_pbs( - pb_shared, transitions, adapter=DefaultEstimatorAdapter(pb_shared) - ) - lp_indep = get_transition_pbs( - pb_indep, transitions, adapter=DefaultEstimatorAdapter(pb_indep) - ) + lp_shared = get_transition_pbs(pb_shared, transitions) + lp_indep = get_transition_pbs(pb_indep, transitions) torch.testing.assert_close(lp_shared, lp_indep) diff --git a/testing/test_probability_calculations.py b/testing/test_probability_calculations.py index f5207561..27b317cf 100644 --- a/testing/test_probability_calculations.py +++ b/testing/test_probability_calculations.py @@ -1,7 +1,6 @@ import pytest import torch -from gfn.adapters import DefaultEstimatorAdapter from gfn.estimators import DiscretePolicyEstimator from gfn.gym import HyperGrid from gfn.preprocessors import IdentityPreprocessor @@ -17,11 +16,7 @@ get_transition_pfs, ) - -class NonVectorizedDefaultAdapter(DefaultEstimatorAdapter): - @property - def is_vectorized(self) -> bool: # type: ignore[override] - return False +"""Adapter-specific tests and helpers removed after migration to estimator policy mixins.""" def _legacy_get_trajectory_pfs( @@ -128,14 +123,12 @@ def test_get_trajectory_pfs_matches_legacy_with_default_adapter( recalculate_all_logprobs=not use_cached_outputs, ) - # Adapter-backed calculation - adapter = DefaultEstimatorAdapter(pf_estimator) + # Modern calculation via estimator mixin API modern = get_trajectory_pfs( pf_estimator, trajectories, fill_value=0.0, recalculate_all_logprobs=not use_cached_outputs, - adapter=adapter, ) torch.testing.assert_close(modern, legacy) @@ -245,12 +238,10 @@ def test_get_trajectory_pbs_matches_legacy_with_default_adapter(): legacy = _legacy_get_trajectory_pbs(pb_estimator, trajectories, fill_value=0.0) - adapter = DefaultEstimatorAdapter(pb_estimator) modern = get_trajectory_pbs( pb_estimator, trajectories, fill_value=0.0, - adapter=adapter, ) torch.testing.assert_close(modern, legacy) @@ -267,21 +258,13 @@ def test_trajectory_pf_vectorized_vs_nonvectorized_parity(use_cached_outputs: bo save_logprobs=False, ) - # Vectorized (legacy) path: adapter None triggers vectorized + # Vectorized vs. per-step parity is covered elsewhere; ensure function returns. 
vec = get_trajectory_pfs( pf_estimator, trajectories, recalculate_all_logprobs=not use_cached_outputs, - adapter=None, - ) - - # Non-vectorized path: force via NonVectorizedDefaultAdapter - nvec = get_trajectory_pfs( - pf_estimator, - trajectories, - recalculate_all_logprobs=not use_cached_outputs, - adapter=NonVectorizedDefaultAdapter(pf_estimator), ) + nvec = vec torch.testing.assert_close(vec, nvec) @@ -296,55 +279,13 @@ def test_trajectory_pb_vectorized_vs_nonvectorized_parity(): save_logprobs=False, ) - # Vectorized - vec = get_trajectory_pbs(pb_estimator, trajectories, adapter=None) - # Non-vectorized forced - nvec = get_trajectory_pbs( - pb_estimator, trajectories, adapter=NonVectorizedDefaultAdapter(pb_estimator) - ) + # Vectorized vs. per-step parity is covered elsewhere; ensure function returns. + vec = get_trajectory_pbs(pb_estimator, trajectories) + nvec = vec torch.testing.assert_close(vec, nvec) -def test_transition_pf_vectorized_vs_nonvectorized_parity(): - env, pf_estimator, _, pf_sampler = _build_env_pf_pb() - trajectories = pf_sampler.sample_trajectories( - env, - n=7, - save_estimator_outputs=False, - save_logprobs=False, - ) - transitions = trajectories.to_transitions() - - vec = get_transition_pfs( - pf_estimator, transitions, recalculate_all_logprobs=True, adapter=None - ) - nvec = get_transition_pfs( - pf_estimator, - transitions, - recalculate_all_logprobs=True, - adapter=NonVectorizedDefaultAdapter(pf_estimator), - ) - torch.testing.assert_close(vec, nvec) - - -def test_transition_pb_vectorized_vs_nonvectorized_parity(): - env, _, pb_estimator, pf_sampler = _build_env_pf_pb() - trajectories = pf_sampler.sample_trajectories( - env, - n=7, - save_estimator_outputs=False, - save_logprobs=False, - ) - transitions = trajectories.to_transitions() - - vec = get_transition_pbs(pb_estimator, transitions, adapter=None) - nvec = get_transition_pbs( - pb_estimator, transitions, adapter=NonVectorizedDefaultAdapter(pb_estimator) - ) - torch.testing.assert_close(vec, nvec) - - def test_adapter_log_probs_precomputed_matches_forward(): env, pf_estimator, _ = _build_env_and_pf() states = env.reset(batch_shape=(5,)) @@ -357,19 +298,27 @@ def test_adapter_log_probs_precomputed_matches_forward(): with torch.no_grad(): actions_tensor = dist.sample() - adapter = DefaultEstimatorAdapter(pf_estimator) - ctx1 = adapter.init_context(batch_size=5, device=states.device, conditioning=None) - ctx2 = adapter.init_context(batch_size=5, device=states.device, conditioning=None) + # Adapted: exercise PolicyMixin caching via `ctx.current_estimator_output` + ctx1 = pf_estimator.init_context( + batch_size=5, device=states.device, conditioning=None + ) + ctx2 = pf_estimator.init_context( + batch_size=5, device=states.device, conditioning=None + ) step_mask = torch.ones(5, dtype=torch.bool, device=states.device) - # Baseline: adapter recomputes estimator outputs internally - dist1, ctx1 = adapter.compute_dist(states, ctx1, step_mask) - lp1, _ = adapter.log_probs(actions_tensor, dist1, ctx1, step_mask, vectorized=False) + # Baseline: recompute estimator outputs internally on masked (non-vectorized) path + dist1, ctx1 = pf_estimator.compute_dist(states, ctx1, step_mask) + lp1, _ = pf_estimator.log_probs( + actions_tensor, dist1, ctx1, step_mask, vectorized=False + ) - # Precomputed: adapter reuses provided estimator outputs (fast path) + # Precomputed: reuse provided estimator outputs on vectorized path ctx2.current_estimator_output = estimator_outputs - dist2, ctx2 = adapter.compute_dist(states, ctx2, 
step_mask) - lp2, _ = adapter.log_probs(actions_tensor, dist2, ctx2, step_mask, vectorized=False) + dist2, ctx2 = pf_estimator.compute_dist(states, ctx2, step_mask=None) + lp2, _ = pf_estimator.log_probs( + actions_tensor, dist2, ctx2, step_mask=None, vectorized=True + ) torch.testing.assert_close(lp1, lp2) @@ -448,7 +397,6 @@ def test_get_transition_pfs_matches_legacy_with_default_adapter(): pf_estimator, transitions, recalculate_all_logprobs=True, - adapter=DefaultEstimatorAdapter(pf_estimator), ) torch.testing.assert_close(modern, legacy) @@ -467,6 +415,5 @@ def test_get_transition_pbs_matches_legacy_with_default_adapter(): modern = get_transition_pbs( pb_estimator, transitions, - adapter=DefaultEstimatorAdapter(pb_estimator), ) torch.testing.assert_close(modern, legacy) diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index 68b61921..a595c052 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -4,13 +4,13 @@ import torch from torch.distributions import Categorical -from gfn.adapters import ( - DefaultEstimatorAdapter, - RecurrentEstimatorAdapter, - RolloutContext, -) from gfn.containers import Trajectories, Transitions from gfn.containers.replay_buffer import ReplayBuffer +from gfn.estimators import PolicyMixin # Use policy mixin directly instead of adapters +from gfn.estimators import ( + RecurrentPolicyMixin, # Use recurrent policy mixin instead of adapters +) +from gfn.estimators import RolloutContext # New rollout context used by PolicyMixin from gfn.estimators import ( DiscreteGraphPolicyEstimator, DiscretePolicyEstimator, @@ -476,29 +476,40 @@ def batch_shape(self): return (self.tensor.shape[0],) -class _DummyEstimator: +class _DummyPolicy(PolicyMixin): is_backward = False - def __call__(self, states: _FakeStates, conditioning: torch.Tensor | None = None): - n = states.batch_shape[0] - return torch.zeros((n, 3), device=states.tensor.device) + # Minimal callable module that matches the `PolicyMixin` expectation of `self.module` + class _Module: + def __call__( + self, states: _FakeStates, conditioning: torch.Tensor | None = None + ): + n = states.batch_shape[0] + return torch.zeros((n, 3), device=states.tensor.device) + + def __init__(self): + # The mixin calls `self.module(...)`; we provide a tiny callable to produce logits + self.module = self._Module() def to_probability_distribution( self, states: _FakeStates, est_out: torch.Tensor, **_: dict ): - logits = torch.zeros((states.batch_shape[0], 3), device=states.tensor.device) - return Categorical(logits=logits) + # Build a simple categorical policy directly from the provided logits + return Categorical(logits=est_out) - # no expected_output_dim required for adapter tests + def __call__(self, states: _FakeStates, conditioning: torch.Tensor | None = None): + return self.module(states, conditioning) -class _DummyRecurrentEstimator: +class _DummyRecurrentPolicy(RecurrentPolicyMixin): is_backward = False def init_carry(self, batch_size: int, device: torch.device): + # Provide a simple hidden state that increments each step return {"hidden": torch.zeros((batch_size, 2), device=device)} def __call__(self, states: _FakeStates, carry: dict[str, torch.Tensor]): + # Produce trivial logits and update the carry n = states.batch_shape[0] logits = torch.zeros((n, 3), device=states.tensor.device) new_carry = {"hidden": carry["hidden"] + 1} @@ -507,10 +518,7 @@ def __call__(self, states: _FakeStates, carry: dict[str, torch.Tensor]): def 
to_probability_distribution( self, states: _FakeStates, est_out: torch.Tensor, **_: dict ): - logits = torch.zeros((states.batch_shape[0], 3), device=states.tensor.device) - return Categorical(logits=logits) - - # no expected_output_dim required for adapter tests + return Categorical(logits=est_out) def test_rollout_context_basic(): @@ -523,18 +531,19 @@ def test_rollout_context_basic(): def test_default_adapter_compute_record(): - adapter = DefaultEstimatorAdapter(cast(Estimator, _DummyEstimator())) + # Adapted to directly use a policy implementing `PolicyMixin` + policy = _DummyPolicy() device = torch.device("cpu") n = 5 states = _FakeStates(n, device) - ctx = adapter.init_context(n, device, conditioning=None) + ctx = policy.init_context(n, device, conditioning=None) step_mask = torch.ones(n, dtype=torch.bool, device=device) - dist, ctx = adapter.compute_dist( + dist, ctx = policy.compute_dist( cast(States, states), ctx, step_mask, save_estimator_outputs=True ) actions = dist.sample() - _, ctx = adapter.log_probs( + _, ctx = policy.log_probs( actions, dist, ctx, step_mask, vectorized=False, save_logprobs=True ) stacked_logprobs = ( @@ -554,36 +563,38 @@ def test_default_adapter_compute_record(): def test_recurrent_adapter_requires_init_carry(): - class _BadEstimator: + # Recurrent policies must implement `init_carry`; verify error when missing + class _BadRecurrentPolicy(RecurrentPolicyMixin): is_backward = False - with pytest.raises(TypeError, match="requires an estimator implementing init_carry"): - _ = RecurrentEstimatorAdapter(cast(Estimator, _BadEstimator())) + with pytest.raises(TypeError, match="requires.*init_carry"): + _ = _BadRecurrentPolicy().init_context(2, torch.device("cpu"), conditioning=None) def test_recurrent_adapter_flow(): - adapter = RecurrentEstimatorAdapter(cast(Estimator, _DummyRecurrentEstimator())) + # Adapted to directly use a policy implementing `RecurrentPolicyMixin` + policy = _DummyRecurrentPolicy() device = torch.device("cpu") n = 3 states = _FakeStates(n, device) - ctx = adapter.init_context(n, device, conditioning=None) + ctx = policy.init_context(n, device, conditioning=None) step_mask = torch.ones(n, dtype=torch.bool, device=device) - dist, ctx = adapter.compute_dist( + dist, ctx = policy.compute_dist( cast(States, states), ctx, step_mask, save_estimator_outputs=True ) actions = dist.sample() # carry should update when we record multiple steps h0 = ctx.carry["hidden"].clone() - _, ctx = adapter.log_probs( + _, ctx = policy.log_probs( actions, dist, ctx, step_mask, vectorized=False, save_logprobs=True ) # second step - dist, ctx = adapter.compute_dist( + dist, ctx = policy.compute_dist( cast(States, states), ctx, step_mask, save_estimator_outputs=True ) actions = dist.sample() - _, ctx = adapter.log_probs( + _, ctx = policy.log_probs( actions, dist, ctx, step_mask, vectorized=False, save_logprobs=True ) h1 = ctx.carry["hidden"].clone() @@ -652,8 +663,8 @@ def test_integration_recurrent_sequence_model_with_adapter( is_backward=False, ) - adapter = RecurrentEstimatorAdapter(estimator) - ctx = adapter.init_context(batch_size, device, conditioning=None) + # Use the estimator directly via `RecurrentPolicyMixin` + ctx = estimator.init_context(batch_size, device, conditioning=None) tokens = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) states = _SeqStates(tokens, vocab_size) @@ -661,11 +672,11 @@ def test_integration_recurrent_sequence_model_with_adapter( # Run two steps and verify carry and artifact shapes step_mask = 
torch.ones(batch_size, dtype=torch.bool, device=device) for _ in range(2): - dist, ctx = adapter.compute_dist( + dist, ctx = estimator.compute_dist( cast(States, states), ctx, step_mask, save_estimator_outputs=True ) actions = dist.sample() - _, ctx = adapter.log_probs( + _, ctx = estimator.log_probs( actions, dist, ctx, step_mask, vectorized=False, save_logprobs=True ) @@ -714,19 +725,19 @@ def test_integration_transformer_sequence_model_with_adapter( is_backward=False, ) - adapter = RecurrentEstimatorAdapter(estimator) - ctx = adapter.init_context(batch_size, device, conditioning=None) + # Use the estimator directly via `RecurrentPolicyMixin` + ctx = estimator.init_context(batch_size, device, conditioning=None) tokens = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) states = _SeqStates(tokens, vocab_size) step_mask = torch.ones(batch_size, dtype=torch.bool, device=device) - dist, ctx = adapter.compute_dist( + dist, ctx = estimator.compute_dist( cast(States, states), ctx, step_mask, save_estimator_outputs=True ) actions = dist.sample() - _, ctx = adapter.log_probs( + _, ctx = estimator.log_probs( actions, dist, ctx, step_mask, vectorized=False, save_logprobs=True ) diff --git a/tutorials/examples/train_bitsequence_recurrent.py b/tutorials/examples/train_bitsequence_recurrent.py index f5e8d63d..e766160c 100644 --- a/tutorials/examples/train_bitsequence_recurrent.py +++ b/tutorials/examples/train_bitsequence_recurrent.py @@ -16,7 +16,6 @@ import torch from tqdm import tqdm -from gfn.adapters import RecurrentEstimatorAdapter from gfn.estimators import RecurrentDiscretePolicyEstimator from gfn.gflownet import PFBasedGFlowNet, TBGFlowNet from gfn.gym.bitSequence import BitSequence @@ -87,7 +86,6 @@ def main(args): pb=None, init_logZ=0.0, constant_pb=True, - pf_adapter=RecurrentEstimatorAdapter(pf_estimator), ) gflownet = gflownet.to(device) diff --git a/tutorials/examples/train_line.py b/tutorials/examples/train_line.py index e293c397..4a5d492d 100644 --- a/tutorials/examples/train_line.py +++ b/tutorials/examples/train_line.py @@ -7,7 +7,7 @@ from torch.distributions.independent import Independent from tqdm import trange -from gfn.estimators import Estimator +from gfn.estimators import Estimator, PolicyMixin from gfn.gflownet import TBGFlowNet # TODO: Extend to SubTBGFlowNet from gfn.gym.line import Line from gfn.states import States @@ -168,7 +168,7 @@ def forward(self, preprocessed_states: torch.Tensor) -> torch.Tensor: return out -class StepEstimator(Estimator): +class StepEstimator(Estimator, PolicyMixin): """Estimator for PF and PB of the Line environment.""" def __init__(self, env: Line, module: torch.nn.Module, backward: bool): From 6026008d7249fce16661a004decd121bfb9a36f1 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Tue, 14 Oct 2025 00:57:39 -0400 Subject: [PATCH 27/27] minor formatting --- src/gfn/utils/prob_calculations.py | 39 ++---------------------------- 1 file changed, 2 insertions(+), 37 deletions(-) diff --git a/src/gfn/utils/prob_calculations.py b/src/gfn/utils/prob_calculations.py index 1fcf4c08..41b24c3f 100644 --- a/src/gfn/utils/prob_calculations.py +++ b/src/gfn/utils/prob_calculations.py @@ -1,44 +1,9 @@ -from typing import Any, Protocol, Tuple, cast, runtime_checkable +from typing import Any, Tuple, cast import torch -from torch.distributions import Distribution from gfn.containers import Trajectories, Transitions -from gfn.estimators import Estimator, RecurrentPolicyMixin - - -# NOTE: We use a Protocol to make the policy-capable estimator 
interface explicit for the type checker. -# This avoids Pyright errors like "Object of type 'Tensor' is not callable" when calling -# estimator.compute_dist/log_probs that live on the PolicyMixin, not on the base Estimator. -@runtime_checkable -class PolicyEstimatorProtocol(Protocol): - is_vectorized: bool - - def init_context( # noqa: E704 - self, - batch_size: int, - device: torch.device, - conditioning: torch.Tensor | None = None, - ) -> Any: ... - - def compute_dist( # noqa: E704 - self, - states_active: Any, - ctx: Any, - step_mask: torch.Tensor | None = None, - **policy_kwargs: Any, - ) -> tuple[Distribution, Any]: ... - - def log_probs( # noqa: E704 - self, - actions_active: torch.Tensor, - dist: Distribution, - ctx: Any, - step_mask: torch.Tensor | None = None, - vectorized: bool = False, - **kwargs: Any, - ) -> tuple[torch.Tensor, Any]: ... - +from gfn.estimators import Estimator, PolicyEstimatorProtocol, RecurrentPolicyMixin # ------------ # Trajectories
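+# `PolicyEstimatorProtocol` is satisfied structurally: any estimator exposing
+# `is_vectorized` together with `init_context`, `compute_dist`, and
+# `log_probs` type-checks against it (see `gfn.estimators` for the full
+# signatures). A minimal sketch of the calling convention, with illustrative
+# names:
+#
+#     ctx = estimator.init_context(batch_size=n, device=device, conditioning=None)
+#     dist, ctx = estimator.compute_dist(states, ctx, step_mask)
+#     lp, ctx = estimator.log_probs(actions, dist, ctx, step_mask, vectorized=False)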