diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index c34bcc908..b01361562 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -641,3 +641,185 @@ def _batch_size_from_data(self, data: Mapping[str, any]) -> int: inference variables as present. """ return keras.ops.shape(data["inference_variables"])[0] + + def compositional_sample( + self, + *, + num_samples: int, + conditions: Mapping[str, np.ndarray], + compute_prior_score: Callable[[Mapping[str, np.ndarray]], np.ndarray], + split: bool = False, + **kwargs, + ) -> dict[str, np.ndarray]: + """ + Generates compositional samples from the approximator given input conditions. + The `conditions` dictionary should have shape (n_datasets, n_compositional_conditions, ...). + This method handles the extra compositional dimension appropriately. + + Parameters + ---------- + num_samples : int + Number of samples to generate. + conditions : dict[str, np.ndarray] + Dictionary of conditioning variables as NumPy arrays with shape + (n_datasets, n_compositional_conditions, ...). + compute_prior_score : Callable[[Mapping[str, np.ndarray]], np.ndarray] + A function that computes the score of the log prior distribution. + split : bool, default=False + Whether to split the output arrays along the last axis and return one column vector per target variable + samples. + **kwargs : dict + Additional keyword arguments for the adapter and sampling process. + + Returns + ------- + dict[str, np.ndarray] + Dictionary containing generated samples with compositional structure preserved. + """ + original_shapes = {} + flattened_conditions = {} + for key, value in conditions.items(): # Flatten compositional dimensions + original_shapes[key] = value.shape + n_datasets, n_comp = value.shape[:2] + flattened_shape = (n_datasets * n_comp,) + value.shape[2:] + flattened_conditions[key] = value.reshape(flattened_shape) + n_datasets, n_comp = original_shapes[next(iter(original_shapes))][:2] + + # Prepare data using existing method (handles adaptation and standardization) + prepared_conditions = self._prepare_data(flattened_conditions, **kwargs) + + # Remove any superfluous keys, just retain actual conditions + prepared_conditions = {k: v for k, v in prepared_conditions.items() if k in self.CONDITION_KEYS} + + # Prepare prior scores to handle adapter + def compute_prior_score_pre(_samples: Tensor) -> Tensor: + if "inference_variables" in self.standardize: + _samples = self.standardize_layers["inference_variables"](_samples, forward=False) + _samples = keras.tree.map_structure(keras.ops.convert_to_numpy, {"inference_variables": _samples}) + adapted_samples, log_det_jac = self.adapter( + _samples, inverse=True, strict=False, log_det_jac=True, **kwargs + ) + + if len(log_det_jac) > 0: + problematic_keys = [key for key in log_det_jac if log_det_jac[key] != 0.0] + raise NotImplementedError( + f"Cannot use compositional sampling with adapters " + f"that have non-zero log_det_jac. 
Problematic keys: {problematic_keys}" + ) + + prior_score = compute_prior_score(adapted_samples) + for key in adapted_samples: + prior_score[key] = prior_score[key].astype(np.float32) + + prior_score = keras.tree.map_structure(keras.ops.convert_to_tensor, prior_score) + out = keras.ops.concatenate([prior_score[key] for key in adapted_samples], axis=-1) + + if "inference_variables" in self.standardize: + # Apply jacobian correction from standardization + # For standardization T^{-1}(z) = z * std + mean, the jacobian is diagonal with std on diagonal + # The gradient of log|det(J)| w.r.t. z is 0 since log|det(J)| = sum(log(std)) is constant w.r.t. z + # But we need to transform the score: score_z = score_x * std where x = T^{-1}(z) + standardize_layer = self.standardize_layers["inference_variables"] + + # Compute the correct standard deviation for all components + std_components = [] + for idx in range(len(standardize_layer.moving_mean)): + std_val = standardize_layer.moving_std(idx) + std_components.append(std_val) + + # Concatenate std components to match the shape of out + if len(std_components) == 1: + std = std_components[0] + else: + std = keras.ops.concatenate(std_components, axis=-1) + + # Expand std to match batch dimension of out + std_expanded = keras.ops.expand_dims(std, (0, 1)) # Add batch, sample dimensions + std_expanded = keras.ops.tile(std_expanded, [n_datasets, num_samples, 1]) + + # Apply the jacobian: score_z = score_x * std + out = out * std_expanded + return out + + # Test prior score function, useful for debugging + test = self.inference_network.base_distribution.sample((n_datasets, num_samples)) + test = compute_prior_score_pre(test) + if test.shape[:2] != (n_datasets, num_samples): + raise ValueError( + "The provided compute_prior_score function does not return the correct shape. " + f"Expected ({n_datasets}, {num_samples}, ...), got {test.shape}." + ) + + # Sample using compositional sampling + samples = self._compositional_sample( + num_samples=num_samples, + n_datasets=n_datasets, + n_compositional=n_comp, + compute_prior_score=compute_prior_score_pre, + **prepared_conditions, + **kwargs, + ) + + if "inference_variables" in self.standardize: + samples = self.standardize_layers["inference_variables"](samples, forward=False) + + samples = {"inference_variables": samples} + samples = keras.tree.map_structure(keras.ops.convert_to_numpy, samples) + + # Back-transform quantities and samples + samples = self.adapter(samples, inverse=True, strict=False, **kwargs) + + if split: + samples = split_arrays(samples, axis=-1) + return samples + + def _compositional_sample( + self, + num_samples: int, + n_datasets: int, + n_compositional: int, + compute_prior_score: Callable[[Tensor], Tensor], + inference_conditions: Tensor = None, + summary_variables: Tensor = None, + **kwargs, + ) -> Tensor: + """ + Internal method for compositional sampling. 
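+
+        Parameters
+        ----------
+        num_samples : int
+            Number of samples to draw per dataset.
+        n_datasets : int
+            Number of datasets in the flattened conditions.
+        n_compositional : int
+            Number of compositional conditions per dataset.
+        compute_prior_score : Callable[[Tensor], Tensor]
+            Function computing the prior score for samples in the network's (standardized) space.
+        inference_conditions : Tensor, optional
+            Direct conditions of shape (n_datasets * n_compositional, ..., dims).
+        summary_variables : Tensor, optional
+            Raw summary variables; required if a summary network is present.
+
+        Returns
+        -------
+        Tensor
+            Samples of shape (n_datasets, num_samples, ...).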
+ """ + if self.summary_network is None: + if summary_variables is not None: + raise ValueError("Cannot use summary variables without a summary network.") + else: + if summary_variables is None: + raise ValueError("Summary variables are required when a summary network is present.") + + if self.summary_network is not None: + summary_outputs = self.summary_network( + summary_variables, **filter_kwargs(kwargs, self.summary_network.call) + ) + inference_conditions = concatenate_valid([inference_conditions, summary_outputs], axis=-1) + + if inference_conditions is not None: + # Reshape conditions for compositional sampling + # From (n_datasets * n_comp, ...., dims) to (n_datasets, n_comp, ...., dims) + condition_dims = keras.ops.shape(inference_conditions)[1:] + inference_conditions = keras.ops.reshape( + inference_conditions, (n_datasets, n_compositional, *condition_dims) + ) + + # Expand for num_samples: (n_datasets, n_comp, dims) -> (n_datasets, n_comp, num_samples, dims) + inference_conditions = keras.ops.expand_dims(inference_conditions, axis=2) + inference_conditions = keras.ops.broadcast_to( + inference_conditions, (n_datasets, n_compositional, num_samples, *condition_dims) + ) + + batch_shape = (n_datasets, num_samples) + else: + raise ValueError("Cannot perform compositional sampling without inference conditions.") + + return self.inference_network.sample( + batch_shape, + conditions=inference_conditions, + compute_prior_score=compute_prior_score, + **filter_kwargs(kwargs, self.inference_network.sample), + ) diff --git a/bayesflow/networks/__init__.py b/bayesflow/networks/__init__.py index f71d4b536..fb9819445 100644 --- a/bayesflow/networks/__init__.py +++ b/bayesflow/networks/__init__.py @@ -7,7 +7,7 @@ from .consistency_models import ConsistencyModel from .coupling_flow import CouplingFlow from .deep_set import DeepSet -from .diffusion_model import DiffusionModel +from .diffusion_model import DiffusionModel, CompositionalDiffusionModel from .flow_matching import FlowMatching from .inference_network import InferenceNetwork from .point_inference_network import PointInferenceNetwork diff --git a/bayesflow/networks/diffusion_model/__init__.py b/bayesflow/networks/diffusion_model/__init__.py index 341c84c62..ca8aa19be 100644 --- a/bayesflow/networks/diffusion_model/__init__.py +++ b/bayesflow/networks/diffusion_model/__init__.py @@ -1,4 +1,5 @@ from .diffusion_model import DiffusionModel +from .compositional_diffusion_model import CompositionalDiffusionModel from .schedules import CosineNoiseSchedule from .schedules import EDMNoiseSchedule from .schedules import NoiseSchedule diff --git a/bayesflow/networks/diffusion_model/compositional_diffusion_model.py b/bayesflow/networks/diffusion_model/compositional_diffusion_model.py new file mode 100644 index 000000000..3d26639ab --- /dev/null +++ b/bayesflow/networks/diffusion_model/compositional_diffusion_model.py @@ -0,0 +1,405 @@ +from typing import Literal, Callable + +import keras +import numpy as np +from keras import ops + +from bayesflow.types import Tensor +from bayesflow.utils import expand_right_as, integrate, integrate_stochastic, STOCHASTIC_METHODS +from bayesflow.utils.serialization import serializable +from .diffusion_model import DiffusionModel +from .schedules.noise_schedule import NoiseSchedule + + +# disable module check, use potential module after moving from experimental +@serializable("bayesflow.networks", disable_module_check=True) +class CompositionalDiffusionModel(DiffusionModel): + """Compositional Diffusion Model for 
Amortized Bayesian Inference. Allows to learn a single + diffusion model one single i.i.d simulations that can perform inference for multiple simulations by leveraging a + compositional score function as in [2]. + + [1] Score-Based Generative Modeling through Stochastic Differential Equations: Song et al. (2021) + [2] Compositional Score Modeling for Simulation-Based Inference: Geffner et al. (2023) + [3] Compositional amortized inference for large-scale hierarchical Bayesian models: Arruda et al. (2025) + """ + + MLP_DEFAULT_CONFIG = { + "widths": (256, 256, 256, 256, 256), + "activation": "mish", + "kernel_initializer": "he_normal", + "residual": True, + "dropout": 0.05, + "spectral_normalization": False, + } + + INTEGRATE_DEFAULT_CONFIG = { + "method": "two_step_adaptive", + "steps": "adaptive", + } + + def __init__( + self, + *, + subnet: str | type | keras.Layer = "mlp", + noise_schedule: Literal["edm", "cosine"] | NoiseSchedule | type = "edm", + prediction_type: Literal["velocity", "noise", "F", "x"] = "F", + loss_type: Literal["velocity", "noise", "F"] = "noise", + subnet_kwargs: dict[str, any] = None, + schedule_kwargs: dict[str, any] = None, + integrate_kwargs: dict[str, any] = None, + **kwargs, + ): + """ + Initializes a diffusion model with configurable subnet architecture, noise schedule, + and prediction/loss types for amortized Bayesian inference. + + Note, that score-based diffusion is the most sluggish of all available samplers, + so expect slower inference times than flow matching and much slower than normalizing flows. + + Parameters + ---------- + subnet : str, type or keras.Layer, optional + Architecture for the transformation network. Can be "mlp", a custom network class, or + a Layer object, e.g., `bayesflow.networks.MLP(widths=[32, 32])`. Default is "mlp". + noise_schedule : {'edm', 'cosine'} or NoiseSchedule or type, optional + Noise schedule controlling the diffusion dynamics. Can be a string identifier, + a schedule class, or a pre-initialized schedule instance. Default is "edm". + prediction_type : {'velocity', 'noise', 'F', 'x'}, optional + Output format of the model's prediction. Default is "F". + loss_type : {'velocity', 'noise', 'F'}, optional + Loss function used to train the model. Default is "noise". + subnet_kwargs : dict[str, any], optional + Additional keyword arguments passed to the subnet constructor. Default is None. + schedule_kwargs : dict[str, any], optional + Additional keyword arguments passed to the noise schedule constructor. Default is None. + integrate_kwargs : dict[str, any], optional + Configuration dictionary for integration during training or inference. Default is None. + concatenate_subnet_input: bool, optional + Flag for advanced users to control whether all inputs to the subnet should be concatenated + into a single vector or passed as separate arguments. If set to False, the subnet + must accept three separate inputs: 'x' (noisy parameters), 't' (log signal-to-noise ratio), + and optional 'conditions'. Default is True. + + **kwargs + Additional keyword arguments passed to the base class and internal components. + """ + super().__init__( + subnet=subnet, + noise_schedule=noise_schedule, + prediction_type=prediction_type, + loss_type=loss_type, + subnet_kwargs=subnet_kwargs, + schedule_kwargs=schedule_kwargs, + integrate_kwargs=integrate_kwargs, + **kwargs, + ) + + def compositional_bridge(self, time: Tensor) -> Tensor: + """ + Bridge function for compositional diffusion. In the simplest case, this is just 1 if d0 == d1. 
+ Otherwise, it can be used to scale the compositional score over time. + + Parameters + ---------- + time: Tensor + Time step for the diffusion process. + + Returns + ------- + Tensor + Bridge function value with same shape as time. + + """ + return ops.exp(-np.log(self.compositional_bridge_d0 / self.compositional_bridge_d1) * time) + + def compositional_velocity( + self, + xz: Tensor, + time: float | Tensor, + stochastic_solver: bool, + conditions: Tensor, + compute_prior_score: Callable[[Tensor], Tensor], + mini_batch_size: int | None = None, + training: bool = False, + ) -> Tensor: + """ + Computes the compositional velocity for multiple datasets using the formula: + s_ψ(θ,t,Y) = (1-n)(1-t) ∇_θ log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) + + Parameters + ---------- + xz : Tensor + The current state of the latent variable, shape (n_datasets, n_compositional, ...) + time : float or Tensor + Time step for the diffusion process + stochastic_solver : bool + Whether to use stochastic (SDE) or deterministic (ODE) formulation + conditions : Tensor + Conditional inputs with compositional structure (n_datasets, n_compositional, ...) + compute_prior_score: Callable + Function to compute the prior score ∇_θ log p(θ). + mini_batch_size : int or None + Mini batch size for computing individual scores. If None, use all conditions. + training : bool, optional + Whether in training mode + + Returns + ------- + Tensor + Compositional velocity of same shape as input xz + """ + # Calculate standard noise schedule components + log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) + log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) + + compositional_score = self.compositional_score( + xz=xz, + time=time, + conditions=conditions, + compute_prior_score=compute_prior_score, + mini_batch_size=mini_batch_size, + training=training, + ) + + # Compute velocity using standard drift-diffusion formulation + f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training) + + if stochastic_solver: + # SDE: dz = [f(z,t) - g(t)² * score(z,t)] dt + g(t) dW + velocity = f - g_squared * compositional_score + else: + # ODE: dz = [f(z,t) - 0.5 * g(t)² * score(z,t)] dt + velocity = f - 0.5 * g_squared * compositional_score + + return velocity + + def compositional_score( + self, + xz: Tensor, + time: float | Tensor, + conditions: Tensor, + compute_prior_score: Callable[[Tensor], Tensor], + mini_batch_size: int | None = None, + training: bool = False, + ) -> Tensor: + """ + Computes the compositional score for multiple datasets using the formula: + s_ψ(θ,t,Y) = (1-n)(1-t) ∇_θ log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) + + Parameters + ---------- + xz : Tensor + The current state of the latent variable, shape (n_datasets, n_compositional, ...) + time : float or Tensor + Time step for the diffusion process + conditions : Tensor + Conditional inputs with compositional structure (n_datasets, n_compositional, ...) + compute_prior_score: Callable + Function to compute the prior score ∇_θ log p(θ). + mini_batch_size : int or None + Mini batch size for computing individual scores. If None, use all conditions. 
+ training : bool, optional + Whether in training mode + + Returns + ------- + Tensor + Compositional velocity of same shape as input xz + """ + if conditions is None: + raise ValueError("Conditions are required for compositional sampling") + + # Get shapes for compositional structure + n_compositional = ops.shape(conditions)[1] + + # Calculate standard noise schedule components + log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) + log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) + + # Compute individual dataset scores + if mini_batch_size is not None and mini_batch_size < n_compositional: + # sample random indices for mini-batch processing + mini_batch_idx = keras.random.shuffle(ops.arange(n_compositional), seed=self.seed_generator) + mini_batch_idx = mini_batch_idx[:mini_batch_size] + conditions_batch = conditions[:, mini_batch_idx] + else: + conditions_batch = conditions + individual_scores = self._compute_individual_scores(xz, log_snr_t, conditions_batch, training) + + # Compute prior score component + prior_score = compute_prior_score(xz) + weighted_prior_score = (1.0 - n_compositional) * (1.0 - time) * prior_score + + # Sum individual scores across compositional dimensions + summed_individual_scores = n_compositional * ops.mean(individual_scores, axis=1) + + # Combined score using compositional formula: (1-n)(1-t)∇log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) + time_tensor = ops.cast(time, dtype=ops.dtype(xz)) + compositional_score = self.compositional_bridge(time_tensor) * (weighted_prior_score + summed_individual_scores) + return compositional_score + + def _compute_individual_scores( + self, + xz: Tensor, + log_snr_t: Tensor, + conditions: Tensor, + training: bool, + ) -> Tensor: + """ + Compute individual dataset scores s_ψ(θ,t,yᵢ) for each compositional condition. + + Returns + ------- + Tensor + Individual scores with shape (n_datasets, n_compositional, ...) 
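+
+        Notes
+        -----
+        The latent state and log-SNR are broadcast along the compositional axis,
+        flattened together with the conditions to shape
+        (n_datasets * n_compositional, num_samples, ...), evaluated with the
+        standard ``score`` method, and reshaped back.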
+ """ + # Get shapes + xz_shape = ops.shape(xz) # (n_datasets, num_samples, ..., dims) + conditions_shape = ops.shape(conditions) # (n_datasets, n_compositional, num_samples, ..., dims) + n_datasets, n_compositional = conditions_shape[0], conditions_shape[1] + conditions_dims = tuple(conditions_shape[3:]) + num_samples = xz_shape[1] + dims = tuple(xz_shape[2:]) + + # Expand xz to match compositional structure + xz_expanded = ops.expand_dims(xz, axis=1) # (n_datasets, 1, num_samples, ..., dims) + xz_expanded = ops.broadcast_to(xz_expanded, (n_datasets, n_compositional, num_samples) + dims) + + # Expand log_snr_t to match compositional structure + log_snr_expanded = ops.expand_dims(log_snr_t, axis=1) + log_snr_expanded = ops.broadcast_to(log_snr_expanded, (n_datasets, n_compositional, num_samples, 1)) + + # Flatten for score computation: (n_datasets * n_compositional, num_samples, ..., dims) + xz_flat = ops.reshape(xz_expanded, (n_datasets * n_compositional, num_samples) + dims) + log_snr_flat = ops.reshape(log_snr_expanded, (n_datasets * n_compositional, num_samples, 1)) + conditions_flat = ops.reshape(conditions, (n_datasets * n_compositional, num_samples) + conditions_dims) + + # Use standard score function + scores_flat = self.score(xz_flat, log_snr_t=log_snr_flat, conditions=conditions_flat, training=training) + + # Reshape back to compositional structure + scores = ops.reshape(scores_flat, (n_datasets, n_compositional, num_samples) + dims) + return scores + + def _inverse_compositional( + self, + z: Tensor, + conditions: Tensor, + compute_prior_score: Callable[[Tensor], Tensor], + density: bool = False, + training: bool = False, + **kwargs, + ) -> Tensor | tuple[Tensor, Tensor]: + """ + Inverse pass for compositional diffusion sampling. + """ + n_compositional = ops.shape(conditions)[1] + integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0} + integrate_kwargs = integrate_kwargs | self.integrate_kwargs + integrate_kwargs = integrate_kwargs | kwargs + if keras.backend.backend() == "jax": + mini_batch_size = integrate_kwargs.pop("mini_batch_size", None) + if mini_batch_size is not None: + raise ValueError( + "Mini batching is not supported with JAX backend. Set mini_batch_size to None " + "or use another backend." 
+ ) + else: + mini_batch_size = max(integrate_kwargs.pop("mini_batch_size", int(n_compositional * 0.1)), 1) + self.compositional_bridge_d0 = float(integrate_kwargs.pop("compositional_bridge_d0", 1.0)) + self.compositional_bridge_d1 = float(integrate_kwargs.pop("compositional_bridge_d1", 1.0)) + + # x is sampled from a normal distribution, must be scaled with var 1/n_compositional + scale_latent = n_compositional * self.compositional_bridge(ops.ones(1)) + z = z / ops.sqrt(ops.cast(scale_latent, dtype=ops.dtype(z))) + + if density: + if integrate_kwargs["method"] in STOCHASTIC_METHODS: + raise ValueError("Stochastic methods are not supported for density computation.") + + def deltas(time, xz): + v = self.compositional_velocity( + xz, + time=time, + stochastic_solver=False, + conditions=conditions, + compute_prior_score=compute_prior_score, + mini_batch_size=mini_batch_size, + training=training, + ) + trace = ops.zeros(ops.shape(xz)[:-1] + (1,), dtype=ops.dtype(xz)) + return {"xz": v, "trace": trace} + + state = { + "xz": z, + "trace": ops.zeros(ops.shape(z)[:-1] + (1,), dtype=ops.dtype(z)), + } + state = integrate(deltas, state, **integrate_kwargs) + + x = state["xz"] + log_density = self.base_distribution.log_prob(ops.mean(z, axis=1)) - ops.squeeze(state["trace"], axis=-1) + return x, log_density + + state = {"xz": z} + + if integrate_kwargs["method"] in STOCHASTIC_METHODS: + + def deltas(time, xz): + return { + "xz": self.compositional_velocity( + xz, + time=time, + stochastic_solver=True, + conditions=conditions, + compute_prior_score=compute_prior_score, + mini_batch_size=mini_batch_size, + training=training, + ) + } + + def diffusion(time, xz): + return {"xz": self.diffusion_term(xz, time=time, training=training)} + + score_fn = None + if "corrector_steps" in integrate_kwargs or integrate_kwargs.get("method") == "langevin": + + def score_fn(time, xz): + return { + "xz": self.compositional_score( + xz, + time=time, + conditions=conditions, + compute_prior_score=compute_prior_score, + mini_batch_size=mini_batch_size, + training=training, + ) + } + + state = integrate_stochastic( + drift_fn=deltas, + diffusion_fn=diffusion, + score_fn=score_fn, + noise_schedule=self.noise_schedule, + state=state, + seed=self.seed_generator, + **integrate_kwargs, + ) + else: + + def deltas(time, xz): + return { + "xz": self.compositional_velocity( + xz, + time=time, + stochastic_solver=False, + conditions=conditions, + compute_prior_score=compute_prior_score, + mini_batch_size=mini_batch_size, + training=training, + ) + } + + state = integrate(deltas, state, **integrate_kwargs) + + x = state["xz"] + return x diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index ca8a634e9..8cbce1e87 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -16,6 +16,7 @@ integrate_stochastic, logging, tensor_utils, + STOCHASTIC_METHODS, ) from bayesflow.utils.serialization import serialize, deserialize, serializable @@ -39,13 +40,13 @@ class DiffusionModel(InferenceNetwork): "activation": "mish", "kernel_initializer": "he_normal", "residual": True, - "dropout": 0.0, + "dropout": 0.05, "spectral_normalization": False, } INTEGRATE_DEFAULT_CONFIG = { - "method": "rk45", - "steps": 100, + "method": "two_step_adaptive", + "steps": "adaptive", } def __init__( @@ -243,6 +244,55 @@ def _apply_subnet( else: return self.subnet(x=xz, t=log_snr, conditions=conditions, training=training) + def 
score( + self, + xz: Tensor, + time: float | Tensor = None, + log_snr_t: Tensor = None, + conditions: Tensor = None, + training: bool = False, + ) -> Tensor: + """ + Computes the score of the target or latent variable `xz`. + + Parameters + ---------- + xz : Tensor + The current state of the latent variable `z`, typically of shape (..., D), + where D is the dimensionality of the latent space. + time : float or Tensor + Scalar or tensor representing the time (or noise level) at which the velocity + should be computed. Will be broadcasted to xz. If None, log_snr_t must be provided. + log_snr_t : Tensor + The log signal-to-noise ratio at time `t`. If None, time must be provided. + conditions : Tensor, optional + Conditional inputs to the network, such as conditioning variables + or encoder outputs. Shape must be broadcastable with `xz`. Default is None. + training : bool, optional + Whether the model is in training mode. Affects behavior of dropout, batch norm, + or other stochastic layers. Default is False. + + Returns + ------- + Tensor + The velocity tensor of the same shape as `xz`, representing the right-hand + side of the SDE or ODE at the given `time`. + """ + if log_snr_t is None: + log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) + log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) + alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) + + subnet_out = self._apply_subnet( + xz, self._transform_log_snr(log_snr_t), conditions=conditions, training=training + ) + pred = self.output_projector(subnet_out, training=training) + + x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t) + + score = (alpha_t * x_pred - xz) / ops.square(sigma_t) + return score + def velocity( self, xz: Tensor, @@ -279,19 +329,10 @@ def velocity( The velocity tensor of the same shape as `xz`, representing the right-hand side of the SDE or ODE at the given `time`. 
""" - # calculate the current noise level and transform into correct shape log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) - alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) - - subnet_out = self._apply_subnet( - xz, self._transform_log_snr(log_snr_t), conditions=conditions, training=training - ) - pred = self.output_projector(subnet_out, training=training) - - x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t) - score = (alpha_t * x_pred - xz) / ops.square(sigma_t) + score = self.score(xz, log_snr_t=log_snr_t, conditions=conditions, training=training) # compute velocity f, g of the SDE or ODE f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training) @@ -368,7 +409,7 @@ def _forward( integrate_kwargs = integrate_kwargs | self.integrate_kwargs integrate_kwargs = integrate_kwargs | kwargs - if integrate_kwargs["method"] == "euler_maruyama": + if integrate_kwargs["method"] in STOCHASTIC_METHODS: raise ValueError("Stochastic methods are not supported for forward integration.") if density: @@ -418,7 +459,7 @@ def _inverse( integrate_kwargs = integrate_kwargs | self.integrate_kwargs integrate_kwargs = integrate_kwargs | kwargs if density: - if integrate_kwargs["method"] == "euler_maruyama": + if integrate_kwargs["method"] in STOCHASTIC_METHODS: raise ValueError("Stochastic methods are not supported for density computation.") def deltas(time, xz): @@ -437,7 +478,7 @@ def deltas(time, xz): return x, log_density state = {"xz": z} - if integrate_kwargs["method"] == "euler_maruyama": + if integrate_kwargs["method"] in STOCHASTIC_METHODS: def deltas(time, xz): return { @@ -447,9 +488,24 @@ def deltas(time, xz): def diffusion(time, xz): return {"xz": self.diffusion_term(xz, time=time, training=training)} + score_fn = None + if "corrector_steps" in integrate_kwargs or integrate_kwargs.get("method") == "langevin": + + def score_fn(time, xz): + return { + "xz": self.score( + xz, + time=time, + conditions=conditions, + training=training, + ) + } + state = integrate_stochastic( drift_fn=deltas, diffusion_fn=diffusion, + score_fn=score_fn, + noise_schedule=self.noise_schedule, state=state, seed=self.seed_generator, **integrate_kwargs, diff --git a/bayesflow/networks/flow_matching/flow_matching.py b/bayesflow/networks/flow_matching/flow_matching.py index fa74089a4..268fc6de4 100644 --- a/bayesflow/networks/flow_matching/flow_matching.py +++ b/bayesflow/networks/flow_matching/flow_matching.py @@ -23,6 +23,9 @@ class FlowMatching(InferenceNetwork): """(IN) Implements Optimal Transport Flow Matching, originally introduced as Rectified Flow, with ideas incorporated from [1-5]. + For optimal transport, the Sinkhorn algorithm is used to compute mini-batch optimal transport plans + between samples from the base distribution and the target distribution during training [6-8]. + [1] Liu et al. (2022). Flow straight and fast: Learning to generate and transfer data with rectified flow. arXiv preprint arXiv:2209.03003. [2] Lipman et al. (2022). Flow matching for generative modeling. @@ -33,6 +36,10 @@ class FlowMatching(InferenceNetwork): Advances in Neural Information Processing Systems, 36, 16837-16864. [5] Orsini et al. (2025). Flow matching posterior estimation for simulation-based atmospheric retrieval of exoplanets. IEEE Access. + [6] Nguyen et al. 
(2022) "Improving Mini-batch Optimal Transport via Partial Transportation" + [7] Cheng et al. (2025) "The Curse of Conditions: Analyzing and Improving Optimal Transport for + Conditional Flow-Based Generation" + [8] Fluri et al. (2024) "Improving Flow Matching for Simulation-Based Inference" """ MLP_DEFAULT_CONFIG = { @@ -49,12 +56,13 @@ class FlowMatching(InferenceNetwork): "regularization": 0.1, "max_steps": 100, "atol": 1e-5, - "rtol": 1e-4, + "partial_ot_factor": 1.0, # no partial OT + "conditional_ot_ratio": 0.01, # only used if conditions are provided } INTEGRATE_DEFAULT_CONFIG = { - "method": "rk45", - "steps": 100, + "method": "tsit5", + "steps": "adaptive", } def __init__( @@ -236,6 +244,7 @@ def f(x): def _forward( self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs ) -> Tensor | tuple[Tensor, Tensor]: + integrate_kwargs = self.integrate_kwargs | kwargs if density: def deltas(time, xz): @@ -243,7 +252,7 @@ def deltas(time, xz): return {"xz": v, "trace": trace} state = {"xz": x, "trace": keras.ops.zeros(keras.ops.shape(x)[:-1] + (1,), dtype=keras.ops.dtype(x))} - state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **(self.integrate_kwargs | kwargs)) + state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **integrate_kwargs) z = state["xz"] log_density = self.base_distribution.log_prob(z) + keras.ops.squeeze(state["trace"], axis=-1) @@ -254,7 +263,7 @@ def deltas(time, xz): return {"xz": self.velocity(xz, time=time, conditions=conditions, training=training)} state = {"xz": x} - state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **(self.integrate_kwargs | kwargs)) + state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **integrate_kwargs) z = state["xz"] @@ -263,6 +272,7 @@ def deltas(time, xz): def _inverse( self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs ) -> Tensor | tuple[Tensor, Tensor]: + integrate_kwargs = self.integrate_kwargs | kwargs if density: def deltas(time, xz): @@ -270,7 +280,7 @@ def deltas(time, xz): return {"xz": v, "trace": trace} state = {"xz": z, "trace": keras.ops.zeros(keras.ops.shape(z)[:-1] + (1,), dtype=keras.ops.dtype(z))} - state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **(self.integrate_kwargs | kwargs)) + state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **integrate_kwargs) x = state["xz"] log_density = self.base_distribution.log_prob(z) - keras.ops.squeeze(state["trace"], axis=-1) @@ -281,7 +291,7 @@ def deltas(time, xz): return {"xz": self.velocity(xz, time=time, conditions=conditions, training=training)} state = {"xz": z} - state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **(self.integrate_kwargs | kwargs)) + state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **integrate_kwargs) x = state["xz"] @@ -311,16 +321,14 @@ def compute_metrics( # since the data is possibly noisy and may contain outliers, it is better # to possibly drop some samples from x1 than from x0 # in the marginal over multiple batches, this is not a problem - x0, x1, assignments = optimal_transport( + x0, x1, conditions, assignments = optimal_transport( x0, x1, + conditions=conditions, seed=self.seed_generator, **self.optimal_transport_kwargs, return_assignments=True, ) - if conditions is not None: - # conditions must be resampled along with x1 - conditions = keras.ops.take(conditions, assignments, axis=0) u = keras.random.uniform((keras.ops.shape(x0)[0],), seed=self.seed_generator) 
# p(t) ∝ t^(1/(1+α)), the inverse CDF: F^(-1)(u) = u^(1+α), α=0 is uniform diff --git a/bayesflow/networks/inference_network.py b/bayesflow/networks/inference_network.py index b092ce2cb..9488f644d 100644 --- a/bayesflow/networks/inference_network.py +++ b/bayesflow/networks/inference_network.py @@ -1,3 +1,4 @@ +from typing import Callable import keras from bayesflow.types import Shape, Tensor @@ -27,11 +28,30 @@ def call( conditions: Tensor = None, inverse: bool = False, density: bool = False, + compute_prior_score: Callable[[Tensor], Tensor] = None, training: bool = False, **kwargs, ) -> Tensor | tuple[Tensor, Tensor]: if inverse: + if compute_prior_score is not None: + return self._inverse_compositional( + xz, + conditions=conditions, + compute_prior_score=compute_prior_score, + density=density, + training=training, + **kwargs, + ) return self._inverse(xz, conditions=conditions, density=density, training=training, **kwargs) + if compute_prior_score is not None: + return self._forward_compositional( + xz, + conditions=conditions, + compute_prior_score=compute_prior_score, + density=density, + training=training, + **kwargs, + ) return self._forward(xz, conditions=conditions, density=density, training=training, **kwargs) def _forward( @@ -44,6 +64,28 @@ def _inverse( ) -> Tensor | tuple[Tensor, Tensor]: raise NotImplementedError + def _forward_compositional( + self, + x: Tensor, + conditions: Tensor, + compute_prior_score: Callable[[Tensor], Tensor], + density: bool = False, + training: bool = False, + **kwargs, + ) -> Tensor | tuple[Tensor, Tensor]: + raise NotImplementedError + + def _inverse_compositional( + self, + z: Tensor, + conditions: Tensor, + compute_prior_score: Callable[[Tensor], Tensor], + density: bool = False, + training: bool = False, + **kwargs, + ) -> Tensor | tuple[Tensor, Tensor]: + raise NotImplementedError + @allow_batch_size def sample(self, batch_shape: Shape, conditions: Tensor = None, **kwargs) -> Tensor: samples = self.base_distribution.sample(batch_shape) diff --git a/bayesflow/simulators/sequential_simulator.py b/bayesflow/simulators/sequential_simulator.py index 21e1542e6..96ab0ead3 100644 --- a/bayesflow/simulators/sequential_simulator.py +++ b/bayesflow/simulators/sequential_simulator.py @@ -88,7 +88,7 @@ def _single_sample(self, batch_shape_ext, **kwargs) -> dict[str, np.ndarray]: return self.sample(batch_shape=(1, *tuple(batch_shape_ext)), **kwargs) def sample_parallel( - self, batch_shape: Shape, n_jobs: int = -1, verbose: int = 0, **kwargs + self, batch_shape: Shape, n_jobs: int = -1, verbose: int = 1, **kwargs ) -> dict[str, np.ndarray]: """ Sample in parallel from the sequential simulator. @@ -101,7 +101,7 @@ def sample_parallel( n_jobs : int, optional Number of parallel jobs. -1 uses all available cores. Default is -1. verbose : int, optional - Verbosity level for joblib. Default is 0 (no output). + Verbosity level for joblib. Default is 1 (minimal output). **kwargs Additional keyword arguments passed to each simulator. These may include previously sampled outputs used as inputs for subsequent simulators. 
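For orientation, a minimal usage sketch of the compositional sampling API introduced above. The fitted `approximator`, the `observables` condition key, and the standard-normal prior are illustrative assumptions, not code from this repository:

```python
import numpy as np

# Hypothetical prior score for a standard-normal prior over each adapted parameter:
# grad_theta log p(theta) = -theta, returned per parameter key.
def compute_prior_score(samples: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    return {key: -value for key, value in samples.items()}

# Conditions carry an extra compositional axis: (n_datasets, n_compositional, ...).
posterior_draws = approximator.compositional_sample(
    num_samples=500,
    conditions={"observables": observed_data},
    compute_prior_score=compute_prior_score,
)
```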
diff --git a/bayesflow/simulators/simulator.py b/bayesflow/simulators/simulator.py index 00d3d84f3..53d54e455 100644 --- a/bayesflow/simulators/simulator.py +++ b/bayesflow/simulators/simulator.py @@ -95,3 +95,8 @@ def accept_all_predicate(x): return np.full((sample_size,), True) return self.rejection_sample(batch_shape, predicate=accept_all_predicate, sample_size=sample_size, **kwargs) + + def sample_parallel( + self, batch_shape: Shape, n_jobs: int = -1, verbose: int = 1, **kwargs + ) -> dict[str, np.ndarray]: + raise NotImplementedError diff --git a/bayesflow/utils/__init__.py b/bayesflow/utils/__init__.py index a8d28a50a..25b7dd920 100644 --- a/bayesflow/utils/__init__.py +++ b/bayesflow/utils/__init__.py @@ -47,7 +47,7 @@ ) from .hparam_utils import find_batch_size, find_memory_budget -from .integrate import integrate, integrate_stochastic +from .integrate import integrate, integrate_stochastic, DETERMINISTIC_METHODS, STOCHASTIC_METHODS from .io import ( pickle_load, diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index b197ea975..2a2000593 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -1,7 +1,9 @@ from collections.abc import Callable, Sequence +from typing import Dict, Tuple, Optional from functools import partial import keras +from keras import backend as K import numpy as np from typing import Literal, Union @@ -11,128 +13,219 @@ from bayesflow.utils import filter_kwargs from bayesflow.utils.logging import warning -from . import logging +import logging ArrayLike = int | float | Tensor +StateDict = Dict[str, ArrayLike] + + +DETERMINISTIC_METHODS = ["euler", "rk45", "tsit5"] +STOCHASTIC_METHODS = ["euler_maruyama", "sea", "shark", "two_step_adaptive", "langevin"] + + +def _check_all_nans(state: StateDict): + all_nans_flags = [] + for v in state.values(): + all_nans_flags.append(keras.ops.all(keras.ops.isnan(v))) + return keras.ops.all(keras.ops.stack(all_nans_flags)) def euler_step( fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, time: ArrayLike, step_size: ArrayLike, - tolerance: ArrayLike = 1e-6, - min_step_size: ArrayLike = -float("inf"), - max_step_size: ArrayLike = float("inf"), - use_adaptive_step_size: bool = False, -) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): + **kwargs, +) -> Tuple[StateDict, ArrayLike, None, ArrayLike]: k1 = fn(time, **filter_kwargs(state, fn)) - if use_adaptive_step_size: - intermediate_state = state.copy() - for key, delta in k1.items(): - intermediate_state[key] = state[key] + step_size * delta - - k2 = fn(time + step_size, **filter_kwargs(intermediate_state, fn)) - - # check all keys are equal - if set(k1.keys()) != set(k2.keys()): - raise ValueError("Keys of the deltas do not match. 
Please return zero for unchanged variables.") - - # compute next step size - intermediate_error = keras.ops.stack([keras.ops.norm(k2[key] - k1[key], ord=2, axis=-1) for key in k1]) - new_step_size = step_size * tolerance / (intermediate_error + 1e-9) - - new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) - - # consolidate step size - new_step_size = keras.ops.take(new_step_size, keras.ops.argmin(keras.ops.abs(new_step_size))) - else: - new_step_size = step_size - - # apply updates new_state = state.copy() for key in k1.keys(): new_state[key] = state[key] + step_size * k1[key] - new_time = time + step_size - return new_state, new_time, new_step_size + return new_state, new_time, None, 0.0 + + +def add_scaled(state, ks, coeffs, h): + out = {} + for key, y in state.items(): + acc = keras.ops.zeros_like(y) + for c, k in zip(coeffs, ks): + acc = acc + c * k[key] + out[key] = y + h * acc + return out def rk45_step( fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, time: ArrayLike, - last_step_size: ArrayLike, - tolerance: ArrayLike = 1e-6, - min_step_size: ArrayLike = -float("inf"), - max_step_size: ArrayLike = float("inf"), - use_adaptive_step_size: bool = False, -) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): - step_size = last_step_size - - k1 = fn(time, **filter_kwargs(state, fn)) - - intermediate_state = state.copy() - for key, delta in k1.items(): - intermediate_state[key] = state[key] + 0.5 * step_size * delta + step_size: ArrayLike, + k1: StateDict = None, + use_adaptive_step_size: bool = True, +) -> Tuple[StateDict, ArrayLike, StateDict | None, ArrayLike]: + """ + Dormand-Prince 5(4) method with embedded error estimation [1]. - k2 = fn(time + 0.5 * step_size, **filter_kwargs(intermediate_state, fn)) + Dormand (1996), Numerical Methods for Differential Equations: A Computational Approach + """ + h = step_size + + if k1 is None: # reuse k1 if available + k1 = fn(time, **filter_kwargs(state, fn)) + k2 = fn(time + h * (1 / 5), **filter_kwargs(add_scaled(state, [k1], [1 / 5], h), fn)) + k3 = fn(time + h * (3 / 10), **filter_kwargs(add_scaled(state, [k1, k2], [3 / 40, 9 / 40], h), fn)) + k4 = fn(time + h * (4 / 5), **filter_kwargs(add_scaled(state, [k1, k2, k3], [44 / 45, -56 / 15, 32 / 9], h), fn)) + k5 = fn( + time + h * (8 / 9), + **filter_kwargs( + add_scaled(state, [k1, k2, k3, k4], [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], h), fn + ), + ) + k6 = fn( + time + h, + **filter_kwargs( + add_scaled(state, [k1, k2, k3, k4, k5], [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], h), + fn, + ), + ) - intermediate_state = state.copy() - for key, delta in k2.items(): - intermediate_state[key] = state[key] + 0.5 * step_size * delta + # 5th order solution + new_state = {} + for key in k1.keys(): + new_state[key] = state[key] + h * ( + 35 / 384 * k1[key] + 500 / 1113 * k3[key] + 125 / 192 * k4[key] - 2187 / 6784 * k5[key] + 11 / 84 * k6[key] + ) - k3 = fn(time + 0.5 * step_size, **filter_kwargs(intermediate_state, fn)) + new_time = time + h + if not use_adaptive_step_size: + return new_state, new_time, None, 0.0 - intermediate_state = state.copy() - for key, delta in k3.items(): - intermediate_state[key] = state[key] + step_size * delta + k7 = fn(time + h, **filter_kwargs(new_state, fn)) - k4 = fn(time + step_size, **filter_kwargs(intermediate_state, fn)) + # 4th order embedded solution + err_state = {} + for key in k1.keys(): + y4 = state[key] + h * ( + 5179 / 57600 * k1[key] + + 7571 / 16695 * k3[key] + + 393 / 640 * k4[key] + 
- 92097 / 339200 * k5[key] + + 187 / 2100 * k6[key] + + 1 / 40 * k7[key] + ) + err_state[key] = new_state[key] - y4 - if use_adaptive_step_size: - intermediate_state = state.copy() - for key, delta in k4.items(): - intermediate_state[key] = state[key] + 0.5 * step_size * delta + err_norm = keras.ops.stack([keras.ops.norm(v, ord=2, axis=-1) for v in err_state.values()]) + err = keras.ops.max(err_norm) - k5 = fn(time + 0.5 * step_size, **filter_kwargs(intermediate_state, fn)) + return new_state, new_time, k7, err - # check all keys are equal - if not all(set(k.keys()) == set(k1.keys()) for k in [k2, k3, k4, k5]): - raise ValueError("Keys of the deltas do not match. Please return zero for unchanged variables.") - # compute next step size - intermediate_error = keras.ops.stack([keras.ops.norm(k5[key] - k4[key], ord=2, axis=-1) for key in k5.keys()]) - new_step_size = step_size * tolerance / (intermediate_error + 1e-9) +def tsit5_step( + fn: Callable, + state: StateDict, + time: ArrayLike, + step_size: ArrayLike, + k1: StateDict = None, + use_adaptive_step_size: bool = True, +) -> Tuple[StateDict, ArrayLike, StateDict | None, ArrayLike]: + """ + Implements a single step of the Tsitouras 5/4 Runge-Kutta method [1]. - new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) + [1] Tsitouras (2011), Runge--Kutta pairs of order 5(4) satisfying only the first column simplifying assumption + """ + h = step_size + + # Butcher tableau coefficients + c2 = 0.161 + c3 = 0.327 + c4 = 0.9 + c5 = 0.9800255409045097 + + if k1 is None: # reuse k1 if available + k1 = fn(time, **filter_kwargs(state, fn)) + k2 = fn(time + h * c2, **filter_kwargs(add_scaled(state, [k1], [0.161], h), fn)) + k3 = fn( + time + h * c3, **filter_kwargs(add_scaled(state, [k1, k2], [-0.0084806554923570, 0.3354806554923570], h), fn) + ) + k4 = fn( + time + h * c4, + **filter_kwargs( + add_scaled(state, [k1, k2, k3], [2.897153057105494, -6.359448489975075, 4.362295432869581], h), fn + ), + ) + k5 = fn( + time + h * c5, + **filter_kwargs( + add_scaled( + state, + [k1, k2, k3, k4], + [5.325864828439257, -11.74888356406283, 7.495539342889836, -0.09249506636175525], + h, + ), + fn, + ), + ) + k6 = fn( + time + h, + **filter_kwargs( + add_scaled( + state, + [k1, k2, k3, k4, k5], + [5.86145544294270, -12.92096931784711, 8.159367898576159, -0.07158497328140100, -0.02826905039406838], + h, + ), + fn, + ), + ) - # consolidate step size - new_step_size = keras.ops.take(new_step_size, keras.ops.argmin(keras.ops.abs(new_step_size))) - else: - new_step_size = step_size + # 5th order solution: b coefficients + new_state = {} + for key in state.keys(): + new_state[key] = state[key] + h * ( + 0.09646076681806523 * k1[key] + + 0.01 * k2[key] + + 0.4798896504144996 * k3[key] + + 1.379008574103742 * k4[key] + - 3.290069515436081 * k5[key] + + 2.324710524099774 * k6[key] + ) - # apply updates - new_state = state.copy() - for key in k1.keys(): - new_state[key] = state[key] + (step_size / 6.0) * (k1[key] + 2.0 * k2[key] + 2.0 * k3[key] + k4[key]) + new_time = time + h + if not use_adaptive_step_size: + return new_state, new_time, None, 0.0 + + k7 = fn(time + h, **filter_kwargs(new_state, fn)) + + err_state = {} + for key in state.keys(): + err_state[key] = h * ( + -0.00178001105222577714 * k1[key] + - 0.0008164344596567469 * k2[key] + + 0.007880878010261995 * k3[key] + - 0.1447110071732629 * k4[key] + + 0.5823571654525552 * k5[key] + - 0.45808210592918697 * k6[key] + + 0.015151515151515152 * k7[key] + ) - new_time = time + step_size + err_norm 
= keras.ops.stack([keras.ops.norm(v, ord=2, axis=-1) for v in err_state.values()]) + err = keras.ops.max(err_norm) - return new_state, new_time, new_step_size + return new_state, new_time, k7, err def integrate_fixed( fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, start_time: ArrayLike, stop_time: ArrayLike, steps: int, - method: str = "rk45", + method: str, **kwargs, -) -> dict[str, ArrayLike]: +) -> StateDict: if steps <= 0: raise ValueError("Number of steps must be positive.") @@ -141,6 +234,8 @@ def integrate_fixed( step_fn = euler_step case "rk45": step_fn = rk45_step + case "tsit5": + step_fn = tsit5_step case str() as name: raise ValueError(f"Unknown integration method name: {name!r}") case other: @@ -149,16 +244,53 @@ def integrate_fixed( step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False) step_size = (stop_time - start_time) / steps - time = start_time - def body(_loop_var, _loop_state): _state, _time = _loop_state - _state, _time, _ = step_fn(_state, _time, step_size) - + _state, _time, _, _ = step_fn(_state, _time, step_size) return _state, _time - state, time = keras.ops.fori_loop(0, steps, body, (state, time)) + state, _ = keras.ops.fori_loop( + 0, + steps, + body, + (state, start_time), + ) + return state + +def integrate_scheduled( + fn: Callable, + state: StateDict, + steps: Tensor | np.ndarray, + method: str, + **kwargs, +) -> StateDict: + match method: + case "euler": + step_fn = euler_step + case "rk45": + step_fn = rk45_step + case "tsit5": + step_fn = tsit5_step + case str() as name: + raise ValueError(f"Unknown integration method name: {name!r}") + case other: + raise TypeError(f"Invalid integration method: {other!r}") + + step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False) + + def body(_loop_var, _loop_state): + _time = steps[_loop_var] + step_size = steps[_loop_var + 1] - steps[_loop_var] + _loop_state, _, _, _ = step_fn(_loop_state, _time, step_size) + return _loop_state + + state = keras.ops.fori_loop( + 0, + keras.ops.shape(steps)[0] - 1, + body, + state, + ) return state @@ -167,114 +299,120 @@ def integrate_adaptive( state: dict[str, ArrayLike], start_time: ArrayLike, stop_time: ArrayLike, - min_steps: int = 10, - max_steps: int = 1000, - method: str = "rk45", + min_steps: int, + max_steps: int, + method: str, **kwargs, ) -> dict[str, ArrayLike]: if max_steps <= min_steps: raise ValueError("Maximum number of steps must be greater than minimum number of steps.") match method: - case "euler": - step_fn = euler_step case "rk45": step_fn = rk45_step + case "tsit5": + step_fn = tsit5_step + case "euler": + raise ValueError("Adaptive step sizing is not supported for the 'euler' method.") case str() as name: raise ValueError(f"Unknown integration method name: {name!r}") case other: raise TypeError(f"Invalid integration method: {other!r}") + tolerance = keras.ops.convert_to_tensor(kwargs.get("tolerance", 1e-6), dtype="float32") step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=True) + initial_step = (stop_time - start_time) / float(min_steps) + step0 = keras.ops.convert_to_tensor(0.0, dtype="float32") + count_not_accepted = 0 - def cond(_state, _time, _step_size, _step): - # while step < min_steps or time_remaining > 0 and step < max_steps + # "First Same As Last" (FSAL) property + k1_0 = fn(start_time, **filter_kwargs(state, fn)) - # time remaining after the next step - time_remaining = keras.ops.abs(stop_time - (_time + _step_size)) + def cond(_state, _time, _step_size, _step, _k1, 
_count_not_accepted): + time_remaining = keras.ops.sign(stop_time - start_time) * (stop_time - (_time + _step_size)) + step_lt_min = keras.ops.less(_step, float(min_steps)) + step_lt_max = keras.ops.less(_step, float(max_steps)) - return keras.ops.logical_or( - keras.ops.all(_step < min_steps), - keras.ops.logical_and(keras.ops.all(time_remaining > 0), keras.ops.all(_step < max_steps)), - ) - - def body(_state, _time, _step_size, _step): - _step = _step + 1 + all_nans = _check_all_nans(_state) - # time remaining after the next step - time_remaining = stop_time - (_time + _step_size) + end_now = keras.ops.logical_or( + step_lt_min, keras.ops.logical_and(keras.ops.all(time_remaining > 0), step_lt_max) + ) + return keras.ops.logical_and(~all_nans, end_now) + def body(_state, _time, _step_size, _step, _k1, _count_not_accepted): + # Time remaining from current point + time_remaining = keras.ops.abs(stop_time - _time) min_step_size = time_remaining / (max_steps - _step) max_step_size = time_remaining / keras.ops.maximum(min_steps - _step, 1.0) - - # reorder - min_step_size, max_step_size = ( - keras.ops.minimum(min_step_size, max_step_size), - keras.ops.maximum(min_step_size, max_step_size), + h = keras.ops.sign(_step_size) * keras.ops.clip(keras.ops.abs(_step_size), min_step_size, max_step_size) + + # Take one trial step + new_state, new_time, new_k1, err = step_fn( + state=_state, + time=_time, + step_size=h, + k1=_k1, ) - _state, _time, _step_size = step_fn( - _state, _time, _step_size, min_step_size=min_step_size, max_step_size=max_step_size + new_step_size = h * keras.ops.clip(0.9 * (tolerance / (err + 1e-12)) ** 0.2, 0.2, 5.0) + new_step_size = keras.ops.sign(new_step_size) * keras.ops.clip( + keras.ops.abs(new_step_size), min_step_size, max_step_size ) - return _state, _time, _step_size, _step - - # select initial step size conservatively - step_size = (stop_time - start_time) / max_steps - - step = 0 - time = start_time - - state, time, step_size, step = keras.ops.while_loop(cond, body, [state, time, step_size, step]) - - # do the last step - step_size = stop_time - time - state, _, _ = step_fn(state, time, step_size) - step = step + 1 + # Error control: reject if err > tolerance + too_big = keras.ops.greater(err, tolerance) + at_min = keras.ops.less_equal( + keras.ops.abs(h), + keras.ops.abs(min_step_size), + ) + accepted = keras.ops.logical_or(keras.ops.logical_not(too_big), at_min) - logging.debug("Finished integration after {} steps.", step) + updated_state = keras.ops.cond(accepted, lambda: new_state, lambda: _state) + updated_time = keras.ops.cond(accepted, lambda: new_time, lambda: _time) + updated_k1 = keras.ops.cond(accepted, lambda: new_k1, lambda: _k1) - return state + # Step counter: increment only on accepted steps + updated_step = _step + keras.ops.where(accepted, 1.0, 0.0) + _count_not_accepted = _count_not_accepted + keras.ops.where(accepted, 1.0, 0.0) + # For the next iteration, always use the new suggested step size + return updated_state, updated_time, new_step_size, updated_step, updated_k1, _count_not_accepted -def integrate_scheduled( - fn: Callable, - state: dict[str, ArrayLike], - steps: Tensor | np.ndarray, - method: str = "rk45", - **kwargs, -) -> dict[str, ArrayLike]: - match method: - case "euler": - step_fn = euler_step - case "rk45": - step_fn = rk45_step - case str() as name: - raise ValueError(f"Unknown integration method name: {name!r}") - case other: - raise TypeError(f"Invalid integration method: {other!r}") - - step_fn = partial(step_fn, fn, **kwargs, 
use_adaptive_step_size=False) - - def body(_loop_var, _loop_state): - _time = steps[_loop_var] - step_size = steps[_loop_var + 1] - steps[_loop_var] + # Run the adaptive loop + state, time, step_size, step, k1, count_not_accepted = keras.ops.while_loop( + cond, + body, + [state, start_time, initial_step, step0, k1_0, count_not_accepted], + ) - _loop_state, _, _ = step_fn(_loop_state, _time, step_size) - return _loop_state + if _check_all_nans(state): + raise RuntimeError(f"All values are NaNs in state during integration at {time}.") + + # Final step to hit stop_time exactly + time_diff = stop_time - time + time_remaining = keras.ops.sign(stop_time - start_time) * time_diff + if keras.ops.all(time_remaining > 0): + state, time, _, _ = step_fn( + state=state, + time=time, + step_size=time_diff, + k1=k1, + ) + step = step + 1.0 - state = keras.ops.fori_loop(0, len(steps) - 1, body, state) + logging.debug(f"Finished integration after {step} steps with {count_not_accepted} rejected steps.") return state def integrate_scipy( fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, start_time: ArrayLike, stop_time: ArrayLike, scipy_kwargs: dict | None = None, **kwargs, -) -> dict[str, ArrayLike]: +) -> StateDict: import scipy.integrate scipy_kwargs = scipy_kwargs or {} @@ -316,15 +454,15 @@ def scipy_wrapper_fn(time, x): def integrate( fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, start_time: ArrayLike | None = None, stop_time: ArrayLike | None = None, min_steps: int = 10, max_steps: int = 10_000, steps: int | Literal["adaptive"] | Tensor | np.ndarray = 100, - method: str = "euler", + method: str = "rk45", **kwargs, -) -> dict[str, ArrayLike]: +) -> StateDict: if isinstance(steps, str) and steps in ["adaptive", "dynamic"]: if start_time is None or stop_time is None: raise ValueError( @@ -351,14 +489,66 @@ def integrate( raise RuntimeError(f"Type or value of `steps` not understood (steps={steps})") +############ SDE Solvers ############# + + +def generate_noise(z: StateDict, seed: keras.random.SeedGenerator) -> StateDict: + noise = { + k: keras.random.normal(keras.ops.shape(val), dtype=keras.ops.dtype(val), seed=seed) for k, val in z.items() + } + return noise + + +def stochastic_adaptive_step_size_controller( + state, + drift, + adaptive_factor: ArrayLike, + min_step_size: float = -float("inf"), + max_step_size: float = float("inf"), +) -> ArrayLike: + """ + Adaptive step size controller based on [1]. Similar to a tamed explicit Euler method when used in Euler-Maruyama. + + Adaptive step sizing uses: + h = max(1, ||x||**2) / max(1, ||f(x)||**2) * adaptive_factor + + + [1] Fang & Giles, Adaptive Euler-Maruyama Method for SDEs with Non-Globally Lipschitz Drift Coefficients (2020) + + Returns + ------- + New step size. 
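+
+    Notes
+    -----
+    The norms are reduced by their maximum across all state entries and batch
+    elements, so a single scalar step size is returned for the whole state.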
+ """ + state_norms = [] + drift_norms = [] + for key in state.keys(): + state_norms.append(keras.ops.norm(state[key], ord=2, axis=-1)) + drift_norms.append(keras.ops.norm(drift[key], ord=2, axis=-1)) + state_norm = keras.ops.stack(state_norms) + drift_norm = keras.ops.stack(drift_norms) + max_state_norm = keras.ops.maximum( + keras.ops.cast(1.0, dtype=keras.ops.dtype(state_norm)), keras.ops.max(state_norm) ** 2 + ) + max_drift_norm = keras.ops.maximum( + keras.ops.cast(1.0, dtype=keras.ops.dtype(drift_norm)), keras.ops.max(drift_norm) ** 2 + ) + new_step_size = max_state_norm / max_drift_norm * adaptive_factor + new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) + return new_step_size + + def euler_maruyama_step( drift_fn: Callable, diffusion_fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, time: ArrayLike, step_size: ArrayLike, - noise: dict[str, ArrayLike], -) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): + noise: StateDict, + use_adaptive_step_size: bool = False, + min_step_size: float = -float("inf"), + max_step_size: float = float("inf"), + **kwargs, +) -> Union[Tuple[StateDict, ArrayLike, ArrayLike], Tuple[StateDict, ArrayLike, ArrayLike, StateDict]]: """ Performs a single Euler-Maruyama step for stochastic differential equations. @@ -369,6 +559,9 @@ def euler_maruyama_step( time: Current time scalar tensor. step_size: Time increment dt. noise: Mapping of variable names to dW noise tensors. + use_adaptive_step_size: Whether to use adaptive step sizing. + min_step_size: Minimum allowed step size. + max_step_size: Maximum allowed step size. Returns: new_state: Updated state after one Euler-Maruyama step. @@ -378,78 +571,853 @@ def euler_maruyama_step( drift = drift_fn(time, **filter_kwargs(state, drift_fn)) diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn)) - # Check noise keys - if set(diffusion.keys()) != set(noise.keys()): - raise ValueError("Keys of diffusion terms and noise do not match.") + new_step_size = step_size + if use_adaptive_step_size: + sign_step = keras.ops.sign(step_size) + new_step_size = stochastic_adaptive_step_size_controller( + state=state, + drift=drift, + adaptive_factor=max_step_size, + min_step_size=min_step_size, + max_step_size=max_step_size, + ) + new_step_size = sign_step * keras.ops.abs(new_step_size) + + sqrt_step_size = keras.ops.sqrt(keras.ops.abs(new_step_size)) new_state = {} for key, d in drift.items(): + base = state[key] + new_step_size * d + if key in diffusion: + base = base + diffusion[key] * sqrt_step_size * noise[key] + new_state[key] = base + + if use_adaptive_step_size: + return new_state, time + new_step_size, new_step_size, state + return new_state, time + new_step_size, new_step_size + + +def two_step_adaptive_step( + drift_fn: Callable, + diffusion_fn: Callable, + state: StateDict, + time: ArrayLike, + step_size: ArrayLike, + noise: StateDict, + last_state: StateDict = None, + use_adaptive_step_size: bool = True, + min_step_size: float = -float("inf"), + max_step_size: float = float("inf"), + e_rel: float = 0.1, + e_abs: float = None, + r: float = 0.9, + adapt_safety: float = 0.9, + **kwargs, +) -> Union[ + Tuple[StateDict, ArrayLike, ArrayLike], + Tuple[StateDict, ArrayLike, ArrayLike, StateDict], +]: + """ + Performs a single adaptive step for stochastic differential equations based on [1]. + + Based on + + This method uses a predictor-corrector approach with error estimation: + 1. Take an Euler-Maruyama step (predictor) + 2. 
Take another Euler-Maruyama step from the predicted state + 3. Average the two predictions (corrector) + 4. Estimate error and adapt step size + + When step_size reaches min_step_size, steps are always accepted regardless of + error to ensure progress and termination within max_steps. + + [1] Jolicoeur-Martineau et al. (2021) "Gotta Go Fast When Generating Data with Score-Based Models" + + Args: + drift_fn: Function computing the drift term f(t, **state). + diffusion_fn: Function computing the diffusion term g(t, **state). + state: Current state, mapping variable names to tensors. + time: Current time scalar tensor. + step_size: Time increment dt. + noise: Mapping of variable names to dW noise tensors (pre-scaled by sqrt(dt)). + last_state: Previous state for error estimation. + use_adaptive_step_size: Whether to adapt step size. + min_step_size: Minimum allowed step size. + max_step_size: Maximum allowed step size. + e_rel: Relative error tolerance. + e_abs: Absolute error tolerance. Default assumes standardized targets. + r: Order of the method for step size adaptation. + adapt_safety: Safety factor for step size adaptation. + **kwargs: Additional arguments passed to drift_fn and diffusion_fn. + + Returns: + new_state: Updated state after one adaptive step. + new_time: time + dt (or time if step rejected). + new_step_size: Adapted step size for next iteration. + """ + state_euler, time_mid, _ = euler_maruyama_step( + drift_fn=drift_fn, + diffusion_fn=diffusion_fn, + state=state, + time=time, + step_size=step_size, + min_step_size=min_step_size, + max_step_size=max_step_size, + noise=noise, + use_adaptive_step_size=False, + ) + + # Compute drift and diffusion at new state, but update from old state + drift_mid = drift_fn(time_mid, **filter_kwargs(state_euler, drift_fn)) + diffusion_mid = diffusion_fn(time_mid, **filter_kwargs(state_euler, diffusion_fn)) + sqrt_step_size = keras.ops.sqrt(keras.ops.abs(step_size)) + + state_euler_mid = {} + for key, d in drift_mid.items(): base = state[key] + step_size * d - if key in diffusion: # stochastic update - base = base + diffusion[key] * noise[key] + if key in diffusion_mid: + base = base + diffusion_mid[key] * sqrt_step_size * noise[key] + state_euler_mid[key] = base + + # average the two predictions + state_heun = {} + for key in state.keys(): + state_heun[key] = 0.5 * (state_euler[key] + state_euler_mid[key]) + + # Error estimation + if use_adaptive_step_size: + if e_abs is None: + e_abs = 0.02576 # 1% of 99% CI of standardized unit variance + # Check if we're at minimum step size - if so, force acceptance + at_min_step = keras.ops.less_equal(keras.ops.abs(step_size), min_step_size) + + # Compute error tolerance for each component + e_abs_tensor = keras.ops.cast(e_abs, dtype=keras.ops.dtype(list(state.values())[0])) + e_rel_tensor = keras.ops.cast(e_rel, dtype=keras.ops.dtype(list(state.values())[0])) + + max_error = keras.ops.cast(0.0, dtype=keras.ops.dtype(list(state.values())[0])) + + for key in state.keys(): + # Local error estimate: difference between Heun and first Euler step + error_estimate = keras.ops.abs(state_heun[key] - state_euler[key]) + + # Tolerance threshold + delta = keras.ops.maximum( + e_abs_tensor, + e_rel_tensor * keras.ops.maximum(keras.ops.abs(state_euler[key]), keras.ops.abs(last_state[key])), + ) + + # Normalized error + normalized_error = error_estimate / (delta + 1e-10) + + # Maximum error across all components and batch dimensions + component_max_error = keras.ops.max(normalized_error) + max_error = 
keras.ops.maximum(max_error, component_max_error) + + error_scale = 1 # 1/sqrt(n_params) + E2 = error_scale * max_error + + # Accept step if error is acceptable OR if at minimum step size + error_acceptable = keras.ops.less_equal(E2, keras.ops.cast(1.0, dtype=keras.ops.dtype(E2))) + accepted = keras.ops.logical_or(error_acceptable, at_min_step) + + # Adapt step size for next iteration (only if not at minimum) + # Ensure E2 is not zero to avoid division issues + E2_safe = keras.ops.maximum(E2, 1e-10) + + # New step size based on error estimate + adapt_factor = adapt_safety * keras.ops.power(E2_safe, -r) + new_step_candidate = step_size * adapt_factor + + # Clamp to valid range + new_step_size = keras.ops.clip(keras.ops.abs(new_step_candidate), min_step_size, max_step_size) + new_step_size = keras.ops.sign(step_size) * new_step_size + + # Return appropriate state based on acceptance + new_state = keras.ops.cond(accepted, lambda: state_heun, lambda: state) + + new_time = keras.ops.cond(accepted, lambda: time_mid, lambda: time) + + prev_state = keras.ops.cond(accepted, lambda: state_euler, lambda: state) + + return new_state, new_time, new_step_size, prev_state + + else: + return state_heun, time_mid, step_size + + +def compute_levy_area( + state: StateDict, diffusion: StateDict, noise: StateDict, noise_aux: StateDict, step_size: ArrayLike +) -> StateDict: + step_size_abs = keras.ops.abs(step_size) + sqrt_step_size = keras.ops.sqrt(step_size_abs) + inv_sqrt3 = keras.ops.cast(1.0 / np.sqrt(3.0), dtype=keras.ops.dtype(step_size_abs)) + + # Build Lévy area H_k from w_k and Z_k + H = {} + for k in state.keys(): + if k in diffusion: + term1 = 0.5 * step_size_abs * noise[k] + term2 = 0.5 * step_size_abs * sqrt_step_size * inv_sqrt3 * noise_aux[k] + H[k] = term1 + term2 + else: + H[k] = keras.ops.zeros_like(state[k]) + return H + + +def sea_step( + drift_fn: Callable, + diffusion_fn: Callable, + state: StateDict, + time: ArrayLike, + step_size: ArrayLike, + noise: StateDict, # standard normals + noise_aux: StateDict, # standard normals + **kwargs, +) -> Tuple[StateDict, ArrayLike, ArrayLike]: + """ + Performs a single shifted Euler step for SDEs with additive noise [1]. + + Compared to Euler-Maruyama, this evaluates the drift at a shifted state, + which improves the local error and the global error constant for additive noise. + + The scheme is + X_{n+1} = X_n + f(t_n, X_n + g(t_n) * (0.5 * ΔW_n + ΔH_n) * h + g(t_n) * ΔW_n + + [1] Foster et al., "High order splitting methods for SDEs satisfying a commutativity condition" (2023) + Args: + drift_fn: Function computing the drift term f(t, **state). + diffusion_fn: Function computing the diffusion term g(t, **state). + state: Current state, mapping variable names to tensors. + time: Current time scalar tensor. + step_size: Time increment dt. + noise: Mapping of variable names to dW noise tensors. + noise_aux: Mapping of variable names to auxiliary noise. + + Returns: + new_state: Updated state after one SEA step. + new_time: time + dt. 
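    A scalar NumPy sketch of the update described above (illustrative only; assumes additive noise,
    i.e. a diffusion g that does not depend on the state, and uses standard normals w and z for the
    Brownian and auxiliary variables):

        import numpy as np

        rng = np.random.default_rng(0)
        f = lambda t, x: -x                        # toy drift
        g, x, t, h = 0.5, 1.0, 0.0, 0.01           # constant diffusion, scalar state, step size
        w, z = rng.standard_normal(2)
        dW = np.sqrt(abs(h)) * w                                          # Brownian increment
        H = 0.5 * abs(h) * w + abs(h) ** 1.5 / (2.0 * np.sqrt(3.0)) * z   # as in compute_levy_area
        x_shift = x + g * (0.5 * dW + H)           # shifted evaluation point
        x_new = x + h * f(t, x_shift) + g * dW     # SEA update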
+ """ + # Compute diffusion + diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn)) + sqrt_step_size = keras.ops.sqrt(keras.ops.abs(step_size)) + + la = compute_levy_area(state=state, diffusion=diffusion, noise=noise, noise_aux=noise_aux, step_size=step_size) + + # Build shifted state: X_shift = X + g * (0.5 * ΔW + ΔH) + shifted_state = {} + for key, x in state.items(): + if key in diffusion: + shifted_state[key] = x + diffusion[key] * (0.5 * sqrt_step_size * noise[key] + la[key]) + else: + shifted_state[key] = x + + # Drift evaluated at shifted state + drift_shifted = drift_fn(time, **filter_kwargs(shifted_state, drift_fn)) + + # Final update + new_state = {} + for key, d in drift_shifted.items(): + base = state[key] + step_size * d + if key in diffusion: + base = base + diffusion[key] * sqrt_step_size * noise[key] new_state[key] = base - return new_state, time + step_size + return new_state, time + step_size, step_size -def integrate_stochastic( +def shark_step( drift_fn: Callable, diffusion_fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, + time: ArrayLike, + step_size: ArrayLike, + noise: StateDict, + noise_aux: StateDict, + **kwargs, +) -> Tuple[StateDict, ArrayLike, ArrayLike]: + """ + Shifted Additive noise Runge Kutta (SHARK) for additive SDEs [1]. Makes two evaluations of the drift and diffusion + per step and has a strong order 1.5. + + SHARK method as specified: + + 1) ỹ_k = y_k + g(y_k) H_k + 2) ỹ_{k+5/6} = ỹ_k + (5/6)[ f(ỹ_k) h + g(ỹ_k) W_k ] + 3) y_{k+1} = y_k + + (2/5) f(ỹ_k) h + + (3/5) f(ỹ_{k+5/6}) h + + g(ỹ_k) ( 2/5 W_k + 6/5 H_k ) + + g(ỹ_{k+5/6}) ( 3/5 W_k - 6/5 H_k ) + + with + H_k = 0.5 * |h| * W_k + (|h| ** 1.5) / (2 * sqrt(3)) * Z_k + + [1] Foster et al., "High order splitting methods for SDEs satisfying a commutativity condition" (2023) + + Args: + drift_fn: Function computing the drift term f(t, **state). + diffusion_fn: Function computing the diffusion term g(t, **state). + state: Current state, mapping variable names to tensors. + time: Current time scalar tensor. + step_size: Time increment dt. + noise: Mapping of variable names to dW noise tensors. + noise_aux: Mapping of variable names to auxiliary noise. + + Returns: + new_state: Updated state after one SHARK step. + new_time: time + dt. 
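    As a consistency check on the coefficients above (illustrative): in the deterministic limit g = 0
    the scheme reduces to the two-stage Runge-Kutta step
    y_{k+1} = y_k + h * (2/5 * k1 + 3/5 * k2) with k1 = f(y_k) and k2 = f(y_k + 5/6 * h * k1),
    whose weights satisfy the second-order conditions 2/5 + 3/5 = 1 and (3/5) * (5/6) = 1/2.
    For a constant diffusion g, the Brownian contributions likewise combine to g * W_k
    (2/5 + 3/5 = 1), while the H_k terms cancel (6/5 - 6/5 = 0).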
+ """ + h = step_size + t = time + h_mag = keras.ops.abs(h) + sqrt_h_mag = keras.ops.sqrt(h_mag) + + diffusion = diffusion_fn(t, **filter_kwargs(state, diffusion_fn)) + + la = compute_levy_area(state=state, diffusion=diffusion, noise=noise, noise_aux=noise_aux, step_size=step_size) + + # === 1) shifted initial state === + y_tilde_k = {} + for k in state.keys(): + if k in diffusion: + y_tilde_k[k] = state[k] + diffusion[k] * la[k] + else: + y_tilde_k[k] = state[k] + + # === evaluate drift and diffusion at ỹ_k === + f_tilde_k = drift_fn(t, **filter_kwargs(y_tilde_k, drift_fn)) + g_tilde_k = diffusion_fn(t, **filter_kwargs(y_tilde_k, diffusion_fn)) + + # === 2) internal stage at 5/6 === + y_tilde_mid = {} + for k in state.keys(): + drift_part = (5.0 / 6.0) * f_tilde_k[k] * h + if k in g_tilde_k: + sto_part = (5.0 / 6.0) * g_tilde_k[k] * sqrt_h_mag * noise[k] + else: + sto_part = keras.ops.zeros_like(state[k]) + y_tilde_mid[k] = y_tilde_k[k] + drift_part + sto_part + + # === evaluate drift and diffusion at ỹ_(k+5/6) === + f_tilde_mid = drift_fn(t + 5.0 / 6.0 * h, **filter_kwargs(y_tilde_mid, drift_fn)) + g_tilde_mid = diffusion_fn(t + 5.0 / 6.0 * h, **filter_kwargs(y_tilde_mid, diffusion_fn)) + + # === 3) final update === + new_state = {} + for k in state.keys(): + # deterministic weights + det = state[k] + (2.0 / 5.0) * f_tilde_k[k] * h + (3.0 / 5.0) * f_tilde_mid[k] * h + + # stochastic parts + sto1 = ( + g_tilde_k[k] * ((2.0 / 5.0) * sqrt_h_mag * noise[k] + (6.0 / 5.0) * la[k]) + if k in g_tilde_k + else keras.ops.zeros_like(det) + ) + sto2 = ( + g_tilde_mid[k] * ((3.0 / 5.0) * sqrt_h_mag * noise[k] - (6.0 / 5.0) * la[k]) + if k in g_tilde_mid + else keras.ops.zeros_like(det) + ) + + new_state[k] = det + sto1 + sto2 + + return new_state, t + h, h + + +def _apply_corrector( + new_state: StateDict, + new_time: ArrayLike, + i: ArrayLike, + corrector_steps: int, + score_fn: Optional[Callable], + corrector_noise_history: StateDict | None, + seed: keras.random.SeedGenerator, + step_size_factor: ArrayLike = 0.01, + noise_schedule=None, +) -> StateDict: + """Helper function to apply corrector steps [1]. 
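    Each corrector iteration applies one annealed Langevin update per state component, with the step
    size set from the ratio of noise norm to score norm:

        e = 2 * alpha_t * (step_size_factor * ||z|| / ||score||)^2
        x <- x + e * score + sqrt(2 * e) * z,    z ~ N(0, I)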
+ + [1] Song et al., "Score-Based Generative Modeling through Stochastic Differential Equations" (2020) + """ + for j in range(corrector_steps): + score = score_fn(new_time, **filter_kwargs(new_state, score_fn)) + if corrector_noise_history is None: + _z_corr = generate_noise(new_state, seed=seed) + else: + _z_corr = {k: val[i, j] for k, val in corrector_noise_history.items()} + + log_snr_t = noise_schedule.get_log_snr(t=new_time, training=False) + alpha_t, _ = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) + + for k in new_state.keys(): + if k in score: + # Calculate required norms for Langevin step + z_norm = keras.ops.norm(_z_corr[k], axis=-1, keepdims=True) + score_norm = keras.ops.norm(score[k], axis=-1, keepdims=True) + score_norm = keras.ops.maximum(score_norm, 1e-8) + + # Compute step size for the Langevin update + e = 2.0 * alpha_t * (step_size_factor * z_norm / score_norm) ** 2 + + # Annealed Langevin Dynamics update + new_state[k] = new_state[k] + e * score[k] + keras.ops.sqrt(2.0 * e) * _z_corr[k] + return new_state + + +def integrate_stochastic_fixed( + step_fn: Callable, + state: StateDict, start_time: ArrayLike, stop_time: ArrayLike, steps: int, + min_step_size: ArrayLike, + max_step_size: ArrayLike, + z_history: StateDict | None, + z_extra_history: StateDict | None, + score_fn: Optional[Callable], + step_size_factor: ArrayLike, + corrector_noise_history: StateDict | None, seed: keras.random.SeedGenerator, + corrector_steps: int = 0, + noise_schedule=None, +) -> StateDict: + """ + Performs fixed-step SDE integration. + """ + initial_step = (stop_time - start_time) / float(steps) + + def cond(_loop_var, _loop_state, _loop_time, _loop_step): + all_nans = _check_all_nans(_loop_state) + end_now = keras.ops.less(_loop_var, steps) + return keras.ops.logical_and(~all_nans, end_now) + + def body(_i, _current_state, _current_time, _current_step): + # Determine step size: either the constant size or the remainder to reach stop_time + remaining = keras.ops.abs(stop_time - _current_time) + sign = keras.ops.sign(_current_step) + dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), remaining) + dt = sign * dt_mag + + # Generate noise increment + if z_history is None: + _noise_i = generate_noise(_current_state, seed=seed) + else: + _noise_i = {k: val[_i] for k, val in z_history.items()} + _noise_extra_i = None + if z_extra_history is not None: + if len(z_extra_history) == 0: + _noise_extra_i = generate_noise(_current_state, seed=seed) + else: + _noise_extra_i = {k: val[_i] for k, val in z_history.items()} + + new_state, new_time, new_step = step_fn( + state=_current_state, + time=_current_time, + step_size=dt, + min_step_size=min_step_size, + max_step_size=keras.ops.minimum(max_step_size, remaining), + noise=_noise_i, + noise_aux=_noise_extra_i, + use_adaptive_step_size=False, + ) + + if corrector_steps > 0: + new_state = _apply_corrector( + new_state=new_state, + new_time=new_time, + i=_i, + corrector_steps=corrector_steps, + score_fn=score_fn, + noise_schedule=noise_schedule, + step_size_factor=step_size_factor, + corrector_noise_history=corrector_noise_history, + seed=seed, + ) + return _i + 1, new_state, new_time, initial_step + + _, final_state, final_time, _ = keras.ops.while_loop( + cond, + body, + [0, state, start_time, initial_step], + ) + if _check_all_nans(final_state): + raise RuntimeError(f"All values are NaNs in state during integration at {final_time}.") + + return final_state + + +def integrate_stochastic_adaptive( + step_fn: Callable, + state: StateDict, + 
start_time: ArrayLike, + stop_time: ArrayLike, + max_steps: int, + min_step_size: ArrayLike, + max_step_size: ArrayLike, + initial_step: ArrayLike, + z_history: StateDict | None, + z_extra_history: StateDict | None, + score_fn: Optional[Callable], + step_size_factor: ArrayLike, + seed: keras.random.SeedGenerator, + corrector_noise_history: StateDict | None, + corrector_steps: int = 0, + noise_schedule=None, +) -> StateDict: + """ + Performs adaptive-step SDE integration. + """ + initial_loop_state = (keras.ops.zeros((), dtype="int32"), state, start_time, initial_step, state) + if K.backend() == "jax": + seed = None # not needed, noise is generated upfront + else: + seed_body = seed + + def cond(i, current_state, current_time, current_step, last_state): + time_remaining = keras.ops.sign(stop_time - start_time) * (stop_time - (current_time + current_step)) + all_nans = _check_all_nans(current_state) + end_now = keras.ops.logical_and(keras.ops.all(time_remaining > 0), keras.ops.less(i, max_steps)) + return keras.ops.logical_and(~all_nans, end_now) + + def body_adaptive(_i, _current_state, _current_time, _current_step, _last_state): + # Step Size Control + remaining = keras.ops.abs(stop_time - _current_time) + sign = keras.ops.sign(_current_step) + # Ensure the next step does not overshoot the stop_time + dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), remaining) + dt = sign * dt_mag + + if z_history is None: + _noise_i = generate_noise(_current_state, seed=seed_body) + else: + _noise_i = {k: val[_i] for k, val in z_history.items()} + + _noise_extra_i = None + if z_extra_history is not None: + if len(z_extra_history) == 0: + _noise_extra_i = generate_noise(_current_state, seed=seed_body) + else: + _noise_extra_i = {k: val[_i] for k, val in z_history.items()} + + new_state, new_time, new_step, _new_current_state = step_fn( + state=_current_state, + last_state=_last_state, + time=_current_time, + step_size=dt, + min_step_size=min_step_size, + max_step_size=keras.ops.minimum(max_step_size, remaining), + noise=_noise_i, + noise_aux=_noise_extra_i, + use_adaptive_step_size=True, + ) + + if corrector_steps > 0: + new_state = _apply_corrector( + new_state=new_state, + new_time=new_time, + i=_i, + corrector_steps=corrector_steps, + score_fn=score_fn, + noise_schedule=noise_schedule, + step_size_factor=step_size_factor, + corrector_noise_history=corrector_noise_history, + seed=seed_body, + ) + + return _i + 1, new_state, new_time, new_step, _new_current_state + + # Execute the adaptive loop + final_counter, final_state, final_time, _, final_k1 = keras.ops.while_loop(cond, body_adaptive, initial_loop_state) + + if _check_all_nans(final_state): + raise RuntimeError(f"All values are NaNs in state during integration at {final_time}.") + + # Final step to hit stop_time exactly + time_diff = stop_time - final_time + time_remaining = keras.ops.sign(stop_time - start_time) * time_diff + if keras.ops.all(time_remaining > 0): + noise_final = generate_noise(final_state, seed=seed) + noise_extra_final = None + if z_extra_history is not None and len(z_extra_history) > 0: + noise_extra_final = generate_noise(final_state, seed=seed) + + final_state, _, _ = step_fn( + state=final_state, + time=final_time, + step_size=time_diff, + last_state=final_k1, + min_step_size=min_step_size, + max_step_size=time_remaining, + noise=noise_final, + noise_aux=noise_extra_final, + use_adaptive_step_size=False, + ) + final_counter = final_counter + 1 + + logging.debug(f"Finished integration after {final_counter}.") + return 
final_state + + +def integrate_langevin( + state: StateDict, + start_time: ArrayLike, + stop_time: ArrayLike, + steps: int, + z_history: StateDict | None, + score_fn: Callable, + noise_schedule, + seed: keras.random.SeedGenerator, + corrector_noise_history: StateDict | None, + step_size_factor: ArrayLike = 0.01, + corrector_steps: int = 0, +) -> StateDict: + """ + Annealed Langevin dynamics using the given score_fn and noise_schedule [1]. + + At each step i with time t_i, performs for every state component k: + state_k <- state_k + e * score_k + sqrt(2 * e) * z + + Times are stepped linearly from start_time to stop_time. + + [1] Song et al., "Generative Modeling by Estimating Gradients of the Data Distribution" (2020) + """ + + if steps <= 0: + raise ValueError("Number of Langevin steps must be positive.") + if score_fn is None or noise_schedule is None: + raise ValueError("score_fn and noise_schedule must be provided.") + + # Linear time grid + dt = (stop_time - start_time) / float(steps) + effective_factor = step_size_factor * 100 / np.sqrt(steps) + + def cond(_loop_var, _loop_state, _loop_time): + all_nans = _check_all_nans(_loop_state) + end_now = keras.ops.less(_loop_var, steps) + return keras.ops.logical_and(~all_nans, end_now) + + def body(_i, _loop_state, _loop_time): + # score at current time + score = score_fn(_loop_time, **filter_kwargs(_loop_state, score_fn)) + + # noise schedule + log_snr_t = noise_schedule.get_log_snr(t=_loop_time, training=False) + _, sigma_t = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) + + new_state: StateDict = {} + if z_history is None: + z_history_i = generate_noise(_loop_state, seed=seed) + else: + z_history_i = {k: val[_i] for k, val in z_history.items()} + for k in _loop_state.keys(): + s_k = score.get(k, None) + if s_k is None: + new_state[k] = _loop_state[k] + continue + + e = effective_factor * sigma_t**2 + new_state[k] = _loop_state[k] + e * s_k + keras.ops.sqrt(2.0 * e) * z_history_i[k] + + new_time = _loop_time + dt + + if corrector_steps > 0: + new_state = _apply_corrector( + new_state=new_state, + new_time=new_time, + i=_i, + corrector_steps=corrector_steps, + score_fn=score_fn, + noise_schedule=noise_schedule, + step_size_factor=step_size_factor, + corrector_noise_history=corrector_noise_history, + seed=seed, + ) + + return _i + 1, new_state, new_time + + _, final_state, final_time = keras.ops.while_loop( + cond, + body, + (0, state, start_time), + ) + if _check_all_nans(final_state): + raise RuntimeError(f"All values are NaNs in state during integration at {final_time}.") + return final_state + + +def integrate_stochastic( + drift_fn: Callable, + diffusion_fn: Callable, + state: StateDict, + start_time: ArrayLike, + stop_time: ArrayLike, + seed: keras.random.SeedGenerator, + steps: int | Literal["adaptive"] = 100, method: str = "euler_maruyama", + min_steps: int = 50, + max_steps: int = 1_000, + score_fn: Callable = None, + corrector_steps: int = 0, + noise_schedule=None, + step_size_factor: ArrayLike = 0.01, **kwargs, -) -> Union[dict[str, ArrayLike], tuple[dict[str, ArrayLike], dict[str, Sequence[ArrayLike]]]]: +) -> StateDict: """ Integrates a stochastic differential equation from start_time to stop_time. + Dispatches to fixed-step or adaptive-step integration logic. + Args: drift_fn: Function that computes the drift term. diffusion_fn: Function that computes the diffusion term. state: Dictionary containing the initial state. start_time: Starting time for integration. - stop_time: Ending time for integration. 
- steps: Number of integration steps. + stop_time: Ending time for integration. steps: Number of integration steps. seed: Random seed for noise generation. - method: Integration method to use, e.g., 'euler_maruyama'. + steps: Number of steps or 'adaptive' for adaptive step sizing. Only 'shark' method supports adaptive steps. + method: Integration method to use, e.g., 'euler_maruyama' or 'shark'. + min_steps: Minimum number of steps for adaptive integration. + max_steps: Maximum number of steps for adaptive integration. We pre-generate noise up to this many steps, + which may impact memory usage. + score_fn: Optional score function for predictor-corrector sampling. + corrector_steps: Number of corrector steps to take after each predictor step. + noise_schedule: Noise schedule object for computing alpha_t in corrector. + step_size_factor: Scaling factor for corrector step size. **kwargs: Additional arguments to pass to the step function. - Returns: - If return_noise is False, returns the final state dictionary. - If return_noise is True, returns a tuple of (final_state, noise_history). + Returns: Final state dictionary after integration. """ - if steps <= 0: - raise ValueError("Number of steps must be positive.") + is_adaptive = isinstance(steps, str) and steps in ["adaptive", "dynamic"] + + if is_adaptive: + if start_time is None or stop_time is None: + raise ValueError("Please provide start_time and stop_time for adaptive integration.") + if min_steps <= 0 or max_steps <= 0 or max_steps < min_steps: + raise ValueError("min_steps and max_steps must be positive, and max_steps >= min_steps.") + + loop_steps = max_steps + initial_step = (stop_time - start_time) / float(min_steps) + span_mag = keras.ops.abs(stop_time - start_time) + min_step_size = span_mag / keras.ops.cast(max_steps, span_mag.dtype) + max_step_size = span_mag / keras.ops.cast(min_steps, span_mag.dtype) + else: + if steps <= 0: + raise ValueError("Number of steps must be positive.") + loop_steps = int(steps) + initial_step = (stop_time - start_time) / float(loop_steps) + # For fixed step, min/max step size are just the fixed step size + min_step_size, max_step_size = initial_step, initial_step + + # Pre-generate corrector noise if requested + corrector_noise_history = None + if corrector_steps > 0: + if score_fn is None or noise_schedule is None: + raise ValueError("Please provide both score_fn and noise_schedule when using corrector_steps > 0.") + if K.backend() == "jax": + corrector_noise_history = {} + for key, val in state.items(): + shape = keras.ops.shape(val) + corrector_noise_history[key] = keras.random.normal( + (loop_steps, corrector_steps, *shape), dtype=keras.ops.dtype(val), seed=seed + ) - # Select step function based on method match method: case "euler_maruyama": - step_fn = euler_maruyama_step + step_fn_raw = euler_maruyama_step + case "sea": + step_fn_raw = sea_step + if is_adaptive: + raise ValueError("SEA SDE solver does not support adaptive steps.") + case "shark": + step_fn_raw = shark_step + if is_adaptive: + raise ValueError("SHARK SDE solver does not support adaptive steps.") + case "two_step_adaptive": + step_fn_raw = two_step_adaptive_step + case "langevin": + if is_adaptive: + raise ValueError("Langevin sampling does not support adaptive steps.") + + z_history = None + if K.backend() == "jax": + logging.warning("JAX backend needs to preallocate random samples for max steps.") + z_history = {} + for key, val in state.items(): + shape = keras.ops.shape(val) + z_history[key] = 
keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) + + return integrate_langevin( + state=state, + start_time=start_time, + stop_time=stop_time, + steps=loop_steps, + z_history=z_history, + score_fn=score_fn, + noise_schedule=noise_schedule, + step_size_factor=step_size_factor, + corrector_steps=corrector_steps, + corrector_noise_history=corrector_noise_history, + seed=seed, + ) case other: raise TypeError(f"Invalid integration method: {other!r}") - # Prepare step function with partial application - step_fn = partial(step_fn, drift_fn=drift_fn, diffusion_fn=diffusion_fn, **kwargs) - - # Time step - step_size = (stop_time - start_time) / steps - sqrt_dt = keras.ops.sqrt(keras.ops.abs(step_size)) + # Partial the step function with common arguments + step_fn = partial( + step_fn_raw, + drift_fn=drift_fn, + diffusion_fn=diffusion_fn, + **kwargs, + ) - # Pre-generate noise history: shape = (steps, *state_shape) - noise_history = {} - for key, val in state.items(): - noise_history[key] = ( - keras.random.normal((steps, *keras.ops.shape(val)), dtype=keras.ops.dtype(val), seed=seed) * sqrt_dt + # Pre-generate standard normals for the predictor step (up to max_steps) + z_history = None + z_extra_history = None if method not in ["sea", "shark"] else {} + if K.backend() == "jax": + logging.warning("JAX backend needs to preallocate random samples for max steps.") + z_history = {} + for key, val in state.items(): + shape = keras.ops.shape(val) + z_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) + if method in ["sea", "shark"]: + z_extra_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) + + if is_adaptive: + return integrate_stochastic_adaptive( + step_fn=step_fn, + state=state, + start_time=start_time, + stop_time=stop_time, + max_steps=max_steps, + min_step_size=min_step_size, + max_step_size=max_step_size, + initial_step=initial_step, + z_history=z_history, + z_extra_history=z_extra_history, + corrector_steps=corrector_steps, + score_fn=score_fn, + noise_schedule=noise_schedule, + step_size_factor=step_size_factor, + corrector_noise_history=corrector_noise_history, + seed=seed, + ) + else: + return integrate_stochastic_fixed( + step_fn=step_fn, + state=state, + start_time=start_time, + stop_time=stop_time, + min_step_size=min_step_size, + max_step_size=max_step_size, + steps=loop_steps, + z_history=z_history, + z_extra_history=z_extra_history, + corrector_steps=corrector_steps, + score_fn=score_fn, + noise_schedule=noise_schedule, + step_size_factor=step_size_factor, + corrector_noise_history=corrector_noise_history, + seed=seed, ) - - def body(_loop_var, _loop_state): - _current_state, _current_time = _loop_state - _noise_i = {k: noise_history[k][_loop_var] for k in _current_state.keys()} - new_state, new_time = step_fn(state=_current_state, time=_current_time, step_size=step_size, noise=_noise_i) - return new_state, new_time - - final_state, final_time = keras.ops.fori_loop(0, steps, body, (state, start_time)) - return final_state diff --git a/bayesflow/utils/optimal_transport/euclidean.py b/bayesflow/utils/optimal_transport/euclidean.py deleted file mode 100644 index 26610a81c..000000000 --- a/bayesflow/utils/optimal_transport/euclidean.py +++ /dev/null @@ -1,11 +0,0 @@ -import keras - - -def euclidean(x1, x2): - # TODO: rename and move this function - result = x1[:, None] - x2[None, :] - shape = list(keras.ops.shape(result)) - shape[2:] = [-1] - result = 
keras.ops.reshape(result, shape) - result = keras.ops.norm(result, ord=2, axis=-1) - return result diff --git a/bayesflow/utils/optimal_transport/log_sinkhorn.py b/bayesflow/utils/optimal_transport/log_sinkhorn.py index 2def2b0c7..f969e78f8 100644 --- a/bayesflow/utils/optimal_transport/log_sinkhorn.py +++ b/bayesflow/utils/optimal_transport/log_sinkhorn.py @@ -1,78 +1,162 @@ import keras -from .. import logging +from bayesflow.types import Tensor +from bayesflow.utils import filter_kwargs +from .ot_utils import ( + squared_euclidean, + cosine_distance, + augment_for_partial_ot, + search_for_conditional_weight, +) -from .euclidean import euclidean +from .. import logging -def log_sinkhorn(x1, x2, seed: int = None, **kwargs): +def log_sinkhorn(x1: Tensor, x2: Tensor, conditions: Tensor | None = None, seed: int = None, **kwargs) -> Tensor: """ Log-stabilized version of :py:func:`~bayesflow.utils.optimal_transport.sinkhorn.sinkhorn`. About 50% slower than the unstabilized version, so use only when you need numerical stability. + + Partial optimal transport can be performed by setting `partial=True` to reduce the effect of misspecified mappings + in mini-batch settings [1]. For conditional optimal transport, conditions can be provided along with a + `condition_ratio` [2]. + + [1] Nguyen et al. (2022) "Improving Mini-batch Optimal Transport via Partial Transportation" + [2] Cheng et al. (2025) "The Curse of Conditions: Analyzing and Improving Optimal Transport for + Conditional Flow-Based Generation" + [3] Fluri et al. (2024) "Improving Flow Matching for Simulation-Based Inference" """ - log_plan = log_sinkhorn_plan(x1, x2, **kwargs) + log_plan = log_sinkhorn_plan(x1, x2, conditions=conditions, **kwargs) + assignments = keras.random.categorical(log_plan, num_samples=1, seed=seed) assignments = keras.ops.squeeze(assignments, axis=1) return assignments -def log_sinkhorn_plan(x1, x2, regularization: float = 1.0, rtol=1e-5, atol=1e-8, max_steps=None): +def log_sinkhorn_plan( + x1: Tensor, + x2: Tensor, + conditions: Tensor | None = None, + regularization: float = 1.0, + atol: float = 1e-5, + max_steps: int = 1000, + conditional_ot_ratio: float = 0.5, + partial_ot_factor: float = 1.0, + **kwargs, +) -> Tensor: """ Log-stabilized version of :py:func:`~bayesflow.utils.optimal_transport.sinkhorn.sinkhorn_plan`. About 50% slower than the unstabilized version, so use primarily when you need numerical stability. - """ - cost = euclidean(x1, x2) - cost_scaled = -cost / regularization - # initialize transport plan from a gaussian kernel - log_plan = cost_scaled - keras.ops.max(cost_scaled) - n, m = keras.ops.shape(log_plan) + :param x1: Tensor of shape (n, ...) + Samples from the first distribution. + + :param x2: Tensor of shape (m, ...) + Samples from the second distribution. + + :param conditions: Optional tensor of shape (m, ...) + Conditions to be used in conditional optimal transport settings. - log_a = -keras.ops.log(n) - log_b = -keras.ops.log(m) + :param regularization: Regularization parameter. + Controls the standard deviation of the Gaussian kernel. + Default: 1.0 - def contains_nans(plan): - return keras.ops.any(keras.ops.isnan(plan)) + :param max_steps: Maximum number of iterations. 
+ Default: 1000 - def is_converged(plan): - # for convergence, the target marginals must match - conv0 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=0), log_b, rtol=0.0, atol=rtol + atol)) - conv1 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=1), log_a, rtol=0.0, atol=rtol + atol)) - return conv0 & conv1 + :param atol: Absolute tolerance for convergence. + Default: 1e-5. - def cond(_, plan): - # break the while loop if the plan contains nans or is converged - return ~(contains_nans(plan) | is_converged(plan)) + :param conditional_ot_ratio: Ratio which measures the proportion of samples that are considered "potential optimal + transport candidates". 0.5 is equivalent to no conditioning. [2] recommends a ratio of 0.01. + Only used if `conditions` is not None. + Default: 0.01 - def body(steps, plan): - # Sinkhorn-Knopp: repeatedly normalize the transport plan along each dimension - plan = plan - keras.ops.logsumexp(plan, axis=0, keepdims=True) + log_b - plan = plan - keras.ops.logsumexp(plan, axis=1, keepdims=True) + log_a + :param partial_ot_factor: Proportion of mass to transport in partial optimal transport. + Default: 1.0 (i.e., balanced OT) - return steps + 1, plan + :return: Tensor of shape (n, m) or (n+1, m+1) if partial=True + The log transport probabilities. + """ + if not (0.0 < partial_ot_factor <= 1.0): + raise ValueError(f"s must be in (0, 1] for partial OT, got {partial_ot_factor}") + partial = partial_ot_factor < 1.0 + + cost = squared_euclidean(x1, x2) + + if regularization <= 0.0: + raise ValueError(f"regularization must be positive, got {regularization}") - steps = 0 - steps, log_plan = keras.ops.while_loop(cond, body, (steps, log_plan), maximum_iterations=max_steps) + if conditions is not None and conditional_ot_ratio < 0.5: + cond_cost = cosine_distance(conditions, conditions) + cost, w = search_for_conditional_weight( + M=cost, + C=cond_cost, + condition_ratio=conditional_ot_ratio, + **filter_kwargs(kwargs, search_for_conditional_weight), + ) + + cost_scaled = -cost / regularization + if partial: + cost_scaled, a, b = augment_for_partial_ot( + cost_scaled=cost_scaled, + regularization=regularization, + s=partial_ot_factor, + **filter_kwargs(kwargs, augment_for_partial_ot), + ) + log_a = keras.ops.log(a) + log_b = keras.ops.log(b) + n, m = keras.ops.shape(cost_scaled) + else: + # balanced uniform marginals (scalars) + n, m = keras.ops.shape(cost_scaled) + log_a = keras.ops.full((n,), -keras.ops.log(keras.ops.cast(n, cost_scaled.dtype))) + log_b = keras.ops.full((m,), -keras.ops.log(keras.ops.cast(m, cost_scaled.dtype))) + + # log-plan is implicitly: log_plan = cost_scaled + u[:, None] + v[None, :] + u = keras.ops.zeros((n,), dtype=cost_scaled.dtype) + v = keras.ops.zeros((m,), dtype=cost_scaled.dtype) + + def contains_nans(_plan): + return keras.ops.any(keras.ops.isnan(_plan)) + + def cond(_, __, ___, _err): + return _err > atol + + def body(_steps, _u, _v, _err): + u_next = log_a - keras.ops.logsumexp(cost_scaled + keras.ops.expand_dims(_v, 0), axis=1) + v_next = log_b - keras.ops.logsumexp(cost_scaled + keras.ops.expand_dims(u_next, 1), axis=0) + + # Error check on dual variable change + err_next = keras.ops.max(keras.ops.abs(u_next - _u)) + return _steps + 1, u_next, v_next, err_next + + err0 = keras.ops.cast(1e30, cost_scaled.dtype) + steps, u, v, err = keras.ops.while_loop(cond, body, (0, u, v, err0), maximum_iterations=max_steps) + + # final reconstruction + log_plan = cost_scaled + keras.ops.expand_dims(u, 1) + 
keras.ops.expand_dims(v, 0) def do_nothing(): pass def log_steps(): msg = "Log-Sinkhorn-Knopp converged after {} steps." - logging.debug(msg, steps) def warn_convergence(): msg = "Log-Sinkhorn-Knopp did not converge after {} steps." - - logging.warning(msg, max_steps) + logging.warning(msg, steps) def warn_nans(): msg = "Log-Sinkhorn-Knopp produced NaNs after {} steps." logging.warning(msg, steps) keras.ops.cond(contains_nans(log_plan), warn_nans, do_nothing) - keras.ops.cond(is_converged(log_plan), log_steps, warn_convergence) + keras.ops.cond(cond(None, None, None, err), log_steps, warn_convergence) + + if partial: + return log_plan[:-1, :-1] return log_plan diff --git a/bayesflow/utils/optimal_transport/optimal_transport.py b/bayesflow/utils/optimal_transport/optimal_transport.py index c1bca6d2c..4162477ff 100644 --- a/bayesflow/utils/optimal_transport/optimal_transport.py +++ b/bayesflow/utils/optimal_transport/optimal_transport.py @@ -1,5 +1,7 @@ import keras +from bayesflow.types import Tensor + from .log_sinkhorn import log_sinkhorn from .sinkhorn import sinkhorn @@ -11,7 +13,9 @@ } -def optimal_transport(x1, x2, method="log_sinkhorn", return_assignments=False, **kwargs): +def optimal_transport( + x1: Tensor, x2: Tensor, conditions: Tensor | None = None, method="sinkhorn", return_assignments=False, **kwargs +): """Matches elements from x2 onto x1, such that the transport cost between them is minimized, according to the method and cost matrix used. @@ -27,6 +31,10 @@ def optimal_transport(x1, x2, method="log_sinkhorn", return_assignments=False, * :param x2: Tensor of shape (m, ...) Samples from the second distribution. + :param conditions: Optional tensor of shape (k, ...) + Conditions to be used in conditional optimal transport settings. + Default: None + :param method: Method used to compute the transport cost. Default: 'log_sinkhorn' @@ -38,10 +46,14 @@ def optimal_transport(x1, x2, method="log_sinkhorn", return_assignments=False, * :return: Tensors of shapes (n, ...) and (m, ...) x1 and x2 in optimal transport permutation order. """ - assignments = methods[method.lower()](x1, x2, **kwargs) + assignments = methods[method.lower()](x1, x2, conditions, **kwargs) x2 = keras.ops.take(x2, assignments, axis=0) + if conditions is not None: + # conditions must be resampled along with x1 + conditions = keras.ops.take(conditions, assignments, axis=0) + if return_assignments: - return x1, x2, assignments + return x1, x2, conditions, assignments - return x1, x2 + return x1, x2, conditions diff --git a/bayesflow/utils/optimal_transport/ot_utils.py b/bayesflow/utils/optimal_transport/ot_utils.py new file mode 100644 index 000000000..8ca9ee4ad --- /dev/null +++ b/bayesflow/utils/optimal_transport/ot_utils.py @@ -0,0 +1,164 @@ +import keras + +from bayesflow.types import Tensor + + +def squared_euclidean(x1: Tensor, x2: Tensor) -> Tensor: + # flatten trailing dims + x1 = keras.ops.reshape(x1, (keras.ops.shape(x1)[0], -1)) + x2 = keras.ops.reshape(x2, (keras.ops.shape(x2)[0], -1)) + + x1_sq = keras.ops.sum(x1 * x1, axis=1, keepdims=True) # (n,1) + x2_sq = keras.ops.sum(x2 * x2, axis=1, keepdims=True) # (m,1) + cross = keras.ops.matmul(x1, keras.ops.transpose(x2)) # (n,m) + + dist2 = x1_sq + keras.ops.transpose(x2_sq) - 2.0 * cross + return keras.ops.maximum(dist2, 0.0) + + +def cosine_distance(x1: Tensor, x2: Tensor, eps: float = 1e-8) -> Tensor: + """ + Pairwise cosine distance: + d(x, y) = 1 - / (||x|| ||y||) + + x1: Tensor of shape (n, ...) + x2: Tensor of shape (m, ...) 
+ returns: Tensor of shape (n, m) + """ + x1 = keras.ops.reshape(x1, (keras.ops.shape(x1)[0], -1)) + x2 = keras.ops.reshape(x2, (keras.ops.shape(x2)[0], -1)) + + x1 = x1 / (keras.ops.norm(x1, axis=1, keepdims=True) + eps) + x2 = x2 / (keras.ops.norm(x2, axis=1, keepdims=True) + eps) + + # cosine similarity + sim = keras.ops.matmul(x1, keras.ops.transpose(x2)) + sim = keras.ops.clip(sim, -1.0, 1.0) + return 1.0 - sim + + +def augment_for_partial_ot( + cost_scaled: Tensor, + regularization: float, + s: float, + dummy_cost: float | None = None, +) -> tuple[Tensor, Tensor, Tensor]: + """ + Augments a scaled cost matrix for partial OT via a dummy row/column. + + For partial OT with mass s ∈ (0,1), we transport s proportion of mass + and leave (1-s) unmatched via dummy nodes. + """ + if dummy_cost is None: + dummy_cost = keras.ops.max(-cost_scaled * regularization) + 1 # same as POT library default + + # cost_scaled is expected to be -C/eps with shape (n, m) + A = keras.ops.convert_to_tensor(dummy_cost, dtype=cost_scaled.dtype) + + n0 = keras.ops.shape(cost_scaled)[0] + m0 = keras.ops.shape(cost_scaled)[1] + + # Augmented cost: [[cost_scaled, 0], [0, -A/eps]] + zero_col = keras.ops.zeros((n0, 1), dtype=cost_scaled.dtype) # (n0, 1) + zero_row = keras.ops.zeros((1, m0), dtype=cost_scaled.dtype) # (1, m0) + br = keras.ops.reshape(-A / regularization, (1, 1)) # (1, 1) + + top = keras.ops.concatenate([cost_scaled, zero_col], axis=1) # (n0, m0+1) + bottom = keras.ops.concatenate([zero_row, br], axis=1) # (1, m0+1) + cost_scaled_aug = keras.ops.concatenate([top, bottom], axis=0) # (n0+1, m0+1) + + # Augmented marginals: [u_n, 1-s] and [u_m, 1-s] + dtype = cost_scaled.dtype + s_t = keras.ops.convert_to_tensor(s, dtype=dtype) + one_minus_s = 1.0 - s_t + + n0_f = keras.ops.cast(n0, dtype) + m0_f = keras.ops.cast(m0, dtype) + + a = keras.ops.concatenate( + [ + keras.ops.ones((n0,), dtype=dtype) * (1.0 / n0_f), + keras.ops.reshape(one_minus_s, (1,)), + ], + axis=0, + ) # (n0+1,) + + b = keras.ops.concatenate( + [ + keras.ops.ones((m0,), dtype=dtype) * (1.0 / m0_f), + keras.ops.reshape(one_minus_s, (1,)), + ], + axis=0, + ) # (m0+1,) + + return cost_scaled_aug, a, b + + +def search_for_conditional_weight( + M: Tensor, + C: Tensor, + condition_ratio: float, + initial_w: float = 1.0, + max_iter: int = 10, + abs_tol: float = 1e-3, + max_w: float = 1e8, +) -> tuple[Tensor, Tensor]: + """ + Find w such that mean((M + w*C) <= diag(M)) ≈ condition_ratio + + Returns: + cost = M + w*C (Tensor, shape (N,N)) + w = Tensor scalar (same dtype as M) + """ + dtype = M.dtype + r_t = keras.ops.convert_to_tensor(condition_ratio, dtype=dtype) + max_w_t = keras.ops.convert_to_tensor(max_w, dtype=dtype) + + # condition: M + w*C <= M_diag => w*C <= M_diag - M + M_diag = keras.ops.expand_dims(keras.ops.diagonal(M), 1) + Delta = M_diag - M # Pre-computed target threshold + + def get_ratio(w): + return keras.ops.mean(keras.ops.cast(w * C <= Delta, dtype)) + + # Boundary check at w=0 + r0 = get_ratio(keras.ops.convert_to_tensor(0.0, dtype=dtype)) + + def do_search(): + # Exponential search to bracket w + def exp_cond(it, low, high): + return (it < max_iter) & (get_ratio(high) > r_t) & (high < max_w_t) + + def exp_body(it, low, high): + return it + 1, high, high * 2.0 + + _, low, high = keras.ops.while_loop(exp_cond, exp_body, (0, 0.0, initial_w)) + high = keras.ops.minimum(high, max_w_t) + + # Binary search for optimal w + def bin_cond(it, low, high, best_w, best_r): + return (it < max_iter) & (keras.ops.abs(best_r - r_t) > abs_tol) + 
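        The bisection below relies on the acceptance ratio mean(M + w * C <= diag(M)) being
        non-increasing in w: the cosine distances in C are non-negative with a zero diagonal, so
        raising w can only remove off-diagonal pairs from the acceptable set while the diagonal
        always stays in it. The exponential search above brackets such a w before the bisection
        refines it towards condition_ratio.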
+ def bin_body(it, low, high, best_w, best_r): + mid = (low + high) / 2.0 + r_mid = get_ratio(mid) + + # Update bounds + new_high = keras.ops.where(r_mid < r_t, mid, high) + new_low = keras.ops.where(r_mid < r_t, low, mid) + + # Update best_w based on proximity to target ratio + closer = keras.ops.abs(r_mid - r_t) < keras.ops.abs(best_r - r_t) + best_w_next = keras.ops.where(closer, mid, best_w) + best_r_next = keras.ops.where(closer, r_mid, best_r) + + return it + 1, new_low, new_high, best_w_next, best_r_next + + _, _, _, final_w, _ = keras.ops.while_loop(bin_cond, bin_body, (0, low, high, high, get_ratio(high))) + return final_w + + # Select w and construct final cost matrix once + optimal_w = keras.ops.cond(r0 < r_t, lambda: 0.0, do_search) + final_cost = M + optimal_w * C + + return final_cost, optimal_w diff --git a/bayesflow/utils/optimal_transport/sinkhorn.py b/bayesflow/utils/optimal_transport/sinkhorn.py index 45c568294..b4697d137 100644 --- a/bayesflow/utils/optimal_transport/sinkhorn.py +++ b/bayesflow/utils/optimal_transport/sinkhorn.py @@ -1,13 +1,18 @@ import keras from bayesflow.types import Tensor +from bayesflow.utils import filter_kwargs +from .ot_utils import ( + squared_euclidean, + cosine_distance, + augment_for_partial_ot, + search_for_conditional_weight, +) from .. import logging -from .euclidean import euclidean - -def sinkhorn(x1: Tensor, x2: Tensor, seed: int = None, **kwargs) -> (Tensor, Tensor): +def sinkhorn(x1: Tensor, x2: Tensor, conditions: Tensor | None = None, seed: int = None, **kwargs) -> Tensor: """ Matches elements from x2 onto x1 using the Sinkhorn-Knopp algorithm. @@ -15,27 +20,40 @@ def sinkhorn(x1: Tensor, x2: Tensor, seed: int = None, **kwargs) -> (Tensor, Ten transport plan, containing assignment probabilities. The permutation is then sampled randomly according to the transport plan. + Partial optimal transport can be performed by setting `partial=True` to reduce the effect of misspecified mappings + in mini-batch settings [1]. For conditional optimal transport, conditions can be provided along with a + `condition_ratio` [2]. + + [1] Nguyen et al. (2022) "Improving Mini-batch Optimal Transport via Partial Transportation" + [2] Cheng et al. (2025) "The Curse of Conditions: Analyzing and Improving Optimal Transport for + Conditional Flow-Based Generation" + [3] Fluri et al. (2024) "Improving Flow Matching for Simulation-Based Inference" + :param x1: Tensor of shape (n, ...) Samples from the first distribution. :param x2: Tensor of shape (m, ...) Samples from the second distribution. - :param kwargs: - Additional keyword arguments that are passed to :py:func:`sinkhorn_plan`. + :param conditions: Optional tensor of shape (k, ...) + Conditions to be used in conditional optimal transport settings. + Default: None :param seed: Random seed to use for sampling indices. Default: None, which means the seed will be auto-determined for non-compiled contexts. + :param kwargs: + Additional keyword arguments that are passed to :py:func:`sinkhorn_plan`. + :return: Tensor of shape (n,) Assignment indices for x2. 
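    The sampling step described above can be pictured with a small NumPy sketch (toy values,
    independent of the tensors used here): each row of the plan is treated as a categorical
    distribution over x2, and the drawn indices reorder x2 to pair with x1.

        import numpy as np

        rng = np.random.default_rng(0)
        plan = np.array([[0.9, 0.1], [0.2, 0.8]])        # toy (n=2, m=2) transport plan
        probs = plan / plan.sum(axis=1, keepdims=True)   # one categorical distribution per row
        assignments = np.array([rng.choice(len(row), p=row) for row in probs])
        x2 = np.array([[10.0], [20.0]])
        x2_matched = x2[assignments]                     # x2_matched[i] is paired with x1[i]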
""" - plan = sinkhorn_plan(x1, x2, **kwargs) + plan = sinkhorn_plan(x1, x2, conditions=conditions, **kwargs) # we sample from log(plan) to receive assignments of length n, corresponding to indices of x2 # such that x2[assignments] matches x1 - assignments = keras.random.categorical(keras.ops.log(plan), num_samples=1, seed=seed) + assignments = keras.random.categorical(keras.ops.log(plan + 1e-10), num_samples=1, seed=seed) assignments = keras.ops.squeeze(assignments, axis=1) return assignments @@ -44,10 +62,13 @@ def sinkhorn(x1: Tensor, x2: Tensor, seed: int = None, **kwargs) -> (Tensor, Ten def sinkhorn_plan( x1: Tensor, x2: Tensor, + conditions: Tensor | None = None, regularization: float = 1.0, - max_steps: int = None, - rtol: float = 1e-5, - atol: float = 1e-8, + max_steps: int = 1000, + atol: float = 1e-5, + conditional_ot_ratio: float = 0.5, + partial_ot_factor: float = 1.0, + **kwargs, ) -> Tensor: """ Computes the Sinkhorn-Knopp optimal transport plan. @@ -58,70 +79,112 @@ def sinkhorn_plan( :param x2: Tensor of shape (m, ...) Samples from the second distribution. + :param conditions: Optional tensor of shape (m, ...) + Conditions to be used in conditional optimal transport settings. + :param regularization: Regularization parameter. Controls the standard deviation of the Gaussian kernel. + Default: 1.0 - :param max_steps: Maximum number of iterations, or None to run until convergence. - Default: None + :param max_steps: Maximum number of iterations. + Default: 1000 - :param rtol: Relative tolerance for convergence. + :param atol: Tolerance for convergence. Default: 1e-5. - :param atol: Absolute tolerance for convergence. - Default: 1e-8. + :param conditional_ot_ratio: Ratio which measures the proportion of samples that are considered “potential optimal + transport candidates”. 0.5 is equivalent to no conditioning. [2] recommends a ratio of 0.01. + Only used if `conditions` is not None. + Default: 0.0 + + :param partial_ot_factor: Proportion of mass to transport in partial optimal transport. + Default: 1.0 (i.e., balanced OT) :return: Tensor of shape (n, m) The transport probabilities. 
""" - cost = euclidean(x1, x2) + if not (0.0 < partial_ot_factor <= 1.0): + raise ValueError(f"s must be in (0, 1] for partial OT, got {partial_ot_factor}") + partial = partial_ot_factor < 1.0 + + cost = squared_euclidean(x1, x2) + + if regularization <= 0.0: + raise ValueError(f"regularization must be positive, got {regularization}") + + if conditions is not None and conditional_ot_ratio < 0.5: + cond_cost = cosine_distance(conditions, conditions) + cost, w = search_for_conditional_weight( + M=cost, + C=cond_cost, + condition_ratio=conditional_ot_ratio, + **filter_kwargs(kwargs, search_for_conditional_weight), + ) + cost_scaled = -cost / regularization + if partial: + cost_scaled, a, b = augment_for_partial_ot( + cost_scaled=cost_scaled, + regularization=regularization, + s=partial_ot_factor, + **filter_kwargs(kwargs, augment_for_partial_ot), + ) + a = keras.ops.reshape(a, (-1,)) # (n,) + b = keras.ops.reshape(b, (-1,)) # (m,) + else: + # balanced uniform marginals (scalars) + n, m = keras.ops.shape(cost_scaled) + a = keras.ops.ones((n,), dtype=cost_scaled.dtype) / keras.ops.cast(n, cost_scaled.dtype) + b = keras.ops.ones((m,), dtype=cost_scaled.dtype) / keras.ops.cast(m, cost_scaled.dtype) # initialize transport plan from a gaussian kernel # (more numerically stable version of keras.ops.exp(-cost/regularization)) plan = keras.ops.exp(cost_scaled - keras.ops.max(cost_scaled)) - n, m = keras.ops.shape(cost) + u = keras.ops.ones_like(a) + v = keras.ops.ones_like(b) + tiny = keras.ops.cast(1e-12, plan.dtype) - def contains_nans(plan): - return keras.ops.any(keras.ops.isnan(plan)) + def contains_nans(_plan): + return keras.ops.any(keras.ops.isnan(_plan)) - def is_converged(plan): - # for convergence, the target marginals must match - conv0 = keras.ops.all(keras.ops.isclose(keras.ops.sum(plan, axis=0), 1.0 / m, rtol=rtol, atol=atol)) - conv1 = keras.ops.all(keras.ops.isclose(keras.ops.sum(plan, axis=1), 1.0 / n, rtol=rtol, atol=atol)) - return conv0 & conv1 + def cond(_, __, ___, _err): + return _err > atol - def cond(_, plan): - # break the while loop if the plan contains nans or is converged - return ~(contains_nans(plan) | is_converged(plan)) + def body(_steps, _u, _v, _err): + plan_v = keras.ops.matmul(plan, keras.ops.expand_dims(_v, 1))[:, 0] + tiny + u_new = a / plan_v + plan_T_u = keras.ops.matmul(keras.ops.transpose(plan), keras.ops.expand_dims(u_new, 1))[:, 0] + tiny + v_new = b / plan_T_u - def body(steps, plan): - # Sinkhorn-Knopp: repeatedly normalize the transport plan along each dimension - plan = plan / keras.ops.sum(plan, axis=0, keepdims=True) * (1.0 / m) - plan = plan / keras.ops.sum(plan, axis=1, keepdims=True) * (1.0 / n) + # log-relative change (stable even if u/v span many orders of magnitude) + du = keras.ops.max(keras.ops.abs(keras.ops.log((u_new + tiny) / (_u + tiny)))) + dv = keras.ops.max(keras.ops.abs(keras.ops.log((v_new + tiny) / (_v + tiny)))) + err_new = keras.ops.maximum(du, dv) - return steps + 1, plan + return _steps + 1, u_new, v_new, err_new - steps = 0 - steps, plan = keras.ops.while_loop(cond, body, (steps, plan), maximum_iterations=max_steps) + err0 = keras.ops.cast(1e30, plan.dtype) + steps, u, v, err = keras.ops.while_loop(cond, body, (0, u, v, err0), maximum_iterations=max_steps) + plan = (keras.ops.expand_dims(u, 1) * plan) * keras.ops.expand_dims(v, 0) def do_nothing(): pass def log_steps(): msg = "Sinkhorn-Knopp converged after {} steps." 
- - logging.debug(msg, max_steps) + logging.debug(msg, steps) def warn_convergence(): - msg = "Sinkhorn-Knopp did not converge after {}." - - logging.warning(msg, max_steps) + msg = "Sinkhorn-Knopp did not converge after {} steps." + logging.warning(msg, steps) def warn_nans(): msg = "Sinkhorn-Knopp produced NaNs after {} steps." logging.warning(msg, steps) keras.ops.cond(contains_nans(plan), warn_nans, do_nothing) - keras.ops.cond(is_converged(plan), log_steps, warn_convergence) + keras.ops.cond(cond(None, None, None, err), log_steps, warn_convergence) + if partial: + plan = plan[:-1, :-1] return plan diff --git a/bayesflow/workflows/basic_workflow.py b/bayesflow/workflows/basic_workflow.py index 54d0d6605..6500125a7 100644 --- a/bayesflow/workflows/basic_workflow.py +++ b/bayesflow/workflows/basic_workflow.py @@ -286,6 +286,42 @@ def sample( """ return self.approximator.sample(num_samples=num_samples, conditions=conditions, **kwargs) + def compositional_sample( + self, + *, + num_samples: int, + conditions: Mapping[str, np.ndarray], + compute_prior_score: Callable[[Mapping[str, np.ndarray]], Mapping[str, np.ndarray]], + **kwargs, + ) -> dict[str, np.ndarray]: + """ + Draws `num_samples` samples from the approximator given specified composition conditions. + The `conditions` dictionary should have shape (n_datasets, n_compositional_conditions, ...). + + Parameters + ---------- + num_samples : int + The number of samples to generate. + conditions : dict[str, np.ndarray] + A dictionary where keys represent variable names and values are + NumPy arrays containing the adapted simulated variables. Keys used as summary or inference + conditions during training should be present. + Should have shape (n_datasets, n_compositional_conditions, ...). + compute_prior_score : Callable[[dict[str, np.ndarray]], dict[str, np.ndarray]] + A function that computes the log probability of samples under the prior distribution. + **kwargs : dict, optional + Additional keyword arguments passed to the approximator's sampling function. + + Returns + ------- + dict[str, np.ndarray] + A dictionary where keys correspond to variable names and + values are arrays containing the generated samples. 
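        A usage sketch (illustrative; `workflow` stands for a trained instance exposing this method,
        and the keys `theta` and `x` are placeholders for whatever the adapter produces):
        `compute_prior_score` returns the gradient of the log prior for each sampled variable, and
        the returned samples carry shape (n_datasets, num_samples, ...).

            import numpy as np

            def prior_score(samples: dict) -> dict:
                # score (gradient of the log density) of a standard normal prior on "theta"
                return {"theta": -samples["theta"]}

            conditions = {"x": np.random.normal(size=(3, 5, 2)).astype("float32")}  # 3 datasets, 5 compositional conditions
            # samples = workflow.compositional_sample(num_samples=100, conditions=conditions,
            #                                         compute_prior_score=prior_score)
            # samples["theta"].shape -> (3, 100, ...)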
+ """ + return self.approximator.compositional_sample( + num_samples=num_samples, conditions=conditions, compute_prior_score=compute_prior_score, **kwargs + ) + def estimate( self, *, diff --git a/tests/test_approximators/conftest.py b/tests/test_approximators/conftest.py index a56802a3e..befc0da06 100644 --- a/tests/test_approximators/conftest.py +++ b/tests/test_approximators/conftest.py @@ -220,3 +220,71 @@ def approximator_with_summaries(request): ) case _: raise ValueError("Invalid param for approximator class.") + + +@pytest.fixture +def simple_log_simulator(): + """Create a simple simulator for testing.""" + import numpy as np + from bayesflow.simulators import Simulator + from bayesflow.utils.decorators import allow_batch_size + from bayesflow.types import Shape, Tensor + + class SimpleSimulator(Simulator): + """Simple simulator that generates mean and scale parameters.""" + + @allow_batch_size + def sample(self, batch_shape: Shape) -> dict[str, Tensor]: + # Generate parameters in original space + loc = np.random.normal(0.0, 1.0, size=batch_shape + (2,)) # location parameters + scale = np.random.lognormal(0.0, 0.5, size=batch_shape + (2,)) # scale parameters > 0 + + # Generate some dummy conditions + conditions = np.random.normal(0.0, 1.0, size=batch_shape + (3,)) + + return dict( + loc=loc.astype("float32"), scale=scale.astype("float32"), conditions=conditions.astype("float32") + ) + + return SimpleSimulator() + + +@pytest.fixture +def identity_adapter(): + """Create an adapter that applies no transformation to the parameters.""" + from bayesflow.adapters import Adapter + + adapter = Adapter() + adapter.to_array() + adapter.convert_dtype("float64", "float32") + + adapter.concatenate(["loc"], into="inference_variables") + adapter.concatenate(["conditions"], into="inference_conditions") + adapter.keep(["inference_variables", "inference_conditions"]) + return adapter + + +@pytest.fixture +def transforming_adapter(): + """Create an adapter that applies log transformation to scale parameters.""" + from bayesflow.adapters import Adapter + + adapter = Adapter() + adapter.to_array() + adapter.convert_dtype("float64", "float32") + + # Apply log transformation to scale parameters (to make them unbounded) + adapter.log(["scale"]) + + adapter.concatenate(["scale", "loc"], into="inference_variables") + adapter.concatenate(["conditions"], into="inference_conditions") + adapter.keep(["inference_variables", "inference_conditions"]) + return adapter + + +@pytest.fixture +def diffusion_network(): + """Create a diffusion network for compositional sampling.""" + from bayesflow.networks import DiffusionModel, MLP + + return DiffusionModel(subnet=MLP(widths=[32, 32])) diff --git a/tests/test_approximators/test_compositional_prior_score.py b/tests/test_approximators/test_compositional_prior_score.py new file mode 100644 index 000000000..02be46c00 --- /dev/null +++ b/tests/test_approximators/test_compositional_prior_score.py @@ -0,0 +1,43 @@ +"""Tests for compositional sampling and prior score computation with adapters.""" + +import numpy as np + +from bayesflow import ContinuousApproximator + + +def mock_prior_score_original_space(data_dict): + """Mock prior score function that expects data in original space.""" + loc = data_dict["loc"] + + # Simple prior: N(0,1) for loc + loc_score = -loc + return {"loc": loc_score} + + +def test_prior_score_identity_adapter(simple_log_simulator, identity_adapter, diffusion_network): + """Test that prior scores work correctly with transforming adapter (log 
transformation).""" + + # Create approximator with transforming adapter + approximator = ContinuousApproximator( + adapter=identity_adapter, + inference_network=diffusion_network, + ) + + # Generate test data and adapt it + data = simple_log_simulator.sample((2,)) + adapted_data = identity_adapter(data) + + # Build approximator + approximator.build_from_data(adapted_data) + + # Test compositional sampling + n_datasets, n_compositional = 3, 5 + conditions = {"conditions": np.random.normal(0.0, 1.0, (n_datasets, n_compositional, 3)).astype("float32")} + samples = approximator.compositional_sample( + num_samples=10, + conditions=conditions, + compute_prior_score=mock_prior_score_original_space, + ) + + assert "loc" in samples + assert samples["loc"].shape == (n_datasets, 10, 2) diff --git a/tests/test_networks/test_diffusion_model/conftest.py b/tests/test_networks/test_diffusion_model/conftest.py index b1ee915ae..581b4abde 100644 --- a/tests/test_networks/test_diffusion_model/conftest.py +++ b/tests/test_networks/test_diffusion_model/conftest.py @@ -1,4 +1,5 @@ import pytest +import keras @pytest.fixture() @@ -21,3 +22,49 @@ def edm_noise_schedule(): ) def noise_schedule(request): return request.getfixturevalue(request.param) + + +@pytest.fixture +def simple_diffusion_model(): + """Create a simple diffusion model for testing compositional sampling.""" + from bayesflow.networks.diffusion_model import DiffusionModel + from bayesflow.networks import MLP + + return DiffusionModel( + subnet=MLP(widths=[32, 32]), + noise_schedule="cosine", + prediction_type="noise", + loss_type="noise", + ) + + +@pytest.fixture +def compositional_conditions(): + """Create test conditions for compositional sampling.""" + batch_size = 2 + n_compositional = 3 + n_samples = 4 + condition_dim = 5 + + return keras.random.normal((batch_size, n_compositional, n_samples, condition_dim)) + + +@pytest.fixture +def compositional_state(): + """Create test state for compositional sampling.""" + batch_size = 2 + n_samples = 4 + param_dim = 3 + + return keras.random.normal((batch_size, n_samples, param_dim)) + + +@pytest.fixture +def mock_prior_score(): + """Create a mock prior score function for testing.""" + + def prior_score_fn(theta): + # Simple quadratic prior: -0.5 * ||theta||^2 + return -theta + + return prior_score_fn diff --git a/tests/test_networks/test_diffusion_model/test_compositional_sampling.py b/tests/test_networks/test_diffusion_model/test_compositional_sampling.py new file mode 100644 index 000000000..2757bd28a --- /dev/null +++ b/tests/test_networks/test_diffusion_model/test_compositional_sampling.py @@ -0,0 +1,132 @@ +import keras +import pytest + + +def test_compositional_score_shape( + simple_diffusion_model, compositional_state, compositional_conditions, mock_prior_score +): + """Test that compositional score returns correct shapes.""" + # Build the model + state_shape = keras.ops.shape(compositional_state) + conditions_shape = keras.ops.shape(compositional_conditions) + simple_diffusion_model.build(state_shape, conditions_shape) + + time = 0.5 + + score = simple_diffusion_model.compositional_score( + xz=compositional_state, + time=time, + conditions=compositional_conditions, + compute_prior_score=mock_prior_score, + training=False, + ) + + expected_shape = keras.ops.shape(compositional_state) + actual_shape = keras.ops.shape(score) + + assert keras.ops.all(keras.ops.equal(expected_shape, actual_shape)), ( + f"Expected shape {expected_shape}, got {actual_shape}" + ) + + +def 
test_compositional_score_no_conditions_raises_error(simple_diffusion_model, compositional_state, mock_prior_score): + """Test that compositional score raises error when conditions is None.""" + simple_diffusion_model.build(keras.ops.shape(compositional_state), None) + + with pytest.raises(ValueError, match="Conditions are required for compositional sampling"): + simple_diffusion_model.compositional_score( + xz=compositional_state, time=0.5, conditions=None, compute_prior_score=mock_prior_score, training=False + ) + + +def test_inverse_compositional_basic( + simple_diffusion_model, compositional_state, compositional_conditions, mock_prior_score +): + """Test basic compositional inverse sampling.""" + state_shape = keras.ops.shape(compositional_state) + conditions_shape = keras.ops.shape(compositional_conditions) + simple_diffusion_model.build(state_shape, conditions_shape) + + # Test inverse sampling with ODE method + result = simple_diffusion_model._inverse_compositional( + z=compositional_state, + conditions=compositional_conditions, + compute_prior_score=mock_prior_score, + density=False, + training=False, + method="euler", + steps=5, + start_time=1.0, + stop_time=0.0, + ) + + expected_shape = keras.ops.shape(compositional_state) + actual_shape = keras.ops.shape(result) + + assert keras.ops.all(keras.ops.equal(expected_shape, actual_shape)), ( + f"Expected shape {expected_shape}, got {actual_shape}" + ) + + +def test_inverse_compositional_euler_maruyama_with_corrector( + simple_diffusion_model, compositional_state, compositional_conditions, mock_prior_score +): + """Test compositional inverse sampling with Euler-Maruyama and corrector steps.""" + state_shape = keras.ops.shape(compositional_state) + conditions_shape = keras.ops.shape(compositional_conditions) + simple_diffusion_model.build(state_shape, conditions_shape) + + result = simple_diffusion_model._inverse_compositional( + z=compositional_state, + conditions=compositional_conditions, + compute_prior_score=mock_prior_score, + density=False, + training=False, + method="euler_maruyama", + steps=5, + corrector_steps=2, + start_time=1.0, + stop_time=0.0, + ) + + expected_shape = keras.ops.shape(compositional_state) + actual_shape = keras.ops.shape(result) + + assert keras.ops.all(keras.ops.equal(expected_shape, actual_shape)), ( + f"Expected shape {expected_shape}, got {actual_shape}" + ) + + +@pytest.mark.parametrize("noise_schedule_name", ["cosine", "edm"]) +def test_compositional_sampling_with_different_schedules( + noise_schedule_name, compositional_state, compositional_conditions, mock_prior_score +): + """Test compositional sampling with different noise schedules.""" + from bayesflow.networks.diffusion_model import DiffusionModel + from bayesflow.networks import MLP + + diffusion_model = DiffusionModel( + subnet=MLP(widths=[32, 32]), + noise_schedule=noise_schedule_name, + prediction_type="noise", + loss_type="noise", + ) + + state_shape = keras.ops.shape(compositional_state) + conditions_shape = keras.ops.shape(compositional_conditions) + diffusion_model.build(state_shape, conditions_shape) + + score = diffusion_model.compositional_score( + xz=compositional_state, + time=0.5, + conditions=compositional_conditions, + compute_prior_score=mock_prior_score, + training=False, + ) + + expected_shape = keras.ops.shape(compositional_state) + actual_shape = keras.ops.shape(score) + + assert keras.ops.all(keras.ops.equal(expected_shape, actual_shape)), ( + f"Expected shape {expected_shape}, got {actual_shape}" + ) diff --git 
a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py index db5c448d7..985e7d2f8 100644 --- a/tests/test_utils/test_integrate.py +++ b/tests/test_utils/test_integrate.py @@ -1,23 +1,31 @@ import numpy as np +import keras +import pytest +from bayesflow.utils import integrate, integrate_stochastic -def test_scheduled_integration(): - import keras - from bayesflow.utils import integrate +TOLERANCE_ADAPTIVE = 1e-6 # Adaptive solvers should be very accurate. +TOLERANCE_EULER = 1e-3 # Euler with fixed steps requires a larger tolerance +# tolerances for SDE tests +TOL_MEAN = 5e-2 +TOL_VAR = 5e-2 + + +@pytest.mark.parametrize("method", ["euler", "rk45", "tsit5"]) +def test_scheduled_integration(method): def fn(t, x): return {"x": t**2} - steps = keras.ops.convert_to_tensor([0.0, 0.5, 1.0]) - approximate_result = 0.0 + 0.5**2 * 0.5 - result = integrate(fn, {"x": 0.0}, steps=steps)["x"] - assert result == approximate_result + def analytical_result(t): + return (t**3) / 3.0 + steps = keras.ops.arange(0.0, 1.0 + 1e-6, 0.01) + result = integrate(fn, {"x": 0.0}, steps=steps, method=method)["x"] + np.testing.assert_allclose(result, analytical_result(steps[-1]), atol=1e-1, rtol=1e-1) -def test_scipy_integration(): - import keras - from bayesflow.utils import integrate +def test_scipy_integration(): def fn(t, x): return {"x": keras.ops.exp(t)} @@ -34,3 +42,300 @@ def fn(t, x): scipy_kwargs={"atol": 1e-6, "rtol": 1e-6}, )["x"] np.testing.assert_allclose(exact_result, result, atol=1e-6, rtol=1e-6) + + +@pytest.mark.parametrize( + "method, atol", [("euler", TOLERANCE_EULER), ("rk45", TOLERANCE_ADAPTIVE), ("tsit5", TOLERANCE_ADAPTIVE)] +) +def test_analytical_integration(method, atol): + def fn(t, x): + return {"x": keras.ops.convert_to_tensor([2.0 * t])} + + initial_state = {"x": keras.ops.convert_to_tensor([1.0])} + T_final = 1.0 + num_steps = 100 + analytical_result = 1.0 + T_final**2 + + result = integrate(fn, initial_state, start_time=0.0, stop_time=T_final, steps=num_steps, method=method)["x"] + if method == "euler": + result_adaptive = result + else: + result_adaptive = integrate( + fn, initial_state, start_time=0.0, stop_time=T_final, steps="adaptive", method=method, max_steps=1_000 + )["x"] + + np.testing.assert_allclose(result, analytical_result, atol=atol, rtol=0.1) + np.testing.assert_allclose(result_adaptive, analytical_result, atol=atol, rtol=0.1) + + +@pytest.mark.parametrize( + "method, atol", [("euler", TOLERANCE_EULER), ("rk45", TOLERANCE_ADAPTIVE), ("tsit5", TOLERANCE_ADAPTIVE)] +) +def test_analytical_backward_integration(method, atol): + T_final = 1.0 + + def fn(t, x): + return {"x": keras.ops.convert_to_tensor([2.0 * t])} + + num_steps = 100 + analytical_result = 1.0 + initial_state = {"x": keras.ops.convert_to_tensor([1.0 + T_final**2])} + + result = integrate(fn, initial_state, start_time=T_final, stop_time=0.0, steps=num_steps, method=method)["x"] + if method == "euler": + result_adaptive = result + else: + result_adaptive = integrate( + fn, initial_state, start_time=T_final, stop_time=0.0, steps="adaptive", method=method, max_steps=1_000 + )["x"] + + np.testing.assert_allclose(result, analytical_result, atol=atol, rtol=0.1) + np.testing.assert_allclose(result_adaptive, analytical_result, atol=atol, rtol=0.1) + + +@pytest.mark.parametrize( + "method,use_adapt", + [ + ("euler_maruyama", False), + ("euler_maruyama", True), + ("sea", False), + ("shark", False), + ("two_step_adaptive", False), + ("two_step_adaptive", True), + ], +) +def 
test_forward_additive_ou_weak_means_and_vars(method, use_adapt): + """ + Ornstein-Uhlenbeck with additive noise, integrated FORWARD in time. + This serves as a sanity check that forward integration still works correctly. + + Forward SDE: + dX = a X dt + sigma dW + + Exact at time T starting from X(0) = x_0: + E[X(T)] = x_0 * exp(a T) + Var[X(T)] = sigma^2 * (exp(2 a T) - 1) / (2 a) + """ + # SDE parameters + a = -1.0 + sigma = 0.5 + x_0 = 1.2 # initial condition at time 0 + T = 1.0 + + N = 10000 + seed = keras.random.SeedGenerator(42) + + def drift_fn(t, x): + return {"x": a * x} + + def diffusion_fn(t, x): + return {"x": keras.ops.convert_to_tensor([sigma])} + + initial_state = {"x": keras.ops.ones((N,)) * x_0} + steps = 200 if not use_adapt else "adaptive" + + # Expected mean and variance at t=T + exp_mean = x_0 * np.exp(a * T) + exp_var = sigma**2 * (np.exp(2.0 * a * T) - 1.0) / (2.0 * a) + + out = integrate_stochastic( + drift_fn=drift_fn, + diffusion_fn=diffusion_fn, + state=initial_state, + start_time=0.0, + stop_time=T, + steps=steps, + seed=seed, + method=method, + ) + + x_T = np.array(out["x"]) + emp_mean = float(x_T.mean()) + emp_var = float(x_T.var()) + + np.testing.assert_allclose(emp_mean, exp_mean, atol=TOL_MEAN) + np.testing.assert_allclose(emp_var, exp_var, atol=TOL_VAR) + + +@pytest.mark.parametrize( + "method,use_adapt", + [ + ("euler_maruyama", False), + ("euler_maruyama", True), + ("sea", False), + ("shark", False), + ("two_step_adaptive", False), + ("two_step_adaptive", True), + ], +) +def test_backward_additive_ou_weak_means_and_vars(method, use_adapt): + """ + Ornstein-Uhlenbeck with additive noise, integrated BACKWARD in time. + + When integrating from t=T back to t=0 with initial condition X(T) = x_T, + we get X(0) which should satisfy: + E[X(0)] = x_T * exp(-a T) (-a because we go backward) + Var[X(0)] = sigma^2 * (exp(-2 a T) - 1) / (-2 a) + + We verify weak accuracy by matching empirical mean and variance. 
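+
+    Derivation sketch: substituting tau = T - t turns the backward integration into a
+    forward OU process dY = -a Y dtau + sigma dW on [0, T], whose standard moments give
+    E[Y(T)] = x_T * exp(-a T) and Var[Y(T)] = sigma^2 * (exp(-2 a T) - 1) / (-2 a),
+    matching the expressions above.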
+ """ + # SDE parameters + a = -1.0 + sigma = 0.5 + x_T = 1.2 # initial condition at time T + T = 1.0 + + N = 10000 + seed = keras.random.SeedGenerator(42) + + def drift_fn(t, x): + return {"x": a * x} + + def diffusion_fn(t, x): + # additive noise, independent of state + return {"x": keras.ops.convert_to_tensor([sigma])} + + # Start at time T with value x_T + initial_state = {"x": keras.ops.ones((N,)) * x_T} + steps = 200 if not use_adapt else "adaptive" + # Expected mean and variance at t=0 after integrating backward from t=T + # For backward integration, the effective drift coefficient changes sign + exp_mean = x_T * np.exp(-a * T) + exp_var = sigma**2 * (np.exp(-2.0 * a * T) - 1.0) / (-2.0 * a) + + out = integrate_stochastic( + drift_fn=drift_fn, + diffusion_fn=diffusion_fn, + state=initial_state, + start_time=T, + stop_time=0.0, + steps=steps, + seed=seed, + method=method, + ) + + x_0 = np.array(out["x"]) + emp_mean = float(x_0.mean()) + emp_var = float(x_0.var()) + + np.testing.assert_allclose(emp_mean, exp_mean, atol=TOL_MEAN) + np.testing.assert_allclose(emp_var, exp_var, atol=TOL_VAR) + + +@pytest.mark.parametrize( + "method,use_adapt", + [ + ("euler_maruyama", False), + ("euler_maruyama", True), + ("sea", False), + ("shark", False), + ("two_step_adaptive", False), + ("two_step_adaptive", True), + ], +) +def test_zero_noise_reduces_to_deterministic(method, use_adapt): + """ + With zero diffusion the SDE reduces to the ODE + dX = a X dt + """ + a = 0.7 + x0 = 0.9 + T = 1.25 + steps = 200 if not use_adapt else "adaptive" + seed = keras.random.SeedGenerator(0) + + def drift_fn(t, x): + return {"x": a * x} + + def diffusion_fn(t, x): + # identically zero diffusion + return {"x": keras.ops.convert_to_tensor([0.0])} + + initial_state = {"x": keras.ops.ones((256,)) * x0} + out = integrate_stochastic( + drift_fn=drift_fn, + diffusion_fn=diffusion_fn, + state=initial_state, + start_time=0.0, + stop_time=T, + steps=steps, + seed=seed, + method=method, + max_steps=1_000, + )["x"] + + exact = x0 * np.exp(a * T) + np.testing.assert_allclose(np.array(out).mean(), exact, atol=1e-3, rtol=0.1) + + +@pytest.mark.parametrize("steps", [500]) +def test_langevin_gaussian_sampling(steps): + """ + Test annealed Langevin dynamics on a 1D Gaussian target. + + Target distribution: N(mu, sigma^2), with score + ∇_x log p(x) = -(x - mu) / sigma^2 + + We verify that the empirical mean and variance after Langevin sampling + match the target within a loose tolerance (to allow for Monte Carlo noise). 
+ """ + # target parameters + mu = 0.3 + sigma = 0.7 + + # number of particles + N = 20000 + start_time = 0.0 + stop_time = 1.0 + + # tolerances for mean and variance + tol_mean = 5e-2 + tol_var = 5e-2 + + # initial state: broad Gaussian, independent of target + seed = keras.random.SeedGenerator(42) + x0 = keras.random.normal((N,), dtype="float32", seed=seed) + initial_state = {"x": x0} + + # simple dummy noise schedule: constant alpha + class DummyNoiseSchedule: + def get_log_snr(self, t, training=False): + return keras.ops.zeros_like(t) + + def get_alpha_sigma(self, log_snr_t): + alpha_t = keras.ops.ones_like(log_snr_t) + sigma_t = keras.ops.ones_like(log_snr_t) + return alpha_t, sigma_t + + noise_schedule = DummyNoiseSchedule() + + # score of the target Gaussian + def score_fn(t, x): + s = -(x - mu) / (sigma**2) + return {"x": s} + + # run Langevin + final_state = integrate_stochastic( + drift_fn=None, + diffusion_fn=None, + score_fn=score_fn, + noise_schedule=noise_schedule, + state=initial_state, + start_time=start_time, + stop_time=stop_time, + steps=steps, + seed=seed, + method="langevin", + max_steps=1_000, + corrector_steps=1, + ) + + xT = np.array(final_state["x"]) + emp_mean = float(xT.mean()) + emp_var = float(xT.var()) + + exp_mean = mu + exp_var = sigma**2 + + np.testing.assert_allclose(emp_mean, exp_mean, atol=tol_mean, rtol=0.1) + np.testing.assert_allclose(emp_var, exp_var, atol=tol_var, rtol=0.1) diff --git a/tests/test_utils/test_optimal_transport.py b/tests/test_utils/test_optimal_transport.py index 53a5fd7a6..f12723c47 100644 --- a/tests/test_utils/test_optimal_transport.py +++ b/tests/test_utils/test_optimal_transport.py @@ -16,44 +16,119 @@ def test_jit_compile(): ot(x, y, regularization=1.0, seed=0, max_steps=10) -@pytest.mark.parametrize("method", ["log_sinkhorn", "sinkhorn"]) -def test_shapes(method): +@pytest.mark.parametrize( + ["method", "partial_ot_factor", "conditional_ot_ratio"], + [ + ("log_sinkhorn", 1.0, 0.01), + ("log_sinkhorn", 0.8, 0.5), + ("sinkhorn", 1.0, 0.01), + ("sinkhorn", 0.8, 0.5), + ], +) +def test_shapes(method, partial_ot_factor, conditional_ot_ratio): x = keras.random.normal((128, 8), seed=0) y = keras.random.normal((128, 8), seed=1) - ox, oy = optimal_transport(x, y, regularization=1.0, seed=0, max_steps=10, method=method) + cond = None + if conditional_ot_ratio < 0.5: + cond = keras.random.normal((128, 4, 1), seed=2) + + ox, oy, ocond = optimal_transport( + x, + y, + conditions=cond, + regularization=1.0, + seed=0, + max_steps=10, + method=method, + partial_ot_factor=partial_ot_factor, + conditional_ot_ratio=conditional_ot_ratio, + ) assert keras.ops.shape(ox) == keras.ops.shape(x) assert keras.ops.shape(oy) == keras.ops.shape(y) - - -@pytest.mark.parametrize("method", ["log_sinkhorn", "sinkhorn"]) -def test_transport_cost_improves(method): + if cond is not None: + assert keras.ops.shape(ocond) == keras.ops.shape(cond) + + +@pytest.mark.parametrize( + ["method", "partial_ot_factor", "conditional_ot_ratio"], + [ + ("log_sinkhorn", 1.0, 0.01), + ("log_sinkhorn", 0.8, 0.5), + ("sinkhorn", 1.0, 0.01), + ("sinkhorn", 0.8, 0.5), + ], +) +def test_transport_cost_improves(method, partial_ot_factor, conditional_ot_ratio): x = keras.random.normal((128, 2), seed=0) y = keras.random.normal((128, 2), seed=1) - before_cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1)) + cond = None + if conditional_ot_ratio < 0.5: + cond = keras.random.normal((128, 4, 1), seed=2) - x, y = optimal_transport(x, y, regularization=0.1, seed=0, max_steps=1000, 
method=method) + before_cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1)) - after_cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1)) + x_after, y_after, cond_after = optimal_transport( + x, + y, + conditions=cond, + regularization=0.1, + seed=0, + max_steps=1000, + method=method, + partial_ot_factor=partial_ot_factor, + conditional_ot_ratio=conditional_ot_ratio, + ) + after_cost = keras.ops.sum(keras.ops.norm(x_after - y_after, axis=-1)) assert after_cost < before_cost -@pytest.mark.parametrize("method", ["log_sinkhorn", "sinkhorn"]) -def test_assignment_is_optimal(method): +@pytest.mark.parametrize( + ["method", "partial_ot_factor"], + [ + ("log_sinkhorn", 1.0), + ("log_sinkhorn", 0.8), + ("sinkhorn", 1.0), + ("sinkhorn", 0.8), + ], +) +def test_assignment_is_optimal(method, partial_ot_factor): y = keras.random.normal((16, 2), seed=0) p = keras.random.shuffle(keras.ops.arange(keras.ops.shape(y)[0]), seed=0) - x = keras.ops.take(y, p, axis=0) - _, _, assignments = optimal_transport( - x, y, regularization=0.1, seed=0, max_steps=10_000, method=method, return_assignments=True + _, _, _, assignments = optimal_transport( + x, + y, + regularization=0.01, + seed=0, + max_steps=10_000, + method=method, + return_assignments=True, + partial_ot_factor=partial_ot_factor, ) - # transport is stochastic, so it is expected that a small fraction of assignments do not match - assert keras.ops.sum(assignments == p) > 14 + # transport is stochastic, so it is expected that a small fraction of assignments does not match + assert keras.ops.sum(assignments == p) >= 14 + + +@pytest.mark.parametrize("method", ["log_sinkhorn", "sinkhorn"]) +def test_no_nans_or_infs(method): + """Test that algorithm produces finite values even with challenging inputs and auto regularization.""" + # Test with well-separated distributions + x = keras.random.normal((64, 4), seed=0) * 10.0 + y = keras.random.normal((64, 4), seed=1) * 10.0 + 100.0 + + ox, oy, _, assignments = optimal_transport( + x, y, regularization=0.1, seed=0, max_steps=1000, method=method, return_assignments=True + ) + + assert keras.ops.all(keras.ops.isfinite(ox)) + assert keras.ops.all(keras.ops.isfinite(oy)) + assert keras.ops.all(keras.ops.isfinite(assignments)) def test_assignment_aligns_with_pot(): @@ -65,7 +140,7 @@ def test_assignment_aligns_with_pot(): x = keras.random.normal((16, 2), seed=0) p = keras.random.shuffle(keras.ops.arange(keras.ops.shape(x)[0]), seed=0) - y = x[p] + y = keras.ops.take(x, p, axis=0) a = keras.ops.ones(keras.ops.shape(x)[0]) b = keras.ops.ones(keras.ops.shape(y)[0]) @@ -76,7 +151,9 @@ def test_assignment_aligns_with_pot(): pot_assignments = keras.random.categorical(keras.ops.log(pot_plan), num_samples=1, seed=0) pot_assignments = keras.ops.squeeze(pot_assignments, axis=-1) - _, _, assignments = optimal_transport(x, y, regularization=1e-3, seed=0, max_steps=10_000, return_assignments=True) + _, _, _, assignments = optimal_transport( + x, y, method="log_sinkhorn", regularization=1e-3, seed=0, max_steps=10_000, return_assignments=True + ) assert_allclose(pot_assignments, assignments) @@ -87,8 +164,10 @@ def test_sinkhorn_plan_correct_marginals(): x1 = keras.random.normal((10, 2), seed=0) x2 = keras.random.normal((20, 2), seed=1) - assert keras.ops.all(keras.ops.isclose(keras.ops.sum(sinkhorn_plan(x1, x2), axis=0), 0.05, atol=1e-6)) - assert keras.ops.all(keras.ops.isclose(keras.ops.sum(sinkhorn_plan(x1, x2), axis=1), 0.1, atol=1e-6)) + plan = sinkhorn_plan(x1, x2, atol=1e-5, max_steps=1000) + + assert 
keras.ops.all(keras.ops.isclose(keras.ops.sum(plan, axis=0), 0.05, atol=1e-6)) + assert keras.ops.all(keras.ops.isclose(keras.ops.sum(plan, axis=1), 0.1, atol=1e-6)) def test_sinkhorn_plan_aligns_with_pot(): @@ -98,22 +177,25 @@ def test_sinkhorn_plan_aligns_with_pot(): pytest.skip("Need to install POT to run this test.") from bayesflow.utils.optimal_transport.sinkhorn import sinkhorn_plan - from bayesflow.utils.optimal_transport.euclidean import euclidean + from bayesflow.utils.optimal_transport.ot_utils import squared_euclidean x1 = keras.random.normal((10, 3), seed=0) x2 = keras.random.normal((20, 3), seed=1) a = keras.ops.ones(10) / 10 b = keras.ops.ones(20) / 20 - M = euclidean(x1, x2) + M = squared_euclidean(x1, x2) - pot_result = sinkhorn(a, b, M, 0.1, stopThr=1e-8) - our_result = sinkhorn_plan(x1, x2, regularization=0.1, rtol=1e-7) + pot_result = sinkhorn( + keras.ops.convert_to_numpy(a), keras.ops.convert_to_numpy(b), keras.ops.convert_to_numpy(M), 0.1, stopThr=1e-7 + ) + our_result = sinkhorn_plan(x1, x2, regularization=0.1, atol=1e-7) - assert_allclose(pot_result, our_result) + assert_allclose(pot_result, our_result, rtol=1e-3, atol=1e-1) -def test_sinkhorn_plan_matches_analytical_result(): +@pytest.mark.parametrize("reg", [0.1, 0.1, 1.0]) +def test_sinkhorn_plan_matches_analytical_result(reg): from bayesflow.utils.optimal_transport.sinkhorn import sinkhorn_plan x1 = keras.ops.ones(16) @@ -122,12 +204,12 @@ def test_sinkhorn_plan_matches_analytical_result(): marginal_x1 = keras.ops.ones(16) / 16 marginal_x2 = keras.ops.ones(64) / 64 - result = sinkhorn_plan(x1, x2, regularization=0.1) + result = sinkhorn_plan(x1, x2, regularization=reg) # If x1 and x2 are identical, the optimal plan is simply the outer product of the marginals expected = keras.ops.outer(marginal_x1, marginal_x2) - assert_allclose(result, expected) + assert_allclose(result, expected, rtol=1e-4) def test_log_sinkhorn_plan_correct_marginals(): @@ -136,12 +218,10 @@ def test_log_sinkhorn_plan_correct_marginals(): x1 = keras.random.normal((10, 2), seed=0) x2 = keras.random.normal((20, 2), seed=1) - assert keras.ops.all( - keras.ops.isclose(keras.ops.logsumexp(log_sinkhorn_plan(x1, x2), axis=0), -keras.ops.log(20), atol=1e-3) - ) - assert keras.ops.all( - keras.ops.isclose(keras.ops.logsumexp(log_sinkhorn_plan(x1, x2), axis=1), -keras.ops.log(10), atol=1e-3) - ) + log_plan = log_sinkhorn_plan(x1, x2, atol=1e-5, max_steps=1000) + + assert keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(log_plan, axis=0), -keras.ops.log(20.0), atol=1e-3)) + assert keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(log_plan, axis=1), -keras.ops.log(10.0), atol=1e-3)) def test_log_sinkhorn_plan_aligns_with_pot(): @@ -151,19 +231,19 @@ def test_log_sinkhorn_plan_aligns_with_pot(): pytest.skip("Need to install POT to run this test.") from bayesflow.utils.optimal_transport.log_sinkhorn import log_sinkhorn_plan - from bayesflow.utils.optimal_transport.euclidean import euclidean + from bayesflow.utils.optimal_transport.ot_utils import squared_euclidean x1 = keras.random.normal((100, 3), seed=0) x2 = keras.random.normal((200, 3), seed=1) a = keras.ops.ones(100) / 100 b = keras.ops.ones(200) / 200 - M = euclidean(x1, x2) + M = squared_euclidean(x1, x2) - pot_result = keras.ops.log(sinkhorn_log(a, b, M, 0.1, stopThr=1e-7)) # sinkhorn_log returns probabilities - our_result = log_sinkhorn_plan(x1, x2, regularization=0.1) + pot_result = sinkhorn_log(a, b, M, 0.1, stopThr=1e-7) # sinkhorn_log returns probabilities + our_result = 
keras.ops.exp(log_sinkhorn_plan(x1, x2, regularization=0.1, atol=1e-7)) - assert_allclose(pot_result, our_result) + assert_allclose(pot_result, our_result, rtol=1e-4, atol=1e-5) def test_log_sinkhorn_plan_matches_analytical_result(): @@ -180,4 +260,62 @@ def test_log_sinkhorn_plan_matches_analytical_result(): # If x1 and x2 are identical, the optimal plan is simply the outer product of the marginals expected = keras.ops.outer(marginal_x1, marginal_x2) - assert_allclose(result, expected) + assert_allclose(result, expected, rtol=1e-4) + + +def test_sinkhorn_vs_log_sinkhorn_consistency(): + """Test that Sinkhorn and log-Sinkhorn produce consistent results.""" + from bayesflow.utils.optimal_transport.sinkhorn import sinkhorn_plan + from bayesflow.utils.optimal_transport.log_sinkhorn import log_sinkhorn_plan + + x1 = keras.random.normal((30, 3), seed=0) + x2 = keras.random.normal((20, 3), seed=1) + + plan_sinkhorn = sinkhorn_plan(x1, x2, regularization=0.1, atol=1e-5) + plan_log_sinkhorn = keras.ops.exp(log_sinkhorn_plan(x1, x2, regularization=0.1, atol=1e-5)) + + assert_allclose(plan_sinkhorn, plan_log_sinkhorn, rtol=1e-3, atol=1e-1) + + +@pytest.mark.parametrize( + ["method", "s"], + [ + ("log_sinkhorn", 0.3), + ("log_sinkhorn", 0.8), + ("sinkhorn", 0.4), + ("sinkhorn", 0.7), + ], +) +def test_partial_ot_leaves_unmatched_mass(method, s): + """Test that partial OT correctly leaves a fraction of mass unmatched.""" + if method == "sinkhorn": + from bayesflow.utils.optimal_transport.sinkhorn import sinkhorn_plan as sinkhorn + else: + from bayesflow.utils.optimal_transport.log_sinkhorn import log_sinkhorn_plan as sinkhorn + n, m = 20, 20 + + # Create two distinct distributions + x = keras.random.normal((n, 2), seed=42) + y = keras.random.normal((m, 2), seed=123) + + # Get the transport plan with partial OT + plan = sinkhorn(x, y, regularization=0.1, max_steps=10_000, partial_ot_factor=s) + + if method == "log_sinkhorn": + plan = keras.ops.exp(plan) + + # Check marginal sums: each should be approximately s/n and s/m + row_sums = keras.ops.sum(plan, axis=1) + col_sums = keras.ops.sum(plan, axis=0) + + expected_row_mass = s / n + expected_col_mass = s / m + + # Each row should have approximately s/n mass (allowing small numerical error) + assert keras.ops.all(keras.ops.abs(row_sums - expected_row_mass) < 0.05) + + # Each column should have approximately s/m mass + assert keras.ops.all(keras.ops.abs(col_sums - expected_col_mass) < 0.05) + + # Total transported mass should be approximately s + assert abs(float(keras.ops.sum(plan)) - s) < 1e-3
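+
+    # Additional illustrative check: since every row of the truncated plan carries
+    # roughly s/n mass and all entries are non-negative, no single entry can exceed
+    # that row mass by more than the allowed tolerance.
+    assert float(keras.ops.max(plan)) <= expected_row_mass + 0.05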