68 changes: 44 additions & 24 deletions k_diffusion/augmentation.py
@@ -1,42 +1,52 @@
from functools import reduce
import math
import operator
from functools import reduce

from typing import Tuple

import numpy as np
from PIL.Image import Image
from skimage import transform
import torch
from torch import nn
from torch import Tensor


def translate2d(tx, ty):
mat = [[1, 0, tx],
[0, 1, ty],
[0, 0, 1]]
def translate2d(tx: float, ty: float) -> Tensor:
mat = [
[1, 0, tx],
[0, 1, ty],
[0, 0, 1]
]
return torch.tensor(mat, dtype=torch.float32)


def scale2d(sx, sy):
mat = [[sx, 0, 0],
[ 0, sy, 0],
[ 0, 0, 1]]
def scale2d(sx: float, sy: float) -> Tensor:
mat = [
[sx, 0, 0],
[0, sy, 0],
[0, 0, 1]
]
return torch.tensor(mat, dtype=torch.float32)


def rotate2d(theta):
mat = [[torch.cos(theta), torch.sin(-theta), 0],
[torch.sin(theta), torch.cos(theta), 0],
[ 0, 0, 1]]
def rotate2d(theta: Tensor) -> Tensor:
mat = [
[torch.cos(theta), torch.sin(-theta), 0],
[torch.sin(theta), torch.cos(theta), 0],
[0, 0, 1],
]
return torch.tensor(mat, dtype=torch.float32)
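
For intuition: all three helpers return 3×3 homogeneous transform matrices, so arbitrary 2D augmentations compose by plain matrix multiplication. A minimal sketch, not part of the diff (values chosen arbitrarily):

import math
import torch

# Rotate 45 degrees about the pixel (16, 16): T(c) @ R(theta) @ T(-c).
center = translate2d(16.0, 16.0)
rot = rotate2d(torch.tensor(math.pi / 4))
uncenter = translate2d(-16.0, -16.0)
mat = center @ rot @ uncenter
origin = torch.tensor([0.0, 0.0, 1.0])  # homogeneous pixel coordinate
print(mat @ origin)  # where the origin lands after the rotation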


class KarrasAugmentationPipeline:
def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8):
def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1 / 8):
self.a_prob = a_prob
self.a_scale = a_scale
self.a_aniso = a_aniso
self.a_trans = a_trans

def __call__(self, image):
def __call__(self, image: Image) -> Tuple[Image, Tensor, Tensor]:
h, w = image.size
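# note: PIL's Image.size is (width, height), so h and w here assume a square image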
mats = [translate2d(h / 2 - 0.5, w / 2 - 0.5)]

@@ -50,7 +60,7 @@ def __call__(self, image):
# scaling
do = (torch.rand([]) < self.a_prob).float()
a2 = torch.randn([]) * do
mats.append(scale2d(self.a_scale ** a2, self.a_scale ** a2))
mats.append(scale2d(self.a_scale**a2, self.a_scale**a2))
# rotation
do = (torch.rand([]) < self.a_prob).float()
a3 = (torch.rand([]) * 2 * math.pi - math.pi) * do
@@ -60,7 +70,7 @@ def __call__(self, image):
a4 = (torch.rand([]) * 2 * math.pi - math.pi) * do
a5 = torch.randn([]) * do
mats.append(rotate2d(a4))
mats.append(scale2d(self.a_aniso ** a5, self.a_aniso ** -a5))
mats.append(scale2d(self.a_aniso**a5, self.a_aniso**-a5))
mats.append(rotate2d(-a4))
# translation
do = (torch.rand([]) < self.a_prob).float()
@@ -71,15 +81,25 @@ def __call__(self, image):
# form the transformation matrix and conditioning vector
mats.append(translate2d(-h / 2 + 0.5, -w / 2 + 0.5))
mat = reduce(operator.matmul, mats)
cond = torch.stack([a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7])
cond = torch.stack(
[a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7]
)

# apply the transformation
image_orig = np.array(image, dtype=np.float32) / 255
if image_orig.ndim == 2:
image_orig = image_orig[..., None]
image_np = np.array(image, dtype=np.float32) / 255
if image_np.ndim == 2:
image_np = image_np[..., None]
tf = transform.AffineTransform(mat.numpy())
image = transform.warp(image_orig, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True)
image_orig = torch.as_tensor(image_orig).movedim(2, 0) * 2 - 1
image = transform.warp(
image_np,
tf.inverse,
order=3,
mode="reflect",
cval=0.5,
clip=False,
preserve_range=True,
)
image_orig = torch.as_tensor(image_np).movedim(2, 0) * 2 - 1
image = torch.as_tensor(image).movedim(2, 0) * 2 - 1
return image, image_orig, cond
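
A minimal usage sketch of the pipeline, not part of the diff (the input path is hypothetical):

from PIL import Image

aug = KarrasAugmentationPipeline(a_prob=0.12)
pil_image = Image.open("example.png").convert("RGB")  # hypothetical input
image, image_orig, cond = aug(pil_image)
# image: the augmented tensor; image_orig: the unaugmented tensor; both are
# channels-first and rescaled from [0, 1] to [-1, 1]. cond is the 9-dim
# augmentation conditioning vector consumed by KarrasAugmentWrapper below.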

@@ -88,7 +108,7 @@ class KarrasAugmentWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model

def forward(self, input, sigma, aug_cond=None, mapping_cond=None, **kwargs):
if aug_cond is None:
aug_cond = input.new_zeros([input.shape[0], 9])
242 changes: 161 additions & 81 deletions k_diffusion/config.py
@@ -1,110 +1,190 @@
from functools import partial
import json
import math
import warnings

from typing import Any, BinaryIO, Callable, List, Optional, TextIO, Tuple, TypedDict, Union

from jsonmerge import merge

from . import augmentation, layers, models, utils
from .models import ImageDenoiserModelV1
from .augmentation import KarrasAugmentWrapper
from . import layers, utils


class ModelConfig(TypedDict):
type: str
input_channels: int
input_size: Tuple[int, int]
patch_size: int
mapping_out: int
depths: List[int]
channels: List[int]
self_attn_depths: List[bool]
has_variance: bool
dropout_rate: float
augment_wrapper: bool
augment_prob: float
sigma_data: float
sigma_min: float
sigma_max: float
sigma_sample_density: dict
mapping_cond_dim: int
unet_cond_dim: int
cross_cond_dim: int
cross_attn_depths: Optional[Any]
skip_stages: int


class DatasetConfig(TypedDict):
type: str
location: str


class OptimizerConfig(TypedDict):
type: str
lr: float
betas: Tuple[float, float]  # serialized in the JSON as a two-element list
eps: float
weight_decay: float


def load_config(file):
class LRSchedConfig(TypedDict):
type: str
inv_gamma: float
power: float
warmup: float
max_value: float


class EMASchedConfig(TypedDict):
type: str
power: float
max_value: float


class Config(TypedDict):
model: ModelConfig
dataset: DatasetConfig
optimizer: OptimizerConfig
lr_sched: LRSchedConfig
ema_sched: EMASchedConfig


def load_config(file: Union[BinaryIO, TextIO]) -> Config:
defaults = {
'model': {
'sigma_data': 1.,
'patch_size': 1,
'dropout_rate': 0.,
'augment_wrapper': True,
'augment_prob': 0.,
'mapping_cond_dim': 0,
'unet_cond_dim': 0,
'cross_cond_dim': 0,
'cross_attn_depths': None,
'skip_stages': 0,
'has_variance': False,
},
'dataset': {
'type': 'imagefolder',
"model": {
"patch_size": 1,
"has_variance": False,
"dropout_rate": 0.0,
"augment_wrapper": True,
"augment_prob": 0.0,
"sigma_data": 1.0,
"mapping_cond_dim": 0,
"unet_cond_dim": 0,
"cross_cond_dim": 0,
"cross_attn_depths": None,
"skip_stages": 0,
},
'optimizer': {
'type': 'adamw',
'lr': 1e-4,
'betas': [0.95, 0.999],
'eps': 1e-6,
'weight_decay': 1e-3,
"dataset": {
"type": "imagefolder",
},
'lr_sched': {
'type': 'inverse',
'inv_gamma': 20000.,
'power': 1.,
'warmup': 0.99,
"optimizer": {
"type": "adamw",
"lr": 1e-4,
"betas": [0.95, 0.999],
"eps": 1e-6,
"weight_decay": 1e-3,
},
'ema_sched': {
'type': 'inverse',
'power': 0.6667,
'max_value': 0.9999
"lr_sched": {
"type": "inverse",
"inv_gamma": 20000.0,
"power": 1.0,
"warmup": 0.99,
},
"ema_sched": {"type": "inverse", "power": 0.6667, "max_value": 0.9999},
}
config = json.load(file)
return merge(defaults, config)
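
A short sketch of the merge behavior, not part of the diff: jsonmerge merges recursively, so keys present in the user's JSON override the defaults while anything unset falls back to them (the path is hypothetical):

with open("configs/my_config.json") as f:  # hypothetical config file
    config = load_config(f)
assert config["model"]["patch_size"] == 1  # default applies if the file omits it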


def make_model(config):
config = config['model']
assert config['type'] == 'image_v1'
model = models.ImageDenoiserModelV1(
config['input_channels'],
config['mapping_out'],
config['depths'],
config['channels'],
config['self_attn_depths'],
config['cross_attn_depths'],
patch_size=config['patch_size'],
dropout_rate=config['dropout_rate'],
mapping_cond_dim=config['mapping_cond_dim'] + (9 if config['augment_wrapper'] else 0),
unet_cond_dim=config['unet_cond_dim'],
cross_cond_dim=config['cross_cond_dim'],
skip_stages=config['skip_stages'],
has_variance=config['has_variance'],
def make_model(
config: Config,
) -> Union[ImageDenoiserModelV1, KarrasAugmentWrapper]:
model_config = config["model"]
assert model_config["type"] == "image_v1"
model: Union[ImageDenoiserModelV1, KarrasAugmentWrapper]
model = ImageDenoiserModelV1(
model_config["input_channels"],
model_config["mapping_out"],
model_config["depths"],
model_config["channels"],
model_config["self_attn_depths"],
model_config["cross_attn_depths"],
patch_size=model_config["patch_size"],
dropout_rate=model_config["dropout_rate"],
mapping_cond_dim=model_config["mapping_cond_dim"]
+ (9 if model_config["augment_wrapper"] else 0),
unet_cond_dim=model_config["unet_cond_dim"],
cross_cond_dim=model_config["cross_cond_dim"],
skip_stages=model_config["skip_stages"],
has_variance=model_config["has_variance"],
)
if config['augment_wrapper']:
model = augmentation.KarrasAugmentWrapper(model)
if model_config["augment_wrapper"]:
model = KarrasAugmentWrapper(model)
return model
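
Construction from a loaded config is then a one-liner; a sketch, assuming the JSON supplies the required model keys (input_channels, mapping_out, depths, channels, self_attn_depths):

model = make_model(config)
# With "augment_wrapper" enabled (the default), the ImageDenoiserModelV1 is
# wrapped in KarrasAugmentWrapper so it accepts the 9-dim aug_cond vector.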


def make_denoiser_wrapper(config):
config = config['model']
sigma_data = config.get('sigma_data', 1.)
has_variance = config.get('has_variance', False)
def make_denoiser_wrapper(config: Config) -> Callable[..., Union[layers.Denoiser, layers.DenoiserWithVariance]]:
model_config = config["model"]
sigma_data = model_config.get("sigma_data", 1.0)
has_variance = model_config.get("has_variance", False)
if not has_variance:
return partial(layers.Denoiser, sigma_data=sigma_data)
return partial(layers.DenoiserWithVariance, sigma_data=sigma_data)
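
Note the factory returns a partial over the wrapper class rather than an instance, so it is applied to the inner model; a one-line sketch:

denoiser = make_denoiser_wrapper(config)(model)
# layers.Denoiser (or DenoiserWithVariance when has_variance is set)
# preconditions the inner model with the configured sigma_data.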


def make_sample_density(config):
sd_config = config['sigma_sample_density']
sigma_data = config['sigma_data']
if sd_config['type'] == 'lognormal':
loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc']
scale = sd_config['std'] if 'std' in sd_config else sd_config['scale']
def make_sample_density(config: ModelConfig):
sd_config = config["sigma_sample_density"]
sigma_data = config["sigma_data"]
if sd_config["type"] == "lognormal":
loc = sd_config["mean"] if "mean" in sd_config else sd_config["loc"]
scale = sd_config["std"] if "std" in sd_config else sd_config["scale"]
return partial(utils.rand_log_normal, loc=loc, scale=scale)
if sd_config['type'] == 'loglogistic':
loc = sd_config['loc'] if 'loc' in sd_config else math.log(sigma_data)
scale = sd_config['scale'] if 'scale' in sd_config else 0.5
min_value = sd_config['min_value'] if 'min_value' in sd_config else 0.
max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf')
return partial(utils.rand_log_logistic, loc=loc, scale=scale, min_value=min_value, max_value=max_value)
if sd_config['type'] == 'loguniform':
min_value = sd_config['min_value'] if 'min_value' in sd_config else config['sigma_min']
max_value = sd_config['max_value'] if 'max_value' in sd_config else config['sigma_max']
if sd_config["type"] == "loglogistic":
loc = sd_config["loc"] if "loc" in sd_config else math.log(sigma_data)
scale = sd_config["scale"] if "scale" in sd_config else 0.5
min_value = sd_config["min_value"] if "min_value" in sd_config else 0.0
max_value = sd_config["max_value"] if "max_value" in sd_config else float("inf")
return partial(
utils.rand_log_logistic,
loc=loc,
scale=scale,
min_value=min_value,
max_value=max_value,
)
if sd_config["type"] == "loguniform":
min_value = (
sd_config["min_value"] if "min_value" in sd_config else config["sigma_min"]
)
max_value = (
sd_config["max_value"] if "max_value" in sd_config else config["sigma_max"]
)
return partial(utils.rand_log_uniform, min_value=min_value, max_value=max_value)
if sd_config['type'] == 'v-diffusion':
min_value = sd_config['min_value'] if 'min_value' in sd_config else 0.
max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf')
return partial(utils.rand_v_diffusion, sigma_data=sigma_data, min_value=min_value, max_value=max_value)
if sd_config['type'] == 'split-lognormal':
loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc']
scale_1 = sd_config['std_1'] if 'std_1' in sd_config else sd_config['scale_1']
scale_2 = sd_config['std_2'] if 'std_2' in sd_config else sd_config['scale_2']
return partial(utils.rand_split_log_normal, loc=loc, scale_1=scale_1, scale_2=scale_2)
raise ValueError('Unknown sample density type')
if sd_config["type"] == "v-diffusion":
min_value = sd_config["min_value"] if "min_value" in sd_config else 0.0
max_value = sd_config["max_value"] if "max_value" in sd_config else float("inf")
return partial(
utils.rand_v_diffusion,
sigma_data=sigma_data,
min_value=min_value,
max_value=max_value,
)
if sd_config["type"] == "split-lognormal":
loc = sd_config["mean"] if "mean" in sd_config else sd_config["loc"]
scale_1 = sd_config["std_1"] if "std_1" in sd_config else sd_config["scale_1"]
scale_2 = sd_config["std_2"] if "std_2" in sd_config else sd_config["scale_2"]
return partial(
utils.rand_split_log_normal, loc=loc, scale_1=scale_1, scale_2=scale_2
)
raise ValueError("Unknown sample density type")
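
A usage sketch, assuming (as the partials above suggest) that the utils samplers take an output shape as their first positional argument:

sample_density = make_sample_density(config["model"])
sigmas = sample_density([4])  # assumed signature: shape first; returns a tensor of noise levels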