3 changes: 3 additions & 0 deletions src/gfn/__init__.py
@@ -1,3 +1,6 @@
import importlib.metadata as met
import logging

__version__ = met.version("torchgfn")

logging.getLogger(__name__).addHandler(logging.NullHandler())
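With a NullHandler on the package logger, the library stays silent until the host application opts in. A minimal application-side sketch (standard-library logging only; the "gfn" logger name follows from the getLogger(__name__) calls added in this PR):

import logging

# Application-side opt-in (illustrative sketch): the NullHandler installed
# above keeps torchgfn quiet until the application configures handlers.
logging.basicConfig(level=logging.WARNING)        # root handler for the application
logging.getLogger("gfn").setLevel(logging.INFO)   # surface torchgfn's INFO records only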
21 changes: 13 additions & 8 deletions src/gfn/gym/diffusion_sampling.py
@@ -1,3 +1,4 @@
import logging
import math
import os
from abc import ABC, abstractmethod
@@ -16,6 +17,8 @@
from gfn.states import States
from gfn.utils.common import filter_kwargs_for_callable, temporarily_set_seed

logger = logging.getLogger(__name__)

# Lightweight typing alias for the target registry entries.
TargetEntry = tuple[type["BaseTarget"], dict[str, Any]]

@@ -215,17 +218,17 @@ def __init__(
)
mixture_weights = mixture_weights / mixture_weights.sum()

print("+ Gaussian Mixture Target initialization:")
print("+ num_components: ", num_components)
print("+ mixture_weights: ", mixture_weights)
logger.info("+ Gaussian Mixture Target initialization:")
logger.info(f"+ num_components: {num_components}")
logger.info(f"+ mixture_weights: {mixture_weights}")
for i, (loc, cov) in enumerate(zip(locs, covariances)):
loc_str = np.array2string(loc, precision=2, separator=", ").replace(
"\n", " "
)
cov_str = np.array2string(cov, precision=2, separator=", ").replace(
"\n", " "
)
print(f"\tComponent {i+1}: loc={loc_str}, cov={cov_str}")
logger.info(f"\tComponent {i+1}: loc={loc_str}, cov={cov_str}")

# Convert to torch tensors
dtype = torch.get_default_dtype()
@@ -719,8 +722,10 @@ def __init__(
target_cls.__init__,
{**default_kwargs, **(target_kwargs or {})},
)
print("DiffusionSampling:")
print(f"+ Initalizing target {target_cls.__name__} with kwargs: {merged_kwargs}")
logger.info("DiffusionSampling:")
logger.info(
f"+ Initalizing target {target_cls.__name__} with kwargs: {merged_kwargs}"
)
self.target = target_cls(device=device, **merged_kwargs)

self.dim = self.target.dim
@@ -779,9 +784,9 @@ def list_available_targets(cls) -> dict[str, dict[str, Any]]:
target class's constructor signature.
"""
out = {}
print("Available DiffusionSampling targets:")
logger.info("Available DiffusionSampling targets:")
for alias, (cls, defaults) in cls.DIFFUSION_TARGETS.items():
print(f"+ {alias}: {cls.__name__} with kwargs: {defaults}")
logger.info(f"+ {alias}: {cls.__name__} with kwargs: {defaults}")
out[alias] = {"class": cls.__name__, "defaults": dict(defaults)}

return out
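The initialization banners that previously went straight to stdout are now INFO records on the module logger, so their verbosity can be tuned without touching library code. A small sketch, with the logger name inferred from the module path:

import logging

# Hide only the target/Gaussian-mixture initialization chatter, assuming the
# application has already enabled logging elsewhere (e.g. via basicConfig).
logging.getLogger("gfn.gym.diffusion_sampling").setLevel(logging.WARNING)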
12 changes: 7 additions & 5 deletions src/gfn/gym/hypergrid.py
@@ -1,9 +1,9 @@
"""Adapted from https://github.com/Tikquuss/GflowNets_Tutorial"""

import itertools
import logging
import multiprocessing
import platform
import warnings
from abc import ABC, abstractmethod
from decimal import Decimal
from functools import reduce
@@ -17,6 +17,8 @@
from gfn.env import DiscreteEnv
from gfn.states import DiscreteStates

logger = logging.getLogger(__name__)

if platform.system() == "Windows":
multiprocessing.set_start_method("spawn", force=True)
else:
@@ -86,7 +88,7 @@ def __init__(
debug: If True, emit States with debug guards (not compile-friendly).
"""
if height <= 4:
warnings.warn("+ Warning: height <= 4 can lead to unsolvable environments.")
logger.warning("+ Warning: height <= 4 can lead to unsolvable environments.")

reward_functions = {
"original": OriginalReward,
@@ -123,10 +125,10 @@ def __init__(

if self.store_all_states:
assert self._all_states_tensor is not None
print(f"+ Environment has {len(self._all_states_tensor)} states")
logger.info(f"+ Environment has {len(self._all_states_tensor)} states")
if self.calculate_partition:
assert self._log_partition is not None
print(f"+ Environment log partition is {self._log_partition}")
logger.info(f"+ Environment log partition is {self._log_partition}")

if isinstance(device, str):
device = torch.device(device)
@@ -414,7 +416,7 @@ def _enumerate_all_states_tensor(self, batch_size: int = 20_000):
total_rewards += self.reward_fn(batch_tensor).sum().item()
end_time = time()

print(
logger.info(
"Enumerated all states in {} minutes".format(
(end_time - start_time) / 60.0,
)
10 changes: 5 additions & 5 deletions src/gfn/modules.py
@@ -1,10 +1,10 @@
import warnings
import logging

warnings.warn(
logger = logging.getLogger(__name__)

logger.warning(
"'modules.py' is deprecated and will be removed in a future release. Please import "
"from 'estimators.py' instead.",
DeprecationWarning,
stacklevel=2,
"from 'estimators.py' instead."
)

from gfn.estimators import ( # noqa: F401, E402
16 changes: 7 additions & 9 deletions src/gfn/states.py
@@ -1,7 +1,7 @@
from __future__ import annotations # This allows to use the class name in type hints

import inspect
import warnings
import logging
from abc import ABC
from math import prod
from typing import (
@@ -25,9 +25,7 @@
from gfn.utils.common import ensure_same_device
from gfn.utils.graphs import GeometricBatch, get_edge_indices

warnings.filterwarnings(
"once", message="Inconsistent conditions when extending states. Setting to None."
)
logger = logging.getLogger(__name__)


def _assert_factory_accepts_debug(factory: Callable, factory_name: str) -> None:
@@ -373,7 +371,7 @@ def __setitem__(
self.conditions[index] = states.conditions
else:
if self.conditions is not None or states.conditions is not None:
warnings.warn(
logger.warning(
"Inconsistent conditions when setting states. Setting to None."
)
self.conditions = None
@@ -439,7 +437,7 @@ def extend(self, other: States) -> None:
)
else:
if self.conditions is not None or other.conditions is not None:
warnings.warn(
logger.warning(
"Inconsistent conditions when extending states. Setting to None."
)
self.conditions = None
@@ -832,7 +830,7 @@ def extend(self, other: DiscreteStates) -> None:
else:
# Inconsistent, raise a warning and set to None
if self.conditions is not None or other.conditions is not None:
warnings.warn(
logger.warning(
"Inconsistent conditions when extending states. Setting to None."
)
self.conditions = None
@@ -1495,7 +1493,7 @@ def __setitem__(
self.conditions[index] = graph.conditions
else:
if self.conditions is not None or graph.conditions is not None:
warnings.warn(
logger.warning(
"Inconsistent conditions when setting states. Setting to None."
)
self.conditions = None
@@ -1600,7 +1598,7 @@ def extend(self, other: GraphStates):
)
else:
if self.conditions is not None or other.conditions is not None:
warnings.warn(
logger.warning(
"Inconsistent conditions when extending states. Setting to None."
)
self.conditions = None
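Note that the removed warnings.filterwarnings("once", ...) call had deduplicated the "Inconsistent conditions" message, whereas logger.warning will emit it on every occurrence. If once-only behaviour matters, an application can restore it with a stdlib logging filter; a hypothetical sketch:

import logging

class OnceFilter(logging.Filter):
    """Emit each distinct message at most once (application-side sketch)."""

    def __init__(self) -> None:
        super().__init__()
        self._seen: set[str] = set()

    def filter(self, record: logging.LogRecord) -> bool:
        message = record.getMessage()
        if message in self._seen:
            return False
        self._seen.add(message)
        return True

# Attach to the emitting logger so repeated "Inconsistent conditions"
# warnings from gfn/states.py are logged only the first time they occur.
logging.getLogger("gfn.states").addFilter(OnceFilter())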
9 changes: 5 additions & 4 deletions src/gfn/utils/common.py
@@ -1,16 +1,18 @@
import inspect
import logging
import os
import random
import threading
import time
import warnings
from contextlib import contextmanager
from typing import Any, Callable, Tuple

import numpy as np
import torch
import torch.distributed as dist

logger = logging.getLogger(__name__)

# -----------------------------------------------------------------------------
# Utility helpers
# -----------------------------------------------------------------------------
@@ -249,9 +251,8 @@ def set_seed(seed: int, deterministic_mode: bool = False) -> None:
torch.use_deterministic_algorithms(True)
except AttributeError:
# Older PyTorch (<1.8) fallback: do nothing.
warnings.warn(
"PyTorch is older than 1.8, deterministic algorithms are not supported.",
UserWarning,
logger.warning(
"PyTorch is older than 1.8, deterministic algorithms are not supported."
)

# CPU-specific settings for non-distributed case
5 changes: 4 additions & 1 deletion src/gfn/utils/compile.py
@@ -1,9 +1,12 @@
from __future__ import annotations

import logging
from typing import Iterable

import torch

logger = logging.getLogger(__name__)


def try_compile_gflownet(
gfn,
@@ -52,4 +55,4 @@ def try_compile_gflownet(
formatted = ", ".join(
f"{name}:{'✓' if success else 'x'}" for name, success in results.items()
)
print(f"[compile] {formatted}")
logger.info(f"[compile] {formatted}")
22 changes: 14 additions & 8 deletions src/gfn/utils/distributed.py
@@ -255,10 +255,13 @@ def initialize_distributed_compute(
training_ranks = [
r for r in range(num_training_ranks)
] # e.g., 0..num_training_ranks-1
train_global_group = dist.new_group(
ranks=training_ranks,
backend=dist_backend,
timeout=datetime.timedelta(minutes=5),
train_global_group = cast(
Optional[dist.ProcessGroup],
dist.new_group(
ranks=training_ranks,
backend=dist_backend,
timeout=datetime.timedelta(minutes=5),
),
)

buffer_group = None
@@ -268,10 +271,13 @@
buffer_ranks = list(
range(num_training_ranks, num_training_ranks + num_remote_buffers)
)
buffer_group = dist.new_group(
buffer_ranks,
backend=dist_backend,
timeout=datetime.timedelta(minutes=5),
buffer_group = cast(
Optional[dist.ProcessGroup],
dist.new_group(
buffer_ranks,
backend=dist_backend,
timeout=datetime.timedelta(minutes=5),
),
)
logger.info("Buffer group ranks: %s", buffer_ranks)

18 changes: 11 additions & 7 deletions src/gfn/utils/handlers.py
@@ -1,9 +1,11 @@
import warnings
import logging
from contextlib import contextmanager
from typing import Any

from gfn.containers import Container

logger = logging.getLogger(__name__)


@contextmanager
def has_conditions_exception_handler(
@@ -19,8 +21,8 @@ def has_conditions_exception_handler(
try:
yield
except TypeError as e:
print(f"conditions was passed but {target_name} is {type(target)}")
print(f"error: {str(e)}")
logger.error(f"conditions was passed but {target_name} is {type(target)}")
logger.error(f"error: {str(e)}")
raise


@@ -38,8 +40,8 @@ def no_conditions_exception_handler(
try:
yield
except TypeError as e:
print(f"conditions was not passed but {target_name} is {type(target)}")
print(f"error: {str(e)}")
logger.error(f"conditions was not passed but {target_name} is {type(target)}")
logger.error(f"error: {str(e)}")
raise


@@ -57,7 +59,7 @@ def is_callable_exception_handler(
try:
yield
except: # noqa
print(f"conditions was passed but {target_name} is not callable: {type(target)}")
logger.error(
f"conditions was passed but {target_name} is not callable: {type(target)}"
)
raise


@@ -72,7 +76,7 @@ def warn_about_recalculating_logprobs(
recalculate_all_logprobs: Whether to recalculate all logprobs.
"""
if recalculate_all_logprobs and obj.has_log_probs:
warnings.warn(
logger.warning(
"Recalculating logprobs for a container that already has them. "
"This might be intended, if the log_probs were calculated off-policy. "
"However, this is inefficient when training on-policy. In this case, "
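For reference, these exception-handler context managers only annotate a failure through the module logger before re-raising, so call sites keep their shape. A hypothetical usage sketch with placeholder objects (the estimator, states, and conditions names are illustrative, not part of the PR):

from gfn.utils.handlers import has_conditions_exception_handler

def estimator(states, conditions):   # placeholder standing in for a real estimator
    return states

states, conditions = [0.0], None     # placeholder inputs

with has_conditions_exception_handler("estimator", estimator):
    # A TypeError raised inside the block is logged at ERROR level together
    # with the target's type, then re-raised unchanged.
    outputs = estimator(states, conditions)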
16 changes: 7 additions & 9 deletions src/gfn/utils/training.py
@@ -1,4 +1,4 @@
import warnings
import logging
from typing import Dict, Iterable, Optional, Tuple

import torch
@@ -11,17 +11,17 @@
from gfn.samplers import Trajectories
from gfn.states import DiscreteStates

logger = logging.getLogger(__name__)


def get_terminating_state_dist(
env: DiscreteEnv,
states: DiscreteStates,
) -> torch.Tensor:
"""[DEPRECATED] Use `env.get_terminating_state_dist(states)` instead."""
warnings.warn(
logger.warning(
"gfn.utils.training.get_terminating_state_dist is deprecated; "
"use DiscreteEnv.get_terminating_state_dist(states) instead.",
DeprecationWarning,
stacklevel=2,
"use DiscreteEnv.get_terminating_state_dist(states) instead."
)
return env.get_terminating_state_dist(states)

@@ -33,10 +33,8 @@ def validate(
visited_terminating_states: Optional[DiscreteStates] = None,
) -> Tuple[Dict[str, float], DiscreteStates | None]:
"""[DEPRECATED] Use `env.validate(gflownet, ...)` instead."""
warnings.warn(
"gfn.utils.training.validate is deprecated; use DiscreteEnv.validate(...) instead.",
DeprecationWarning,
stacklevel=2,
logger.warning(
"gfn.utils.training.validate is deprecated; use DiscreteEnv.validate(...) instead."
)
return env.validate(
gflownet=gflownet,