diff --git a/.gitignore b/.gitignore
index 8b626aa..2c898a0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,6 +4,7 @@ __pycache__/
 *.py[cod]
 *$py.class
+*.pyc
 
 # C extensions
 *.so
 
diff --git a/gan_inv/PTI.py b/gan_inv/PTI.py
new file mode 100644
index 0000000..1b39679
--- /dev/null
+++ b/gan_inv/PTI.py
@@ -0,0 +1,53 @@
+import torch
+from inversion import inverse_image,get_lr
+
+from tqdm import tqdm
+from torch.nn import functional as F
+from lpips import util
+def toogle_grad(model, flag=True):
+    for p in model.parameters():
+        p.requires_grad = flag
+
+
+class PTI:
+    def __init__(self,G,l2_lambda = 1,max_pti_step = 400, pti_lr = 3e-4 ):
+        self.g_ema = G
+        self.l2_lambda = l2_lambda
+        self.max_pti_step = max_pti_step
+        self.pti_lr = pti_lr
+    def cacl_loss(self,percept, generated_image,real_image):
+
+        mse_loss = F.mse_loss(generated_image, real_image)
+        p_loss = percept(generated_image, real_image).sum()
+        loss = p_loss +self.l2_lambda * mse_loss
+        return loss
+
+    def train(self,img):
+        inversed_result = inverse_image(self.g_ema,img,self.g_ema.img_resolution)
+        w_pivot = inversed_result['latent']
+        ws = w_pivot.repeat([1, self.g_ema.mapping.num_ws, 1])
+        toogle_grad(self.g_ema,True)
+        percept = util.PerceptualLoss(
+            model="net-lin", net="vgg", use_gpu='cuda:0'
+        )
+        optimizer = torch.optim.Adam(self.g_ema.parameters(), lr=self.pti_lr)
+        print('start PTI')
+        pbar = tqdm(range(self.max_pti_step))
+        for i in pbar:
+            lr = get_lr(i, self.pti_lr)
+            optimizer.param_groups[0]["lr"] = lr
+
+            generated_image,feature = self.g_ema.synthesis(ws,noise_mode='const')
+            loss = self.cacl_loss(percept,generated_image,inversed_result['real'])
+            pbar.set_description(
+                (
+                    f"loss: {loss.item():.4f}"
+                )
+            )
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+        with torch.no_grad():
+            generated_image = self.g_ema.synthesis(ws, noise_mode='const')
+
+        return generated_image
diff --git a/gan_inv/__init__.py b/gan_inv/__init__.py
new file mode 100644
index 0000000..939e7c6
--- /dev/null
+++ b/gan_inv/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+ +# empty diff --git a/gan_inv/__pycache__/PTI.cpython-39.pyc b/gan_inv/__pycache__/PTI.cpython-39.pyc new file mode 100644 index 0000000..1454320 Binary files /dev/null and b/gan_inv/__pycache__/PTI.cpython-39.pyc differ diff --git a/gan_inv/__pycache__/__init__.cpython-39.pyc b/gan_inv/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..e02497c Binary files /dev/null and b/gan_inv/__pycache__/__init__.cpython-39.pyc differ diff --git a/gan_inv/__pycache__/inversion.cpython-39.pyc b/gan_inv/__pycache__/inversion.cpython-39.pyc new file mode 100644 index 0000000..386a933 Binary files /dev/null and b/gan_inv/__pycache__/inversion.cpython-39.pyc differ diff --git a/gan_inv/checkpoints/weights/v0.1/vgg.pth b/gan_inv/checkpoints/weights/v0.1/vgg.pth new file mode 100644 index 0000000..47e943c Binary files /dev/null and b/gan_inv/checkpoints/weights/v0.1/vgg.pth differ diff --git a/gan_inv/inversion.py b/gan_inv/inversion.py new file mode 100644 index 0000000..c03a383 --- /dev/null +++ b/gan_inv/inversion.py @@ -0,0 +1,343 @@ +import math +import os +from viz import renderer +import torch +from torch import optim +from torch.nn import functional as F +from torchvision import transforms +from PIL import Image +from tqdm import tqdm +import dataclasses +import dnnlib +from .lpips import util +import imageio + + + + +def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05): + lr_ramp = min(1, (1 - t) / rampdown) + lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi) + lr_ramp = lr_ramp * min(1, t / rampup) + + return initial_lr * lr_ramp + + + + + +def make_image(tensor): + return ( + tensor.detach() + .clamp_(min=-1, max=1) + .add(1) + .div_(2) + .mul(255) + .type(torch.uint8) + .permute(0, 2, 3, 1) + .to("cpu") + .numpy() + ) + + +@dataclasses.dataclass +class InverseConfig: + lr_warmup = 0.05 + lr_decay = 0.25 + lr = 0.1 + noise = 0.05 + noise_decay = 0.75 + step = 1000 + noise_regularize = 1e5 + mse = 0.1 + + + +def inverse_image( + g_ema, + image, + percept, + image_size=256, + w_plus = False, + config=InverseConfig(), + device='cuda:0' +): + args = config + + n_mean_latent = 10000 + + resize = min(image_size, 256) + + if torch.is_tensor(image)==False: + transform = transforms.Compose( + [ + transforms.Resize(resize,), + transforms.CenterCrop(resize), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ] + ) + + img = transform(image) + + else: + img = transforms.functional.resize(image,resize) + transform = transforms.Compose( + [ + transforms.CenterCrop(resize), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ] + ) + img = transform(img) + imgs = [] + imgs.append(img) + imgs = torch.stack(imgs, 0).to(device) + + with torch.no_grad(): + + #noise_sample = torch.randn(n_mean_latent, 512, device=device) + noise_sample = torch.randn(n_mean_latent, g_ema.z_dim, device=device) + #label = torch.zeros([n_mean_latent,g_ema.c_dim],device = device) + w_samples = g_ema.mapping(noise_sample,None) + w_samples = w_samples[:, :1, :] + w_avg = w_samples.mean(0) + w_std = ((w_samples - w_avg).pow(2).sum() / n_mean_latent) ** 0.5 + + + + + noises = {name: buf for (name, buf) in g_ema.synthesis.named_buffers() if 'noise_const' in name} + for noise in noises.values(): + noise = torch.randn_like(noise) + noise.requires_grad = True + + + + w_opt = w_avg.detach().clone() + if w_plus: + w_opt = w_opt.repeat(1,g_ema.mapping.num_ws, 1) + w_opt.requires_grad = True + #if args.w_plus: + #latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1) + + + + 
optimizer = optim.Adam([w_opt] + list(noises.values()), lr=args.lr) + + pbar = tqdm(range(args.step)) + latent_path = [] + + for i in pbar: + t = i / args.step + lr = get_lr(t, args.lr) + optimizer.param_groups[0]["lr"] = lr + noise_strength = w_std * args.noise * max(0, 1 - t / args.noise_decay) ** 2 + + w_noise = torch.randn_like(w_opt) * noise_strength + if w_plus: + ws = w_opt + w_noise + else: + ws = (w_opt + w_noise).repeat([1, g_ema.mapping.num_ws, 1]) + + img_gen = g_ema.synthesis(ws, noise_mode='const', force_fp32=True) + + #latent_n = latent_noise(latent_in, noise_strength.item()) + + #latent, noise = g_ema.prepare([latent_n], input_is_latent=True, noise=noises) + #img_gen, F = g_ema.generate(latent, noise) + + # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images. + + if img_gen.shape[2] > 256: + img_gen = F.interpolate(img_gen, size=(256, 256), mode='area') + + p_loss = percept(img_gen,imgs) + + + # Noise regularization. + reg_loss = 0.0 + for v in noises.values(): + noise = v[None, None, :, :] # must be [1,1,H,W] for F.avg_pool2d() + while True: + reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2 + reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2 + if noise.shape[2] <= 8: + break + noise = F.avg_pool2d(noise, kernel_size=2) + mse_loss = F.mse_loss(img_gen, imgs) + + loss = p_loss + args.noise_regularize * reg_loss + args.mse * mse_loss + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # Normalize noise. + with torch.no_grad(): + for buf in noises.values(): + buf -= buf.mean() + buf *= buf.square().mean().rsqrt() + + if (i + 1) % 100 == 0: + latent_path.append(w_opt.detach().clone()) + + pbar.set_description( + ( + f"perceptual: {p_loss.item():.4f}; noise regularize: {reg_loss:.4f};" + f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}" + ) + ) + + #latent, noise = g_ema.prepare([latent_path[-1]], input_is_latent=True, noise=noises) + #img_gen, F = g_ema.generate(latent, noise) + if w_plus: + ws = latent_path[-1] + else: + ws = latent_path[-1].repeat([1, g_ema.mapping.num_ws, 1]) + + img_gen = g_ema.synthesis(ws, noise_mode='const') + + + result = { + "latent": latent_path[-1], + "sample": img_gen, + "real": imgs, + } + + return result + +def toogle_grad(model, flag=True): + for p in model.parameters(): + p.requires_grad = flag + + +class PTI: + def __init__(self,G, percept, l2_lambda = 1,max_pti_step = 400, pti_lr = 3e-4 ): + self.g_ema = G + self.l2_lambda = l2_lambda + self.max_pti_step = max_pti_step + self.pti_lr = pti_lr + self.percept = percept + def cacl_loss(self,percept, generated_image,real_image): + + mse_loss = F.mse_loss(generated_image, real_image) + p_loss = percept(generated_image, real_image).sum() + loss = p_loss +self.l2_lambda * mse_loss + return loss + + def train(self,img,w_plus=False): + if torch.is_tensor(img) == False: + transform = transforms.Compose( + [ + transforms.Resize(self.g_ema.img_resolution, ), + transforms.CenterCrop(self.g_ema.img_resolution), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ] + ) + + real_img = transform(img).to('cuda').unsqueeze(0) + + else: + img = transforms.functional.resize(img, self.g_ema.img_resolution) + transform = transforms.Compose( + [ + transforms.CenterCrop(self.g_ema.img_resolution), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ] + ) + real_img = transform(img).to('cuda').unsqueeze(0) + inversed_result = 
inverse_image(self.g_ema,img,self.percept,self.g_ema.img_resolution,w_plus) + w_pivot = inversed_result['latent'] + if w_plus: + ws = w_pivot + else: + ws = w_pivot.repeat([1, self.g_ema.mapping.num_ws, 1]) + toogle_grad(self.g_ema,True) + optimizer = torch.optim.Adam(self.g_ema.parameters(), lr=self.pti_lr) + print('start PTI') + pbar = tqdm(range(self.max_pti_step)) + for i in pbar: + t = i / self.max_pti_step + lr = get_lr(t, self.pti_lr) + optimizer.param_groups[0]["lr"] = lr + + generated_image = self.g_ema.synthesis(ws,noise_mode='const') + loss = self.cacl_loss(self.percept,generated_image,real_img) + pbar.set_description( + ( + f"loss: {loss.item():.4f}" + ) + ) + optimizer.zero_grad() + loss.backward() + optimizer.step() + with torch.no_grad(): + generated_image = self.g_ema.synthesis(ws, noise_mode='const') + + return generated_image,ws + +if __name__ == "__main__": + state = { + "images": { + # image_orig: the original image, change with seed/model is changed + # image_raw: image with mask and points, change durning optimization + # image_show: image showed on screen + }, + "temporal_params": { + # stop + }, + 'mask': + None, # mask for visualization, 1 for editing and 0 for unchange + 'last_mask': None, # last edited mask + 'show_mask': True, # add button + "generator_params": dnnlib.EasyDict(), + "params": { + "seed": 0, + "motion_lambda": 20, + "r1_in_pixels": 3, + "r2_in_pixels": 12, + "magnitude_direction_in_pixels": 1.0, + "latent_space": "w+", + "trunc_psi": 0.7, + "trunc_cutoff": None, + "lr": 0.001, + }, + "device": 'cuda:0', + "draw_interval": 1, + "renderer": renderer.Renderer(disable_timing=True), + "points": {}, + "curr_point": None, + "curr_type_point": "start", + 'editing_state': 'add_points', + 'pretrained_weight': 'stylegan2_horses_256_pytorch' + } + cache_dir = '../checkpoints' + valid_checkpoints_dict = { + f.split('/')[-1].split('.')[0]: os.path.join(cache_dir, f) + for f in os.listdir(cache_dir) + if (f.endswith('pkl') and os.path.exists(os.path.join(cache_dir, f))) + } + state['renderer'].init_network(state['generator_params'], # res + valid_checkpoints_dict[state['pretrained_weight']], # pkl + state['params']['seed'], # w0_seed, + None, # w_load + state['params']['latent_space'] == 'w+', # w_plus + 'const', + state['params']['trunc_psi'], # trunc_psi, + state['params']['trunc_cutoff'], # trunc_cutoff, + None, # input_transform + state['params']['lr'] # lr + ) + image = Image.open('/home/tianhao/research/drag3d/horse/render/0.png') + G = state['renderer'].G + #result = inverse_image(G,image,G.img_resolution) + percept = util.PerceptualLoss( + model="net-lin", net="vgg", use_gpu=True + ) + pti = PTI(G,percept) + result = pti.train(image,True) + imageio.imsave('../horse/test.png', make_image(result[0])[0]) + + + diff --git a/gan_inv/lpips/__init__.py b/gan_inv/lpips/__init__.py new file mode 100644 index 0000000..25f4ddc --- /dev/null +++ b/gan_inv/lpips/__init__.py @@ -0,0 +1,5 @@ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + diff --git a/gan_inv/lpips/__pycache__/__init__.cpython-39.pyc b/gan_inv/lpips/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..acb6b4a Binary files /dev/null and b/gan_inv/lpips/__pycache__/__init__.cpython-39.pyc differ diff --git a/gan_inv/lpips/__pycache__/base_model.cpython-39.pyc b/gan_inv/lpips/__pycache__/base_model.cpython-39.pyc new file mode 100644 index 0000000..8777517 Binary files /dev/null and 
b/gan_inv/lpips/__pycache__/base_model.cpython-39.pyc differ diff --git a/gan_inv/lpips/__pycache__/dist_model.cpython-39.pyc b/gan_inv/lpips/__pycache__/dist_model.cpython-39.pyc new file mode 100644 index 0000000..86f2878 Binary files /dev/null and b/gan_inv/lpips/__pycache__/dist_model.cpython-39.pyc differ diff --git a/gan_inv/lpips/__pycache__/networks_basic.cpython-39.pyc b/gan_inv/lpips/__pycache__/networks_basic.cpython-39.pyc new file mode 100644 index 0000000..f1c9166 Binary files /dev/null and b/gan_inv/lpips/__pycache__/networks_basic.cpython-39.pyc differ diff --git a/gan_inv/lpips/__pycache__/pretrained_networks.cpython-39.pyc b/gan_inv/lpips/__pycache__/pretrained_networks.cpython-39.pyc new file mode 100644 index 0000000..388aec6 Binary files /dev/null and b/gan_inv/lpips/__pycache__/pretrained_networks.cpython-39.pyc differ diff --git a/gan_inv/lpips/__pycache__/util.cpython-39.pyc b/gan_inv/lpips/__pycache__/util.cpython-39.pyc new file mode 100644 index 0000000..cfaa8a4 Binary files /dev/null and b/gan_inv/lpips/__pycache__/util.cpython-39.pyc differ diff --git a/gan_inv/lpips/base_model.py b/gan_inv/lpips/base_model.py new file mode 100644 index 0000000..8de1d16 --- /dev/null +++ b/gan_inv/lpips/base_model.py @@ -0,0 +1,58 @@ +import os +import numpy as np +import torch +from torch.autograd import Variable +from pdb import set_trace as st +from IPython import embed + +class BaseModel(): + def __init__(self): + pass; + + def name(self): + return 'BaseModel' + + def initialize(self, use_gpu=True, gpu_ids=[0]): + self.use_gpu = use_gpu + self.gpu_ids = gpu_ids + + def forward(self): + pass + + def get_image_paths(self): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + return self.input + + def get_current_errors(self): + return {} + + def save(self, label): + pass + + # helper saving function that can be used by subclasses + def save_network(self, network, path, network_label, epoch_label): + save_filename = '%s_net_%s.pth' % (epoch_label, network_label) + save_path = os.path.join(path, save_filename) + torch.save(network.state_dict(), save_path) + + # helper loading function that can be used by subclasses + def load_network(self, network, network_label, epoch_label): + save_filename = '%s_net_%s.pth' % (epoch_label, network_label) + save_path = os.path.join(self.save_dir, save_filename) + print('Loading network from %s'%save_path) + network.load_state_dict(torch.load(save_path)) + + def update_learning_rate(): + pass + + def get_image_paths(self): + return self.image_paths + + def save_done(self, flag=False): + np.save(os.path.join(self.save_dir, 'done_flag'),flag) + np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') diff --git a/gan_inv/lpips/dist_model.py b/gan_inv/lpips/dist_model.py new file mode 100644 index 0000000..23bf66a --- /dev/null +++ b/gan_inv/lpips/dist_model.py @@ -0,0 +1,314 @@ + +from __future__ import absolute_import + +import sys +import numpy as np +import torch +from torch import nn +import os +from collections import OrderedDict +from torch.autograd import Variable +import itertools +from .base_model import BaseModel +from scipy.ndimage import zoom +import fractions +import functools +import skimage.transform +from tqdm import tqdm +import urllib + +from IPython import embed + +from . import networks_basic as networks +from . 
import util + + +class DownloadProgressBar(tqdm): + def update_to(self, b=1, bsize=1, tsize=None): + if tsize is not None: + self.total = tsize + self.update(b * bsize - self.n) + + +def get_path(base_path): + BASE_DIR = os.path.join('checkpoints') + + save_path = os.path.join(BASE_DIR, base_path) + if not os.path.exists(save_path): + url = f"https://huggingface.co/aaronb/StyleGAN2/resolve/main/{base_path}" + print(f'{base_path} not found') + print('Try to download from huggingface: ', url) + os.makedirs(os.path.dirname(save_path), exist_ok=True) + download_url(url, save_path) + print('Downloaded to ', save_path) + return save_path + + +def download_url(url, output_path): + with DownloadProgressBar(unit='B', unit_scale=True, + miniters=1, desc=url.split('/')[-1]) as t: + urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to) + + +class DistModel(BaseModel): + def name(self): + return self.model_name + + def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None, + use_gpu=True, printNet=False, spatial=False, + is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]): + ''' + INPUTS + model - ['net-lin'] for linearly calibrated network + ['net'] for off-the-shelf network + ['L2'] for L2 distance in Lab colorspace + ['SSIM'] for ssim in RGB colorspace + net - ['squeeze','alex','vgg'] + model_path - if None, will look in weights/[NET_NAME].pth + colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM + use_gpu - bool - whether or not to use a GPU + printNet - bool - whether or not to print network architecture out + spatial - bool - whether to output an array containing varying distances across spatial dimensions + spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below). + spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images. + spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear). 
+ is_train - bool - [True] for training mode + lr - float - initial learning rate + beta1 - float - initial momentum term for adam + version - 0.1 for latest, 0.0 was original (with a bug) + gpu_ids - int array - [0] by default, gpus to use + ''' + BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids) + + self.model = model + self.net = net + self.is_train = is_train + self.spatial = spatial + self.gpu_ids = gpu_ids + self.model_name = '%s [%s]' % (model, net) + + if(self.model == 'net-lin'): # pretrained net + linear layer + self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net, + use_dropout=True, spatial=spatial, version=version, lpips=True) + kw = {} + if not use_gpu: + kw['map_location'] = 'cpu' + if(model_path is None): + model_path = get_path('weights/v%s/%s.pth' % (version, net)) + + if(not is_train): + print('Loading model from: %s' % model_path) + self.net.load_state_dict(torch.load(model_path, **kw), strict=False) + + elif(self.model == 'net'): # pretrained network + self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False) + elif(self.model in ['L2', 'l2']): + self.net = networks.L2(use_gpu=use_gpu, colorspace=colorspace) # not really a network, only for testing + self.model_name = 'L2' + elif(self.model in ['DSSIM', 'dssim', 'SSIM', 'ssim']): + self.net = networks.DSSIM(use_gpu=use_gpu, colorspace=colorspace) + self.model_name = 'SSIM' + else: + raise ValueError("Model [%s] not recognized." % self.model) + + self.parameters = list(self.net.parameters()) + + if self.is_train: # training mode + # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) + self.rankLoss = networks.BCERankingLoss() + self.parameters += list(self.rankLoss.net.parameters()) + self.lr = lr + self.old_lr = lr + self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999)) + else: # test mode + self.net.eval() + + if(use_gpu): + self.net.to(gpu_ids[0]) + self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids) + if(self.is_train): + self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0 + + if(printNet): + print('---------- Networks initialized -------------') + networks.print_network(self.net) + print('-----------------------------------------------') + + def forward(self, in0, in1, retPerLayer=False): + ''' Function computes the distance between image patches in0 and in1 + INPUTS + in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] + OUTPUT + computed distances between in0 and in1 + ''' + + return self.net.forward(in0, in1, retPerLayer=retPerLayer) + + # ***** TRAINING FUNCTIONS ***** + def optimize_parameters(self): + self.forward_train() + self.optimizer_net.zero_grad() + self.backward_train() + self.optimizer_net.step() + self.clamp_weights() + + def clamp_weights(self): + for module in self.net.modules(): + if(hasattr(module, 'weight') and module.kernel_size == (1, 1)): + module.weight.data = torch.clamp(module.weight.data, min=0) + + def set_input(self, data): + self.input_ref = data['ref'] + self.input_p0 = data['p0'] + self.input_p1 = data['p1'] + self.input_judge = data['judge'] + + if(self.use_gpu): + self.input_ref = self.input_ref.to(device=self.gpu_ids[0]) + self.input_p0 = self.input_p0.to(device=self.gpu_ids[0]) + self.input_p1 = self.input_p1.to(device=self.gpu_ids[0]) + self.input_judge = self.input_judge.to(device=self.gpu_ids[0]) + + self.var_ref = Variable(self.input_ref, requires_grad=True) + self.var_p0 = 
Variable(self.input_p0, requires_grad=True) + self.var_p1 = Variable(self.input_p1, requires_grad=True) + + def forward_train(self): # run forward pass + # print(self.net.module.scaling_layer.shift) + # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item()) + + self.d0 = self.forward(self.var_ref, self.var_p0) + self.d1 = self.forward(self.var_ref, self.var_p1) + self.acc_r = self.compute_accuracy(self.d0, self.d1, self.input_judge) + + self.var_judge = Variable(1. * self.input_judge).view(self.d0.size()) + + self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge * 2. - 1.) + + return self.loss_total + + def backward_train(self): + torch.mean(self.loss_total).backward() + + def compute_accuracy(self, d0, d1, judge): + ''' d0, d1 are Variables, judge is a Tensor ''' + d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten() + judge_per = judge.cpu().numpy().flatten() + return d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per) + + def get_current_errors(self): + retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()), + ('acc_r', self.acc_r)]) + + for key in retDict.keys(): + retDict[key] = np.mean(retDict[key]) + + return retDict + + def get_current_visuals(self): + zoom_factor = 256 / self.var_ref.data.size()[2] + + ref_img = util.tensor2im(self.var_ref.data) + p0_img = util.tensor2im(self.var_p0.data) + p1_img = util.tensor2im(self.var_p1.data) + + ref_img_vis = zoom(ref_img, [zoom_factor, zoom_factor, 1], order=0) + p0_img_vis = zoom(p0_img, [zoom_factor, zoom_factor, 1], order=0) + p1_img_vis = zoom(p1_img, [zoom_factor, zoom_factor, 1], order=0) + + return OrderedDict([('ref', ref_img_vis), + ('p0', p0_img_vis), + ('p1', p1_img_vis)]) + + def save(self, path, label): + if(self.use_gpu): + self.save_network(self.net.module, path, '', label) + else: + self.save_network(self.net, path, '', label) + self.save_network(self.rankLoss.net, path, 'rank', label) + + def update_learning_rate(self, nepoch_decay): + lrd = self.lr / nepoch_decay + lr = self.old_lr - lrd + + for param_group in self.optimizer_net.param_groups: + param_group['lr'] = lr + + print('update lr [%s] decay: %f -> %f' % (type, self.old_lr, lr)) + self.old_lr = lr + + +def score_2afc_dataset(data_loader, func, name=''): + ''' Function computes Two Alternative Forced Choice (2AFC) score using + distance function 'func' in dataset 'data_loader' + INPUTS + data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside + func - callable distance function - calling d=func(in0,in1) should take 2 + pytorch tensors with shape Nx3xXxY, and return numpy array of length N + OUTPUTS + [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators + [1] - dictionary with following elements + d0s,d1s - N arrays containing distances between reference patch to perturbed patches + gts - N array in [0,1], preferred patch selected by human evaluators + (closer to "0" for left patch p0, "1" for right patch p1, + "0.6" means 60pct people preferred right patch, 40pct preferred left) + scores - N array in [0,1], corresponding to what percentage function agreed with humans + CONSTS + N - number of test triplets in data_loader + ''' + + d0s = [] + d1s = [] + gts = [] + + for data in tqdm(data_loader.load_data(), desc=name): + d0s += func(data['ref'], data['p0']).data.cpu().numpy().flatten().tolist() + d1s += func(data['ref'], data['p1']).data.cpu().numpy().flatten().tolist() + gts += 
data['judge'].cpu().numpy().flatten().tolist() + + d0s = np.array(d0s) + d1s = np.array(d1s) + gts = np.array(gts) + scores = (d0s < d1s) * (1. - gts) + (d1s < d0s) * gts + (d1s == d0s) * .5 + + return(np.mean(scores), dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores)) + + +def score_jnd_dataset(data_loader, func, name=''): + ''' Function computes JND score using distance function 'func' in dataset 'data_loader' + INPUTS + data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside + func - callable distance function - calling d=func(in0,in1) should take 2 + pytorch tensors with shape Nx3xXxY, and return pytorch array of length N + OUTPUTS + [0] - JND score in [0,1], mAP score (area under precision-recall curve) + [1] - dictionary with following elements + ds - N array containing distances between two patches shown to human evaluator + sames - N array containing fraction of people who thought the two patches were identical + CONSTS + N - number of test triplets in data_loader + ''' + + ds = [] + gts = [] + + for data in tqdm(data_loader.load_data(), desc=name): + ds += func(data['p0'], data['p1']).data.cpu().numpy().tolist() + gts += data['same'].cpu().numpy().flatten().tolist() + + sames = np.array(gts) + ds = np.array(ds) + + sorted_inds = np.argsort(ds) + ds_sorted = ds[sorted_inds] + sames_sorted = sames[sorted_inds] + + TPs = np.cumsum(sames_sorted) + FPs = np.cumsum(1 - sames_sorted) + FNs = np.sum(sames_sorted) - TPs + + precs = TPs / (TPs + FPs) + recs = TPs / (TPs + FNs) + score = util.voc_ap(recs, precs) + + return(score, dict(ds=ds, sames=sames)) diff --git a/gan_inv/lpips/networks_basic.py b/gan_inv/lpips/networks_basic.py new file mode 100644 index 0000000..ea45e4c --- /dev/null +++ b/gan_inv/lpips/networks_basic.py @@ -0,0 +1,188 @@ + +from __future__ import absolute_import + +import sys +import torch +import torch.nn as nn +import torch.nn.init as init +from torch.autograd import Variable +import numpy as np +from pdb import set_trace as st +from skimage import color +from IPython import embed +from . import pretrained_networks as pn + +from . 
import util + + +def spatial_average(in_tens, keepdim=True): + return in_tens.mean([2,3],keepdim=keepdim) + +def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W + in_H = in_tens.shape[2] + scale_factor = 1.*out_H/in_H + + return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens) + +# Learned perceptual metric +class PNetLin(nn.Module): + def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True): + super(PNetLin, self).__init__() + + self.pnet_type = pnet_type + self.pnet_tune = pnet_tune + self.pnet_rand = pnet_rand + self.spatial = spatial + self.lpips = lpips + self.version = version + self.scaling_layer = ScalingLayer() + + if(self.pnet_type in ['vgg','vgg16']): + net_type = pn.vgg16 + self.chns = [64,128,256,512,512] + elif(self.pnet_type=='alex'): + net_type = pn.alexnet + self.chns = [64,192,384,256,256] + elif(self.pnet_type=='squeeze'): + net_type = pn.squeezenet + self.chns = [64,128,256,384,384,512,512] + self.L = len(self.chns) + + self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) + + if(lpips): + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4] + if(self.pnet_type=='squeeze'): # 7 layers for squeezenet + self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) + self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) + self.lins+=[self.lin5,self.lin6] + + def forward(self, in0, in1, retPerLayer=False): + # v0.0 - original release had a bug, where input was not scaled + in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1) + outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) + feats0, feats1, diffs = {}, {}, {} + + for kk in range(self.L): + feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk]-feats1[kk])**2 + + if(self.lpips): + if(self.spatial): + res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)] + else: + res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)] + else: + if(self.spatial): + res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)] + else: + res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)] + + val = res[0] + for l in range(1,self.L): + val += res[l] + + if(retPerLayer): + return (val, res) + else: + return val + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None]) + self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + ''' A single linear layer which does a 1x1 conv ''' + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + + layers = [nn.Dropout(),] if(use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),] + 
self.model = nn.Sequential(*layers) + + +class Dist2LogitLayer(nn.Module): + ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) ''' + def __init__(self, chn_mid=32, use_sigmoid=True): + super(Dist2LogitLayer, self).__init__() + + layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),] + layers += [nn.LeakyReLU(0.2,True),] + layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),] + layers += [nn.LeakyReLU(0.2,True),] + layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),] + if(use_sigmoid): + layers += [nn.Sigmoid(),] + self.model = nn.Sequential(*layers) + + def forward(self,d0,d1,eps=0.1): + return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1)) + +class BCERankingLoss(nn.Module): + def __init__(self, chn_mid=32): + super(BCERankingLoss, self).__init__() + self.net = Dist2LogitLayer(chn_mid=chn_mid) + # self.parameters = list(self.net.parameters()) + self.loss = torch.nn.BCELoss() + + def forward(self, d0, d1, judge): + per = (judge+1.)/2. + self.logit = self.net.forward(d0,d1) + return self.loss(self.logit, per) + +# L2, DSSIM metrics +class FakeNet(nn.Module): + def __init__(self, use_gpu=True, colorspace='Lab'): + super(FakeNet, self).__init__() + self.use_gpu = use_gpu + self.colorspace=colorspace + +class L2(FakeNet): + + def forward(self, in0, in1, retPerLayer=None): + assert(in0.size()[0]==1) # currently only supports batchSize 1 + + if(self.colorspace=='RGB'): + (N,C,X,Y) = in0.size() + value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N) + return value + elif(self.colorspace=='Lab'): + value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), + util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') + ret_var = Variable( torch.Tensor((value,) ) ) + if(self.use_gpu): + ret_var = ret_var.cuda() + return ret_var + +class DSSIM(FakeNet): + + def forward(self, in0, in1, retPerLayer=None): + assert(in0.size()[0]==1) # currently only supports batchSize 1 + + if(self.colorspace=='RGB'): + value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float') + elif(self.colorspace=='Lab'): + value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), + util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') + ret_var = Variable( torch.Tensor((value,) ) ) + if(self.use_gpu): + ret_var = ret_var.cuda() + return ret_var + +def print_network(net): + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + print('Network',net) + print('Total number of parameters: %d' % num_params) diff --git a/gan_inv/lpips/pretrained_networks.py b/gan_inv/lpips/pretrained_networks.py new file mode 100644 index 0000000..077a244 --- /dev/null +++ b/gan_inv/lpips/pretrained_networks.py @@ -0,0 +1,181 @@ +from collections import namedtuple +import torch +from torchvision import models as tv +from IPython import embed + +class squeezenet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(squeezenet, self).__init__() + pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.slice6 = torch.nn.Sequential() + self.slice7 = 
torch.nn.Sequential() + self.N_slices = 7 + for x in range(2): + self.slice1.add_module(str(x), pretrained_features[x]) + for x in range(2,5): + self.slice2.add_module(str(x), pretrained_features[x]) + for x in range(5, 8): + self.slice3.add_module(str(x), pretrained_features[x]) + for x in range(8, 10): + self.slice4.add_module(str(x), pretrained_features[x]) + for x in range(10, 11): + self.slice5.add_module(str(x), pretrained_features[x]) + for x in range(11, 12): + self.slice6.add_module(str(x), pretrained_features[x]) + for x in range(12, 13): + self.slice7.add_module(str(x), pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1 = h + h = self.slice2(h) + h_relu2 = h + h = self.slice3(h) + h_relu3 = h + h = self.slice4(h) + h_relu4 = h + h = self.slice5(h) + h_relu5 = h + h = self.slice6(h) + h_relu6 = h + h = self.slice7(h) + h_relu7 = h + vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) + out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) + + return out + + +class alexnet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(alexnet, self).__init__() + alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(2): + self.slice1.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(2, 5): + self.slice2.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(5, 8): + self.slice3.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(8, 10): + self.slice4.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(10, 12): + self.slice5.add_module(str(x), alexnet_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1 = h + h = self.slice2(h) + h_relu2 = h + h = self.slice3(h) + h_relu3 = h + h = self.slice4(h) + h_relu4 = h + h = self.slice5(h) + h_relu5 = h + alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) + out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) + + return out + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + 
h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + + return out + + + +class resnet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True, num=18): + super(resnet, self).__init__() + if(num==18): + self.net = tv.resnet18(pretrained=pretrained) + elif(num==34): + self.net = tv.resnet34(pretrained=pretrained) + elif(num==50): + self.net = tv.resnet50(pretrained=pretrained) + elif(num==101): + self.net = tv.resnet101(pretrained=pretrained) + elif(num==152): + self.net = tv.resnet152(pretrained=pretrained) + self.N_slices = 5 + + self.conv1 = self.net.conv1 + self.bn1 = self.net.bn1 + self.relu = self.net.relu + self.maxpool = self.net.maxpool + self.layer1 = self.net.layer1 + self.layer2 = self.net.layer2 + self.layer3 = self.net.layer3 + self.layer4 = self.net.layer4 + + def forward(self, X): + h = self.conv1(X) + h = self.bn1(h) + h = self.relu(h) + h_relu1 = h + h = self.maxpool(h) + h = self.layer1(h) + h_conv2 = h + h = self.layer2(h) + h_conv3 = h + h = self.layer3(h) + h_conv4 = h + h = self.layer4(h) + h_conv5 = h + + outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) + out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) + + return out diff --git a/gan_inv/lpips/util.py b/gan_inv/lpips/util.py new file mode 100644 index 0000000..4f8b582 --- /dev/null +++ b/gan_inv/lpips/util.py @@ -0,0 +1,160 @@ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from skimage.metrics import structural_similarity +import torch + + +from . import dist_model + +class PerceptualLoss(torch.nn.Module): + def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric) + # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss + super(PerceptualLoss, self).__init__() + print('Setting up Perceptual loss...') + self.use_gpu = use_gpu + self.spatial = spatial + self.gpu_ids = gpu_ids + self.model = dist_model.DistModel() + self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids) + print('...[%s] initialized'%self.model.name()) + print('...Done') + + def forward(self, pred, target, normalize=False): + """ + Pred and target are Variables. + If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] + If normalize is False, assumes the images are already between [-1,+1] + + Inputs pred and target are Nx3xHxW + Output pytorch Variable N long + """ + + if normalize: + target = 2 * target - 1 + pred = 2 * pred - 1 + + return self.model.forward(target, pred) + +def normalize_tensor(in_feat,eps=1e-10): + norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) + return in_feat/(norm_factor+eps) + +def l2(p0, p1, range=255.): + return .5*np.mean((p0 / range - p1 / range)**2) + +def psnr(p0, p1, peak=255.): + return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) + +def dssim(p0, p1, range=255.): + return (1 - structural_similarity(p0, p1, data_range=range, multichannel=True)) / 2. 
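As an aside to the util.py hunk above, a minimal sketch of how this PerceptualLoss wrapper is driven: inversion.py and visualizer_gradio_custom.py in this same patch construct it with model="net-lin", net="vgg". The tensor shapes, the 256x256 resolution, and the CUDA device below are illustrative assumptions; inputs are expected as Nx3xHxW tensors already scaled to [-1, 1].

    import torch
    from gan_inv.lpips import util

    # Build the LPIPS-style metric; the net-lin/vgg weights are loaded from the
    # checkpoints directory (or fetched by dist_model.get_path on first use).
    percept = util.PerceptualLoss(model="net-lin", net="vgg", use_gpu=True)

    # Two illustrative image batches in [-1, 1], shape Nx3xHxW (1x3x256x256 assumed here).
    pred = torch.rand(1, 3, 256, 256, device="cuda") * 2 - 1
    target = torch.rand(1, 3, 256, 256, device="cuda") * 2 - 1

    # The forward pass returns one distance per image; summing mirrors how
    # cacl_loss in inversion.py combines it with an MSE term.
    loss = percept(pred, target).sum()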
+ +def rgb2lab(in_img,mean_cent=False): + from skimage import color + img_lab = color.rgb2lab(in_img) + if(mean_cent): + img_lab[:,:,0] = img_lab[:,:,0]-50 + return img_lab + +def tensor2np(tensor_obj): + # change dimension of a tensor object into a numpy array + return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) + +def np2tensor(np_obj): + # change dimenion of np array into tensor array + return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) + +def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): + # image tensor to lab tensor + from skimage import color + + img = tensor2im(image_tensor) + img_lab = color.rgb2lab(img) + if(mc_only): + img_lab[:,:,0] = img_lab[:,:,0]-50 + if(to_norm and not mc_only): + img_lab[:,:,0] = img_lab[:,:,0]-50 + img_lab = img_lab/100. + + return np2tensor(img_lab) + +def tensorlab2tensor(lab_tensor,return_inbnd=False): + from skimage import color + import warnings + warnings.filterwarnings("ignore") + + lab = tensor2np(lab_tensor)*100. + lab[:,:,0] = lab[:,:,0]+50 + + rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) + if(return_inbnd): + # convert back to lab, see if we match + lab_back = color.rgb2lab(rgb_back.astype('uint8')) + mask = 1.*np.isclose(lab_back,lab,atol=2.) + mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) + return (im2tensor(rgb_back),mask) + else: + return im2tensor(rgb_back) + +def rgb2lab(input): + from skimage import color + return color.rgb2lab(input / 255.) + +def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): + image_numpy = image_tensor[0].cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor + return image_numpy.astype(imtype) + +def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): + return torch.Tensor((image / factor - cent) + [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) + +def tensor2vec(vector_tensor): + return vector_tensor.data.cpu().numpy()[:, :, 0, 0] + +def voc_ap(rec, prec, use_07_metric=False): + """ ap = voc_ap(rec, prec, [use_07_metric]) + Compute VOC AP given precision and recall. + If use_07_metric is true, uses the + VOC 07 11 point method (default:False). + """ + if use_07_metric: + # 11 point metric + ap = 0. + for t in np.arange(0., 1.1, 0.1): + if np.sum(rec >= t) == 0: + p = 0 + else: + p = np.max(prec[rec >= t]) + ap = ap + p / 11. 
+ else: + # correct AP calculation + # first append sentinel values at the end + mrec = np.concatenate(([0.], rec, [1.])) + mpre = np.concatenate(([0.], prec, [0.])) + + # compute the precision envelope + for i in range(mpre.size - 1, 0, -1): + mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) + + # to calculate area under PR curve, look for points + # where X axis (recall) changes value + i = np.where(mrec[1:] != mrec[:-1])[0] + + # and sum (\Delta recall) * prec + ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) + return ap + +def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): +# def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): + image_numpy = image_tensor[0].cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor + return image_numpy.astype(imtype) + +def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): +# def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): + return torch.Tensor((image / factor - cent) + [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) diff --git a/visualizer_gradio_custom.py b/visualizer_gradio_custom.py new file mode 100644 index 0000000..6e84057 --- /dev/null +++ b/visualizer_gradio_custom.py @@ -0,0 +1,964 @@ +import os +import os.path as osp +from argparse import ArgumentParser +from functools import partial + +import gradio as gr +import numpy as np +import torch +from PIL import Image +import imageio +import dnnlib +from gradio_utils import (ImageMask, draw_mask_on_image, draw_points_on_image, + get_latest_points_pair, get_valid_mask, + on_change_single_global_state) +from viz.renderer import Renderer, add_watermark_np +from gan_inv.inversion import PTI +from gan_inv.lpips import util +parser = ArgumentParser() +parser.add_argument('--share',default='False') +parser.add_argument('--cache-dir', type=str, default='./checkpoints') +args = parser.parse_args() + +cache_dir = args.cache_dir + +device = 'cuda' + + +def reverse_point_pairs(points): + new_points = [] + for p in points: + new_points.append([p[1], p[0]]) + return new_points + + +def clear_state(global_state, target=None): + """Clear target history state from global_state + If target is not defined, points and mask will be both removed. + 1. set global_state['points'] as empty dict + 2. set global_state['mask'] as full-one mask. + """ + if target is None: + target = ['point', 'mask'] + if not isinstance(target, list): + target = [target] + if 'point' in target: + global_state['points'] = dict() + print('Clear Points State!') + if 'mask' in target: + image_raw = global_state["images"]["image_raw"] + global_state['mask'] = np.ones((image_raw.size[1], image_raw.size[0]), + dtype=np.uint8) + print('Clear mask State!') + + return global_state + + +def init_images(global_state): + """This function is called only ones with Gradio App is started. + 0. pre-process global_state, unpack value from global_state of need + 1. Re-init renderer + 2. run `renderer._render_drag_impl` with `is_drag=False` to generate + new image + 3. 
Assign images to global state and re-generate mask + """ + + if isinstance(global_state, gr.State): + state = global_state.value + else: + state = global_state + + state['renderer'].init_network( + state['generator_params'], # res + valid_checkpoints_dict[state['pretrained_weight']], # pkl + state['params']['seed'], # w0_seed, + None, # w_load + state['params']['latent_space'] == 'w+', # w_plus + 'const', + state['params']['trunc_psi'], # trunc_psi, + state['params']['trunc_cutoff'], # trunc_cutoff, + None, # input_transform + state['params']['lr'] # lr, + ) + + state['renderer']._render_drag_impl(state['generator_params'], + is_drag=False, + to_pil=True) + + init_image = state['generator_params'].image + state['images']['image_orig'] = init_image + state['images']['image_raw'] = init_image + state['images']['image_show'] = Image.fromarray( + add_watermark_np(np.array(init_image))) + state['mask'] = np.ones((init_image.size[1], init_image.size[0]), + dtype=np.uint8) + return global_state + + +def update_image_draw(image, points, mask, show_mask, global_state=None): + + image_draw = draw_points_on_image(image, points) + if show_mask and mask is not None and not (mask == 0).all() and not ( + mask == 1).all(): + image_draw = draw_mask_on_image(image_draw, mask) + + image_draw = Image.fromarray(add_watermark_np(np.array(image_draw))) + if global_state is not None: + global_state['images']['image_show'] = image_draw + return image_draw + + +def preprocess_mask_info(global_state, image): + """Function to handle mask information. + 1. last_mask is None: Do not need to change mask, return mask + 2. last_mask is not None: + 2.1 global_state is remove_mask: + 2.2 global_state is add_mask: + """ + if isinstance(image, dict): + last_mask = get_valid_mask(image['mask']) + else: + last_mask = None + mask = global_state['mask'] + + # mask in global state is a placeholder with all 1. + if (mask == 1).all(): + mask = last_mask + + # last_mask = global_state['last_mask'] + editing_mode = global_state['editing_state'] + + if last_mask is None: + return global_state + + if editing_mode == 'remove_mask': + updated_mask = np.clip(mask - last_mask, 0, 1) + print(f'Last editing_state is {editing_mode}, do remove.') + elif editing_mode == 'add_mask': + updated_mask = np.clip(mask + last_mask, 0, 1) + print(f'Last editing_state is {editing_mode}, do add.') + else: + updated_mask = mask + print(f'Last editing_state is {editing_mode}, ' + 'do nothing to mask.') + + global_state['mask'] = updated_mask + # global_state['last_mask'] = None # clear buffer + return global_state + + +valid_checkpoints_dict = { + f.split('/')[-1].split('.')[0]: osp.join(cache_dir, f) + for f in os.listdir(cache_dir) + if (f.endswith('pkl') and osp.exists(osp.join(cache_dir, f))) +} +print(f'File under cache_dir ({cache_dir}):') +print(os.listdir(cache_dir)) +print('Valid checkpoint file:') +print(valid_checkpoints_dict) + +init_pkl = 'stylegan2_lions_512_pytorch' + + + +# Network & latents tab listeners +def on_change_pretrained_dropdown(pretrained_value, global_state): + """Function to handle model change. + 1. Set pretrained value to global_state + 2. Re-init images and clear all states + """ + global_state['pretrained_weight'] = pretrained_value + init_images(global_state) + clear_state(global_state) + + return global_state, global_state["images"]['image_show'] + + + +def on_click_reset_image(global_state): + """Reset image to the original one and clear all states + 1. Re-init images + 2. 
Clear all states + """ + + init_images(global_state) + clear_state(global_state) + + return global_state, global_state['images']['image_show'] + + + + # Update parameters +def on_change_update_image_seed(seed, global_state): + """Function to handle generation seed change. + 1. Set seed to global_state + 2. Re-init images and clear all states + """ + + global_state["params"]["seed"] = int(seed) + init_images(global_state) + clear_state(global_state) + + return global_state, global_state['images']['image_show'] + + + +def on_click_latent_space(latent_space, global_state): + """Function to reset latent space to optimize. + NOTE: this function we reset the image and all controls + 1. Set latent-space to global_state + 2. Re-init images and clear all state + """ + + global_state['params']['latent_space'] = latent_space + init_images(global_state) + clear_state(global_state) + + return global_state, global_state['images']['image_show'] + + + +def on_click_inverse_custom_image(custom_image,global_state): + print('inverse GAN') + + if isinstance(global_state, gr.State): + state = global_state.value + else: + state = global_state + + state['renderer'].init_network( + state['generator_params'], # res + valid_checkpoints_dict[state['pretrained_weight']], # pkl + state['params']['seed'], # w0_seed, + None, # w_load + state['params']['latent_space'] == 'w+', # w_plus + 'const', + state['params']['trunc_psi'], # trunc_psi, + state['params']['trunc_cutoff'], # trunc_cutoff, + None, # input_transform + state['params']['lr'] # lr, + ) + + percept = util.PerceptualLoss( + model="net-lin", net="vgg", use_gpu=True + ) + + image = Image.open(custom_image.name) + + pti = PTI(global_state['renderer'].G,percept) + inversed_img, w_pivot = pti.train(image,state['params']['latent_space'] == 'w+') + inversed_img = (inversed_img[0] * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0) + inversed_img = inversed_img.cpu().numpy() + inversed_img = Image.fromarray(inversed_img) + global_state['images']['image_show'] = Image.fromarray( + add_watermark_np(np.array(inversed_img))) + + global_state['images']['image_orig'] = inversed_img + global_state['images']['image_raw'] = inversed_img + + global_state['mask'] = np.ones((inversed_img.size[1], inversed_img.size[0]), + dtype=np.uint8) + global_state['generator_params'].image = inversed_img + global_state['generator_params'].w = w_pivot.detach().cpu().numpy() + global_state['renderer'].set_latent(w_pivot,global_state['params']['trunc_psi'],global_state['params']['trunc_cutoff']) + + del percept + del pti + print('inverse end') + + return global_state, global_state['images']['image_show'], gr.Button.update(interactive=True) + +def on_save_image(global_state,form_save_image_path): + imageio.imsave(form_save_image_path,global_state['images']['image_raw']) + +def on_reset_custom_image(global_state): + if isinstance(global_state, gr.State): + state = global_state.value + else: + state = global_state + clear_state(state) + state['renderer'].w = state['renderer'].w0.detach().clone() + state['renderer'].w.requires_grad = True + state['renderer'].w_optim = torch.optim.Adam([state['renderer'].w], lr=state['renderer'].lr) + state['renderer']._render_drag_impl(state['generator_params'], + is_drag=False, + to_pil=True) + + init_image = state['generator_params'].image + state['images']['image_orig'] = init_image + state['images']['image_raw'] = init_image + state['images']['image_show'] = Image.fromarray( + add_watermark_np(np.array(init_image))) + state['mask'] = 
np.ones((init_image.size[1], init_image.size[0]), + dtype=np.uint8) + return state, state['images']['image_show'] +def on_change_lr(lr, global_state): + if lr == 0: + print('lr is 0, do nothing.') + return global_state + else: + global_state["params"]["lr"] = lr + renderer = global_state['renderer'] + renderer.update_lr(lr) + print('New optimizer: ') + print(renderer.w_optim) + return global_state + + +def on_click_start(global_state, image): + p_in_pixels = [] + t_in_pixels = [] + valid_points = [] + + # handle of start drag in mask editing mode + global_state = preprocess_mask_info(global_state, image) + + # Prepare the points for the inference + if len(global_state["points"]) == 0: + # yield on_click_start_wo_points(global_state, image) + image_raw = global_state['images']['image_raw'] + update_image_draw( + image_raw, + global_state['points'], + global_state['mask'], + global_state['show_mask'], + global_state, + ) + + yield ( + global_state, + 0, + global_state['images']['image_show'], + # gr.File.update(visible=False), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + # latent space + gr.Radio.update(interactive=True), + gr.Button.update(interactive=True), + # NOTE: disable stop button + gr.Button.update(interactive=False), + + # update other comps + gr.Dropdown.update(interactive=True), + gr.Number.update(interactive=True), + gr.Number.update(interactive=True), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + gr.Checkbox.update(interactive=True), + # gr.Number.update(interactive=True), + gr.Number.update(interactive=True), + ) + else: + + # Transform the points into torch tensors + for key_point, point in global_state["points"].items(): + try: + p_start = point.get("start_temp", point["start"]) + p_end = point["target"] + + if p_start is None or p_end is None: + continue + + except KeyError: + continue + + p_in_pixels.append(p_start) + t_in_pixels.append(p_end) + valid_points.append(key_point) + + mask = torch.tensor(global_state['mask']).float() + drag_mask = 1 - mask + + renderer: Renderer = global_state["renderer"] + global_state['temporal_params']['stop'] = False + global_state['editing_state'] = 'running' + + # reverse points order + p_to_opt = reverse_point_pairs(p_in_pixels) + t_to_opt = reverse_point_pairs(t_in_pixels) + #print('Running with:') + #print(f' Source: {p_in_pixels}') + #print(f' Target: {t_in_pixels}') + step_idx = 0 + while True: + if global_state["temporal_params"]["stop"]: + break + + # do drage here! 
+ renderer._render_drag_impl( + global_state['generator_params'], + p_to_opt, # point + t_to_opt, # target + drag_mask, # mask, + global_state['params']['motion_lambda'], # lambda_mask + reg=0, + feature_idx=5, # NOTE: do not support change for now + r1=global_state['params']['r1_in_pixels'], # r1 + r2=global_state['params']['r2_in_pixels'], # r2 + # random_seed = 0, + # noise_mode = 'const', + trunc_psi=global_state['params']['trunc_psi'], + # force_fp32 = False, + # layer_name = None, + # sel_channels = 3, + # base_channel = 0, + # img_scale_db = 0, + # img_normalize = False, + # untransform = False, + is_drag=True, + to_pil=True) + + if step_idx % global_state['draw_interval'] == 0: + #print('Current Source:') + for key_point, p_i, t_i in zip(valid_points, p_to_opt, + t_to_opt): + global_state["points"][key_point]["start_temp"] = [ + p_i[1], + p_i[0], + ] + global_state["points"][key_point]["target"] = [ + t_i[1], + t_i[0], + ] + start_temp = global_state["points"][key_point][ + "start_temp"] + #print(f' {start_temp}') + + image_result = global_state['generator_params']['image'] + image_draw = update_image_draw( + image_result, + global_state['points'], + global_state['mask'], + global_state['show_mask'], + global_state, + ) + global_state['images']['image_raw'] = image_result + + yield ( + global_state, + step_idx, + global_state['images']['image_show'], + # gr.File.update(visible=False), + gr.Button.update(interactive=False), + gr.Button.update(interactive=False), + gr.Button.update(interactive=False), + gr.Button.update(interactive=False), + gr.Button.update(interactive=False), + # latent space + gr.Radio.update(interactive=False), + gr.Button.update(interactive=False), + # enable stop button in loop + gr.Button.update(interactive=True), + + # update other comps + gr.Dropdown.update(interactive=False), + gr.Number.update(interactive=False), + gr.Number.update(interactive=False), + gr.Button.update(interactive=False), + gr.Button.update(interactive=False), + gr.Checkbox.update(interactive=False), + # gr.Number.update(interactive=False), + gr.Number.update(interactive=False), + ) + + # increate step + step_idx += 1 + + image_result = global_state['generator_params']['image'] + global_state['images']['image_raw'] = image_result + image_draw = update_image_draw(image_result, + global_state['points'], + global_state['mask'], + global_state['show_mask'], + global_state) + + # fp = NamedTemporaryFile(suffix=".png", delete=False) + # image_result.save(fp, "PNG") + + global_state['editing_state'] = 'add_points' + + yield ( + global_state, + 0, # reset step to 0 after stop. + global_state['images']['image_show'], + # gr.File.update(visible=True, value=fp.name), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + # latent space + gr.Radio.update(interactive=True), + gr.Button.update(interactive=True), + # NOTE: disable stop button with loop finish + gr.Button.update(interactive=False), + + # update other comps + gr.Dropdown.update(interactive=True), + gr.Number.update(interactive=True), + gr.Number.update(interactive=True), + gr.Checkbox.update(interactive=True), + gr.Number.update(interactive=True), + ) + + + +def on_click_stop(global_state): + """Function to handle stop button is clicked. + 1. send a stop signal by set global_state["temporal_params"]["stop"] as True + 2. 
Disable Stop button
+    """
+    global_state["temporal_params"]["stop"] = True
+
+    return global_state, gr.Button.update(interactive=False)
+
+
+def on_click_remove_point(global_state):
+    choice = global_state["curr_point"]
+    del global_state["points"][choice]
+
+    choices = list(global_state["points"].keys())
+
+    if len(choices) > 0:
+        global_state["curr_point"] = choices[0]
+
+    # NOTE: assumes at least one point remains; choices[0] would fail on an empty dict.
+    return (
+        gr.Dropdown.update(choices=choices, value=choices[0]),
+        global_state,
+    )
+
+    # Mask
+def on_click_reset_mask(global_state):
+    global_state['mask'] = np.ones(
+        (
+            global_state["images"]["image_raw"].size[1],
+            global_state["images"]["image_raw"].size[0],
+        ),
+        dtype=np.uint8,
+    )
+    image_draw = update_image_draw(global_state['images']['image_raw'],
+                                   global_state['points'],
+                                   global_state['mask'],
+                                   global_state['show_mask'], global_state)
+    return global_state, image_draw
+
+
+    # Image
+def on_click_enable_draw(global_state, image):
+    """Function to start add-mask mode.
+    1. Preprocess mask info from last state
+    2. Change editing state to add_mask
+    3. Set current image with points and mask
+    """
+    global_state = preprocess_mask_info(global_state, image)
+    global_state['editing_state'] = 'add_mask'
+    image_raw = global_state['images']['image_raw']
+    image_draw = update_image_draw(image_raw, global_state['points'],
+                                   global_state['mask'], True,
+                                   global_state)
+    return (global_state,
+            gr.Image.update(value=image_draw, interactive=True))
+
+def on_click_remove_draw(global_state, image):
+    """Function to start remove-mask mode.
+    1. Preprocess mask info from last state
+    2. Change editing state to remove_mask
+    3. Set current image with points and mask
+    """
+    global_state = preprocess_mask_info(global_state, image)
+    global_state['editing_state'] = 'remove_mask'
+    image_raw = global_state['images']['image_raw']
+    image_draw = update_image_draw(image_raw, global_state['points'],
+                                   global_state['mask'], True,
+                                   global_state)
+    return (global_state,
+            gr.Image.update(value=image_draw, interactive=True))
+
+
+def on_click_add_point(global_state, image: dict):
+    """Function to switch from add-mask mode to add-points mode.
+    1. Update mask buffer if needed
+    2. Change global_state['editing_state'] to 'add_points'
+    3. Set current image with mask
+    """
+
+    global_state = preprocess_mask_info(global_state, image)
+    global_state['editing_state'] = 'add_points'
+    mask = global_state['mask']
+    image_raw = global_state['images']['image_raw']
+    image_draw = update_image_draw(image_raw, global_state['points'], mask,
+                                   global_state['show_mask'], global_state)
+
+    return (global_state,
+            gr.Image.update(value=image_draw, interactive=False))
+
+
+def on_click_image(global_state, evt: gr.SelectData):
+    """This function only supports clicks for point selection.
+    """
+    xy = evt.index
+    if global_state['editing_state'] != 'add_points':
+        print(f'In {global_state["editing_state"]} state. 
' + 'Do not add points.') + + return global_state, global_state['images']['image_show'] + + points = global_state["points"] + + point_idx = get_latest_points_pair(points) + if point_idx is None: + points[0] = {'start': xy, 'target': None} + print(f'Click Image - Start - {xy}') + elif points[point_idx].get('target', None) is None: + points[point_idx]['target'] = xy + print(f'Click Image - Target - {xy}') + else: + points[point_idx + 1] = {'start': xy, 'target': None} + print(f'Click Image - Start - {xy}') + + image_raw = global_state['images']['image_raw'] + image_draw = update_image_draw( + image_raw, + global_state['points'], + global_state['mask'], + global_state['show_mask'], + global_state, + ) + + return global_state, image_draw + + + +def on_click_clear_points(global_state): + """Function to handle clear all control points + 1. clear global_state['points'] (clear_state) + 2. re-init network + 2. re-draw image + """ + clear_state(global_state, target='point') + + renderer: Renderer = global_state["renderer"] + renderer.feat_refs = None + + image_raw = global_state['images']['image_raw'] + image_draw = update_image_draw(image_raw, {}, global_state['mask'], + global_state['show_mask'], global_state) + return global_state, image_draw + + + +def on_click_show_mask(global_state, show_mask): + """Function to control whether show mask on image.""" + global_state['show_mask'] = show_mask + + image_raw = global_state['images']['image_raw'] + image_draw = update_image_draw( + image_raw, + global_state['points'], + global_state['mask'], + global_state['show_mask'], + global_state, + ) + return global_state, image_draw + + +if __name__ == "__main__": + with gr.Blocks() as app: + # renderer = Renderer() + global_state = gr.State({ + "images": { + # image_orig: the original image, change with seed/model is changed + # image_raw: image with mask and points, change durning optimization + # image_show: image showed on screen + }, + "temporal_params": { + # stop + }, + 'mask': + None, # mask for visualization, 1 for editing and 0 for unchange + 'last_mask': None, # last edited mask + 'show_mask': True, # add button + "generator_params": dnnlib.EasyDict(), + "params": { + "seed": 0, + "motion_lambda": 20, + "r1_in_pixels": 3, + "r2_in_pixels": 12, + "magnitude_direction_in_pixels": 1.0, + "latent_space": "w+", + "trunc_psi": 0.7, + "trunc_cutoff": None, + "lr": 0.001, + }, + "device": device, + "draw_interval": 1, + "renderer": Renderer(disable_timing=True), + "points": {}, + "curr_point": None, + "curr_type_point": "start", + 'editing_state': 'add_points', + 'pretrained_weight': init_pkl + }) + + # init image + global_state = init_images(global_state) + + with gr.Row(): + with gr.Row(): + # Left --> tools + with gr.Column(scale=3): + # Pickle + with gr.Row(): + with gr.Column(scale=1, min_width=10): + gr.Markdown(value='Pickle', show_label=False) + + with gr.Column(scale=4, min_width=10): + form_pretrained_dropdown = gr.Dropdown( + choices=list(valid_checkpoints_dict.keys()), + label="Pretrained Model", + value=init_pkl, + ) + + # Latent + with gr.Row(): + with gr.Column(scale=1, min_width=10): + gr.Markdown(value='Latent', show_label=False) + + with gr.Column(scale=4, min_width=10): + form_seed_number = gr.Number( + value=global_state.value['params']['seed'], + interactive=True, + label="Seed", + ) + form_lr_number = gr.Number( + value=global_state.value["params"]["lr"], + interactive=True, + label="Step Size") + + with gr.Row(): + with gr.Column(scale=2, min_width=10): + form_reset_image = 
gr.Button("Reset Image") + with gr.Column(scale=3, min_width=10): + form_latent_space = gr.Radio( + ['w', 'w+'], + value=global_state.value['params'] + ['latent_space'], + interactive=True, + label='Latent space to optimize', + show_label=False, + ) + with gr.Row(): + with gr.Column(scale=3, min_width=10): + form_custom_image = gr.UploadButton(label="inverse custom image", + file_types=['.png', '.jpg', '.jpeg']) + with gr.Column(scale=3, min_width=10): + form_reset_custom_image = gr.Button('reset custom image', interactive=False) + with gr.Row(): + with gr.Column(scale=3, min_width=10): + form_save_image_path = gr.Textbox(label="save image to",value='./test.png') + form_save_image = gr.Button('save',interactive=True) + + + # Drag + with gr.Row(): + with gr.Column(scale=1, min_width=10): + gr.Markdown(value='Drag', show_label=False) + with gr.Column(scale=4, min_width=10): + with gr.Row(): + with gr.Column(scale=1, min_width=10): + enable_add_points = gr.Button('Add Points') + with gr.Column(scale=1, min_width=10): + undo_points = gr.Button('Reset Points') + with gr.Row(): + with gr.Column(scale=1, min_width=10): + form_start_btn = gr.Button("Start") + with gr.Column(scale=1, min_width=10): + form_stop_btn = gr.Button("Stop") + + form_steps_number = gr.Number(value=0, + label="Steps", + interactive=False) + + # Mask + with gr.Row(): + with gr.Column(scale=1, min_width=10): + gr.Markdown(value='Mask', show_label=False) + with gr.Column(scale=4, min_width=10): + enable_add_mask = gr.Button('Edit Flexible Area') + with gr.Row(): + with gr.Column(scale=1, min_width=10): + form_reset_mask_btn = gr.Button("Reset mask") + with gr.Column(scale=1, min_width=10): + show_mask = gr.Checkbox( + label='Show Mask', + value=global_state.value['show_mask'], + show_label=False) + + with gr.Row(): + form_lambda_number = gr.Number( + value=global_state.value["params"] + ["motion_lambda"], + interactive=True, + label="Lambda", + ) + + form_draw_interval_number = gr.Number( + value=global_state.value["draw_interval"], + label="Draw Interval (steps)", + interactive=True, + visible=False) + + # Right --> Image + with gr.Column(scale=8): + form_image = ImageMask( + value=global_state.value['images']['image_show'], + brush_radius=20).style( + width=768, + height=768) # NOTE: hard image size code here. + gr.Markdown(""" + ## Quick Start + + 1. Select desired `Pretrained Model` and adjust `Seed` to generate an + initial image. + 2. Click on image to add control points. + 3. Click `Start` and enjoy it! + + ## Advance Usage + + 1. Change `Step Size` to adjust learning rate in drag optimization. + 2. Select `w` or `w+` to change latent space to optimize: + * Optimize on `w` space may cause greater influence to the image. + * Optimize on `w+` space may work slower than `w`, but usually achieve + better results. + * Note that changing the latent space will reset the image, points and + mask (this has the same effect as `Reset Image` button). + 3. Click `Edit Flexible Area` to create a mask and constrain the + unmasked region to remain unchanged. + """) + gr.HTML(""" + +