From d6a27210dd12d92a3d4b3dce1aa69aa94b3e1c8f Mon Sep 17 00:00:00 2001 From: younik Date: Thu, 29 Jan 2026 11:43:26 +0100 Subject: [PATCH 1/3] uniformize logging --- src/gfn/__init__.py | 3 ++ src/gfn/gym/diffusion_sampling.py | 19 +++++---- src/gfn/gym/hypergrid.py | 12 +++--- src/gfn/modules.py | 10 ++--- src/gfn/states.py | 14 +++--- src/gfn/utils/common.py | 9 ++-- src/gfn/utils/compile.py | 5 ++- src/gfn/utils/distributed.py | 71 ++++++++++++++++--------------- src/gfn/utils/handlers.py | 16 ++++--- src/gfn/utils/modules.py | 5 ++- src/gfn/utils/training.py | 16 +++---- 11 files changed, 98 insertions(+), 82 deletions(-) diff --git a/src/gfn/__init__.py b/src/gfn/__init__.py index a521e9f6..1010066c 100644 --- a/src/gfn/__init__.py +++ b/src/gfn/__init__.py @@ -1,3 +1,6 @@ import importlib.metadata as met +import logging __version__ = met.version("torchgfn") + +logging.getLogger(__name__).addHandler(logging.NullHandler()) diff --git a/src/gfn/gym/diffusion_sampling.py b/src/gfn/gym/diffusion_sampling.py index 68fc5385..394e664b 100644 --- a/src/gfn/gym/diffusion_sampling.py +++ b/src/gfn/gym/diffusion_sampling.py @@ -1,5 +1,6 @@ import math import os +import logging from abc import ABC, abstractmethod from contextlib import nullcontext from typing import Any, cast @@ -16,6 +17,8 @@ from gfn.states import States from gfn.utils.common import filter_kwargs_for_callable, temporarily_set_seed +logger = logging.getLogger(__name__) + # Lightweight typing alias for the target registry entries. TargetEntry = tuple[type["BaseTarget"], dict[str, Any]] @@ -215,9 +218,9 @@ def __init__( ) mixture_weights = mixture_weights / mixture_weights.sum() - print("+ Gaussian Mixture Target initialization:") - print("+ num_components: ", num_components) - print("+ mixture_weights: ", mixture_weights) + logger.info("+ Gaussian Mixture Target initialization:") + logger.info(f"+ num_components: {num_components}") + logger.info(f"+ mixture_weights: {mixture_weights}") for i, (loc, cov) in enumerate(zip(locs, covariances)): loc_str = np.array2string(loc, precision=2, separator=", ").replace( "\n", " " @@ -225,7 +228,7 @@ def __init__( cov_str = np.array2string(cov, precision=2, separator=", ").replace( "\n", " " ) - print(f"\tComponent {i+1}: loc={loc_str}, cov={cov_str}") + logger.info(f"\tComponent {i+1}: loc={loc_str}, cov={cov_str}") # Convert to torch tensors dtype = torch.get_default_dtype() @@ -719,8 +722,8 @@ def __init__( target_cls.__init__, {**default_kwargs, **(target_kwargs or {})}, ) - print("DiffusionSampling:") - print(f"+ Initalizing target {target_cls.__name__} with kwargs: {merged_kwargs}") + logger.info("DiffusionSampling:") + logger.info(f"+ Initalizing target {target_cls.__name__} with kwargs: {merged_kwargs}") self.target = target_cls(device=device, **merged_kwargs) self.dim = self.target.dim @@ -779,9 +782,9 @@ def list_available_targets(cls) -> dict[str, dict[str, Any]]: target class's constructor signature. 
""" out = {} - print("Available DiffusionSampling targets:") + logger.info("Available DiffusionSampling targets:") for alias, (cls, defaults) in cls.DIFFUSION_TARGETS.items(): - print(f"+ {alias}: {cls.__name__} with kwargs: {defaults}") + logger.info(f"+ {alias}: {cls.__name__} with kwargs: {defaults}") out[alias] = {"class": cls.__name__, "defaults": dict(defaults)} return out diff --git a/src/gfn/gym/hypergrid.py b/src/gfn/gym/hypergrid.py index 9196aa11..f47d24dd 100644 --- a/src/gfn/gym/hypergrid.py +++ b/src/gfn/gym/hypergrid.py @@ -1,9 +1,9 @@ """Adapted from https://github.com/Tikquuss/GflowNets_Tutorial""" import itertools +import logging import multiprocessing import platform -import warnings from abc import ABC, abstractmethod from decimal import Decimal from functools import reduce @@ -17,6 +17,8 @@ from gfn.env import DiscreteEnv from gfn.states import DiscreteStates +logger = logging.getLogger(__name__) + if platform.system() == "Windows": multiprocessing.set_start_method("spawn", force=True) else: @@ -86,7 +88,7 @@ def __init__( debug: If True, emit States with debug guards (not compile-friendly). """ if height <= 4: - warnings.warn("+ Warning: height <= 4 can lead to unsolvable environments.") + logger.warning("+ Warning: height <= 4 can lead to unsolvable environments.") reward_functions = { "original": OriginalReward, @@ -123,10 +125,10 @@ def __init__( if self.store_all_states: assert self._all_states_tensor is not None - print(f"+ Environment has {len(self._all_states_tensor)} states") + logger.info(f"+ Environment has {len(self._all_states_tensor)} states") if self.calculate_partition: assert self._log_partition is not None - print(f"+ Environment log partition is {self._log_partition}") + logger.info(f"+ Environment log partition is {self._log_partition}") if isinstance(device, str): device = torch.device(device) @@ -414,7 +416,7 @@ def _enumerate_all_states_tensor(self, batch_size: int = 20_000): total_rewards += self.reward_fn(batch_tensor).sum().item() end_time = time() - print( + logger.info( "Enumerated all states in {} minutes".format( (end_time - start_time) / 60.0, ) diff --git a/src/gfn/modules.py b/src/gfn/modules.py index fe847cf7..c8e4f797 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -1,10 +1,10 @@ -import warnings +import logging -warnings.warn( +logger = logging.getLogger(__name__) + +logger.warning( "'modules.py' is deprecated and will be removed in a future release. Please import " - "from 'estimators.py' instead.", - DeprecationWarning, - stacklevel=2, + "from 'estimators.py' instead." ) from gfn.estimators import ( # noqa: F401, E402 diff --git a/src/gfn/states.py b/src/gfn/states.py index ce881a93..b85fe7fb 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -1,7 +1,7 @@ from __future__ import annotations # This allows to use the class name in type hints import inspect -import warnings +import logging from abc import ABC from math import prod from typing import ( @@ -25,9 +25,7 @@ from gfn.utils.common import ensure_same_device from gfn.utils.graphs import GeometricBatch, get_edge_indices -warnings.filterwarnings( - "once", message="Inconsistent conditions when extending states. Setting to None." 
-) +logger = logging.getLogger(__name__) def _assert_factory_accepts_debug(factory: Callable, factory_name: str) -> None: @@ -373,7 +371,7 @@ def __setitem__( self.conditions[index] = states.conditions else: if self.conditions is not None or states.conditions is not None: - warnings.warn( + logger.warning( "Inconsistent conditions when setting states. Setting to None." ) self.conditions = None @@ -439,7 +437,7 @@ def extend(self, other: States) -> None: ) else: if self.conditions is not None or other.conditions is not None: - warnings.warn( + logger.warning( "Inconsistent conditions when extending states. Setting to None." ) self.conditions = None @@ -1495,7 +1493,7 @@ def __setitem__( self.conditions[index] = graph.conditions else: if self.conditions is not None or graph.conditions is not None: - warnings.warn( + logger.warning( "Inconsistent conditions when setting states. Setting to None." ) self.conditions = None @@ -1600,7 +1598,7 @@ def extend(self, other: GraphStates): ) else: if self.conditions is not None or other.conditions is not None: - warnings.warn( + logger.warning( "Inconsistent conditions when extending states. Setting to None." ) self.conditions = None diff --git a/src/gfn/utils/common.py b/src/gfn/utils/common.py index 5c7fd44d..dc4ec157 100644 --- a/src/gfn/utils/common.py +++ b/src/gfn/utils/common.py @@ -1,9 +1,9 @@ import inspect +import logging import os import random import threading import time -import warnings from contextlib import contextmanager from typing import Any, Callable, Tuple @@ -11,6 +11,8 @@ import torch import torch.distributed as dist +logger = logging.getLogger(__name__) + # ----------------------------------------------------------------------------- # Utility helpers # ----------------------------------------------------------------------------- @@ -249,9 +251,8 @@ def set_seed(seed: int, deterministic_mode: bool = False) -> None: torch.use_deterministic_algorithms(True) except AttributeError: # Older PyTorch (<1.8) fallback: do nothing. - warnings.warn( - "PyTorch is older than 1.8, deterministic algorithms are not supported.", - UserWarning, + logger.warning( + "PyTorch is older than 1.8, deterministic algorithms are not supported." ) # CPU-specific settings for non-distributed case diff --git a/src/gfn/utils/compile.py b/src/gfn/utils/compile.py index d772a503..b5263152 100644 --- a/src/gfn/utils/compile.py +++ b/src/gfn/utils/compile.py @@ -1,9 +1,12 @@ from __future__ import annotations +import logging from typing import Iterable import torch +logger = logging.getLogger(__name__) + def try_compile_gflownet( gfn, @@ -52,4 +55,4 @@ def try_compile_gflownet( formatted = ", ".join( f"{name}:{'✓' if success else 'x'}" for name, success in results.items() ) - print(f"[compile] {formatted}") + logger.info(f"[compile] {formatted}") diff --git a/src/gfn/utils/distributed.py b/src/gfn/utils/distributed.py index 5b9f02fd..b27d31da 100644 --- a/src/gfn/utils/distributed.py +++ b/src/gfn/utils/distributed.py @@ -1,4 +1,5 @@ import datetime +import logging import os from dataclasses import dataclass from typing import Dict, List, Optional, cast @@ -6,6 +7,8 @@ import torch import torch.distributed as dist +logger = logging.getLogger(__name__) + def report_load_imbalance( all_timing_dict: List[Dict[str, List[float]]], @@ -20,8 +23,8 @@ def report_load_imbalance( param world_size: The total number of ranks in the distributed setup. 
""" # Header - print(f"{'Step Name':<25} {'Useful Work':>12} {'Waiting':>12}") - print("-" * 80) + logger.info(f"{'Step Name':<25} {'Useful Work':>12} {'Waiting':>12}") + logger.info("-" * 80) for step, times in all_timing_dict[0].items(): if type(times) is not list: @@ -37,7 +40,7 @@ def report_load_imbalance( is_valid_key = False break if not is_valid_key: - print(f"Time for Step - '{step}' not found in all ranks, skipping...") + logger.info(f"Time for Step - '{step}' not found in all ranks, skipping...") continue # Calculate the timing profile for the step. @@ -56,7 +59,7 @@ def report_load_imbalance( total_useful = sum(useful_work) total_waiting = sum(waiting_times) - print(f"{step:<25} {total_useful:>10.4f}s {total_waiting:>10.4f}s") + logger.info(f"{step:<25} {total_useful:>10.4f}s {total_waiting:>10.4f}s") def report_time_info( @@ -72,9 +75,9 @@ def report_time_info( param world_size: The total number of ranks in the distributed setup. """ overall_timing = {} - print("Timing information for each rank:") + logger.info("Timing information for each rank:") for rank in range(world_size): - print(f"Rank {rank} timing information:") + logger.info(f"Rank {rank} timing information:") for step, times in all_timing_dict[rank].items(): if type(times) is not list: times = [times] # Ensure times is a list @@ -82,20 +85,20 @@ def report_time_info( times_tensor = torch.tensor(times) avg_time = torch.sum(times_tensor).item() / len(times) sum_time = torch.sum(times_tensor).item() - print(f" {step}: {avg_time:.4f} seconds (total: {sum_time:.4f} seconds)") + logger.info(f" {step}: {avg_time:.4f} seconds (total: {sum_time:.4f} seconds)") if overall_timing.get(step) is None: overall_timing[step] = [sum_time] else: overall_timing[step].append(sum_time) - print("\nMaximum timing information:") + logger.info("\nMaximum timing information:") for step, times in overall_timing.items(): - print(f" {step}: {max(times):.4f} seconds") + logger.info(f" {step}: {max(times):.4f} seconds") - print("\nAverage timing information:") + logger.info("\nAverage timing information:") for step, times in overall_timing.items(): - print(f" {step}: {sum(times) / len(times):.4f} seconds") + logger.info(f" {step}: {sum(times) / len(times):.4f} seconds") def average_gradients(model): @@ -159,16 +162,16 @@ def initialize_distributed_compute( ], f"Invalid backend requested: {dist_backend}" pmi_size = int(os.environ.get("PMI_SIZE", "0")) # 0 or 1 default value? - print("+ Initalizing distributed compute, PMI_SIZE={}".format(pmi_size)) + logger.info("+ Initalizing distributed compute, PMI_SIZE={}".format(pmi_size)) if pmi_size <= 1: - print("+ PMI_SIZE <= 1, running in single process mode.") + logger.info("+ PMI_SIZE <= 1, running in single process mode.") return DistributedContext( my_rank=0, world_size=1, num_training_ranks=1, agent_group_size=1 ) if dist_backend == "ccl": - print("+ CCL backend requested...") + logger.info("+ CCL backend requested...") try: # Note - intel must be imported before oneccl! 
import oneccl_bindings_for_pytorch # noqa: F401 @@ -176,7 +179,7 @@ def initialize_distributed_compute( raise Exception("import oneccl_bindings_for_pytorch failed, {}".format(e)) elif dist_backend == "mpi": - print("+ MPI backend requested...") + logger.info("+ MPI backend requested...") assert torch.distributed.is_mpi_available() try: import torch_mpi # noqa: F401 @@ -184,7 +187,7 @@ def initialize_distributed_compute( raise Exception("import torch_mpi failed, {}".format(e)) elif dist_backend == "gloo": - print("+ Gloo backend requested...") + logger.info("+ Gloo backend requested...") assert torch.distributed.is_gloo_available() else: @@ -193,7 +196,7 @@ def initialize_distributed_compute( os.environ["RANK"] = os.environ.get("PMI_RANK", "0") os.environ["WORLD_SIZE"] = os.environ.get("PMI_SIZE", "1") - print("+ OMP_NUM_THREADS = ", os.getenv("OMP_NUM_THREADS")) + logger.info("+ OMP_NUM_THREADS = {}".format(os.getenv("OMP_NUM_THREADS"))) world_size = os.environ.get("WORLD_SIZE") if world_size is None: @@ -211,7 +214,7 @@ def initialize_distributed_compute( ) dist.barrier() - print("+ Distributed compute initialized, backend = {}".format(dist_backend)) + logger.info("+ Distributed compute initialized, backend = {}".format(dist_backend)) my_rank = dist.get_rank() # Global! world_size = dist.get_world_size() # Global! @@ -220,8 +223,8 @@ def initialize_distributed_compute( # make sure that we have atmost 1 remote buffer per training rank. assert num_training_ranks >= num_remote_buffers - print("num_train = ", num_training_ranks) - print("num_remote_buffers = ", num_remote_buffers) + logger.info("num_train = {}".format(num_training_ranks)) + logger.info("num_remote_buffers = {}".format(num_remote_buffers)) # for now, let us enforce that each agent gets equal number of ranks. # TODO: later, we can relax this condition. 
@@ -231,7 +234,7 @@ def initialize_distributed_compute( list(range(i * agent_group_size, (i + 1) * agent_group_size)) for i in range(num_agent_groups) ] - print(f"Agent group ranks: {agent_group_rank_list}") + logger.info(f"Agent group ranks: {agent_group_rank_list}") agent_group_list = [ cast( dist.ProcessGroup, @@ -266,7 +269,7 @@ def initialize_distributed_compute( backend=dist_backend, timeout=datetime.timedelta(minutes=5), ) - print(f"Buffer group ranks: {buffer_ranks}") + logger.info(f"Buffer group ranks: {buffer_ranks}") # Each training rank gets assigned to a buffer rank if my_rank < (num_training_ranks): @@ -278,14 +281,14 @@ def initialize_distributed_compute( if (ranks % num_remote_buffers) == (my_rank - num_training_ranks) ] - print(f"+ My rank: {my_rank} size: {world_size}") + logger.info(f"+ My rank: {my_rank} size: {world_size}") if my_rank < (num_training_ranks): - print(f" -> Training group, assigned buffer rank = {assigned_buffer}") + logger.info(f" -> Training group, assigned buffer rank = {assigned_buffer}") else: - print(" -> Buffer group") + logger.info(" -> Buffer group") dist.barrier() - print("+ Distributed compute initialized, rank = ", my_rank) + logger.info("+ Distributed compute initialized, rank = {}".format(my_rank)) return DistributedContext( my_rank=my_rank, @@ -321,7 +324,7 @@ def gather_distributed_data( On other ranks: None """ if verbose: - print("syncing distributed data") + logger.info("syncing distributed data") if world_size is None: world_size = dist.get_world_size() @@ -346,8 +349,8 @@ def gather_distributed_data( batch_size_list = None if verbose: - print("rank={}, batch_size_list={}".format(rank, batch_size_list)) - print( + logger.info("rank={}, batch_size_list={}".format(rank, batch_size_list)) + logger.info( "+ gather of local_batch_size={} to batch_size_list".format(local_batch_size) ) dist.gather( @@ -357,7 +360,7 @@ def gather_distributed_data( # Pad local tensor to maximum size. if verbose: - print("+ padding local tensor") + logger.info("+ padding local tensor") if rank == 0: assert batch_size_list is not None @@ -394,8 +397,8 @@ def gather_distributed_data( tensor_list = None if verbose: - print("+ gathering all tensors from world_size={}".format(world_size)) - print("rank={}, tensor_list={}".format(rank, tensor_list)) + logger.info("+ gathering all tensors from world_size={}".format(world_size)) + logger.info("rank={}, tensor_list={}".format(rank, tensor_list)) dist.gather(local_tensor, gather_list=tensor_list, dst=0, group=training_group) dist.barrier(group=training_group) # Add synchronization @@ -409,10 +412,10 @@ def gather_distributed_data( results.append(trimmed_tensor) if verbose: - print("distributed n_results={}".format(len(results))) + logger.info("distributed n_results={}".format(len(results))) for r in results: - print(" {}".format(r.shape)) + logger.info(" {}".format(r.shape)) return torch.cat(results, dim=0) # Concatenates along the batch dimension. 
diff --git a/src/gfn/utils/handlers.py b/src/gfn/utils/handlers.py index 74e3905e..c62d2a31 100644 --- a/src/gfn/utils/handlers.py +++ b/src/gfn/utils/handlers.py @@ -1,9 +1,11 @@ -import warnings +import logging from contextlib import contextmanager from typing import Any from gfn.containers import Container +logger = logging.getLogger(__name__) + @contextmanager def has_conditions_exception_handler( @@ -19,8 +21,8 @@ def has_conditions_exception_handler( try: yield except TypeError as e: - print(f"conditions was passed but {target_name} is {type(target)}") - print(f"error: {str(e)}") + logger.error(f"conditions was passed but {target_name} is {type(target)}") + logger.error(f"error: {str(e)}") raise @@ -38,8 +40,8 @@ def no_conditions_exception_handler( try: yield except TypeError as e: - print(f"conditions was not passed but {target_name} is {type(target)}") - print(f"error: {str(e)}") + logger.error(f"conditions was not passed but {target_name} is {type(target)}") + logger.error(f"error: {str(e)}") raise @@ -57,7 +59,7 @@ def is_callable_exception_handler( try: yield except: # noqa - print(f"conditions was passed but {target_name} is not callable: {type(target)}") + logger.error(f"conditions was passed but {target_name} is not callable: {type(target)}") raise @@ -72,7 +74,7 @@ def warn_about_recalculating_logprobs( recalculate_all_logprobs: Whether to recalculate all logprobs. """ if recalculate_all_logprobs and obj.has_log_probs: - warnings.warn( + logger.warning( "Recalculating logprobs for a container that already has them. " "This might be intended, if the log_probs were calculated off-policy. " "However, this is inefficient when training on-policy. In this case, " diff --git a/src/gfn/utils/modules.py b/src/gfn/utils/modules.py index 72697457..1ece17a4 100644 --- a/src/gfn/utils/modules.py +++ b/src/gfn/utils/modules.py @@ -1,5 +1,6 @@ """This file contains some examples of modules that can be used with GFN.""" +import logging import math from abc import ABC, abstractmethod from typing import Literal, Optional @@ -16,6 +17,8 @@ from gfn.utils.common import is_int_dtype from gfn.utils.graphs import GeometricBatch, get_edge_indices +logger = logging.getLogger(__name__) + ACTIVATION_FNS = { "relu": nn.ReLU, "leaky_relu": nn.LeakyReLU, @@ -1742,7 +1745,7 @@ def forward( # TODO: learn variance, lp, clipping, ... if torch.isnan(out).any(): - print("+ out has {} nans".format(torch.isnan(out).sum())) + logger.warning("+ out has {} nans".format(torch.isnan(out).sum())) out = torch.nan_to_num(out) return out diff --git a/src/gfn/utils/training.py b/src/gfn/utils/training.py index b617a93b..5679cbe1 100644 --- a/src/gfn/utils/training.py +++ b/src/gfn/utils/training.py @@ -1,4 +1,4 @@ -import warnings +import logging from typing import Dict, Iterable, Optional, Tuple import torch @@ -11,17 +11,17 @@ from gfn.samplers import Trajectories from gfn.states import DiscreteStates +logger = logging.getLogger(__name__) + def get_terminating_state_dist( env: DiscreteEnv, states: DiscreteStates, ) -> torch.Tensor: """[DEPRECATED] Use `env.get_terminating_state_dist(states)` instead.""" - warnings.warn( + logger.warning( "gfn.utils.training.get_terminating_state_dist is deprecated; " - "use DiscreteEnv.get_terminating_state_dist(states) instead.", - DeprecationWarning, - stacklevel=2, + "use DiscreteEnv.get_terminating_state_dist(states) instead." 
) return env.get_terminating_state_dist(states) @@ -33,10 +33,8 @@ def validate( visited_terminating_states: Optional[DiscreteStates] = None, ) -> Tuple[Dict[str, float], DiscreteStates | None]: """[DEPRECATED] Use `env.validate(gflownet, ...)` instead.""" - warnings.warn( - "gfn.utils.training.validate is deprecated; use DiscreteEnv.validate(...) instead.", - DeprecationWarning, - stacklevel=2, + logger.warning( + "gfn.utils.training.validate is deprecated; use DiscreteEnv.validate(...) instead." ) return env.validate( gflownet=gflownet, From 6042dc9d6d6cda0b530640f6cff99682edf8d7cb Mon Sep 17 00:00:00 2001 From: younik Date: Thu, 29 Jan 2026 12:17:21 +0100 Subject: [PATCH 2/3] fixes --- src/gfn/gym/diffusion_sampling.py | 6 ++++-- src/gfn/states.py | 2 +- src/gfn/utils/handlers.py | 4 +++- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/gfn/gym/diffusion_sampling.py b/src/gfn/gym/diffusion_sampling.py index 394e664b..53c117dd 100644 --- a/src/gfn/gym/diffusion_sampling.py +++ b/src/gfn/gym/diffusion_sampling.py @@ -1,6 +1,6 @@ +import logging import math import os -import logging from abc import ABC, abstractmethod from contextlib import nullcontext from typing import Any, cast @@ -723,7 +723,9 @@ def __init__( {**default_kwargs, **(target_kwargs or {})}, ) logger.info("DiffusionSampling:") - logger.info(f"+ Initalizing target {target_cls.__name__} with kwargs: {merged_kwargs}") + logger.info( + f"+ Initalizing target {target_cls.__name__} with kwargs: {merged_kwargs}" + ) self.target = target_cls(device=device, **merged_kwargs) self.dim = self.target.dim diff --git a/src/gfn/states.py b/src/gfn/states.py index b85fe7fb..f679d34c 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -830,7 +830,7 @@ def extend(self, other: DiscreteStates) -> None: else: # Inconsistent, raise a warning and set to None if self.conditions is not None or other.conditions is not None: - warnings.warn( + logger.warning( "Inconsistent conditions when extending states. Setting to None." 
) self.conditions = None diff --git a/src/gfn/utils/handlers.py b/src/gfn/utils/handlers.py index c62d2a31..fd5aaecc 100644 --- a/src/gfn/utils/handlers.py +++ b/src/gfn/utils/handlers.py @@ -59,7 +59,9 @@ def is_callable_exception_handler( try: yield except: # noqa - logger.error(f"conditions was passed but {target_name} is not callable: {type(target)}") + logger.error( + f"conditions was passed but {target_name} is not callable: {type(target)}" + ) raise From 1d75862921c8c46c79126f9148b92689c19e2c4d Mon Sep 17 00:00:00 2001 From: younik Date: Thu, 29 Jan 2026 12:32:00 +0100 Subject: [PATCH 3/3] fix ci --- src/gfn/utils/distributed.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/gfn/utils/distributed.py b/src/gfn/utils/distributed.py index 08bbf9b2..7be84a52 100644 --- a/src/gfn/utils/distributed.py +++ b/src/gfn/utils/distributed.py @@ -255,10 +255,13 @@ def initialize_distributed_compute( training_ranks = [ r for r in range(num_training_ranks) ] # e.g., 0..num_training_ranks-1 - train_global_group = dist.new_group( - ranks=training_ranks, - backend=dist_backend, - timeout=datetime.timedelta(minutes=5), + train_global_group = cast( + Optional[dist.ProcessGroup], + dist.new_group( + ranks=training_ranks, + backend=dist_backend, + timeout=datetime.timedelta(minutes=5), + ), ) buffer_group = None @@ -268,10 +271,13 @@ def initialize_distributed_compute( buffer_ranks = list( range(num_training_ranks, num_training_ranks + num_remote_buffers) ) - buffer_group = dist.new_group( - buffer_ranks, - backend=dist_backend, - timeout=datetime.timedelta(minutes=5), + buffer_group = cast( + Optional[dist.ProcessGroup], + dist.new_group( + buffer_ranks, + backend=dist_backend, + timeout=datetime.timedelta(minutes=5), + ), ) logger.info("Buffer group ranks: %s", buffer_ranks)
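Usage note on the logging switch above: the NullHandler added to src/gfn/__init__.py leaves handler configuration to the application, and Python's logging shows only WARNING and above by default, so the logger.info(...) calls introduced by these patches stay silent out of the box. A minimal sketch of how an application could opt in, assuming the package logger is named "gfn" (as implied by logging.getLogger(__name__) in src/gfn/__init__.py) and using only the standard library:

    import logging

    # Attach a stderr handler with a simple format to the root logger.
    logging.basicConfig(format="%(levelname)s %(name)s: %(message)s")

    # Lower the package logger's threshold so the library's logger.info(...) records get through.
    logging.getLogger("gfn").setLevel(logging.INFO)

After this, messages such as the hypergrid state-enumeration timing or the DiffusionSampling target-initialization lines appear on stderr, while an application that configures nothing still sees only warnings and errors (via logging's last-resort handler).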