From 139baaab11363978e59c7fd7eb21d895dc5743b7 Mon Sep 17 00:00:00 2001 From: jorektheglitch Date: Thu, 3 Nov 2022 18:28:15 +0300 Subject: [PATCH 1/5] Add typehints --- k_diffusion/augmentation.py | 26 +++++---- k_diffusion/config.py | 97 +++++++++++++++++++++++++--------- k_diffusion/evaluation.py | 27 ++++++---- k_diffusion/gns.py | 1 - k_diffusion/layers.py | 11 ++-- k_diffusion/models/image_v1.py | 26 +++++---- k_diffusion/sampling.py | 51 +++++++++--------- k_diffusion/utils.py | 20 +++---- 8 files changed, 165 insertions(+), 94 deletions(-) diff --git a/k_diffusion/augmentation.py b/k_diffusion/augmentation.py index 7dd17c68..694f9e37 100644 --- a/k_diffusion/augmentation.py +++ b/k_diffusion/augmentation.py @@ -1,28 +1,32 @@ -from functools import reduce import math import operator +from functools import reduce + +from typing import Tuple import numpy as np +from PIL.Image import Image from skimage import transform import torch from torch import nn +from torch import Tensor -def translate2d(tx, ty): +def translate2d(tx: float, ty: float) -> Tensor: mat = [[1, 0, tx], [0, 1, ty], [0, 0, 1]] return torch.tensor(mat, dtype=torch.float32) -def scale2d(sx, sy): +def scale2d(sx: float, sy: float) -> Tensor: mat = [[sx, 0, 0], [ 0, sy, 0], [ 0, 0, 1]] return torch.tensor(mat, dtype=torch.float32) -def rotate2d(theta): +def rotate2d(theta: Tensor) -> Tensor: mat = [[torch.cos(theta), torch.sin(-theta), 0], [torch.sin(theta), torch.cos(theta), 0], [ 0, 0, 1]] @@ -36,7 +40,7 @@ def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8): self.a_aniso = a_aniso self.a_trans = a_trans - def __call__(self, image): + def __call__(self, image: Image) -> Tuple[Image, Tensor, Tensor]: h, w = image.size mats = [translate2d(h / 2 - 0.5, w / 2 - 0.5)] @@ -74,12 +78,12 @@ def __call__(self, image): cond = torch.stack([a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7]) # apply the transformation - image_orig = np.array(image, dtype=np.float32) / 255 - if image_orig.ndim == 2: - image_orig = image_orig[..., None] + image_np = np.array(image, dtype=np.float32) / 255 + if image_np.ndim == 2: + image_np = image_np[..., None] tf = transform.AffineTransform(mat.numpy()) - image = transform.warp(image_orig, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True) - image_orig = torch.as_tensor(image_orig).movedim(2, 0) * 2 - 1 + image = transform.warp(image_np, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True) + image_orig = torch.as_tensor(image_np).movedim(2, 0) * 2 - 1 image = torch.as_tensor(image).movedim(2, 0) * 2 - 1 return image, image_orig, cond @@ -88,7 +92,7 @@ class KarrasAugmentWrapper(nn.Module): def __init__(self, model): super().__init__() self.inner_model = model - + def forward(self, input, sigma, aug_cond=None, mapping_cond=None, **kwargs): if aug_cond is None: aug_cond = input.new_zeros([input.shape[0], 9]) diff --git a/k_diffusion/config.py b/k_diffusion/config.py index 4b504d6d..72d7bb4d 100644 --- a/k_diffusion/config.py +++ b/k_diffusion/config.py @@ -3,13 +3,62 @@ import math import warnings +from typing import Any, BinaryIO, List, Optional, TextIO, TypedDict, Union + from jsonmerge import merge from . import augmentation, layers, models, utils -def load_config(file): - defaults = { +class ModelConfig(TypedDict): + sigma_data: float + patch_size: int + dropout_rate: float + augment_wrapper: bool + augment_prob: float + mapping_cond_dim: int + unet_cond_dim: int + cross_cond_dim: int + cross_attn_depths: Optional[Any] + skip_stages: int + has_variance: bool + + +class DatasetConfig(TypedDict): + type: str + + +class OptimizerConfig(TypedDict): + type: str + lr: float + betas: List[float] + eps: float + weight_decay: float + + +class LRSchedConfig(TypedDict): + type: str + inv_gamma: float + power: float + warmup: float + + +class EMASchedConfig(TypedDict): + type: str + power: float + max_value: float + + +class Config(TypedDict): + model: ModelConfig + dataset: DatasetConfig + optimizer: OptimizerConfig + lr_sched: LRSchedConfig + ema_sched: EMASchedConfig + + +def load_config(file: Union[BinaryIO, TextIO]) -> Config: + defaults: Config = { 'model': { 'sigma_data': 1., 'patch_size': 1, @@ -49,39 +98,39 @@ def load_config(file): return merge(defaults, config) -def make_model(config): - config = config['model'] - assert config['type'] == 'image_v1' +def make_model(config: Config): + model_config = config['model'] + assert model_config['type'] == 'image_v1' model = models.ImageDenoiserModelV1( - config['input_channels'], - config['mapping_out'], - config['depths'], - config['channels'], - config['self_attn_depths'], - config['cross_attn_depths'], - patch_size=config['patch_size'], - dropout_rate=config['dropout_rate'], - mapping_cond_dim=config['mapping_cond_dim'] + (9 if config['augment_wrapper'] else 0), - unet_cond_dim=config['unet_cond_dim'], - cross_cond_dim=config['cross_cond_dim'], - skip_stages=config['skip_stages'], - has_variance=config['has_variance'], + model_config['input_channels'], + model_config['mapping_out'], + model_config['depths'], + model_config['channels'], + model_config['self_attn_depths'], + model_config['cross_attn_depths'], + patch_size=model_config['patch_size'], + dropout_rate=model_config['dropout_rate'], + mapping_cond_dim=model_config['mapping_cond_dim'] + (9 if model_config['augment_wrapper'] else 0), + unet_cond_dim=model_config['unet_cond_dim'], + cross_cond_dim=model_config['cross_cond_dim'], + skip_stages=model_config['skip_stages'], + has_variance=model_config['has_variance'], ) - if config['augment_wrapper']: + if model_config['augment_wrapper']: model = augmentation.KarrasAugmentWrapper(model) return model -def make_denoiser_wrapper(config): - config = config['model'] - sigma_data = config.get('sigma_data', 1.) - has_variance = config.get('has_variance', False) +def make_denoiser_wrapper(config: Config): + model_config = config['model'] + sigma_data = model_config.get('sigma_data', 1.) + has_variance = model_config.get('has_variance', False) if not has_variance: return partial(layers.Denoiser, sigma_data=sigma_data) return partial(layers.DenoiserWithVariance, sigma_data=sigma_data) -def make_sample_density(config): +def make_sample_density(config: ModelConfig): sd_config = config['sigma_sample_density'] sigma_data = config['sigma_data'] if sd_config['type'] == 'lognormal': diff --git a/k_diffusion/evaluation.py b/k_diffusion/evaluation.py index 2c34bbf1..17127fe7 100644 --- a/k_diffusion/evaluation.py +++ b/k_diffusion/evaluation.py @@ -2,11 +2,14 @@ import os from pathlib import Path +from typing import Optional + from cleanfid.inception_torchscript import InceptionV3W import clip from resize_right import resize import torch from torch import nn +from torch import Tensor from torch.nn import functional as F from torchvision import transforms from tqdm.auto import trange @@ -24,7 +27,7 @@ def __init__(self, device='cpu'): self.model = InceptionV3W(str(path), resize_inside=False).to(device) self.size = (299, 299) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: if x.shape[2:4] != self.size: x = resize(x, out_shape=self.size, pad_mode='reflect') if x.shape[1] == 1: @@ -41,7 +44,7 @@ def __init__(self, name='ViT-L/14@336px', device='cpu'): std=(0.26862954, 0.26130258, 0.27577711)) self.size = (self.model.visual.input_resolution, self.model.visual.input_resolution) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: if x.shape[2:4] != self.size: x = resize(x.add(1).div(2), out_shape=self.size, pad_mode='reflect').clamp(0, 1) x = self.normalize(x) @@ -50,7 +53,7 @@ def forward(self, x): return x -def compute_features(accelerator, sample_fn, extractor_fn, n, batch_size): +def compute_features(accelerator, sample_fn, extractor_fn, n, batch_size: Optional[int]) -> Tensor: n_per_proc = math.ceil(n / accelerator.num_processes) feats_all = [] try: @@ -63,13 +66,13 @@ def compute_features(accelerator, sample_fn, extractor_fn, n, batch_size): return torch.cat(feats_all)[:n] -def polynomial_kernel(x, y): +def polynomial_kernel(x: Tensor, y: Tensor) -> Tensor: d = x.shape[-1] dot = x @ y.transpose(-2, -1) return (dot / d + 1) ** 3 -def squared_mmd(x, y, kernel=polynomial_kernel): +def squared_mmd(x: Tensor, y: Tensor, kernel=polynomial_kernel) -> Tensor: m = x.shape[-2] n = y.shape[-2] kxx = kernel(x, x) @@ -85,7 +88,7 @@ def squared_mmd(x, y, kernel=polynomial_kernel): @utils.tf32_mode(matmul=False) -def kid(x, y, max_size=5000): +def kid(x: Tensor, y: Tensor, max_size: int = 5000) -> Tensor: x_size, y_size = x.shape[0], y.shape[0] n_partitions = math.ceil(max(x_size / max_size, y_size / max_size)) total_mmd = x.new_zeros([]) @@ -98,20 +101,24 @@ def kid(x, y, max_size=5000): class _MatrixSquareRootEig(torch.autograd.Function): @staticmethod - def forward(ctx, a): + def forward(ctx, a: Tensor) -> Tensor: # type: ignore # this violates LSP + vals: Tensor + vecs: Tensor vals, vecs = torch.linalg.eigh(a) ctx.save_for_backward(vals, vecs) return vecs @ vals.abs().sqrt().diag_embed() @ vecs.transpose(-2, -1) @staticmethod - def backward(ctx, grad_output): + def backward(ctx, grad_output: Tensor) -> Tensor: # type: ignore # this violates LSP + vals: Tensor + vecs: Tensor vals, vecs = ctx.saved_tensors d = vals.abs().sqrt().unsqueeze(-1).repeat_interleave(vals.shape[-1], -1) vecs_t = vecs.transpose(-2, -1) return vecs @ (vecs_t @ grad_output @ vecs / (d + d.transpose(-2, -1))) @ vecs_t -def sqrtm_eig(a): +def sqrtm_eig(a: Tensor) -> Tensor: if a.ndim < 2: raise RuntimeError('tensor of matrices must have at least 2 dimensions') if a.shape[-2] != a.shape[-1]: @@ -120,7 +127,7 @@ def sqrtm_eig(a): @utils.tf32_mode(matmul=False) -def fid(x, y, eps=1e-8): +def fid(x: Tensor, y: Tensor, eps: float = 1e-8) -> Tensor: x_mean = x.mean(dim=0) y_mean = y.mean(dim=0) mean_term = (x_mean - y_mean).pow(2).sum() diff --git a/k_diffusion/gns.py b/k_diffusion/gns.py index dcb7b8d8..6cdbe0fc 100644 --- a/k_diffusion/gns.py +++ b/k_diffusion/gns.py @@ -1,5 +1,4 @@ import torch -from torch import nn class DDPGradientStatsHook: diff --git a/k_diffusion/layers.py b/k_diffusion/layers.py index cdeba0ad..1cd84455 100644 --- a/k_diffusion/layers.py +++ b/k_diffusion/layers.py @@ -1,8 +1,9 @@ import math +from typing import Tuple -from einops import rearrange, repeat import torch from torch import nn +from torch import Tensor from torch.nn import functional as F from . import utils @@ -17,26 +18,26 @@ def __init__(self, inner_model, sigma_data=1.): self.inner_model = inner_model self.sigma_data = sigma_data - def get_scalings(self, sigma): + def get_scalings(self, sigma: Tensor) -> Tuple[Tensor, Tensor, Tensor]: c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 return c_skip, c_out, c_in - def loss(self, input, noise, sigma, **kwargs): + def loss(self, input: Tensor, noise: Tensor, sigma: Tensor, **kwargs): c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] noised_input = input + noise * utils.append_dims(sigma, input.ndim) model_output = self.inner_model(noised_input * c_in, sigma, **kwargs) target = (input - c_skip * noised_input) / c_out return (model_output - target).pow(2).flatten(1).mean(1) - def forward(self, input, sigma, **kwargs): + def forward(self, input: Tensor, sigma: Tensor, **kwargs): c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] return self.inner_model(input * c_in, sigma, **kwargs) * c_out + input * c_skip class DenoiserWithVariance(Denoiser): - def loss(self, input, noise, sigma, **kwargs): + def loss(self, input: Tensor, noise: Tensor, sigma: Tensor, **kwargs): c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] noised_input = input + noise * utils.append_dims(sigma, input.ndim) model_output, logvar = self.inner_model(noised_input * c_in, sigma, return_variance=True, **kwargs) diff --git a/k_diffusion/models/image_v1.py b/k_diffusion/models/image_v1.py index 9ffd5f2c..54d2e5c0 100644 --- a/k_diffusion/models/image_v1.py +++ b/k_diffusion/models/image_v1.py @@ -1,5 +1,7 @@ import math +from typing import List, Union, overload + import torch from torch import nn from torch.nn import functional as F @@ -7,13 +9,17 @@ from .. import layers, utils -def orthogonal_(module): +@overload +def orthogonal_(module: nn.Conv2d) -> nn.Conv2d: ... +@overload +def orthogonal_(module: nn.Linear) -> nn.Linear: ... +def orthogonal_(module: Union[nn.Conv2d, nn.Linear]) -> Union[nn.Conv2d, nn.Linear]: nn.init.orthogonal_(module.weight) return module class ResConvBlock(layers.ConditionedResidualBlock): - def __init__(self, feats_in, c_in, c_mid, c_out, group_size=32, dropout_rate=0.): + def __init__(self, feats_in, c_in, c_mid, c_out, group_size=32, dropout_rate=0.) -> None: skip = None if c_in == c_out else orthogonal_(nn.Conv2d(c_in, c_out, 1, bias=False)) super().__init__( layers.AdaGN(feats_in, c_in, max(1, c_in // group_size)), @@ -28,8 +34,8 @@ def __init__(self, feats_in, c_in, c_mid, c_out, group_size=32, dropout_rate=0.) class DBlock(layers.ConditionedSequential): - def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., downsample=False, self_attn=False, cross_attn=False, c_enc=0): - modules = [nn.Identity()] + def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., downsample=False, self_attn=False, cross_attn=False, c_enc=0) -> None: + modules: List[nn.Module] = [nn.Identity()] for i in range(n_layers): my_c_in = c_in if i == 0 else c_mid my_c_out = c_mid if i < n_layers - 1 else c_out @@ -49,8 +55,8 @@ def set_downsample(self, downsample): class UBlock(layers.ConditionedSequential): - def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., upsample=False, self_attn=False, cross_attn=False, c_enc=0): - modules = [] + def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., upsample=False, self_attn=False, cross_attn=False, c_enc=0) -> None: + modules: List[nn.Module] = [] for i in range(n_layers): my_c_in = c_in if i == 0 else c_mid my_c_out = c_mid if i < n_layers - 1 else c_out @@ -76,8 +82,8 @@ def set_upsample(self, upsample): class MappingNet(nn.Sequential): - def __init__(self, feats_in, feats_out, n_layers=2): - layers = [] + def __init__(self, feats_in, feats_out, n_layers=2) -> None: + layers: List[nn.Module] = [] for i in range(n_layers): layers.append(orthogonal_(nn.Linear(feats_in if i == 0 else feats_out, feats_out))) layers.append(nn.GELU()) @@ -85,7 +91,7 @@ def __init__(self, feats_in, feats_out, n_layers=2): class ImageDenoiserModelV1(nn.Module): - def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, cross_attn_depths=None, mapping_cond_dim=0, unet_cond_dim=0, cross_cond_dim=0, dropout_rate=0., patch_size=1, skip_stages=0, has_variance=False): + def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, cross_attn_depths=None, mapping_cond_dim=0, unet_cond_dim=0, cross_cond_dim=0, dropout_rate=0., patch_size=1, skip_stages=0, has_variance=False) -> None: super().__init__() self.c_in = c_in self.channels = channels @@ -99,7 +105,7 @@ def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, cross_att self.proj_in = nn.Conv2d((c_in + unet_cond_dim) * self.patch_size ** 2, channels[max(0, skip_stages - 1)], 1) self.proj_out = nn.Conv2d(channels[max(0, skip_stages - 1)], c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1) nn.init.zeros_(self.proj_out.weight) - nn.init.zeros_(self.proj_out.bias) + nn.init.zeros_(self.proj_out.bias) # type: ignore # self.proj_out.bias may be None and all falls if cross_cond_dim == 0: cross_attn_depths = [False] * len(self_attn_depths) d_blocks, u_blocks = [], [] diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py index f072dd5c..cf47d4e9 100644 --- a/k_diffusion/sampling.py +++ b/k_diffusion/sampling.py @@ -1,19 +1,22 @@ import math +from typing import Any, Callable, Dict, Tuple + from scipy import integrate import torch from torch import nn +from torch import Tensor from torchdiffeq import odeint from tqdm.auto import trange, tqdm from . import utils -def append_zero(x): +def append_zero(x: Tensor) -> Tensor: return torch.cat([x, x.new_zeros([1])]) -def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'): +def get_sigmas_karras(n: int, sigma_min: float, sigma_max: float, rho: float = 7., device='cpu') -> Tensor: """Constructs the noise schedule of Karras et al. (2022).""" ramp = torch.linspace(0, 1, n) min_inv_rho = sigma_min ** (1 / rho) @@ -22,20 +25,20 @@ def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'): return append_zero(sigmas).to(device) -def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'): +def get_sigmas_exponential(n: int, sigma_min: float, sigma_max: float, device='cpu') -> Tensor: """Constructs an exponential noise schedule.""" sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp() return append_zero(sigmas) -def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'): +def get_sigmas_vp(n: int, beta_d: float = 19.9, beta_min: float = 0.1, eps_s: float = 1e-3, device='cpu') -> Tensor: """Constructs a continuous VP noise schedule.""" t = torch.linspace(1, eps_s, n, device=device) sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1) return append_zero(sigmas) -def to_d(x, sigma, denoised): +def to_d(x: Tensor, sigma: Tensor, denoised: Tensor) -> Tensor: """Converts a denoiser output to a Karras ODE derivative.""" return (x - denoised) / utils.append_dims(sigma, x.ndim) @@ -51,7 +54,7 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.): @torch.no_grad() -def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): +def sample_euler(model, x: Tensor, sigmas: Tensor, extra_args: Dict[str, Any] = None, callback: Callable[[Dict], Any] = None, disable: bool = None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.) -> Tensor: """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -72,7 +75,7 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, @torch.no_grad() -def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.): +def sample_euler_ancestral(model, x: Tensor, sigmas: Tensor, extra_args: Dict[str, Any] = None, callback: Callable[[Dict], Any] = None, disable: bool = None, eta=1.) -> Tensor: """Ancestral sampling with Euler method steps.""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -90,7 +93,7 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis @torch.no_grad() -def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): +def sample_heun(model, x: Tensor, sigmas: Tensor, extra_args: Dict[str, Any] = None, callback: Callable[[Dict], Any] = None, disable: bool = None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.) -> Tensor: """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -119,7 +122,7 @@ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, @torch.no_grad() -def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): +def sample_dpm_2(model, x: Tensor, sigmas: Tensor, extra_args: Dict[str, Any] = None, callback: Callable[[Dict], Any] = None, disable: bool = None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.) -> Tensor: """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -150,7 +153,7 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, @torch.no_grad() -def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.): +def sample_dpm_2_ancestral(model, x: Tensor, sigmas: Tensor, extra_args: Dict[str, Any] = None, callback: Callable[[Dict], Any] = None, disable: bool = None, eta=1.) -> Tensor: """Ancestral sampling with DPM-Solver inspired second-order steps.""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -177,7 +180,7 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis return x -def linear_multistep_coeff(order, t, i, j): +def linear_multistep_coeff(order: int, t, i: int, j: int): if order - 1 > i: raise ValueError(f'Order {order} too high for step {i}') def fn(tau): @@ -191,7 +194,7 @@ def fn(tau): @torch.no_grad() -def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4): +def sample_lms(model, x: Tensor, sigmas: Tensor, extra_args: Dict[str, Any] = None, callback: Callable[[Dict], Any] = None, disable: bool = None, order=4) -> Tensor: extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) sigmas_cpu = sigmas.detach().cpu().numpy() @@ -211,12 +214,12 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o @torch.no_grad() -def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4): +def log_likelihood(model, x: Tensor, sigma_min: float, sigma_max: float, extra_args=None, atol=1e-4, rtol=1e-4) -> Tuple[Tensor, Dict[str, int]]: extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) v = torch.randint_like(x, 2) * 2 - 1 - fevals = 0 - def ode_fn(sigma, x): + fevals: int = 0 + def ode_fn(sigma, x: Tensor): nonlocal fevals with torch.enable_grad(): x = x[0].detach().requires_grad_() @@ -230,7 +233,7 @@ def ode_fn(sigma, x): t = x.new_tensor([sigma_min, sigma_max]) sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5') latent, delta_ll = sol[0][-1], sol[1][-1] - ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1) + ll_prior: Tensor = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1) return ll_prior + delta_ll, {'fevals': fevals} @@ -248,7 +251,7 @@ def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1 def limiter(self, x): return 1 + math.atan(x - 1) - def propose_step(self, error): + def propose_step(self, error) -> bool: inv_error = 1 / (float(error) + self.eps) if not self.errs: self.errs = [inv_error, inv_error, inv_error] @@ -266,7 +269,7 @@ def propose_step(self, error): class DPMSolver(nn.Module): """DPM-Solver. See https://arxiv.org/abs/2206.00927.""" - def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None): + def __init__(self, model, extra_args: Dict[str, Any] = None, eps_callback=None, info_callback=None): super().__init__() self.model = model self.extra_args = {} if extra_args is None else extra_args @@ -318,7 +321,7 @@ def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None): x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps) return x_3, eps_cache - def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1.): + def dpm_solver_fast(self, x: Tensor, t_start, t_end, nfe, eta=0., s_noise=1.) -> Tensor: if not t_end > t_start and eta: raise ValueError('eta must be 0 for reverse sampling') @@ -331,7 +334,7 @@ def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1.): orders = [3] * (m - 1) + [nfe % 3] for i in range(len(orders)): - eps_cache = {} + eps_cache: dict = {} t, t_next = ts[i], ts[i + 1] if eta: sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta) @@ -356,7 +359,7 @@ def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1.): return x - def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1.): + def dpm_solver_adaptive(self, x: Tensor, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1.) -> Tuple[Tensor, Dict[str, int]]: if order not in {2, 3}: raise ValueError('order should be 2 or 3') forward = t_end > t_start @@ -372,7 +375,7 @@ def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078 info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0} while s < t_end - 1e-5 if forward else s > t_end + 1e-5: - eps_cache = {} + eps_cache: dict = {} t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h) if eta: sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta) @@ -410,7 +413,7 @@ def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078 @torch.no_grad() -def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1.): +def sample_dpm_fast(model, x: Tensor, sigma_min, sigma_max, n, extra_args: Dict[str, Any] = None, callback: Callable[[Dict], Any] = None, disable: bool = None, eta=0., s_noise=1.): """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927.""" if sigma_min <= 0 or sigma_max <= 0: raise ValueError('sigma_min and sigma_max must not be 0') @@ -422,7 +425,7 @@ def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback @torch.no_grad() -def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., return_info=False): +def sample_dpm_adaptive(model, x: Tensor, sigma_min, sigma_max, extra_args: Dict[str, Any] = None, callback: Callable[[Dict], Any] = None, disable: bool = None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., return_info: bool = False): """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927.""" if sigma_min <= 0 or sigma_max <= 0: raise ValueError('sigma_min and sigma_max must not be 0') diff --git a/k_diffusion/utils.py b/k_diffusion/utils.py index 9afedb99..705e895b 100644 --- a/k_diffusion/utils.py +++ b/k_diffusion/utils.py @@ -3,25 +3,27 @@ import math from pathlib import Path import shutil +from typing import Optional, Union import urllib import warnings from PIL import Image import torch from torch import nn, optim +from torch import Tensor from torch.utils import data from torchvision.transforms import functional as TF -def from_pil_image(x): +def from_pil_image(x: Image.Image) -> Tensor: """Converts from a PIL image to a tensor.""" - x = TF.to_tensor(x) - if x.ndim == 2: - x = x[..., None] - return x * 2 - 1 + tensor = TF.to_tensor(x) + if tensor.ndim == 2: + tensor = tensor[..., None] + return tensor * 2 - 1 -def to_pil_image(x): +def to_pil_image(x: Tensor) -> Image.Image: """Converts from a tensor to a PIL image.""" if x.ndim == 4: assert x.shape[0] == 1 @@ -37,7 +39,7 @@ def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'): return {image_key: images} -def append_dims(x, target_dims): +def append_dims(x: Tensor, target_dims: int) -> Tensor: """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" dims_to_append = target_dims - x.ndim if dims_to_append < 0: @@ -45,12 +47,12 @@ def append_dims(x, target_dims): return x[(...,) + (None,) * dims_to_append] -def n_params(module): +def n_params(module) -> int: """Returns the number of trainable parameters in a module.""" return sum(p.numel() for p in module.parameters()) -def download_file(path, url, digest=None): +def download_file(path: Union[str, Path], url: str, digest: Optional[str] = None) -> Path: """Downloads a file if it does not exist, optionally checking its SHA-256 hash.""" path = Path(path) path.parent.mkdir(parents=True, exist_ok=True) From 931b477d4d24f7e5bdbb875f2721c2a266730d46 Mon Sep 17 00:00:00 2001 From: jorektheglitch Date: Thu, 3 Nov 2022 18:59:35 +0300 Subject: [PATCH 2/5] Autoformat internals code --- k_diffusion/augmentation.py | 36 +-- k_diffusion/config.py | 160 ++++++------ k_diffusion/evaluation.py | 56 +++-- k_diffusion/external.py | 58 +++-- k_diffusion/gns.py | 38 ++- k_diffusion/layers.py | 123 ++++++--- k_diffusion/models/image_v1.py | 215 +++++++++++++--- k_diffusion/sampling.py | 447 +++++++++++++++++++++++++++------ k_diffusion/utils.py | 154 ++++++++---- 9 files changed, 964 insertions(+), 323 deletions(-) diff --git a/k_diffusion/augmentation.py b/k_diffusion/augmentation.py index 694f9e37..126c780b 100644 --- a/k_diffusion/augmentation.py +++ b/k_diffusion/augmentation.py @@ -13,28 +13,26 @@ def translate2d(tx: float, ty: float) -> Tensor: - mat = [[1, 0, tx], - [0, 1, ty], - [0, 0, 1]] + mat = [[1, 0, tx], [0, 1, ty], [0, 0, 1]] return torch.tensor(mat, dtype=torch.float32) def scale2d(sx: float, sy: float) -> Tensor: - mat = [[sx, 0, 0], - [ 0, sy, 0], - [ 0, 0, 1]] + mat = [[sx, 0, 0], [0, sy, 0], [0, 0, 1]] return torch.tensor(mat, dtype=torch.float32) def rotate2d(theta: Tensor) -> Tensor: - mat = [[torch.cos(theta), torch.sin(-theta), 0], - [torch.sin(theta), torch.cos(theta), 0], - [ 0, 0, 1]] + mat = [ + [torch.cos(theta), torch.sin(-theta), 0], + [torch.sin(theta), torch.cos(theta), 0], + [0, 0, 1], + ] return torch.tensor(mat, dtype=torch.float32) class KarrasAugmentationPipeline: - def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8): + def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1 / 8): self.a_prob = a_prob self.a_scale = a_scale self.a_aniso = a_aniso @@ -54,7 +52,7 @@ def __call__(self, image: Image) -> Tuple[Image, Tensor, Tensor]: # scaling do = (torch.rand([]) < self.a_prob).float() a2 = torch.randn([]) * do - mats.append(scale2d(self.a_scale ** a2, self.a_scale ** a2)) + mats.append(scale2d(self.a_scale**a2, self.a_scale**a2)) # rotation do = (torch.rand([]) < self.a_prob).float() a3 = (torch.rand([]) * 2 * math.pi - math.pi) * do @@ -64,7 +62,7 @@ def __call__(self, image: Image) -> Tuple[Image, Tensor, Tensor]: a4 = (torch.rand([]) * 2 * math.pi - math.pi) * do a5 = torch.randn([]) * do mats.append(rotate2d(a4)) - mats.append(scale2d(self.a_aniso ** a5, self.a_aniso ** -a5)) + mats.append(scale2d(self.a_aniso**a5, self.a_aniso**-a5)) mats.append(rotate2d(-a4)) # translation do = (torch.rand([]) < self.a_prob).float() @@ -75,14 +73,24 @@ def __call__(self, image: Image) -> Tuple[Image, Tensor, Tensor]: # form the transformation matrix and conditioning vector mats.append(translate2d(-h / 2 + 0.5, -w / 2 + 0.5)) mat = reduce(operator.matmul, mats) - cond = torch.stack([a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7]) + cond = torch.stack( + [a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7] + ) # apply the transformation image_np = np.array(image, dtype=np.float32) / 255 if image_np.ndim == 2: image_np = image_np[..., None] tf = transform.AffineTransform(mat.numpy()) - image = transform.warp(image_np, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True) + image = transform.warp( + image_np, + tf.inverse, + order=3, + mode="reflect", + cval=0.5, + clip=False, + preserve_range=True, + ) image_orig = torch.as_tensor(image_np).movedim(2, 0) * 2 - 1 image = torch.as_tensor(image).movedim(2, 0) * 2 - 1 return image, image_orig, cond diff --git a/k_diffusion/config.py b/k_diffusion/config.py index 72d7bb4d..f8762697 100644 --- a/k_diffusion/config.py +++ b/k_diffusion/config.py @@ -59,101 +59,115 @@ class Config(TypedDict): def load_config(file: Union[BinaryIO, TextIO]) -> Config: defaults: Config = { - 'model': { - 'sigma_data': 1., - 'patch_size': 1, - 'dropout_rate': 0., - 'augment_wrapper': True, - 'augment_prob': 0., - 'mapping_cond_dim': 0, - 'unet_cond_dim': 0, - 'cross_cond_dim': 0, - 'cross_attn_depths': None, - 'skip_stages': 0, - 'has_variance': False, + "model": { + "sigma_data": 1.0, + "patch_size": 1, + "dropout_rate": 0.0, + "augment_wrapper": True, + "augment_prob": 0.0, + "mapping_cond_dim": 0, + "unet_cond_dim": 0, + "cross_cond_dim": 0, + "cross_attn_depths": None, + "skip_stages": 0, + "has_variance": False, }, - 'dataset': { - 'type': 'imagefolder', + "dataset": { + "type": "imagefolder", }, - 'optimizer': { - 'type': 'adamw', - 'lr': 1e-4, - 'betas': [0.95, 0.999], - 'eps': 1e-6, - 'weight_decay': 1e-3, + "optimizer": { + "type": "adamw", + "lr": 1e-4, + "betas": [0.95, 0.999], + "eps": 1e-6, + "weight_decay": 1e-3, }, - 'lr_sched': { - 'type': 'inverse', - 'inv_gamma': 20000., - 'power': 1., - 'warmup': 0.99, - }, - 'ema_sched': { - 'type': 'inverse', - 'power': 0.6667, - 'max_value': 0.9999 + "lr_sched": { + "type": "inverse", + "inv_gamma": 20000.0, + "power": 1.0, + "warmup": 0.99, }, + "ema_sched": {"type": "inverse", "power": 0.6667, "max_value": 0.9999}, } config = json.load(file) return merge(defaults, config) def make_model(config: Config): - model_config = config['model'] - assert model_config['type'] == 'image_v1' + model_config = config["model"] + assert model_config["type"] == "image_v1" model = models.ImageDenoiserModelV1( - model_config['input_channels'], - model_config['mapping_out'], - model_config['depths'], - model_config['channels'], - model_config['self_attn_depths'], - model_config['cross_attn_depths'], - patch_size=model_config['patch_size'], - dropout_rate=model_config['dropout_rate'], - mapping_cond_dim=model_config['mapping_cond_dim'] + (9 if model_config['augment_wrapper'] else 0), - unet_cond_dim=model_config['unet_cond_dim'], - cross_cond_dim=model_config['cross_cond_dim'], - skip_stages=model_config['skip_stages'], - has_variance=model_config['has_variance'], + model_config["input_channels"], + model_config["mapping_out"], + model_config["depths"], + model_config["channels"], + model_config["self_attn_depths"], + model_config["cross_attn_depths"], + patch_size=model_config["patch_size"], + dropout_rate=model_config["dropout_rate"], + mapping_cond_dim=model_config["mapping_cond_dim"] + + (9 if model_config["augment_wrapper"] else 0), + unet_cond_dim=model_config["unet_cond_dim"], + cross_cond_dim=model_config["cross_cond_dim"], + skip_stages=model_config["skip_stages"], + has_variance=model_config["has_variance"], ) - if model_config['augment_wrapper']: + if model_config["augment_wrapper"]: model = augmentation.KarrasAugmentWrapper(model) return model def make_denoiser_wrapper(config: Config): - model_config = config['model'] - sigma_data = model_config.get('sigma_data', 1.) - has_variance = model_config.get('has_variance', False) + model_config = config["model"] + sigma_data = model_config.get("sigma_data", 1.0) + has_variance = model_config.get("has_variance", False) if not has_variance: return partial(layers.Denoiser, sigma_data=sigma_data) return partial(layers.DenoiserWithVariance, sigma_data=sigma_data) def make_sample_density(config: ModelConfig): - sd_config = config['sigma_sample_density'] - sigma_data = config['sigma_data'] - if sd_config['type'] == 'lognormal': - loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc'] - scale = sd_config['std'] if 'std' in sd_config else sd_config['scale'] + sd_config = config["sigma_sample_density"] + sigma_data = config["sigma_data"] + if sd_config["type"] == "lognormal": + loc = sd_config["mean"] if "mean" in sd_config else sd_config["loc"] + scale = sd_config["std"] if "std" in sd_config else sd_config["scale"] return partial(utils.rand_log_normal, loc=loc, scale=scale) - if sd_config['type'] == 'loglogistic': - loc = sd_config['loc'] if 'loc' in sd_config else math.log(sigma_data) - scale = sd_config['scale'] if 'scale' in sd_config else 0.5 - min_value = sd_config['min_value'] if 'min_value' in sd_config else 0. - max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf') - return partial(utils.rand_log_logistic, loc=loc, scale=scale, min_value=min_value, max_value=max_value) - if sd_config['type'] == 'loguniform': - min_value = sd_config['min_value'] if 'min_value' in sd_config else config['sigma_min'] - max_value = sd_config['max_value'] if 'max_value' in sd_config else config['sigma_max'] + if sd_config["type"] == "loglogistic": + loc = sd_config["loc"] if "loc" in sd_config else math.log(sigma_data) + scale = sd_config["scale"] if "scale" in sd_config else 0.5 + min_value = sd_config["min_value"] if "min_value" in sd_config else 0.0 + max_value = sd_config["max_value"] if "max_value" in sd_config else float("inf") + return partial( + utils.rand_log_logistic, + loc=loc, + scale=scale, + min_value=min_value, + max_value=max_value, + ) + if sd_config["type"] == "loguniform": + min_value = ( + sd_config["min_value"] if "min_value" in sd_config else config["sigma_min"] + ) + max_value = ( + sd_config["max_value"] if "max_value" in sd_config else config["sigma_max"] + ) return partial(utils.rand_log_uniform, min_value=min_value, max_value=max_value) - if sd_config['type'] == 'v-diffusion': - min_value = sd_config['min_value'] if 'min_value' in sd_config else 0. - max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf') - return partial(utils.rand_v_diffusion, sigma_data=sigma_data, min_value=min_value, max_value=max_value) - if sd_config['type'] == 'split-lognormal': - loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc'] - scale_1 = sd_config['std_1'] if 'std_1' in sd_config else sd_config['scale_1'] - scale_2 = sd_config['std_2'] if 'std_2' in sd_config else sd_config['scale_2'] - return partial(utils.rand_split_log_normal, loc=loc, scale_1=scale_1, scale_2=scale_2) - raise ValueError('Unknown sample density type') + if sd_config["type"] == "v-diffusion": + min_value = sd_config["min_value"] if "min_value" in sd_config else 0.0 + max_value = sd_config["max_value"] if "max_value" in sd_config else float("inf") + return partial( + utils.rand_v_diffusion, + sigma_data=sigma_data, + min_value=min_value, + max_value=max_value, + ) + if sd_config["type"] == "split-lognormal": + loc = sd_config["mean"] if "mean" in sd_config else sd_config["loc"] + scale_1 = sd_config["std_1"] if "std_1" in sd_config else sd_config["scale_1"] + scale_2 = sd_config["std_2"] if "std_2" in sd_config else sd_config["scale_2"] + return partial( + utils.rand_split_log_normal, loc=loc, scale_1=scale_1, scale_2=scale_2 + ) + raise ValueError("Unknown sample density type") diff --git a/k_diffusion/evaluation.py b/k_diffusion/evaluation.py index 17127fe7..2634b3c6 100644 --- a/k_diffusion/evaluation.py +++ b/k_diffusion/evaluation.py @@ -18,18 +18,21 @@ class InceptionV3FeatureExtractor(nn.Module): - def __init__(self, device='cpu'): + def __init__(self, device="cpu"): super().__init__() - path = Path(os.environ.get('XDG_CACHE_HOME', Path.home() / '.cache')) / 'k-diffusion' - url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' - digest = 'f58cb9b6ec323ed63459aa4fb441fe750cfe39fafad6da5cb504a16f19e958f4' - utils.download_file(path / 'inception-2015-12-05.pt', url, digest) + path = ( + Path(os.environ.get("XDG_CACHE_HOME", Path.home() / ".cache")) + / "k-diffusion" + ) + url = "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt" + digest = "f58cb9b6ec323ed63459aa4fb441fe750cfe39fafad6da5cb504a16f19e958f4" + utils.download_file(path / "inception-2015-12-05.pt", url, digest) self.model = InceptionV3W(str(path), resize_inside=False).to(device) self.size = (299, 299) def forward(self, x: Tensor) -> Tensor: if x.shape[2:4] != self.size: - x = resize(x, out_shape=self.size, pad_mode='reflect') + x = resize(x, out_shape=self.size, pad_mode="reflect") if x.shape[1] == 1: x = torch.cat([x] * 3, dim=1) x = (x * 127.5 + 127.5).clamp(0, 255) @@ -37,27 +40,38 @@ def forward(self, x: Tensor) -> Tensor: class CLIPFeatureExtractor(nn.Module): - def __init__(self, name='ViT-L/14@336px', device='cpu'): + def __init__(self, name="ViT-L/14@336px", device="cpu"): super().__init__() self.model = clip.load(name, device=device)[0].eval().requires_grad_(False) - self.normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), - std=(0.26862954, 0.26130258, 0.27577711)) - self.size = (self.model.visual.input_resolution, self.model.visual.input_resolution) + self.normalize = transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ) + self.size = ( + self.model.visual.input_resolution, + self.model.visual.input_resolution, + ) def forward(self, x: Tensor) -> Tensor: if x.shape[2:4] != self.size: - x = resize(x.add(1).div(2), out_shape=self.size, pad_mode='reflect').clamp(0, 1) + x = resize(x.add(1).div(2), out_shape=self.size, pad_mode="reflect").clamp( + 0, 1 + ) x = self.normalize(x) x = self.model.encode_image(x).float() x = F.normalize(x) * x.shape[1] ** 0.5 return x -def compute_features(accelerator, sample_fn, extractor_fn, n, batch_size: Optional[int]) -> Tensor: +def compute_features( + accelerator, sample_fn, extractor_fn, n, batch_size: Optional[int] +) -> Tensor: n_per_proc = math.ceil(n / accelerator.num_processes) feats_all = [] try: - for i in trange(0, n_per_proc, batch_size, disable=not accelerator.is_main_process): + for i in trange( + 0, n_per_proc, batch_size, disable=not accelerator.is_main_process + ): cur_batch_size = min(n - i, batch_size) samples = sample_fn(cur_batch_size)[:cur_batch_size] feats_all.append(accelerator.gather(extractor_fn(samples))) @@ -93,8 +107,12 @@ def kid(x: Tensor, y: Tensor, max_size: int = 5000) -> Tensor: n_partitions = math.ceil(max(x_size / max_size, y_size / max_size)) total_mmd = x.new_zeros([]) for i in range(n_partitions): - cur_x = x[round(i * x_size / n_partitions):round((i + 1) * x_size / n_partitions)] - cur_y = y[round(i * y_size / n_partitions):round((i + 1) * y_size / n_partitions)] + cur_x = x[ + round(i * x_size / n_partitions) : round((i + 1) * x_size / n_partitions) + ] + cur_y = y[ + round(i * y_size / n_partitions) : round((i + 1) * y_size / n_partitions) + ] total_mmd = total_mmd + squared_mmd(cur_x, cur_y) return total_mmd / n_partitions @@ -120,9 +138,9 @@ def backward(ctx, grad_output: Tensor) -> Tensor: # type: ignore # this violat def sqrtm_eig(a: Tensor) -> Tensor: if a.ndim < 2: - raise RuntimeError('tensor of matrices must have at least 2 dimensions') + raise RuntimeError("tensor of matrices must have at least 2 dimensions") if a.shape[-2] != a.shape[-1]: - raise RuntimeError('tensor must be batches of square matrices') + raise RuntimeError("tensor must be batches of square matrices") return _MatrixSquareRootEig.apply(a) @@ -137,5 +155,7 @@ def fid(x: Tensor, y: Tensor, eps: float = 1e-8) -> Tensor: x_cov = x_cov + eps_eye y_cov = y_cov + eps_eye x_cov_sqrt = sqrtm_eig(x_cov) - cov_term = torch.trace(x_cov + y_cov - 2 * sqrtm_eig(x_cov_sqrt @ y_cov @ x_cov_sqrt)) + cov_term = torch.trace( + x_cov + y_cov - 2 * sqrtm_eig(x_cov_sqrt @ y_cov @ x_cov_sqrt) + ) return mean_term + cov_term diff --git a/k_diffusion/external.py b/k_diffusion/external.py index 2f1d2588..e949df7a 100644 --- a/k_diffusion/external.py +++ b/k_diffusion/external.py @@ -12,12 +12,12 @@ class VDenoiser(nn.Module): def __init__(self, inner_model): super().__init__() self.inner_model = inner_model - self.sigma_data = 1. + self.sigma_data = 1.0 def get_scalings(self, sigma): - c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 - c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = -sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 + c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 return c_skip, c_out, c_in def sigma_to_t(self, sigma): @@ -27,15 +27,24 @@ def t_to_sigma(self, t): return (t * math.pi / 2).tan() def loss(self, input, noise, sigma, **kwargs): - c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + c_skip, c_out, c_in = [ + utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma) + ] noised_input = input + noise * utils.append_dims(sigma, input.ndim) - model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) + model_output = self.inner_model( + noised_input * c_in, self.sigma_to_t(sigma), **kwargs + ) target = (input - c_skip * noised_input) / c_out return (model_output - target).pow(2).flatten(1).mean(1) def forward(self, input, sigma, **kwargs): - c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] - return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip + c_skip, c_out, c_in = [ + utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma) + ] + return ( + self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + + input * c_skip + ) class DiscreteSchedule(nn.Module): @@ -44,8 +53,8 @@ class DiscreteSchedule(nn.Module): def __init__(self, sigmas, quantize): super().__init__() - self.register_buffer('sigmas', sigmas) - self.register_buffer('log_sigmas', sigmas.log()) + self.register_buffer("sigmas", sigmas) + self.register_buffer("log_sigmas", sigmas.log()) self.quantize = quantize @property @@ -69,7 +78,12 @@ def sigma_to_t(self, sigma, quantize=None): dists = log_sigma - self.log_sigmas[:, None] if quantize: return dists.abs().argmin(dim=0).view(sigma.shape) - low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2) + low_idx = ( + dists.ge(0) + .cumsum(dim=0) + .argmax(dim=0) + .clamp(max=self.log_sigmas.shape[0] - 2) + ) high_idx = low_idx + 1 low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx] w = (low - log_sigma) / (low - high) @@ -91,24 +105,28 @@ class DiscreteEpsDDPMDenoiser(DiscreteSchedule): def __init__(self, model, alphas_cumprod, quantize): super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize) self.inner_model = model - self.sigma_data = 1. + self.sigma_data = 1.0 def get_scalings(self, sigma): c_out = -sigma - c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 return c_out, c_in def get_eps(self, *args, **kwargs): return self.inner_model(*args, **kwargs) def loss(self, input, noise, sigma, **kwargs): - c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + c_out, c_in = [ + utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma) + ] noised_input = input + noise * utils.append_dims(sigma, input.ndim) eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) return (eps - noise).pow(2).flatten(1).mean(1) def forward(self, input, sigma, **kwargs): - c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + c_out, c_in = [ + utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma) + ] eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs) return input + eps * c_out @@ -116,8 +134,12 @@ def forward(self, input, sigma, **kwargs): class OpenAIDenoiser(DiscreteEpsDDPMDenoiser): """A wrapper for OpenAI diffusion models.""" - def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'): - alphas_cumprod = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32) + def __init__( + self, model, diffusion, quantize=False, has_learned_sigmas=True, device="cpu" + ): + alphas_cumprod = torch.tensor( + diffusion.alphas_cumprod, device=device, dtype=torch.float32 + ) super().__init__(model, alphas_cumprod, quantize=quantize) self.has_learned_sigmas = has_learned_sigmas @@ -131,7 +153,7 @@ def get_eps(self, *args, **kwargs): class CompVisDenoiser(DiscreteEpsDDPMDenoiser): """A wrapper for CompVis diffusion models.""" - def __init__(self, model, quantize=False, device='cpu'): + def __init__(self, model, quantize=False, device="cpu"): super().__init__(model, model.alphas_cumprod, quantize=quantize) def get_eps(self, *args, **kwargs): diff --git a/k_diffusion/gns.py b/k_diffusion/gns.py index 6cdbe0fc..c97e9490 100644 --- a/k_diffusion/gns.py +++ b/k_diffusion/gns.py @@ -6,7 +6,9 @@ def __init__(self, ddp_module): try: ddp_module.register_comm_hook(self, self._hook_fn) except AttributeError: - raise ValueError('DDPGradientStatsHook does not support non-DDP wrapped modules') + raise ValueError( + "DDPGradientStatsHook does not support non-DDP wrapped modules" + ) self._clear_state() def _clear_state(self): @@ -17,11 +19,15 @@ def _clear_state(self): def _hook_fn(self, bucket): buf = bucket.buffer() self.bucket_sq_norms_small_batch.append(buf.pow(2).sum()) - fut = torch.distributed.all_reduce(buf, op=torch.distributed.ReduceOp.AVG, async_op=True).get_future() + fut = torch.distributed.all_reduce( + buf, op=torch.distributed.ReduceOp.AVG, async_op=True + ).get_future() + def callback(fut): buf = fut.value()[0] self.bucket_sq_norms_large_batch.append(buf.pow(2).sum()) return buf + return fut.then(callback) def get_stats(self): @@ -49,10 +55,10 @@ class GradientNoiseScale: def __init__(self, beta=0.9998, eps=1e-8): self.beta = beta self.eps = eps - self.ema_sq_norm = 0. - self.ema_var = 0. - self.beta_cumprod = 1. - self.gradient_noise_scale = float('nan') + self.ema_sq_norm = 0.0 + self.ema_var = 0.0 + self.beta_cumprod = 1.0 + self.gradient_noise_scale = float("nan") def state_dict(self): """Returns the state of the object as a :class:`dict`.""" @@ -66,7 +72,9 @@ def load_state_dict(self, state_dict): """ self.__dict__.update(state_dict) - def update(self, sq_norm_small_batch, sq_norm_large_batch, n_small_batch, n_large_batch): + def update( + self, sq_norm_small_batch, sq_norm_large_batch, n_small_batch, n_large_batch + ): """Updates the state with a new batch's gradient statistics, and returns the current gradient noise scale. @@ -80,12 +88,18 @@ def update(self, sq_norm_small_batch, sq_norm_large_batch, n_small_batch, n_larg n_large_batch (int): The total batch size of the mean of the microbatch or per sample gradients. """ - est_sq_norm = (n_large_batch * sq_norm_large_batch - n_small_batch * sq_norm_small_batch) / (n_large_batch - n_small_batch) - est_var = (sq_norm_small_batch - sq_norm_large_batch) / (1 / n_small_batch - 1 / n_large_batch) + est_sq_norm = ( + n_large_batch * sq_norm_large_batch - n_small_batch * sq_norm_small_batch + ) / (n_large_batch - n_small_batch) + est_var = (sq_norm_small_batch - sq_norm_large_batch) / ( + 1 / n_small_batch - 1 / n_large_batch + ) self.ema_sq_norm = self.beta * self.ema_sq_norm + (1 - self.beta) * est_sq_norm self.ema_var = self.beta * self.ema_var + (1 - self.beta) * est_var self.beta_cumprod *= self.beta - self.gradient_noise_scale = max(self.ema_var, self.eps) / max(self.ema_sq_norm, self.eps) + self.gradient_noise_scale = max(self.ema_var, self.eps) / max( + self.ema_sq_norm, self.eps + ) return self.gradient_noise_scale def get_gns(self): @@ -95,4 +109,6 @@ def get_gns(self): def get_stats(self): """Returns the current (debiased) estimates of the squared mean gradient and gradient variance.""" - return self.ema_sq_norm / (1 - self.beta_cumprod), self.ema_var / (1 - self.beta_cumprod) + return self.ema_sq_norm / (1 - self.beta_cumprod), self.ema_var / ( + 1 - self.beta_cumprod + ) diff --git a/k_diffusion/layers.py b/k_diffusion/layers.py index 1cd84455..6ab4bfe4 100644 --- a/k_diffusion/layers.py +++ b/k_diffusion/layers.py @@ -10,37 +10,46 @@ # Karras et al. preconditioned denoiser + class Denoiser(nn.Module): """A Karras et al. preconditioner for denoising diffusion models.""" - def __init__(self, inner_model, sigma_data=1.): + def __init__(self, inner_model, sigma_data=1.0): super().__init__() self.inner_model = inner_model self.sigma_data = sigma_data def get_scalings(self, sigma: Tensor) -> Tuple[Tensor, Tensor, Tensor]: - c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 - c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 + c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 return c_skip, c_out, c_in def loss(self, input: Tensor, noise: Tensor, sigma: Tensor, **kwargs): - c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + c_skip, c_out, c_in = [ + utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma) + ] noised_input = input + noise * utils.append_dims(sigma, input.ndim) model_output = self.inner_model(noised_input * c_in, sigma, **kwargs) target = (input - c_skip * noised_input) / c_out return (model_output - target).pow(2).flatten(1).mean(1) def forward(self, input: Tensor, sigma: Tensor, **kwargs): - c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + c_skip, c_out, c_in = [ + utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma) + ] return self.inner_model(input * c_in, sigma, **kwargs) * c_out + input * c_skip class DenoiserWithVariance(Denoiser): def loss(self, input: Tensor, noise: Tensor, sigma: Tensor, **kwargs): - c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + c_skip, c_out, c_in = [ + utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma) + ] noised_input = input + noise * utils.append_dims(sigma, input.ndim) - model_output, logvar = self.inner_model(noised_input * c_in, sigma, return_variance=True, **kwargs) + model_output, logvar = self.inner_model( + noised_input * c_in, sigma, return_variance=True, **kwargs + ) logvar = utils.append_dims(logvar, model_output.ndim) target = (input - c_skip * noised_input) / c_out losses = ((model_output - target) ** 2 / logvar.exp() + logvar) / 2 @@ -49,6 +58,7 @@ def loss(self, input: Tensor, noise: Tensor, sigma: Tensor, **kwargs): # Residual blocks + class ResidualBlock(nn.Module): def __init__(self, *main, skip=None): super().__init__() @@ -61,6 +71,7 @@ def forward(self, input): # Noise level (and other) conditioning + class ConditionedModule(nn.Module): pass @@ -91,12 +102,16 @@ def __init__(self, *main, skip=None): self.skip = skip if skip else nn.Identity() def forward(self, input, cond): - skip = self.skip(input, cond) if isinstance(self.skip, ConditionedModule) else self.skip(input) + skip = ( + self.skip(input, cond) + if isinstance(self.skip, ConditionedModule) + else self.skip(input) + ) return self.main(input, cond) + skip class AdaGN(ConditionedModule): - def __init__(self, feats_in, c_out, num_groups, eps=1e-5, cond_key='cond'): + def __init__(self, feats_in, c_out, num_groups, eps=1e-5, cond_key="cond"): super().__init__() self.num_groups = num_groups self.eps = eps @@ -106,13 +121,18 @@ def __init__(self, feats_in, c_out, num_groups, eps=1e-5, cond_key='cond'): def forward(self, input, cond): weight, bias = self.mapper(cond[self.cond_key]).chunk(2, dim=-1) input = F.group_norm(input, self.num_groups, eps=self.eps) - return torch.addcmul(utils.append_dims(bias, input.ndim), input, utils.append_dims(weight, input.ndim) + 1) + return torch.addcmul( + utils.append_dims(bias, input.ndim), + input, + utils.append_dims(weight, input.ndim) + 1, + ) # Attention + class SelfAttention2d(ConditionedModule): - def __init__(self, c_in, n_head, norm, dropout_rate=0.): + def __init__(self, c_in, n_head, norm, dropout_rate=0.0): super().__init__() assert c_in % n_head == 0 self.norm_in = norm(c_in) @@ -134,8 +154,16 @@ def forward(self, input, cond): class CrossAttention2d(ConditionedModule): - def __init__(self, c_dec, c_enc, n_head, norm_dec, dropout_rate=0., - cond_key='cross', cond_key_padding='cross_padding'): + def __init__( + self, + c_dec, + c_enc, + n_head, + norm_dec, + dropout_rate=0.0, + cond_key="cross", + cond_key_padding="cross_padding", + ): super().__init__() assert c_dec % n_head == 0 self.cond_key = cond_key @@ -156,7 +184,7 @@ def forward(self, input, cond): kv = kv.view([n, -1, self.n_head * 2, c // self.n_head]).transpose(1, 2) k, v = kv.chunk(2, dim=1) scale = k.shape[3] ** -0.25 - att = ((q * scale) @ (k.transpose(2, 3) * scale)) + att = (q * scale) @ (k.transpose(2, 3) * scale) att = att - (cond[self.cond_key_padding][:, None, None, :]) * 10000 att = att.softmax(3) att = self.dropout(att) @@ -168,48 +196,67 @@ def forward(self, input, cond): # Downsampling/upsampling _kernels = { - 'linear': - [1 / 8, 3 / 8, 3 / 8, 1 / 8], - 'cubic': - [-0.01171875, -0.03515625, 0.11328125, 0.43359375, - 0.43359375, 0.11328125, -0.03515625, -0.01171875], - 'lanczos3': - [0.003689131001010537, 0.015056144446134567, -0.03399861603975296, - -0.066637322306633, 0.13550527393817902, 0.44638532400131226, - 0.44638532400131226, 0.13550527393817902, -0.066637322306633, - -0.03399861603975296, 0.015056144446134567, 0.003689131001010537] + "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8], + "cubic": [ + -0.01171875, + -0.03515625, + 0.11328125, + 0.43359375, + 0.43359375, + 0.11328125, + -0.03515625, + -0.01171875, + ], + "lanczos3": [ + 0.003689131001010537, + 0.015056144446134567, + -0.03399861603975296, + -0.066637322306633, + 0.13550527393817902, + 0.44638532400131226, + 0.44638532400131226, + 0.13550527393817902, + -0.066637322306633, + -0.03399861603975296, + 0.015056144446134567, + 0.003689131001010537, + ], } -_kernels['bilinear'] = _kernels['linear'] -_kernels['bicubic'] = _kernels['cubic'] +_kernels["bilinear"] = _kernels["linear"] +_kernels["bicubic"] = _kernels["cubic"] class Downsample2d(nn.Module): - def __init__(self, kernel='linear', pad_mode='reflect'): + def __init__(self, kernel="linear", pad_mode="reflect"): super().__init__() self.pad_mode = pad_mode kernel_1d = torch.tensor([_kernels[kernel]]) self.pad = kernel_1d.shape[1] // 2 - 1 - self.register_buffer('kernel', kernel_1d.T @ kernel_1d) + self.register_buffer("kernel", kernel_1d.T @ kernel_1d) def forward(self, x): x = F.pad(x, (self.pad,) * 4, self.pad_mode) - weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) + weight = x.new_zeros( + [x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]] + ) indices = torch.arange(x.shape[1], device=x.device) weight[indices, indices] = self.kernel.to(weight) return F.conv2d(x, weight, stride=2) class Upsample2d(nn.Module): - def __init__(self, kernel='linear', pad_mode='reflect'): + def __init__(self, kernel="linear", pad_mode="reflect"): super().__init__() self.pad_mode = pad_mode kernel_1d = torch.tensor([_kernels[kernel]]) * 2 self.pad = kernel_1d.shape[1] // 2 - 1 - self.register_buffer('kernel', kernel_1d.T @ kernel_1d) + self.register_buffer("kernel", kernel_1d.T @ kernel_1d) def forward(self, x): x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode) - weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) + weight = x.new_zeros( + [x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]] + ) indices = torch.arange(x.shape[1], device=x.device) weight[indices, indices] = self.kernel.to(weight) return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1) @@ -217,11 +264,14 @@ def forward(self, x): # Embeddings + class FourierFeatures(nn.Module): - def __init__(self, in_features, out_features, std=1.): + def __init__(self, in_features, out_features, std=1.0): super().__init__() assert out_features % 2 == 0 - self.register_buffer('weight', torch.randn([out_features // 2, in_features]) * std) + self.register_buffer( + "weight", torch.randn([out_features // 2, in_features]) * std + ) def forward(self, input): f = 2 * math.pi * input @ self.weight.T @@ -230,6 +280,7 @@ def forward(self, input): # U-Nets + class UNet(ConditionedModule): def __init__(self, d_blocks, u_blocks, skip_stages=0): super().__init__() @@ -239,7 +290,7 @@ def __init__(self, d_blocks, u_blocks, skip_stages=0): def forward(self, input, cond): skips = [] - for block in self.d_blocks[self.skip_stages:]: + for block in self.d_blocks[self.skip_stages :]: input = block(input, cond) skips.append(input) for i, (block, skip) in enumerate(zip(self.u_blocks, reversed(skips))): diff --git a/k_diffusion/models/image_v1.py b/k_diffusion/models/image_v1.py index 54d2e5c0..e09301b9 100644 --- a/k_diffusion/models/image_v1.py +++ b/k_diffusion/models/image_v1.py @@ -1,5 +1,3 @@ -import math - from typing import List, Union, overload import torch @@ -13,14 +11,22 @@ def orthogonal_(module: nn.Conv2d) -> nn.Conv2d: ... @overload def orthogonal_(module: nn.Linear) -> nn.Linear: ... + + def orthogonal_(module: Union[nn.Conv2d, nn.Linear]) -> Union[nn.Conv2d, nn.Linear]: nn.init.orthogonal_(module.weight) return module class ResConvBlock(layers.ConditionedResidualBlock): - def __init__(self, feats_in, c_in, c_mid, c_out, group_size=32, dropout_rate=0.) -> None: - skip = None if c_in == c_out else orthogonal_(nn.Conv2d(c_in, c_out, 1, bias=False)) + def __init__( + self, feats_in, c_in, c_mid, c_out, group_size=32, dropout_rate=0.0 + ) -> None: + skip = ( + None + if c_in == c_out + else orthogonal_(nn.Conv2d(c_in, c_out, 1, bias=False)) + ) super().__init__( layers.AdaGN(feats_in, c_in, max(1, c_in // group_size)), nn.GELU(), @@ -30,22 +36,57 @@ def __init__(self, feats_in, c_in, c_mid, c_out, group_size=32, dropout_rate=0.) nn.GELU(), nn.Conv2d(c_mid, c_out, 3, padding=1), nn.Dropout2d(dropout_rate, inplace=True), - skip=skip) + skip=skip, + ) class DBlock(layers.ConditionedSequential): - def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., downsample=False, self_attn=False, cross_attn=False, c_enc=0) -> None: + def __init__( + self, + n_layers, + feats_in, + c_in, + c_mid, + c_out, + group_size=32, + head_size=64, + dropout_rate=0.0, + downsample=False, + self_attn=False, + cross_attn=False, + c_enc=0, + ) -> None: modules: List[nn.Module] = [nn.Identity()] for i in range(n_layers): my_c_in = c_in if i == 0 else c_mid my_c_out = c_mid if i < n_layers - 1 else c_out - modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate)) + modules.append( + ResConvBlock( + feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate + ) + ) if self_attn: - norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) - modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate)) + norm = lambda c_in: layers.AdaGN( + feats_in, c_in, max(1, my_c_out // group_size) + ) + modules.append( + layers.SelfAttention2d( + my_c_out, max(1, my_c_out // head_size), norm, dropout_rate + ) + ) if cross_attn: - norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) - modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate)) + norm = lambda c_in: layers.AdaGN( + feats_in, c_in, max(1, my_c_out // group_size) + ) + modules.append( + layers.CrossAttention2d( + my_c_out, + c_enc, + max(1, my_c_out // head_size), + norm, + dropout_rate, + ) + ) super().__init__(*modules) self.set_downsample(downsample) @@ -55,18 +96,52 @@ def set_downsample(self, downsample): class UBlock(layers.ConditionedSequential): - def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., upsample=False, self_attn=False, cross_attn=False, c_enc=0) -> None: + def __init__( + self, + n_layers, + feats_in, + c_in, + c_mid, + c_out, + group_size=32, + head_size=64, + dropout_rate=0.0, + upsample=False, + self_attn=False, + cross_attn=False, + c_enc=0, + ) -> None: modules: List[nn.Module] = [] for i in range(n_layers): my_c_in = c_in if i == 0 else c_mid my_c_out = c_mid if i < n_layers - 1 else c_out - modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate)) + modules.append( + ResConvBlock( + feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate + ) + ) if self_attn: - norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) - modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate)) + norm = lambda c_in: layers.AdaGN( + feats_in, c_in, max(1, my_c_out // group_size) + ) + modules.append( + layers.SelfAttention2d( + my_c_out, max(1, my_c_out // head_size), norm, dropout_rate + ) + ) if cross_attn: - norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) - modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate)) + norm = lambda c_in: layers.AdaGN( + feats_in, c_in, max(1, my_c_out // group_size) + ) + modules.append( + layers.CrossAttention2d( + my_c_out, + c_enc, + max(1, my_c_out // head_size), + norm, + dropout_rate, + ) + ) modules.append(nn.Identity()) super().__init__(*modules) self.set_upsample(upsample) @@ -85,13 +160,30 @@ class MappingNet(nn.Sequential): def __init__(self, feats_in, feats_out, n_layers=2) -> None: layers: List[nn.Module] = [] for i in range(n_layers): - layers.append(orthogonal_(nn.Linear(feats_in if i == 0 else feats_out, feats_out))) + layers.append( + orthogonal_(nn.Linear(feats_in if i == 0 else feats_out, feats_out)) + ) layers.append(nn.GELU()) super().__init__(*layers) class ImageDenoiserModelV1(nn.Module): - def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, cross_attn_depths=None, mapping_cond_dim=0, unet_cond_dim=0, cross_cond_dim=0, dropout_rate=0., patch_size=1, skip_stages=0, has_variance=False) -> None: + def __init__( + self, + c_in, + feats_in, + depths, + channels, + self_attn_depths, + cross_attn_depths=None, + mapping_cond_dim=0, + unet_cond_dim=0, + cross_cond_dim=0, + dropout_rate=0.0, + patch_size=1, + skip_stages=0, + has_variance=False, + ) -> None: super().__init__() self.c_in = c_in self.channels = channels @@ -102,8 +194,16 @@ def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, cross_att if mapping_cond_dim > 0: self.mapping_cond = nn.Linear(mapping_cond_dim, feats_in, bias=False) self.mapping = MappingNet(feats_in, feats_in) - self.proj_in = nn.Conv2d((c_in + unet_cond_dim) * self.patch_size ** 2, channels[max(0, skip_stages - 1)], 1) - self.proj_out = nn.Conv2d(channels[max(0, skip_stages - 1)], c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1) + self.proj_in = nn.Conv2d( + (c_in + unet_cond_dim) * self.patch_size**2, + channels[max(0, skip_stages - 1)], + 1, + ) + self.proj_out = nn.Conv2d( + channels[max(0, skip_stages - 1)], + c_in * self.patch_size**2 + (1 if self.has_variance else 0), + 1, + ) nn.init.zeros_(self.proj_out.weight) nn.init.zeros_(self.proj_out.bias) # type: ignore # self.proj_out.bias may be None and all falls if cross_cond_dim == 0: @@ -111,24 +211,63 @@ def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, cross_att d_blocks, u_blocks = [], [] for i in range(len(depths)): my_c_in = channels[max(0, i - 1)] - d_blocks.append(DBlock(depths[i], feats_in, my_c_in, channels[i], channels[i], downsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate)) + d_blocks.append( + DBlock( + depths[i], + feats_in, + my_c_in, + channels[i], + channels[i], + downsample=i > skip_stages, + self_attn=self_attn_depths[i], + cross_attn=cross_attn_depths[i], + c_enc=cross_cond_dim, + dropout_rate=dropout_rate, + ) + ) for i in range(len(depths)): my_c_in = channels[i] * 2 if i < len(depths) - 1 else channels[i] my_c_out = channels[max(0, i - 1)] - u_blocks.append(UBlock(depths[i], feats_in, my_c_in, channels[i], my_c_out, upsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate)) + u_blocks.append( + UBlock( + depths[i], + feats_in, + my_c_in, + channels[i], + my_c_out, + upsample=i > skip_stages, + self_attn=self_attn_depths[i], + cross_attn=cross_attn_depths[i], + c_enc=cross_cond_dim, + dropout_rate=dropout_rate, + ) + ) self.u_net = layers.UNet(d_blocks, reversed(u_blocks), skip_stages=skip_stages) - def forward(self, input, sigma, mapping_cond=None, unet_cond=None, cross_cond=None, cross_cond_padding=None, return_variance=False): + def forward( + self, + input, + sigma, + mapping_cond=None, + unet_cond=None, + cross_cond=None, + cross_cond_padding=None, + return_variance=False, + ): c_noise = sigma.log() / 4 timestep_embed = self.timestep_embed(utils.append_dims(c_noise, 2)) - mapping_cond_embed = torch.zeros_like(timestep_embed) if mapping_cond is None else self.mapping_cond(mapping_cond) + mapping_cond_embed = ( + torch.zeros_like(timestep_embed) + if mapping_cond is None + else self.mapping_cond(mapping_cond) + ) mapping_out = self.mapping(timestep_embed + mapping_cond_embed) - cond = {'cond': mapping_out} + cond = {"cond": mapping_out} if unet_cond is not None: input = torch.cat([input, unet_cond], dim=1) if cross_cond is not None: - cond['cross'] = cross_cond - cond['cross_padding'] = cross_cond_padding + cond["cross"] = cross_cond + cond["cross_padding"] = cross_cond_padding if self.patch_size > 1: input = F.pixel_unshuffle(input, self.patch_size) input = self.proj_in(input) @@ -143,8 +282,12 @@ def forward(self, input, sigma, mapping_cond=None, unet_cond=None, cross_cond=No return input def set_skip_stages(self, skip_stages): - self.proj_in = nn.Conv2d(self.proj_in.in_channels, self.channels[max(0, skip_stages - 1)], 1) - self.proj_out = nn.Conv2d(self.channels[max(0, skip_stages - 1)], self.proj_out.out_channels, 1) + self.proj_in = nn.Conv2d( + self.proj_in.in_channels, self.channels[max(0, skip_stages - 1)], 1 + ) + self.proj_out = nn.Conv2d( + self.channels[max(0, skip_stages - 1)], self.proj_out.out_channels, 1 + ) nn.init.zeros_(self.proj_out.weight) nn.init.zeros_(self.proj_out.bias) self.u_net.skip_stages = skip_stages @@ -156,7 +299,15 @@ def set_skip_stages(self, skip_stages): def set_patch_size(self, patch_size): self.patch_size = patch_size - self.proj_in = nn.Conv2d((self.c_in + self.unet_cond_dim) * self.patch_size ** 2, self.channels[max(0, self.u_net.skip_stages - 1)], 1) - self.proj_out = nn.Conv2d(self.channels[max(0, self.u_net.skip_stages - 1)], self.c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1) + self.proj_in = nn.Conv2d( + (self.c_in + self.unet_cond_dim) * self.patch_size**2, + self.channels[max(0, self.u_net.skip_stages - 1)], + 1, + ) + self.proj_out = nn.Conv2d( + self.channels[max(0, self.u_net.skip_stages - 1)], + self.c_in * self.patch_size**2 + (1 if self.has_variance else 0), + 1, + ) nn.init.zeros_(self.proj_out.weight) nn.init.zeros_(self.proj_out.bias) diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py index cf47d4e9..4e12cd18 100644 --- a/k_diffusion/sampling.py +++ b/k_diffusion/sampling.py @@ -16,7 +16,9 @@ def append_zero(x: Tensor) -> Tensor: return torch.cat([x, x.new_zeros([1])]) -def get_sigmas_karras(n: int, sigma_min: float, sigma_max: float, rho: float = 7., device='cpu') -> Tensor: +def get_sigmas_karras( + n: int, sigma_min: float, sigma_max: float, rho: float = 7.0, device="cpu" +) -> Tensor: """Constructs the noise schedule of Karras et al. (2022).""" ramp = torch.linspace(0, 1, n) min_inv_rho = sigma_min ** (1 / rho) @@ -25,16 +27,26 @@ def get_sigmas_karras(n: int, sigma_min: float, sigma_max: float, rho: float = 7 return append_zero(sigmas).to(device) -def get_sigmas_exponential(n: int, sigma_min: float, sigma_max: float, device='cpu') -> Tensor: +def get_sigmas_exponential( + n: int, sigma_min: float, sigma_max: float, device="cpu" +) -> Tensor: """Constructs an exponential noise schedule.""" - sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp() + sigmas = torch.linspace( + math.log(sigma_max), math.log(sigma_min), n, device=device + ).exp() return append_zero(sigmas) -def get_sigmas_vp(n: int, beta_d: float = 19.9, beta_min: float = 0.1, eps_s: float = 1e-3, device='cpu') -> Tensor: +def get_sigmas_vp( + n: int, + beta_d: float = 19.9, + beta_min: float = 0.1, + eps_s: float = 1e-3, + device="cpu", +) -> Tensor: """Constructs a continuous VP noise schedule.""" t = torch.linspace(1, eps_s, n, device=device) - sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1) + sigmas = torch.sqrt(torch.exp(beta_d * t**2 / 2 + beta_min * t) - 1) return append_zero(sigmas) @@ -43,31 +55,58 @@ def to_d(x: Tensor, sigma: Tensor, denoised: Tensor) -> Tensor: return (x - denoised) / utils.append_dims(sigma, x.ndim) -def get_ancestral_step(sigma_from, sigma_to, eta=1.): +def get_ancestral_step(sigma_from, sigma_to, eta=1.0): """Calculates the noise level (sigma_down) to step down to and the amount of noise to add (sigma_up) when doing an ancestral sampling step.""" if not eta: - return sigma_to, 0. - sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5) - sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 + return sigma_to, 0.0 + sigma_up = min( + sigma_to, + eta + * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5, + ) + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 return sigma_down, sigma_up @torch.no_grad() -def sample_euler(model, x: Tensor, sigmas: Tensor, extra_args: Dict[str, Any] = None, callback: Callable[[Dict], Any] = None, disable: bool = None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.) -> Tensor: +def sample_euler( + model, + x: Tensor, + sigmas: Tensor, + extra_args: Dict[str, Any] = None, + callback: Callable[[Dict], Any] = None, + disable: bool = None, + s_churn=0.0, + s_tmin=0.0, + s_tmax=float("inf"), + s_noise=1.0, +) -> Tensor: """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): - gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. + gamma = ( + min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) + if s_tmin <= sigmas[i] <= s_tmax + else 0.0 + ) eps = torch.randn_like(x) * s_noise sigma_hat = sigmas[i] * (gamma + 1) if gamma > 0: - x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) d = to_d(x, sigma_hat, denoised) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigma_hat, + "denoised": denoised, + } + ) dt = sigmas[i + 1] - sigma_hat # Euler method x = x + d * dt @@ -75,7 +114,15 @@ def sample_euler(model, x: Tensor, sigmas: Tensor, extra_args: Dict[str, Any] = @torch.no_grad() -def sample_euler_ancestral(model, x: Tensor, sigmas: Tensor, extra_args: Dict[str, Any] = None, callback: Callable[[Dict], Any] = None, disable: bool = None, eta=1.) -> Tensor: +def sample_euler_ancestral( + model, + x: Tensor, + sigmas: Tensor, + extra_args: Dict[str, Any] = None, + callback: Callable[[Dict], Any] = None, + disable: bool = None, + eta=1.0, +) -> Tensor: """Ancestral sampling with Euler method steps.""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -83,7 +130,15 @@ def sample_euler_ancestral(model, x: Tensor, sigmas: Tensor, extra_args: Dict[st denoised = model(x, sigmas[i] * s_in, **extra_args) sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) d = to_d(x, sigmas[i], denoised) # Euler method dt = sigma_down - sigmas[i] @@ -93,20 +148,43 @@ def sample_euler_ancestral(model, x: Tensor, sigmas: Tensor, extra_args: Dict[st @torch.no_grad() -def sample_heun(model, x: Tensor, sigmas: Tensor, extra_args: Dict[str, Any] = None, callback: Callable[[Dict], Any] = None, disable: bool = None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.) -> Tensor: +def sample_heun( + model, + x: Tensor, + sigmas: Tensor, + extra_args: Dict[str, Any] = None, + callback: Callable[[Dict], Any] = None, + disable: bool = None, + s_churn=0.0, + s_tmin=0.0, + s_tmax=float("inf"), + s_noise=1.0, +) -> Tensor: """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): - gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. + gamma = ( + min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) + if s_tmin <= sigmas[i] <= s_tmax + else 0.0 + ) eps = torch.randn_like(x) * s_noise sigma_hat = sigmas[i] * (gamma + 1) if gamma > 0: - x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) d = to_d(x, sigma_hat, denoised) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigma_hat, + "denoised": denoised, + } + ) dt = sigmas[i + 1] - sigma_hat if sigmas[i + 1] == 0: # Euler method @@ -122,20 +200,43 @@ def sample_heun(model, x: Tensor, sigmas: Tensor, extra_args: Dict[str, Any] = N @torch.no_grad() -def sample_dpm_2(model, x: Tensor, sigmas: Tensor, extra_args: Dict[str, Any] = None, callback: Callable[[Dict], Any] = None, disable: bool = None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.) -> Tensor: +def sample_dpm_2( + model, + x: Tensor, + sigmas: Tensor, + extra_args: Dict[str, Any] = None, + callback: Callable[[Dict], Any] = None, + disable: bool = None, + s_churn=0.0, + s_tmin=0.0, + s_tmax=float("inf"), + s_noise=1.0, +) -> Tensor: """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): - gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. + gamma = ( + min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) + if s_tmin <= sigmas[i] <= s_tmax + else 0.0 + ) eps = torch.randn_like(x) * s_noise sigma_hat = sigmas[i] * (gamma + 1) if gamma > 0: - x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) d = to_d(x, sigma_hat, denoised) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigma_hat, + "denoised": denoised, + } + ) if sigmas[i + 1] == 0: # Euler method dt = sigmas[i + 1] - sigma_hat @@ -153,7 +254,15 @@ def sample_dpm_2(model, x: Tensor, sigmas: Tensor, extra_args: Dict[str, Any] = @torch.no_grad() -def sample_dpm_2_ancestral(model, x: Tensor, sigmas: Tensor, extra_args: Dict[str, Any] = None, callback: Callable[[Dict], Any] = None, disable: bool = None, eta=1.) -> Tensor: +def sample_dpm_2_ancestral( + model, + x: Tensor, + sigmas: Tensor, + extra_args: Dict[str, Any] = None, + callback: Callable[[Dict], Any] = None, + disable: bool = None, + eta=1.0, +) -> Tensor: """Ancestral sampling with DPM-Solver inspired second-order steps.""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -161,7 +270,15 @@ def sample_dpm_2_ancestral(model, x: Tensor, sigmas: Tensor, extra_args: Dict[st denoised = model(x, sigmas[i] * s_in, **extra_args) sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) d = to_d(x, sigmas[i], denoised) if sigma_down == 0: # Euler method @@ -182,19 +299,29 @@ def sample_dpm_2_ancestral(model, x: Tensor, sigmas: Tensor, extra_args: Dict[st def linear_multistep_coeff(order: int, t, i: int, j: int): if order - 1 > i: - raise ValueError(f'Order {order} too high for step {i}') + raise ValueError(f"Order {order} too high for step {i}") + def fn(tau): - prod = 1. + prod = 1.0 for k in range(order): if j == k: continue prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) return prod + return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0] @torch.no_grad() -def sample_lms(model, x: Tensor, sigmas: Tensor, extra_args: Dict[str, Any] = None, callback: Callable[[Dict], Any] = None, disable: bool = None, order=4) -> Tensor: +def sample_lms( + model, + x: Tensor, + sigmas: Tensor, + extra_args: Dict[str, Any] = None, + callback: Callable[[Dict], Any] = None, + disable: bool = None, + order=4, +) -> Tensor: extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) sigmas_cpu = sigmas.detach().cpu().numpy() @@ -206,19 +333,39 @@ def sample_lms(model, x: Tensor, sigmas: Tensor, extra_args: Dict[str, Any] = No if len(ds) > order: ds.pop(0) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) cur_order = min(i + 1, order) - coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)] + coeffs = [ + linear_multistep_coeff(cur_order, sigmas_cpu, i, j) + for j in range(cur_order) + ] x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) return x @torch.no_grad() -def log_likelihood(model, x: Tensor, sigma_min: float, sigma_max: float, extra_args=None, atol=1e-4, rtol=1e-4) -> Tuple[Tensor, Dict[str, int]]: +def log_likelihood( + model, + x: Tensor, + sigma_min: float, + sigma_max: float, + extra_args=None, + atol=1e-4, + rtol=1e-4, +) -> Tuple[Tensor, Dict[str, int]]: extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) v = torch.randint_like(x, 2) * 2 - 1 fevals: int = 0 + def ode_fn(sigma, x: Tensor): nonlocal fevals with torch.enable_grad(): @@ -229,17 +376,23 @@ def ode_fn(sigma, x: Tensor): grad = torch.autograd.grad((d * v).sum(), x)[0] d_ll = (v * grad).flatten(1).sum(1) return d.detach(), d_ll + x_min = x, x.new_zeros([x.shape[0]]) t = x.new_tensor([sigma_min, sigma_max]) - sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5') + sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method="dopri5") latent, delta_ll = sol[0][-1], sol[1][-1] - ll_prior: Tensor = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1) - return ll_prior + delta_ll, {'fevals': fevals} + ll_prior: Tensor = ( + torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1) + ) + return ll_prior + delta_ll, {"fevals": fevals} class PIDStepSizeController: """A PID controller for ODE adaptive step size control.""" - def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8): + + def __init__( + self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8 + ): self.h = h self.b1 = (pcoeff + icoeff + dcoeff) / order self.b2 = -(pcoeff + 2 * dcoeff) / order @@ -256,7 +409,9 @@ def propose_step(self, error) -> bool: if not self.errs: self.errs = [inv_error, inv_error, inv_error] self.errs[0] = inv_error - factor = self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3 + factor = ( + self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3 + ) factor = self.limiter(factor) accept = factor >= self.accept_safety if accept: @@ -269,7 +424,13 @@ def propose_step(self, error) -> bool: class DPMSolver(nn.Module): """DPM-Solver. See https://arxiv.org/abs/2206.00927.""" - def __init__(self, model, extra_args: Dict[str, Any] = None, eps_callback=None, info_callback=None): + def __init__( + self, + model, + extra_args: Dict[str, Any] = None, + eps_callback=None, + info_callback=None, + ): super().__init__() self.model = model self.extra_args = {} if extra_args is None else extra_args @@ -286,7 +447,9 @@ def eps(self, eps_cache, key, x, t, *args, **kwargs): if key in eps_cache: return eps_cache[key], eps_cache sigma = self.sigma(t) * x.new_ones([x.shape[0]]) - eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t) + eps = ( + x - self.model(x, sigma, *args, **self.extra_args, **kwargs) + ) / self.sigma(t) if self.eps_callback is not None: self.eps_callback() return eps, {key: eps, **eps_cache} @@ -294,36 +457,53 @@ def eps(self, eps_cache, key, x, t, *args, **kwargs): def dpm_solver_1_step(self, x, t, t_next, eps_cache=None): eps_cache = {} if eps_cache is None else eps_cache h = t_next - t - eps, eps_cache = self.eps(eps_cache, 'eps', x, t) + eps, eps_cache = self.eps(eps_cache, "eps", x, t) x_1 = x - self.sigma(t_next) * h.expm1() * eps return x_1, eps_cache def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None): eps_cache = {} if eps_cache is None else eps_cache h = t_next - t - eps, eps_cache = self.eps(eps_cache, 'eps', x, t) + eps, eps_cache = self.eps(eps_cache, "eps", x, t) s1 = t + r1 * h u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps - eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1) - x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps) + eps_r1, eps_cache = self.eps(eps_cache, "eps_r1", u1, s1) + x_2 = ( + x + - self.sigma(t_next) * h.expm1() * eps + - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps) + ) return x_2, eps_cache def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None): eps_cache = {} if eps_cache is None else eps_cache h = t_next - t - eps, eps_cache = self.eps(eps_cache, 'eps', x, t) + eps, eps_cache = self.eps(eps_cache, "eps", x, t) s1 = t + r1 * h s2 = t + r2 * h u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps - eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1) - u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps) - eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2) - x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps) + eps_r1, eps_cache = self.eps(eps_cache, "eps_r1", u1, s1) + u2 = ( + x + - self.sigma(s2) * (r2 * h).expm1() * eps + - self.sigma(s2) + * (r2 / r1) + * ((r2 * h).expm1() / (r2 * h) - 1) + * (eps_r1 - eps) + ) + eps_r2, eps_cache = self.eps(eps_cache, "eps_r2", u2, s2) + x_3 = ( + x + - self.sigma(t_next) * h.expm1() * eps + - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps) + ) return x_3, eps_cache - def dpm_solver_fast(self, x: Tensor, t_start, t_end, nfe, eta=0., s_noise=1.) -> Tensor: + def dpm_solver_fast( + self, x: Tensor, t_start, t_end, nfe, eta=0.0, s_noise=1.0 + ) -> Tensor: if not t_end > t_start and eta: - raise ValueError('eta must be 0 for reverse sampling') + raise ValueError("eta must be 0 for reverse sampling") m = math.floor(nfe / 3) + 1 ts = torch.linspace(t_start, t_end, m + 1, device=x.device) @@ -341,58 +521,93 @@ def dpm_solver_fast(self, x: Tensor, t_start, t_end, nfe, eta=0., s_noise=1.) -> t_next_ = torch.minimum(t_end, self.t(sd)) su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5 else: - t_next_, su = t_next, 0. + t_next_, su = t_next, 0.0 - eps, eps_cache = self.eps(eps_cache, 'eps', x, t) + eps, eps_cache = self.eps(eps_cache, "eps", x, t) denoised = x - self.sigma(t) * eps if self.info_callback is not None: - self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised}) + self.info_callback( + {"x": x, "i": i, "t": ts[i], "t_up": t, "denoised": denoised} + ) if orders[i] == 1: - x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache) + x, eps_cache = self.dpm_solver_1_step( + x, t, t_next_, eps_cache=eps_cache + ) elif orders[i] == 2: - x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache) + x, eps_cache = self.dpm_solver_2_step( + x, t, t_next_, eps_cache=eps_cache + ) else: - x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache) + x, eps_cache = self.dpm_solver_3_step( + x, t, t_next_, eps_cache=eps_cache + ) x = x + su * s_noise * torch.randn_like(x) return x - def dpm_solver_adaptive(self, x: Tensor, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1.) -> Tuple[Tensor, Dict[str, int]]: + def dpm_solver_adaptive( + self, + x: Tensor, + t_start, + t_end, + order=3, + rtol=0.05, + atol=0.0078, + h_init=0.05, + pcoeff=0.0, + icoeff=1.0, + dcoeff=0.0, + accept_safety=0.81, + eta=0.0, + s_noise=1.0, + ) -> Tuple[Tensor, Dict[str, int]]: if order not in {2, 3}: - raise ValueError('order should be 2 or 3') + raise ValueError("order should be 2 or 3") forward = t_end > t_start if not forward and eta: - raise ValueError('eta must be 0 for reverse sampling') + raise ValueError("eta must be 0 for reverse sampling") h_init = abs(h_init) * (1 if forward else -1) atol = torch.tensor(atol) rtol = torch.tensor(rtol) s = t_start x_prev = x accept = True - pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety) - info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0} + pid = PIDStepSizeController( + h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety + ) + info = {"steps": 0, "nfe": 0, "n_accept": 0, "n_reject": 0} while s < t_end - 1e-5 if forward else s > t_end + 1e-5: eps_cache: dict = {} - t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h) + t = ( + torch.minimum(t_end, s + pid.h) + if forward + else torch.maximum(t_end, s + pid.h) + ) if eta: sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta) t_ = torch.minimum(t_end, self.t(sd)) su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5 else: - t_, su = t, 0. + t_, su = t, 0.0 - eps, eps_cache = self.eps(eps_cache, 'eps', x, s) + eps, eps_cache = self.eps(eps_cache, "eps", x, s) denoised = x - self.sigma(s) * eps if order == 2: x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache) - x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache) + x_high, eps_cache = self.dpm_solver_2_step( + x, s, t_, eps_cache=eps_cache + ) else: - x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache) - x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache) + x_low, eps_cache = self.dpm_solver_2_step( + x, s, t_, r1=1 / 3, eps_cache=eps_cache + ) + x_high, eps_cache = self.dpm_solver_3_step( + x, s, t_, eps_cache=eps_cache + ) delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs())) error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5 accept = pid.propose_step(error) @@ -400,40 +615,114 @@ def dpm_solver_adaptive(self, x: Tensor, t_start, t_end, order=3, rtol=0.05, ato x_prev = x_low x = x_high + su * s_noise * torch.randn_like(x_high) s = t - info['n_accept'] += 1 + info["n_accept"] += 1 else: - info['n_reject'] += 1 - info['nfe'] += order - info['steps'] += 1 + info["n_reject"] += 1 + info["nfe"] += order + info["steps"] += 1 if self.info_callback is not None: - self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info}) + self.info_callback( + { + "x": x, + "i": info["steps"] - 1, + "t": s, + "t_up": s, + "denoised": denoised, + "error": error, + "h": pid.h, + **info, + } + ) return x, info @torch.no_grad() -def sample_dpm_fast(model, x: Tensor, sigma_min, sigma_max, n, extra_args: Dict[str, Any] = None, callback: Callable[[Dict], Any] = None, disable: bool = None, eta=0., s_noise=1.): +def sample_dpm_fast( + model, + x: Tensor, + sigma_min, + sigma_max, + n, + extra_args: Dict[str, Any] = None, + callback: Callable[[Dict], Any] = None, + disable: bool = None, + eta=0.0, + s_noise=1.0, +): """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927.""" if sigma_min <= 0 or sigma_max <= 0: - raise ValueError('sigma_min and sigma_max must not be 0') + raise ValueError("sigma_min and sigma_max must not be 0") with tqdm(total=n, disable=disable) as pbar: dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update) if callback is not None: - dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info}) - return dpm_solver.dpm_solver_fast(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), n, eta, s_noise) + dpm_solver.info_callback = lambda info: callback( + { + "sigma": dpm_solver.sigma(info["t"]), + "sigma_hat": dpm_solver.sigma(info["t_up"]), + **info, + } + ) + return dpm_solver.dpm_solver_fast( + x, + dpm_solver.t(torch.tensor(sigma_max)), + dpm_solver.t(torch.tensor(sigma_min)), + n, + eta, + s_noise, + ) @torch.no_grad() -def sample_dpm_adaptive(model, x: Tensor, sigma_min, sigma_max, extra_args: Dict[str, Any] = None, callback: Callable[[Dict], Any] = None, disable: bool = None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., return_info: bool = False): +def sample_dpm_adaptive( + model, + x: Tensor, + sigma_min, + sigma_max, + extra_args: Dict[str, Any] = None, + callback: Callable[[Dict], Any] = None, + disable: bool = None, + order=3, + rtol=0.05, + atol=0.0078, + h_init=0.05, + pcoeff=0.0, + icoeff=1.0, + dcoeff=0.0, + accept_safety=0.81, + eta=0.0, + s_noise=1.0, + return_info: bool = False, +): """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927.""" if sigma_min <= 0 or sigma_max <= 0: - raise ValueError('sigma_min and sigma_max must not be 0') + raise ValueError("sigma_min and sigma_max must not be 0") with tqdm(disable=disable) as pbar: dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update) if callback is not None: - dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info}) - x, info = dpm_solver.dpm_solver_adaptive(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise) + dpm_solver.info_callback = lambda info: callback( + { + "sigma": dpm_solver.sigma(info["t"]), + "sigma_hat": dpm_solver.sigma(info["t_up"]), + **info, + } + ) + x, info = dpm_solver.dpm_solver_adaptive( + x, + dpm_solver.t(torch.tensor(sigma_max)), + dpm_solver.t(torch.tensor(sigma_min)), + order, + rtol, + atol, + h_init, + pcoeff, + icoeff, + dcoeff, + accept_safety, + eta, + s_noise, + ) if return_info: return x, info return x diff --git a/k_diffusion/utils.py b/k_diffusion/utils.py index 705e895b..ce53a6d1 100644 --- a/k_diffusion/utils.py +++ b/k_diffusion/utils.py @@ -33,7 +33,7 @@ def to_pil_image(x: Tensor) -> Image.Image: return TF.to_pil_image((x.clamp(-1, 1) + 1) / 2) -def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'): +def hf_datasets_augs_helper(examples, transform, image_key, mode="RGB"): """Apply passed in transforms for HuggingFace Datasets.""" images = [transform(image.convert(mode)) for image in examples[image_key]] return {image_key: images} @@ -43,7 +43,9 @@ def append_dims(x: Tensor, target_dims: int) -> Tensor: """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" dims_to_append = target_dims - x.ndim if dims_to_append < 0: - raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + raise ValueError( + f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" + ) return x[(...,) + (None,) * dims_to_append] @@ -52,17 +54,19 @@ def n_params(module) -> int: return sum(p.numel() for p in module.parameters()) -def download_file(path: Union[str, Path], url: str, digest: Optional[str] = None) -> Path: +def download_file( + path: Union[str, Path], url: str, digest: Optional[str] = None +) -> Path: """Downloads a file if it does not exist, optionally checking its SHA-256 hash.""" path = Path(path) path.parent.mkdir(parents=True, exist_ok=True) if not path.exists(): - with urllib.request.urlopen(url) as response, open(path, 'wb') as f: + with urllib.request.urlopen(url) as response, open(path, "wb") as f: shutil.copyfileobj(response, f) if digest is not None: - file_digest = hashlib.sha256(open(path, 'rb').read()).hexdigest() + file_digest = hashlib.sha256(open(path, "rb").read()).hexdigest() if digest != file_digest: - raise OSError(f'hash of {path} (url: {url}) failed to validate') + raise OSError(f"hash of {path} (url: {url}) failed to validate") return path @@ -119,8 +123,15 @@ class EMAWarmup: last_epoch (int): The index of last epoch. Default: 0. """ - def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0, - last_epoch=0): + def __init__( + self, + inv_gamma=1.0, + power=1.0, + min_value=0.0, + max_value=1.0, + start_at=0, + last_epoch=0, + ): self.inv_gamma = inv_gamma self.power = power self.min_value = min_value @@ -144,7 +155,7 @@ def get_value(self): """Gets the current EMA decay rate.""" epoch = max(0, self.last_epoch - self.start_at) value = 1 - (1 + epoch / self.inv_gamma) ** -self.power - return 0. if epoch < 0 else min(self.max_value, max(self.min_value, value)) + return 0.0 if epoch < 0 else min(self.max_value, max(self.min_value, value)) def step(self): """Updates the step count.""" @@ -168,28 +179,39 @@ class InverseLR(optim.lr_scheduler._LRScheduler): each update. Default: ``False``. """ - def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0., - last_epoch=-1, verbose=False): + def __init__( + self, + optimizer, + inv_gamma=1.0, + power=1.0, + warmup=0.0, + min_lr=0.0, + last_epoch=-1, + verbose=False, + ): self.inv_gamma = inv_gamma self.power = power - if not 0. <= warmup < 1: - raise ValueError('Invalid value for warmup') + if not 0.0 <= warmup < 1: + raise ValueError("Invalid value for warmup") self.warmup = warmup self.min_lr = min_lr super().__init__(optimizer, last_epoch, verbose) def get_lr(self): if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.") + warnings.warn( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`." + ) return self._get_closed_form_lr() def _get_closed_form_lr(self): warmup = 1 - self.warmup ** (self.last_epoch + 1) lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power - return [warmup * max(self.min_lr, base_lr * lr_mult) - for base_lr in self.base_lrs] + return [ + warmup * max(self.min_lr, base_lr * lr_mult) for base_lr in self.base_lrs + ] class ExponentialLR(optim.lr_scheduler._LRScheduler): @@ -209,53 +231,85 @@ class ExponentialLR(optim.lr_scheduler._LRScheduler): each update. Default: ``False``. """ - def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0., - last_epoch=-1, verbose=False): + def __init__( + self, + optimizer, + num_steps, + decay=0.5, + warmup=0.0, + min_lr=0.0, + last_epoch=-1, + verbose=False, + ): self.num_steps = num_steps self.decay = decay - if not 0. <= warmup < 1: - raise ValueError('Invalid value for warmup') + if not 0.0 <= warmup < 1: + raise ValueError("Invalid value for warmup") self.warmup = warmup self.min_lr = min_lr super().__init__(optimizer, last_epoch, verbose) def get_lr(self): if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.") + warnings.warn( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`." + ) return self._get_closed_form_lr() def _get_closed_form_lr(self): warmup = 1 - self.warmup ** (self.last_epoch + 1) lr_mult = (self.decay ** (1 / self.num_steps)) ** self.last_epoch - return [warmup * max(self.min_lr, base_lr * lr_mult) - for base_lr in self.base_lrs] + return [ + warmup * max(self.min_lr, base_lr * lr_mult) for base_lr in self.base_lrs + ] -def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32): +def rand_log_normal(shape, loc=0.0, scale=1.0, device="cpu", dtype=torch.float32): """Draws samples from an lognormal distribution.""" return (torch.randn(shape, device=device, dtype=dtype) * scale + loc).exp() -def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32): +def rand_log_logistic( + shape, + loc=0.0, + scale=1.0, + min_value=0.0, + max_value=float("inf"), + device="cpu", + dtype=torch.float32, +): """Draws samples from an optionally truncated log-logistic distribution.""" min_value = torch.as_tensor(min_value, device=device, dtype=torch.float64) max_value = torch.as_tensor(max_value, device=device, dtype=torch.float64) min_cdf = min_value.log().sub(loc).div(scale).sigmoid() max_cdf = max_value.log().sub(loc).div(scale).sigmoid() - u = torch.rand(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf) + min_cdf + u = ( + torch.rand(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf) + + min_cdf + ) return u.logit().mul(scale).add(loc).exp().to(dtype) -def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32): +def rand_log_uniform(shape, min_value, max_value, device="cpu", dtype=torch.float32): """Draws samples from an log-uniform distribution.""" min_value = math.log(min_value) max_value = math.log(max_value) - return (torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp() - - -def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32): + return ( + torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + + min_value + ).exp() + + +def rand_v_diffusion( + shape, + sigma_data=1.0, + min_value=0.0, + max_value=float("inf"), + device="cpu", + dtype=torch.float32, +): """Draws samples from a truncated v-diffusion training timestep distribution.""" min_cdf = math.atan(min_value / sigma_data) * 2 / math.pi max_cdf = math.atan(max_value / sigma_data) * 2 / math.pi @@ -263,7 +317,9 @@ def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), return torch.tan(u * math.pi / 2) * sigma_data -def rand_split_log_normal(shape, loc, scale_1, scale_2, device='cpu', dtype=torch.float32): +def rand_split_log_normal( + shape, loc, scale_1, scale_2, device="cpu", dtype=torch.float32 +): """Draws samples from a split lognormal distribution.""" n = torch.randn(shape, device=device, dtype=dtype).abs() u = torch.rand(shape, device=device, dtype=dtype) @@ -277,13 +333,27 @@ class FolderOfImages(data.Dataset): """Recursively finds all images in a directory. It does not support classes/targets.""" - IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'} + IMG_EXTENSIONS = { + ".jpg", + ".jpeg", + ".png", + ".ppm", + ".bmp", + ".pgm", + ".tif", + ".tiff", + ".webp", + } def __init__(self, root, transform=None): super().__init__() self.root = Path(root) self.transform = nn.Identity() if transform is None else transform - self.paths = sorted(path for path in self.root.rglob('*') if path.suffix.lower() in self.IMG_EXTENSIONS) + self.paths = sorted( + path + for path in self.root.rglob("*") + if path.suffix.lower() in self.IMG_EXTENSIONS + ) def __repr__(self): return f'FolderOfImages(root="{self.root}", len: {len(self)})' @@ -293,10 +363,10 @@ def __len__(self): def __getitem__(self, key): path = self.paths[key] - with open(path, 'rb') as f: - image = Image.open(f).convert('RGB') + with open(path, "rb") as f: + image = Image.open(f).convert("RGB") image = self.transform(image) - return image, + return (image,) class CSVLogger: @@ -304,13 +374,13 @@ def __init__(self, filename, columns): self.filename = Path(filename) self.columns = columns if self.filename.exists(): - self.file = open(self.filename, 'a') + self.file = open(self.filename, "a") else: - self.file = open(self.filename, 'w') + self.file = open(self.filename, "w") self.write(*self.columns) def write(self, *args): - print(*args, sep=',', file=self.file, flush=True) + print(*args, sep=",", file=self.file, flush=True) @contextmanager From a3af42c16f0d679d80d0fdf0b92bc10f357d3bbd Mon Sep 17 00:00:00 2001 From: jorektheglitch Date: Thu, 3 Nov 2022 20:23:36 +0300 Subject: [PATCH 3/5] Improve typehints in config.py --- k_diffusion/config.py | 43 ++++++++++++++++++++++++++++++------------- 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/k_diffusion/config.py b/k_diffusion/config.py index f8762697..72cdea67 100644 --- a/k_diffusion/config.py +++ b/k_diffusion/config.py @@ -1,37 +1,50 @@ from functools import partial import json import math -import warnings -from typing import Any, BinaryIO, List, Optional, TextIO, TypedDict, Union +from typing import Any, BinaryIO, TextIO, TypedDict +from typing import Callable, List, Optional, Tuple, Union from jsonmerge import merge -from . import augmentation, layers, models, utils +from .models import ImageDenoiserModelV1 +from .augmentation import KarrasAugmentWrapper +from . import layers, utils class ModelConfig(TypedDict): - sigma_data: float + type: str + input_channels: int + input_size: Tuple[int, int] patch_size: int + mapping_out: int + depths: List[int] + channels: List[int] + self_attn_depths: List[bool] + has_variance: bool dropout_rate: float augment_wrapper: bool augment_prob: float + sigma_data: float + sigma_min: float + sigma_max: float + sigma_sample_density: dict mapping_cond_dim: int unet_cond_dim: int cross_cond_dim: int cross_attn_depths: Optional[Any] skip_stages: int - has_variance: bool class DatasetConfig(TypedDict): type: str + location: str class OptimizerConfig(TypedDict): type: str lr: float - betas: List[float] + betas: Tuple[float, float] # actually in JSON it's a list with two numbers eps: float weight_decay: float @@ -41,6 +54,7 @@ class LRSchedConfig(TypedDict): inv_gamma: float power: float warmup: float + max_value: float class EMASchedConfig(TypedDict): @@ -58,19 +72,19 @@ class Config(TypedDict): def load_config(file: Union[BinaryIO, TextIO]) -> Config: - defaults: Config = { + defaults = { "model": { - "sigma_data": 1.0, "patch_size": 1, + "has_variance": False, "dropout_rate": 0.0, "augment_wrapper": True, "augment_prob": 0.0, + "sigma_data": 1.0, "mapping_cond_dim": 0, "unet_cond_dim": 0, "cross_cond_dim": 0, "cross_attn_depths": None, "skip_stages": 0, - "has_variance": False, }, "dataset": { "type": "imagefolder", @@ -94,10 +108,13 @@ def load_config(file: Union[BinaryIO, TextIO]) -> Config: return merge(defaults, config) -def make_model(config: Config): +def make_model( + config: Config, +) -> Union[ImageDenoiserModelV1, KarrasAugmentWrapper]: model_config = config["model"] assert model_config["type"] == "image_v1" - model = models.ImageDenoiserModelV1( + model: Union[ImageDenoiserModelV1, KarrasAugmentWrapper] + model = ImageDenoiserModelV1( model_config["input_channels"], model_config["mapping_out"], model_config["depths"], @@ -114,11 +131,11 @@ def make_model(config: Config): has_variance=model_config["has_variance"], ) if model_config["augment_wrapper"]: - model = augmentation.KarrasAugmentWrapper(model) + model = KarrasAugmentWrapper(model) return model -def make_denoiser_wrapper(config: Config): +def make_denoiser_wrapper(config: Config) -> Callable[..., Union[layers.Denoiser, layers.DenoiserWithVariance]]: model_config = config["model"] sigma_data = model_config.get("sigma_data", 1.0) has_variance = model_config.get("has_variance", False) From b2f1b0954d2c9a31ee975064cd7ad696a46d05c6 Mon Sep 17 00:00:00 2001 From: jorektheglitch Date: Thu, 3 Nov 2022 20:25:40 +0300 Subject: [PATCH 4/5] Codestyle and typehints fixes in train.py --- train.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/train.py b/train.py index eec6d782..ac3d57bf 100755 --- a/train.py +++ b/train.py @@ -8,19 +8,20 @@ import math import json from pathlib import Path +from typing import Union import accelerate import torch -from torch import nn, optim +from torch import optim from torch import multiprocessing as mp from torch.utils import data from torchvision import datasets, transforms, utils -from tqdm.auto import trange, tqdm +from tqdm.auto import tqdm import k_diffusion as K -def main(): +def main() -> None: p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter) p.add_argument('--batch-size', type=int, default=64, @@ -104,6 +105,7 @@ def main(): log_config['parameters'] = K.utils.n_params(inner_model) wandb.init(project=args.wandb_project, entity=args.wandb_entity, group=args.wandb_group, config=log_config, save_code=True) + opt: Union[optim.AdamW, optim.SGD] if opt_config['type'] == 'adamw': opt = optim.AdamW(inner_model.parameters(), lr=opt_config['lr'] if args.lr is None else args.lr, @@ -214,7 +216,7 @@ def main(): ema_sched.load_state_dict(ckpt['ema_sched']) epoch = ckpt['epoch'] + 1 step = ckpt['step'] + 1 - if args.gns and ckpt.get('gns_stats', None) is not None: + if gns_stats and ckpt.get('gns_stats', None) is not None: gns_stats.load_state_dict(ckpt['gns_stats']) del ckpt @@ -305,7 +307,7 @@ def save(): losses_all = accelerator.gather(losses) loss = losses_all.mean() accelerator.backward(losses.mean()) - if args.gns: + if gns_stats: sq_norm_small_batch, sq_norm_large_batch = gns_stats_hook.get_stats() gns_stats.update(sq_norm_small_batch, sq_norm_large_batch, reals.shape[0], reals.shape[0] * accelerator.num_processes) opt.step() @@ -318,7 +320,7 @@ def save(): if accelerator.is_main_process: if step % 25 == 0: - if args.gns: + if gns_stats: tqdm.write(f'Epoch: {epoch}, step: {step}, loss: {loss.item():g}, gns: {gns_stats.get_gns():g}') else: tqdm.write(f'Epoch: {epoch}, step: {step}, loss: {loss.item():g}') @@ -330,7 +332,7 @@ def save(): 'lr': sched.get_last_lr()[0], 'ema_decay': ema_decay, } - if args.gns: + if gns_stats: log_dict['gradient_noise_scale'] = gns_stats.get_gns() wandb.log(log_dict, step=step) From 595a7dd06a67d4472070ee3bf95393d3b3bf635c Mon Sep 17 00:00:00 2001 From: jorektheglitch Date: Thu, 3 Nov 2022 22:26:20 +0300 Subject: [PATCH 5/5] Return aestetical matrixes in code --- k_diffusion/augmentation.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/k_diffusion/augmentation.py b/k_diffusion/augmentation.py index 126c780b..aef58040 100644 --- a/k_diffusion/augmentation.py +++ b/k_diffusion/augmentation.py @@ -13,20 +13,28 @@ def translate2d(tx: float, ty: float) -> Tensor: - mat = [[1, 0, tx], [0, 1, ty], [0, 0, 1]] + mat = [ + [1, 0, tx], + [0, 1, ty], + [0, 0, 1] + ] return torch.tensor(mat, dtype=torch.float32) def scale2d(sx: float, sy: float) -> Tensor: - mat = [[sx, 0, 0], [0, sy, 0], [0, 0, 1]] + mat = [ + [sx, 0, 0], + [0, sy, 0], + [0, 0, 1] + ] return torch.tensor(mat, dtype=torch.float32) def rotate2d(theta: Tensor) -> Tensor: mat = [ [torch.cos(theta), torch.sin(-theta), 0], - [torch.sin(theta), torch.cos(theta), 0], - [0, 0, 1], + [torch.sin(theta), torch.cos(theta), 0], + [0, 0, 1], ] return torch.tensor(mat, dtype=torch.float32)