diff --git a/src/gfn/__init__.py b/src/gfn/__init__.py
index a521e9f6..1010066c 100644
--- a/src/gfn/__init__.py
+++ b/src/gfn/__init__.py
@@ -1,3 +1,6 @@
 import importlib.metadata as met
+import logging
 
 __version__ = met.version("torchgfn")
+
+logging.getLogger(__name__).addHandler(logging.NullHandler())
diff --git a/src/gfn/gym/diffusion_sampling.py b/src/gfn/gym/diffusion_sampling.py
index 68fc5385..53c117dd 100644
--- a/src/gfn/gym/diffusion_sampling.py
+++ b/src/gfn/gym/diffusion_sampling.py
@@ -1,3 +1,4 @@
+import logging
 import math
 import os
 from abc import ABC, abstractmethod
@@ -16,6 +17,8 @@
 from gfn.states import States
 from gfn.utils.common import filter_kwargs_for_callable, temporarily_set_seed
 
+logger = logging.getLogger(__name__)
+
 # Lightweight typing alias for the target registry entries.
 TargetEntry = tuple[type["BaseTarget"], dict[str, Any]]
@@ -215,9 +218,9 @@ def __init__(
         )
         mixture_weights = mixture_weights / mixture_weights.sum()
 
-        print("+ Gaussian Mixture Target initialization:")
-        print("+ num_components: ", num_components)
-        print("+ mixture_weights: ", mixture_weights)
+        logger.info("+ Gaussian Mixture Target initialization:")
+        logger.info(f"+ num_components: {num_components}")
+        logger.info(f"+ mixture_weights: {mixture_weights}")
         for i, (loc, cov) in enumerate(zip(locs, covariances)):
             loc_str = np.array2string(loc, precision=2, separator=", ").replace(
                 "\n", " "
@@ -225,7 +228,7 @@ def __init__(
             cov_str = np.array2string(cov, precision=2, separator=", ").replace(
                 "\n", " "
             )
-            print(f"\tComponent {i+1}: loc={loc_str}, cov={cov_str}")
+            logger.info(f"\tComponent {i+1}: loc={loc_str}, cov={cov_str}")
 
         # Convert to torch tensors
         dtype = torch.get_default_dtype()
@@ -719,8 +722,10 @@ def __init__(
             target_cls.__init__,
             {**default_kwargs, **(target_kwargs or {})},
         )
-        print("DiffusionSampling:")
-        print(f"+ Initalizing target {target_cls.__name__} with kwargs: {merged_kwargs}")
+        logger.info("DiffusionSampling:")
+        logger.info(
+            f"+ Initializing target {target_cls.__name__} with kwargs: {merged_kwargs}"
+        )
 
         self.target = target_cls(device=device, **merged_kwargs)
         self.dim = self.target.dim
@@ -779,9 +784,9 @@ def list_available_targets(cls) -> dict[str, dict[str, Any]]:
            target class's constructor signature.
         """
         out = {}
-        print("Available DiffusionSampling targets:")
+        logger.info("Available DiffusionSampling targets:")
         for alias, (cls, defaults) in cls.DIFFUSION_TARGETS.items():
-            print(f"+ {alias}: {cls.__name__} with kwargs: {defaults}")
+            logger.info(f"+ {alias}: {cls.__name__} with kwargs: {defaults}")
             out[alias] = {"class": cls.__name__, "defaults": dict(defaults)}
         return out
 
diff --git a/src/gfn/gym/hypergrid.py b/src/gfn/gym/hypergrid.py
index 9196aa11..f47d24dd 100644
--- a/src/gfn/gym/hypergrid.py
+++ b/src/gfn/gym/hypergrid.py
@@ -1,9 +1,9 @@
 """Adapted from https://github.com/Tikquuss/GflowNets_Tutorial"""
 
 import itertools
+import logging
 import multiprocessing
 import platform
-import warnings
 from abc import ABC, abstractmethod
 from decimal import Decimal
 from functools import reduce
@@ -17,6 +17,8 @@
 from gfn.env import DiscreteEnv
 from gfn.states import DiscreteStates
 
+logger = logging.getLogger(__name__)
+
 if platform.system() == "Windows":
     multiprocessing.set_start_method("spawn", force=True)
 else:
@@ -86,7 +88,7 @@ def __init__(
             debug: If True, emit States with debug guards (not compile-friendly).
         """
         if height <= 4:
-            warnings.warn("+ Warning: height <= 4 can lead to unsolvable environments.")
+            logger.warning("+ Warning: height <= 4 can lead to unsolvable environments.")
 
         reward_functions = {
             "original": OriginalReward,
@@ -123,10 +125,10 @@ def __init__(
 
         if self.store_all_states:
             assert self._all_states_tensor is not None
-            print(f"+ Environment has {len(self._all_states_tensor)} states")
+            logger.info(f"+ Environment has {len(self._all_states_tensor)} states")
         if self.calculate_partition:
             assert self._log_partition is not None
-            print(f"+ Environment log partition is {self._log_partition}")
+            logger.info(f"+ Environment log partition is {self._log_partition}")
 
         if isinstance(device, str):
             device = torch.device(device)
@@ -414,7 +416,7 @@ def _enumerate_all_states_tensor(self, batch_size: int = 20_000):
            total_rewards += self.reward_fn(batch_tensor).sum().item()
 
         end_time = time()
-        print(
+        logger.info(
             "Enumerated all states in {} minutes".format(
                 (end_time - start_time) / 60.0,
             )
diff --git a/src/gfn/modules.py b/src/gfn/modules.py
index fe847cf7..c8e4f797 100644
--- a/src/gfn/modules.py
+++ b/src/gfn/modules.py
@@ -1,10 +1,10 @@
-import warnings
+import logging
 
-warnings.warn(
+logger = logging.getLogger(__name__)
+
+logger.warning(
     "'modules.py' is deprecated and will be removed in a future release. Please import "
-    "from 'estimators.py' instead.",
-    DeprecationWarning,
-    stacklevel=2,
+    "from 'estimators.py' instead."
 )
 
 from gfn.estimators import (  # noqa: F401, E402
diff --git a/src/gfn/states.py b/src/gfn/states.py
index ce881a93..f679d34c 100644
--- a/src/gfn/states.py
+++ b/src/gfn/states.py
@@ -1,7 +1,7 @@
 from __future__ import annotations  # This allows to use the class name in type hints
 
 import inspect
-import warnings
+import logging
 from abc import ABC
 from math import prod
 from typing import (
@@ -25,9 +25,7 @@
 from gfn.utils.common import ensure_same_device
 from gfn.utils.graphs import GeometricBatch, get_edge_indices
 
-warnings.filterwarnings(
-    "once", message="Inconsistent conditions when extending states. Setting to None."
-)
+logger = logging.getLogger(__name__)
 
 
 def _assert_factory_accepts_debug(factory: Callable, factory_name: str) -> None:
@@ -373,7 +371,7 @@ def __setitem__(
                 self.conditions[index] = states.conditions
         else:
             if self.conditions is not None or states.conditions is not None:
-                warnings.warn(
+                logger.warning(
                     "Inconsistent conditions when setting states. Setting to None."
                 )
                 self.conditions = None
@@ -439,7 +437,7 @@ def extend(self, other: States) -> None:
                 )
         else:
             if self.conditions is not None or other.conditions is not None:
-                warnings.warn(
+                logger.warning(
                     "Inconsistent conditions when extending states. Setting to None."
                 )
                 self.conditions = None
@@ -832,7 +830,7 @@ def extend(self, other: DiscreteStates) -> None:
         else:
             # Inconsistent, raise a warning and set to None
             if self.conditions is not None or other.conditions is not None:
-                warnings.warn(
+                logger.warning(
                     "Inconsistent conditions when extending states. Setting to None."
                 )
                 self.conditions = None
@@ -1495,7 +1493,7 @@ def __setitem__(
                 self.conditions[index] = graph.conditions
         else:
             if self.conditions is not None or graph.conditions is not None:
-                warnings.warn(
+                logger.warning(
                     "Inconsistent conditions when setting states. Setting to None."
                 )
                 self.conditions = None
@@ -1600,7 +1598,7 @@ def extend(self, other: GraphStates):
                 )
         else:
             if self.conditions is not None or other.conditions is not None:
-                warnings.warn(
+                logger.warning(
                     "Inconsistent conditions when extending states. Setting to None."
                 )
                 self.conditions = None
diff --git a/src/gfn/utils/common.py b/src/gfn/utils/common.py
index 5c7fd44d..dc4ec157 100644
--- a/src/gfn/utils/common.py
+++ b/src/gfn/utils/common.py
@@ -1,9 +1,9 @@
 import inspect
+import logging
 import os
 import random
 import threading
 import time
-import warnings
 from contextlib import contextmanager
 from typing import Any, Callable, Tuple
 
@@ -11,6 +11,8 @@
 import torch
 import torch.distributed as dist
 
+logger = logging.getLogger(__name__)
+
 # -----------------------------------------------------------------------------
 # Utility helpers
 # -----------------------------------------------------------------------------
@@ -249,9 +251,8 @@ def set_seed(seed: int, deterministic_mode: bool = False) -> None:
             torch.use_deterministic_algorithms(True)
         except AttributeError:
             # Older PyTorch (<1.8) fallback: do nothing.
-            warnings.warn(
-                "PyTorch is older than 1.8, deterministic algorithms are not supported.",
-                UserWarning,
+            logger.warning(
+                "PyTorch is older than 1.8, deterministic algorithms are not supported."
             )
 
     # CPU-specific settings for non-distributed case
diff --git a/src/gfn/utils/compile.py b/src/gfn/utils/compile.py
index d772a503..b5263152 100644
--- a/src/gfn/utils/compile.py
+++ b/src/gfn/utils/compile.py
@@ -1,9 +1,12 @@
 from __future__ import annotations
 
+import logging
 from typing import Iterable
 
 import torch
 
+logger = logging.getLogger(__name__)
+
 
 def try_compile_gflownet(
     gfn,
@@ -52,4 +55,4 @@ def try_compile_gflownet(
     formatted = ", ".join(
         f"{name}:{'✓' if success else 'x'}" for name, success in results.items()
     )
-    print(f"[compile] {formatted}")
+    logger.info(f"[compile] {formatted}")
diff --git a/src/gfn/utils/distributed.py b/src/gfn/utils/distributed.py
index 08bbf9b2..7be84a52 100644
--- a/src/gfn/utils/distributed.py
+++ b/src/gfn/utils/distributed.py
@@ -255,10 +255,13 @@ def initialize_distributed_compute(
         training_ranks = [
             r for r in range(num_training_ranks)
         ]  # e.g., 0..num_training_ranks-1
-        train_global_group = dist.new_group(
-            ranks=training_ranks,
-            backend=dist_backend,
-            timeout=datetime.timedelta(minutes=5),
+        train_global_group = cast(
+            Optional[dist.ProcessGroup],
+            dist.new_group(
+                ranks=training_ranks,
+                backend=dist_backend,
+                timeout=datetime.timedelta(minutes=5),
+            ),
         )
 
         buffer_group = None
@@ -268,10 +271,13 @@ def initialize_distributed_compute(
         buffer_ranks = list(
             range(num_training_ranks, num_training_ranks + num_remote_buffers)
         )
-        buffer_group = dist.new_group(
-            buffer_ranks,
-            backend=dist_backend,
-            timeout=datetime.timedelta(minutes=5),
+        buffer_group = cast(
+            Optional[dist.ProcessGroup],
+            dist.new_group(
+                buffer_ranks,
+                backend=dist_backend,
+                timeout=datetime.timedelta(minutes=5),
+            ),
         )
 
         logger.info("Buffer group ranks: %s", buffer_ranks)
diff --git a/src/gfn/utils/handlers.py b/src/gfn/utils/handlers.py
index 74e3905e..fd5aaecc 100644
--- a/src/gfn/utils/handlers.py
+++ b/src/gfn/utils/handlers.py
@@ -1,9 +1,11 @@
-import warnings
+import logging
 from contextlib import contextmanager
 from typing import Any
 
 from gfn.containers import Container
 
+logger = logging.getLogger(__name__)
+
 
 @contextmanager
 def has_conditions_exception_handler(
@@ -19,8 +21,8 @@ def has_conditions_exception_handler(
     try:
         yield
     except TypeError as e:
-        print(f"conditions was passed but {target_name} is {type(target)}")
-        print(f"error: {str(e)}")
+        logger.error(f"conditions was passed but {target_name} is {type(target)}")
+        logger.error(f"error: {str(e)}")
         raise
 
 
@@ -38,8 +40,8 @@ def no_conditions_exception_handler(
     try:
         yield
     except TypeError as e:
-        print(f"conditions was not passed but {target_name} is {type(target)}")
-        print(f"error: {str(e)}")
+        logger.error(f"conditions was not passed but {target_name} is {type(target)}")
+        logger.error(f"error: {str(e)}")
         raise
 
 
@@ -57,7 +59,9 @@ def is_callable_exception_handler(
     try:
         yield
     except:  # noqa
-        print(f"conditions was passed but {target_name} is not callable: {type(target)}")
+        logger.error(
+            f"conditions was passed but {target_name} is not callable: {type(target)}"
+        )
         raise
 
 
@@ -72,7 +76,7 @@ def warn_about_recalculating_logprobs(
         recalculate_all_logprobs: Whether to recalculate all logprobs.
     """
     if recalculate_all_logprobs and obj.has_log_probs:
-        warnings.warn(
+        logger.warning(
             "Recalculating logprobs for a container that already has them. "
             "This might be intended, if the log_probs were calculated off-policy. "
             "However, this is inefficient when training on-policy. In this case, "
diff --git a/src/gfn/utils/training.py b/src/gfn/utils/training.py
index b617a93b..5679cbe1 100644
--- a/src/gfn/utils/training.py
+++ b/src/gfn/utils/training.py
@@ -1,4 +1,4 @@
-import warnings
+import logging
 from typing import Dict, Iterable, Optional, Tuple
 
 import torch
@@ -11,17 +11,17 @@
 from gfn.samplers import Trajectories
 from gfn.states import DiscreteStates
 
+logger = logging.getLogger(__name__)
+
 
 def get_terminating_state_dist(
     env: DiscreteEnv,
     states: DiscreteStates,
 ) -> torch.Tensor:
     """[DEPRECATED] Use `env.get_terminating_state_dist(states)` instead."""
-    warnings.warn(
+    logger.warning(
         "gfn.utils.training.get_terminating_state_dist is deprecated; "
-        "use DiscreteEnv.get_terminating_state_dist(states) instead.",
-        DeprecationWarning,
-        stacklevel=2,
+        "use DiscreteEnv.get_terminating_state_dist(states) instead."
     )
     return env.get_terminating_state_dist(states)
@@ -33,10 +33,8 @@ def validate(
     visited_terminating_states: Optional[DiscreteStates] = None,
 ) -> Tuple[Dict[str, float], DiscreteStates | None]:
     """[DEPRECATED] Use `env.validate(gflownet, ...)` instead."""
-    warnings.warn(
-        "gfn.utils.training.validate is deprecated; use DiscreteEnv.validate(...) instead.",
-        DeprecationWarning,
-        stacklevel=2,
+    logger.warning(
+        "gfn.utils.training.validate is deprecated; use DiscreteEnv.validate(...) instead."
     )
     return env.validate(
         gflownet=gflownet,
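
Note for reviewers: because `src/gfn/__init__.py` now attaches a `logging.NullHandler()` to the package logger, all of the messages converted above are silent by default, and the output policy is left to the application, per the standard library-logging convention. Below is a minimal sketch of how a downstream script might opt in; the handler configuration and format string are illustrative choices, not part of this diff:

```python
import logging

import gfn  # noqa: F401  # importing the package attaches the NullHandler

# Route the library's records through the root logger. INFO exposes the
# environment/target initialization messages converted from print() above;
# a WARNING level would keep only the deprecation and
# inconsistent-conditions messages.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(name)s %(levelname)s: %(message)s",
)

# Alternatively, scope verbosity per module without touching the root logger:
logging.getLogger("gfn.gym.hypergrid").setLevel(logging.WARNING)
```

Since every module creates its logger via `logging.getLogger(__name__)`, the logger hierarchy mirrors the package layout (`gfn`, `gfn.states`, `gfn.utils.common`, ...), so applications can filter at any granularity.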