From 8bedc2d9c3c8db9f657d0239704f007971b30ae7 Mon Sep 17 00:00:00 2001 From: wrh12345 Date: Tue, 20 Aug 2024 06:12:26 +0000 Subject: [PATCH 01/14] feature(wrh): add initial version of edm --- density_func.ipynb | 771 ++++++++++++++++++ .../edm_diffusion_model.py | 266 ++++++ .../edm_diffusion_model/edm_preconditioner.py | 116 +++ .../swiss_roll/swiss_roll_edm_diffusion.py | 220 +++++ 4 files changed, 1373 insertions(+) create mode 100644 density_func.ipynb create mode 100644 grl/generative_models/edm_diffusion_model/edm_diffusion_model.py create mode 100644 grl/generative_models/edm_diffusion_model/edm_preconditioner.py create mode 100644 grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py diff --git a/density_func.ipynb b/density_func.ipynb new file mode 100644 index 0000000..d788bec --- /dev/null +++ b/density_func.ipynb @@ -0,0 +1,771 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from base64 import b64encode\n", + "import pickle\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "from easydict import EasyDict\n", + "from IPython.display import HTML\n", + "from rich.progress import track\n", + "from sklearn.datasets import make_swiss_roll\n", + "import matplotlib\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib import animation\n", + "matplotlib.use(\"Agg\")\n", + "\n", + "from grl.generative_models.diffusion_model import DiffusionModel\n", + "from grl.generative_models.conditional_flow_model import IndependentConditionalFlowModel, OptimalTransportConditionalFlowModel\n", + "from grl.generative_models.metric import compute_likelihood\n", + "from grl.utils import set_seed\n", + "from grl.utils.log import log" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.utils import shuffle as util_shuffle\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "\n", + "\n", + "def plot2d(data):\n", + " plt.scatter(data[:, 0], data[:, 1])\n", + " plt.show()\n", + "\n", + "\n", + "def show_video(video_path, video_width=600):\n", + "\n", + " video_file = open(video_path, \"r+b\").read()\n", + "\n", + " video_url = f\"data:video/mp4;base64,{b64encode(video_file).decode()}\"\n", + " return HTML(\n", + " f\"\"\"\"\"\"\n", + " )\n", + "\n", + "\n", + "def render_video(\n", + " data_list, video_save_path, iteration, fps=100, dpi=100\n", + "):\n", + " if not os.path.exists(video_save_path):\n", + " os.makedirs(video_save_path)\n", + " fig = plt.figure(figsize=(6, 6))\n", + " plt.xlim([-5, 5])\n", + " plt.ylim([-5, 5])\n", + " ims = []\n", + " colors = np.linspace(0, 1, len(data_list))\n", + "\n", + " for i, data in enumerate(data_list):\n", + " im = plt.scatter(data[:, 0], data[:, 1], s=1)\n", + " title = plt.text(0.5, 1.05, f't={i/len(data_list):.2f}', ha='center', va='bottom', transform=plt.gca().transAxes)\n", + " ims.append([im, title])\n", + "\n", + " ani = animation.ArtistAnimation(fig, ims, interval=0.1, blit=True)\n", + " ani.save(os.path.join(video_save_path, f'iteration_{iteration}.mp4'), fps=fps, dpi=dpi)\n", + " # clean up\n", + " plt.close(fig)\n", + " plt.clf()\n", + "\n", + "def load_and_plot_results(file_path):\n", + " try:\n", + " with open(file_path, \"rb\") as f:\n", + " results = pickle.load(f)\n", + " except Exception as e:\n", + " print(f\"Failed to load the file: {e}\")\n", + " return\n", + "\n", + " plt.figure(figsize=(10, 6))\n", + " x = results[\"iterations\"]\n", + " if \"gradients\" in results and results[\"gradients\"]:\n", + " plt.plot(x, results[\"gradients\"], label=\"Gradients\")\n", + " if \"losses\" in results and results[\"losses\"]:\n", + " plt.plot(x, results[\"losses\"], label=\"Losses\")\n", + " plt.xlabel(\"Iteration\")\n", + " plt.ylabel(\"Log(Value)\")\n", + " plt.yscale(\"log\")\n", + " # Specify y-ticks\n", + " y_ticks = [1e-1, 5e-1, 1, 5, 10]\n", + " plt.yticks(y_ticks, [f\"{y:.0e}\" for y in y_ticks])\n", + " plt.title(\"Training Metrics Over Iterations\")\n", + " plt.legend()\n", + " plt.grid(True)\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "x_size = 2\n", + "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", + "t_embedding_dim = 32\n", + "\n", + "diffusion_model_config = EasyDict(\n", + " dict(\n", + " device=device,\n", + " project=\"linear_vp_sde_noise_function_score_matching\",\n", + " diffusion_model=dict(\n", + " device=device,\n", + " x_size=x_size,\n", + " alpha=1.0,\n", + " solver=dict(\n", + " type=\"ODESolver\",\n", + " args=dict(\n", + " library=\"torchdiffeq_adjoint\",\n", + " ),\n", + " ),\n", + " path=dict(\n", + " type=\"linear_vp_sde\",\n", + " beta_0=0.1,\n", + " beta_1=20.0,\n", + " ),\n", + " model=dict(\n", + " type=\"noise_function\",\n", + " args=dict(\n", + " t_encoder=dict(\n", + " type=\"GaussianFourierProjectionTimeEncoder\",\n", + " args=dict(\n", + " embed_dim=t_embedding_dim,\n", + " scale=30.0,\n", + " ),\n", + " ),\n", + " backbone=dict(\n", + " type=\"TemporalSpatialResidualNet\",\n", + " args=dict(\n", + " hidden_sizes=[128, 64, 32],\n", + " output_dim=x_size,\n", + " t_dim=t_embedding_dim,\n", + " ),\n", + " ),\n", + " ),\n", + " ),\n", + " ),\n", + " parameter=dict(\n", + " training_loss_type=\"score_matching\",\n", + " lr=5e-4,\n", + " data_num=100000,\n", + " # weight_decay=1e-4,\n", + " iterations=100000,\n", + " batch_size=4096,\n", + " # clip_grad_norm=1.0,\n", + " eval_freq=1000,\n", + " video_save_path=\"./video-diffusion\",\n", + " device=device,\n", + " ),\n", + " )\n", + ")\n", + "\n", + "flow_model_config = EasyDict(\n", + " dict(\n", + " device=device,\n", + " project=\"icfm_velocity_function_flow_matching\",\n", + " flow_model=dict(\n", + " device=device,\n", + " x_size=x_size,\n", + " alpha=1.0,\n", + " solver=dict(\n", + " type=\"ODESolver\",\n", + " args=dict(\n", + " library=\"torchdiffeq_adjoint\",\n", + " ),\n", + " ),\n", + " path=dict(\n", + " sigma=0.1,\n", + " ),\n", + " model=dict(\n", + " type=\"velocity_function\",\n", + " args=dict(\n", + " t_encoder=dict(\n", + " type=\"GaussianFourierProjectionTimeEncoder\",\n", + " args=dict(\n", + " embed_dim=t_embedding_dim,\n", + " scale=30.0,\n", + " ),\n", + " ),\n", + " backbone=dict(\n", + " type=\"TemporalSpatialResidualNet\",\n", + " args=dict(\n", + " hidden_sizes=[128, 64, 32],\n", + " output_dim=x_size,\n", + " t_dim=t_embedding_dim,\n", + " ),\n", + " ),\n", + " ),\n", + " ),\n", + " ),\n", + " parameter=dict(\n", + " training_loss_type=\"flow_matching\",\n", + " lr=5e-4,\n", + " data_num=100000,\n", + " # weight_decay=1e-4,\n", + " iterations=100000,\n", + " batch_size=4096,\n", + " # clip_grad_norm=1.0,\n", + " eval_freq=1000,\n", + " video_save_path=\"./video-flow\",\n", + " device=device,\n", + " ),\n", + " )\n", + ")\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAi8AAAGdCAYAAADaPpOnAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAABrrUlEQVR4nO3deVhV1f4/8DcgowiEIjgGSk45kDNqioqi0kBaZtY1MzVNK1PrHr3mlAmpmdeGa369Stb1eq3UhmNOKOKAs+SMoRKKggPBEZkE9u8Pf5Ao4Fn7rLWn83k9D89TuIfPOZy99/usvfZaDpIkSSCEEEII0QlHtQsghBBCCGFB4YUQQgghukLhhRBCCCG6QuGFEEIIIbpC4YUQQgghukLhhRBCCCG6QuGFEEIIIbpC4YUQQgghulJD7QJ4Ky0txZUrV1CrVi04ODioXQ4hhBBCrCBJEm7duoX69evD0bH6thXDhZcrV66gUaNGapdBCCGEEBkuXbqEhg0bVruM4cJLrVq1ANx98V5eXipXQwghhBBrWCwWNGrUqPw6Xh3DhZeyW0VeXl4UXgghhBCdsabLB3XYJYQQQoiuUHghhBBCiK5QeCGEEEKIrlB4IYQQQoiuUHghhBBCiK5QeCGEEEKIrlB4IYQQQoiuUHghhBBCiK4YbpA6QohxnL5sQeTnuyFZseze9/ugga+78JoIIeoTGl6io6Oxfv16nD17Fu7u7ujWrRs+/vhjNG/evNr1vvvuO3zwwQdITU3FY489ho8//hiDBg0SWSohxEYHU7IwdEWiavvvvmAH921un9QLwQGe3LdLCLGN0PCya9cuTJgwAZ06dUJxcTGmT5+O/v374/Tp06hZs2al6+zbtw8vvfQSoqOj8dRTT2HNmjWIiorC0aNH0bp1a5HlEkIqoXYoUVP4kl1WLUchhxBlOUiSZE2LLBfXr19H3bp1sWvXLvTs2bPSZV588UXcvn0bv/zyS/nvunbtipCQECxbtuyh+7BYLPD29kZOTg7NbUSIFZKv3MLApQkoVbsQg/IAsGt6OPy8XNUuhRBNY7l+K9rnJScnBwDg6+tb5TKJiYmYPHlyhd9FRERg48aNlS5fWFiIwsLC8v+3WCy2F0qIweQWFOO1lQdxKO1PtUuxO3kAOs3fXum/bRzXHSGBPorWQ4gRKBZeSktLMWnSJHTv3r3a2z8ZGRnw9/ev8Dt/f39kZGRUunx0dDTmzJnDtVZC9CwpNRtRy/aqXQaxQlV/p4SpvdG4jofC1RCiH4qFlwkTJuDkyZPYs2cP1+1OmzatQkuNxWJBo0aNuO6DEK2y5/4oRtZz0c4Hfkf9agj5iyLhZeLEifjll1+QkJCAhg0bVrtsQEAAMjMzK/wuMzMTAQEBlS7v6uoKV1e6l0yMLyUjFxH/3IUSxXqpES25v/NwbSdg27R+8PV0UakiQtQjNLxIkoS33noLGzZsQHx8PIKCgh66TmhoKOLi4jBp0qTy323btg2hoaECKyVEe6hVhVTnZgnQft62Cr/b8nZPNK9fS6WKCFGO0PAyYcIErFmzBj/++CNq1apV3m/F29sb7u53B5MaMWIEGjRogOjoaADAO++8g169euGTTz5BZGQk1q5di8OHD2P58uUiSyVEdfvP3cSwlfvVLoPoWMTShAr/T2GGGJXQR6UdHBwq/f2qVaswcuRIAEBYWBgCAwMRGxtb/u/fffcdZsyYUT5I3YIFC6wepI4elSZ6QR1ridJ2Tg5DUN3Kx9giRG0s129Fx3lRAoUXolXpWfno92k88u7QiCpyRDTzxJJXesDdxQnA3fez/6fxuE3vpyz9H6uJf/7tyfL3kxC1UXih8EI04siFPzFk+T61y1DcutGh6Bxc9XhOWpWTdwdDl+1D8rVctUtRHD3NRNRG4YXCC1FJflEJ3l5zFNvOXlO7FCH0Gkp4M3rIocHziBoovFB4IQq6binEwH/uwo3bd9QuxWZBnsDGyf3h7eGsdim6dvHabfRfEg8j3NGaM+gxvNLjMTg5Vt6HkRBeKLxQeCGC6T2w/PxmD7Rp7K12GXanqLgUH/18Gl8f+EPtUmShIENEovBC4YUIkFtQjBH/PoCjl7LVLsVqL4X6Y05ke7jUcFS7FFKNjOwCRCyJR05BidqlWI1uIRLeKLxQeCGclJRKiN19ER/+ekbtUh7qw8hmGN49mL4VG0RWbhGe/Xw3LmUXqF3KQ22a+CRaNaTzLbENhRcKL8RGpy9bMOjz3WqXUSUaGt4+6WGKiP2mvgjwcVO7DKJDFF4ovBAZtNyPZe2orujarLbaZRCN0XLrzKyBwRjxZDNqCSRWo/BC4YUw0OIcQtSfgMiRk3cHL/xrL85dv612KRXQGDLEGhReKLyQh9BaK8uLXeviw6c6UMdawlVGdgH6L4mHRSMdgak1hlSHwguFF1KFE2k5ePrLPWqXAQBImNobjet4qF0GsSNamk9r7/t90MDXXe0yiIZQeKHwQu5RUirh3/EXMH/rWVXroNYVoiVaaX2kMYdIGQovFF4I7nZm7Lt4B/7MU6/JnIZZJ3qghWktvh3RGT1a+am2f6I+Ci8UXuzaxWu30XtxvGr73/J2TzSvX0u1/RNiq/3nbmLYyv2q7PtRD+CnqTRFhT2i8ELhxS6pOTYLBRZiVGoGGRozxr5QeKHwYleOXPgTQ5bvU3y/dK+e2JOi4lJ8+NMpfHMwTfF9U+d2+0DhhcKLXVBjfBYaLI6Qu+PJPP+vvfhd4fFkdk4OQ1DdmorukyiHwguFF0NTuqWF7sETUjU1piygEGNMFF4ovBiS0mNUUD8WQtgo3T+GbicZC4UXCi+GomRH3Be6+OGjpzvSWCyE2EDpkX2pY68xUHih8GIIKRm5CF+yS5F9USsLIWLsO3sDw2MPKLKv32bS7V09o/BC4UXXMrIL0DUmTvh+qC8LIcpRavwlDxdHJM2MoNZTHaLwQuFFl3ILitF69hbh+6EnhghRj1JPKj3Vuh4+f6W90H0Qvii8UHjRndBZZlwtFLuPTROfRKuG9JkgRCv2nL6OV1YfFLoPOu71g8ILhRfdUGKWZ3oigRBtU+JJQurUq30UXii8aJ7ofi2OAA7P6AdfTxdh+yCE8JV85RYiliYI234NRyB53iA4OToI2weRj8ILhRdNCzSZhW3b3Qk49EEEPN1qCNsHIUQs0Z17/96/Bcb3aSps+0QeCi8UXjRJ5HgtFFoIMZ60G3nouWinsO0fmh4OPy9XYdsnbCi8UHjRlJy8O2g3d6uQbbs4AkdnUmghxMhEtsS4OzvizIcDhWybsKHwQuFFM9pMM+OWgE+YA4AkGpCKELsicuBKmi9JfRReKLyo7rqlEJ3mbxeybWrqtR/XLYUYsGQXbubdqXa5lcM6oE9IgEJVEbWJugXtAOBiTCT37RLraCa8JCQkYOHChThy5AiuXr2KDRs2ICoqqsrl4+Pj0bt37wd+f/XqVQQEWHdiovCiPlEdcumRZ/1QehJNnuiRWv0Q9Tmj6ULUwXL9FtpR4Pbt22jXrh1GjRqFwYMHW71ecnJyhcLr1q0rojzCmajWFhpkSlusbQ3Rq4c9wk+tPNoREuiD1JhIvLk6AZtO3+K23bLHtVOpFUazFLtt5ODgYHXLy59//gkfHx9Z+6GWF3WIaG2Z0rcZ3ur3GPftEuvsSMrAqLVH1C5DVzaO646QQB+1y7BLRcWlaDbjV+7bpb4wytFMy4tcISEhKCwsROvWrTF79mx07969ymULCwtRWPjXuPIWi0WJEsn/J2I+IhpISlnJV25h4NIElKpdiAFUdgvDC0A8DZgonEsNR6TGRHIfALP34njqC6NBmgov9erVw7Jly9CxY0cUFhZixYoVCAsLw4EDB9C+feUTbEVHR2POnDkKV0oAoNdHZvzBr6UWAPU3EO3itdvovyQedyipKMYCoP28bRV+F/aoM/71el+4uzipU5SBBfi4ITUmkuvUIxLuti7TwwLaoanbRpXp1asXGjdujG+++abSf6+s5aVRo0Z020gw3reJfn6zB9o09ua6TQLsO3sDw2MPqF0GsQJ1EhWD97nKw8URp+fSuDAi6P620b06d+6MPXuqTs+urq5wdaUkrBTeTbI1HIGU+dQcywuFFf26f04fCjN8pMZEcn2YIK+oFIEmM3XmVZnmw0tSUhLq1aundhkE/L/B7H2/Dxr4unPdpr2h/irGdW+YqVcD2DydBmWUy8/LFakxkXh7zT78dPxPLtsMNJnpHKYioeElNzcXKSkp5f9/8eJFJCUlwdfXF40bN8a0adOQnp6O1atXAwCWLFmCoKAgPP744ygoKMCKFSuwY8cObN0qZmh5Yj2ewcW/lgsO/KMft+3Zm+1Hr2L0uqNql0EUdLUYFabY+GFsN3Ro8oiKFenT0uHdsGgov6eSui/YQa3HKhEaXg4fPlxh0LnJkycDAF599VXExsbi6tWrSEtLK//3oqIiTJkyBenp6fDw8EDbtm2xffv2SgeuI8rgPXbLbzSkP7Os3CI8tTQBVyyFD1+Y2IUhy/eV/zeNO8Om7KkkXlMNFJeCbiOpgKYHIFUKMpnB68MR7OeJ7VN6cdqa8WVkF6D/4nhYikrULkVT7h1HJf54JkauOaxuQRqz/Pkn0L9jfbXL0BWercpH6ZF4m2hmegA1UHjhg+cBfXI2zfpsjazcIjz1zwRcuWW8FhYtPwLP85FaLaEWGevxnLXax90ZSbP6c9mWvaHwQuFFtpJSCU2nb+KyLW/3GvhtVgSXbRnZ1sNXMPb7Y2qXIZs9jCqr99GG140ORedgX7XL0DyeX9roNhI7Ci8UXmT5/uAfmLr+JJdtUfNp9Q6mZGHoikS1y7Da3/sFYmzvVjTq8X0ysgvQ/9N4WAr1c3uPJjit3rAvdmD/pXwu26IAw4bCC4UXZk1NZvA4/To6ABei6YCtjF5uC9FcLrbTQzil20pVyy8qQcuZm7ls68zcATSSspUovFB4YcKrqZS+0VVOyxcymrFbOXtOX8crqw+qXUal6NitHK9z45PBdfDN6C5ctmVkFF4ovFiN18FJzaMVlZRK+HxzMj5NOK92KeVa+wBrJ1Hnaa1IychFxD93oURDZ+A1I7ugW4s6apehKbxGFafJHR+OwguFF6vwCC40QFNFGdkFCFsYhwINdIGo7wz8Oo3G1dELLYWZZUNCMKBTA7XL0BT6oicehRcKLw/F40CMfq4NXurSmEM1+pd85dYDc9OoYfukXggO8FS7DMKBFuapCnACtvyDAnAZCjBiUXih8FKlomI+Q2Ofnz+InjwBcOTCnxVGO1Xae+GPYlyfx+lvYXBaGLRQy2P1KGnt/oswbTxt83YowDyIwguFl0p98ONv+Cbxss3boYMO2H/uJoat3K/Kvqlfgn0rKi7FjO+OY91v6arsnzr38hsPi86lFVF4ofDygObTzCi08S9d39sN+6b15VOQTqkVWmiQMVIVtQY5pBDD5zYSBZi/UHih8FIBjwPM3idUVOP2ELWwEBZFxaWYse43rDt+RdH92vvtJAow/FB4ofBSjg4s2yjdEZcm1iM8KN1HxgHACTuew2zahmP47wHbQqM9n2fLUHih8AKAgost0rPy0X3BDsX2Z+/fXok4SanZiFq2V5F9NfZxR4KpjyL70hoeD0PY+4MQFF4ovFBwkYnX01jW+HZEZ/Ro5afIvggpKi7F9HW/4XsFbivFDu+IsLb+wvejRbaee78c3h6D2tbjVI2+UHix8/Bi68HjCOCCHQaXN1btxJbkPOH72ft+HzTwdRe+H0KqolQfLnudfsLWc/CYJ4Pwj8hWnKrRDwovdhxebD1oWtXzwqZ3nuRUjT6cvmzBoM93C90H9WUhWnTdUoiIxbuQVXBH6H7sscO/7QEmEP+IfJxTNfpA4cVOw4utB8tJO+twl1tQjNaztwjdx9pRXdG1WW2h+yCEh00HL+PN9b8J275fTRcc+qCfsO1rke23kJ7AoLb286WHwosdhhdbDxJ769/SdaYZGUXitm+vzeVE/+KPZ2LkmsPCtm9vU1jYem62p068FF7sLLxQcLEerxliq7JzchiC6tYUtn1ClCJ6biV7Ou+0/cAMiw135uzlvaLwYkfhpYnJjFIb1reXgwLgN6laZWi0UWJUB1OyMHRFopBt//xmD7Rp7C1k21rz2spE7DyXJXt9ezhXU3ixk/DSI2YHLmfny17fHg4GQOyYLfTkELEXIkOMvZyLfklKx8S1SbLXN/r7ROHFDsLLhqPpeHddkuz1jX4QlBHV2mJv9+0JKbPn9HW8svog9+3aSz8xWyd1NPK5m8KLwcMLffgfLifvDtrN3cp9uxvHdUdIoA/37RKiN9uPXsXodUe5b9cezk+AbV+sjPoeUXgxeHihD3312k43w2JLR6BK0CSJhFROROvmoenh8PNy5b5draFzeUUUXgwcXujDXj3eJ1L3Go44M28g120SYjTXLYXoNH871216ODvi9IfGPvZsnY7EaOd0luu3o0I1EQ4ouFTtuqWQe3A5ND2cggshVvDzckVqTCQ2juvObZt5d0qFPiGoBS41HDGqe6Ds9Y3+/lSHWl50goJL1XgfwNSvhRDb8D4mj87oB19PF67b1BJbnhx9sokXvhlrjCld6LaRwcJLsxm/oqhYXicOCi7Wq+EApEQb+/0iRClZuUVoP28bt+35uDsjaVZ/btvTmqYmM0pkrntm7gC4uzhxrUcNdNvIQP538JLs4HLO4Lc8eAaXve/3oeBCCEe+ni5IjYnEU635DEKXnX/H0LdJztvwRbPlzM0cK9EHannRMFseiX65awN8FBXCtyCN4DnEv7uzI84YvFOg3m09fAVjvz9W7TJ0q0/bbO2Yej8jtyjbcxcBzdw2SkhIwMKFC3HkyBFcvXoVGzZsQFRUVLXrxMfHY/LkyTh16hQaNWqEGTNmYOTIkVbv00jhxZ4/xFXh+c3LXh7H1AJRY4LwYC+Do2nBsvgziNl8gcu2jHKrpDL2eu7XTHj59ddfsXfvXnTo0AGDBw9+aHi5ePEiWrdujXHjxmH06NGIi4vDpEmTYDabERERYdU+jRJe7PXDWx1ewaWGI5Ay35jvkZqsaSHRG3uae0cptg6yea8ng2vjm9FduWxLa+Se75o+4oS4vw/gXI0yNBNeKuzIweGh4eXvf/87zGYzTp48Wf67YcOGITs7G5s3W3dPzwjhpdXMzcgrktd1i4JL9WguItvlFhTj+S924+z1PLVLUcXSqDZ4pmtjtcvQvTbTzLjF6epD572KTs6OgKdbDc7ViMdy/dbUq0tMTER4eHiF30VERGDSpElVrlNYWIjCwsLy/7dYLKLKU8R1SyEFl/vwCi5GfX9Eiz+eiZFrDqtdhma8vfEE3t54osLvaK4rdieiI7lN4xFoMhvy+D4zd4CszritZ28x5PtxL02Fl4yMDPj7+1f4nb+/PywWC/Lz8+Hu/uA35ujoaMyZM0epEoWTO0qlUT+oPIJLsJ8ntk/pxaEa+7Dv7A0Mjz2gdhm6Er5kV4X/pxY+63h7OCM1JpLLcW7EAOPu4oQng2tjd8pN5nWN+H7cS/ePSk+bNg05OTnlP5cuXVK7JNnkHsC/zTTm2Ac8TmgnZ0dQcHmIouJSjFt1AIEmMwJNZgouHHRfsKP8/fzlgH7PSUpJjYlEIIe7/EZ8lNqWPj1h8433fpTRVMtLQEAAMjMzK/wuMzMTXl5elba6AICrqytcXfX/xEjofHmP/vq63P32YjQ8TkJG/tZhq4vXbiP803iUGGqgBG2auOE4Jm44DgCYN7AZXun1mMoVaVP89EjkFhSj9ewtNm3HiC0OclunUi13+6jpsf/Lw2iq5SU0NBRxcRUv4tu2bUNoaKhKFSkjJ+8OrloKZK17dK6xDtKiYtvnM6nhSMGlMulZ+Wjxj00INJnRezEFFzXM+PVceYvMmoQUtcvRHE+3GlyOXSO2wMh9X2wNg1ol9Gmj3NxcpKTcPUCfeOIJLF68GL1794avry8aN26MadOmIT09HatXrwbw16PSEyZMwKhRo7Bjxw68/fbbhn9UWu6BZrQL9Ie/nMS/9/xh0zaor8GDYncmY/YWulBq2bIhIRjQqYHaZWgKtb4+yJaWKT28F5p5VDo+Ph69e/d+4PevvvoqYmNjMXLkSKSmpiI+Pr7COu+++y5Onz6Nhg0b4oMPPjD0IHXB082QM/r/uXkD4VJDUw1nNnlq6S6cvJJr0zb0cHAqJeHkNYz49pDaZRAZEqb2RuM6HmqXoQnD/xWPfX/ctmkbRjsv9F+8C+eusZ8r9fC50kx4UYOewkt6Vj66L9jBvN7g9n5YPLSzgIrU8fRnu3Ei3bZH3I12gpJr5Y6zmLv1vNplCLX3/T7449ptw3csptaYu/KLSmyeu8do5wejttZTeNFJeDHqB5DFh7+cxr/3XLRpG0Z6P+TQY+fbnZPDEFS3pir71lur1Mz+TTGqTwu1y1CdrbeRjHaeMOL1g8KLDsJLk2lmlMp457X8wWPFY7I2I70frPafu4lhK/erXUaV9NBMfb89p6/jldUH1S6jSkdn9IOvp4vaZaiGAsxfsnKL0H7eNub1tHxcUnjReHhJu5GHnot2Mq9ntBOXLSciBwAXDXQiYqG1loPFzzyOwd0C1S5DmOuWQoQv3omcAnkjX4tgzx3TKcD8pd2cLcjJL2ZeT6vvAYUXjYcXOQdfDQApGv3AyWHLCaihjzv2mPpwrEYftNIqQLMwa2eW7P2mvgjwcVO7DMVRgPmLkW4fUXjRcHjpPG8bruUWMa+nxQ+aXLaceD4dGoLn2ttXJ8ak1GxELdur2v5/GNsNHZo8otr+9eCXA5fKB6JTmgOA03MHwN3FSZX9q4UCzF1yOzRrMfhSeNFoeJE7CdkZA52YbDnhnJ8/CE6ODhyr0baM7AJ0jZE38rIt5g9qjuE9gxXfr1GkZ+Wj7yc7UaBwD+oWfh7YPOXBoSmMjALMXUP/lYiDf2Qxr6e110/hRaPhRc6B1rGeM75/xxhzF9lyotHaQSaa0iOEfjm4HQZ1bqjoPu1BSamEuRuS8PWhK4rt097+lhRg7pLzPkztH4iJfR4XUI08FF40GF5OX7Zg0Oe7mdez5wOrjFHeA2tsPpSOcT8kKbKvpVFt8EzXxorsi9xt3h/xf/tw6JJtYxpZy576JtlyfnmstjO2vaf/L4hybx9pqUWbwosGw4ucg+u3mf0NMelikMkMuR8yI90yq47cAQtZ0Zgh2pB85RYGLE2QfVywsJfwb0uAOTk7whCTFz7/5V4cTstmXk8rnxGW67dxxpfXsPGrDzOv4wFjzBb9zb4/ZJ+gwx57xC6CS6DJLDy4rBsditSYSAouGtG8fi1cjIlEakwklka1EbqvQJMZe05fF7oPLbDlAmyUyQu/f7O7rPXSs/I5VyIetbwIJncgNq0kYVuUlEpoOn2T7PWN8B5UJyUjF+FLdgnb/gsdvREzuLtmmoRJ9ZKv3ELE0gSh+zD6MQXQLWq5Hf218Nqp5UVDWnzAHly2T+oloBLlUXCpWqDJLCy4LI1qg9SYSCx8vgcFFx1pXr8WUmMicX7+ILzSoZ6QfQSazDiYwv5Uip7Ycu5QuqO8CAE+bnCRcdyfvqxMXyxeKLwIlJFdIGsKgOAAT/7FKMzev/1UJb+oRNgJMnZ4R6TGRFInXJ1zcnTAvBfaIzUmEvMHNee+/aErEg1xka6OLeeQFz9XfngC3s7NH8S8jpwHStRE4UUgOU13J2dHCKhEWe3myL9/bOTgEvXJJptnx63MtyM6IzUmEmFt/blvm6hreM9gpMZEYsFTLblvO9BkRk7eHe7b1Qq555IDlwuQX6SdqSDk+mFsN+Z19p+7KaASMSi8CJKSkcu8jr8zdN/jPSu3SNZcG4Cxg0ugyYyk63y7l61+pRNSYyLRo5Uf1+0S7RnaowlSYyIRE8m3w3W7uVvRWcbkfnoh98ugiC8ZSpMzKraWJ3q9H4UXQeT0Zzjwof4v3nJmOQWMH1x4WjYkBKkxkejZui7X7RLtG/ZkU+7HyrXcIsPeRvJ0q4FmdeXdhjfCeyKn/2TCyWsCKuGPwosAyVduMa+zbnSogEqUJfdg3/u+MSdZTLuRx/UE2KKuB1JjIjGgk33N7UQelBoTiaMz+nHdphEu1pXZOln+AxByzuVaIqf/pJZmrK8OhRcB5Dzu2DnYV0AlyrGlp3oDX3eOlWhDoMmMnot2ctvembkDsHmyfc1bQ6rn6+mC1JhIWX0bqmLUfjByW6tEP7quBDkhN/54poBK+KLwwpmcpL5p4pMCKlGW3J7qRrxdxPMbbNngcvYwWB+Rp0OTR7geR+3mbkW3aP0/cXM/ue+R3lukfD1d4Mx4pR+5hn1gVaVReOFMTlLX+/wjcg9uCi7VS42J1H2LHFFOakwkfpvJZ46eKzkFur9oV2a/qa+s9fQ+Ns7x2QOY19F63xcKLxyl3chjXkfvA9IlpWbLWu/Q9HC+hWgAr5P99km9DBnsiHjeHs5IjYnE4z58tme0ABPg4wYXJ/YB3IauSBRQjXLcXZxQ253tdWu97wuFF47k9HHQ+4B0Ucv2ylrPz8uVcyXq4nWST42J1P1ngqjPbIrEmbns37YrY7QAc+4j9gHcAP2/D3umsT82ruVxXyi8cJKVW8S8jt5bXYLodhEAPie1jeO6G+59Iepyd3Hi9pnS+4X7fudljEALACfScjhXohw5rS9aHveFwgsnvWSMb6Lnb9gXr92WNVv0uXkDudeiJh4n9dSYSIQE+theDCGVSI2J5PJFyUgBxsnRAVP6NmNe7+kv9wioRjlyWl+0OucRhRdOWJ8x+vnNHkLqUErvxfHM60Q+7gWXGsb5yPEKLoSIFhzgyeWzZqQA81a/x2St117H74G7ixNYn1vU6pxHxrmSqGhHUgbzOm0aewuoRBlye6F/8Tf9PxJextaT+LSIFhRciOIowFQk5/3IApBbIG8KFC04JGPcFzndIkSj8MLBqLVHmJZfO6qroEqUIacXOq/Og1pg68n7/PxBeKN3U07VEMImNSYSts6gZqQAs2ZkF+Z1Ws+WP/ms2nw9XZjXeUqDg/VReLGRnEHpujarLaASZTz1KfucTW18YZhB1mw9aafGRMLJkf1RTUJ4SuEwvYBRAky3FnVkrSdn8l2t2Dk5jGn5K5ZCMYXYgMKLjVgHpZsaJu8+qxbkF5XgZCb7Afvz+8a4PcIjuBCiFWXTC9jCKAFGTpCTM/muVgTVrcm8jtYem6bworDx/fUbXh6XMU38lrd7CqhEeRRciFHZ+tl86Ut+c3ipxdfTBZ4yWof1PPIua/cFrT02beutT7vG2lHXAdDtLYPrlkKUylivef1a3GtRGgUX9U2I3Qrz2conDHy7ZwNMHhSibEEGkxoTKftznpiWh/yiEt3fGj45dwDzezB0RaJuj2853Rdy8u7A28NZQDXsHCRJkjNcB5MvvvgCCxcuREZGBtq1a4fPPvsMnTt3rnTZ2NhYvPbaaxV+5+rqioKCAqv2ZbFY4O3tjZycHHh5iZ0ziPWDvn1SL92O7SLnxHZm7gDdn9CSr9yyaWZZvZ7YlPDtrt8x49dziu7z5zd76PpJP9FsCepG+KwnnLzG/EDCmpFdZPebUVsbk5lpmI+mtT0Q95642e1Zrt/CW17+97//YfLkyVi2bBm6dOmCJUuWICIiAsnJyahbt26l63h5eSE5Obn8/x0ctNdacV1GBya9BpeMbOuC472a1TRGJ10KLnys35eKyT+dUruMSgcZ2/t+HzTwdVehGu2xpQUm0GTW/We+Z+vKr0nVGR57QLeve9eMfmjPMMDq+Zvs8/eJIrzPy+LFizFmzBi89tpraNWqFZYtWwYPDw+sXLmyynUcHBwQEBBQ/uPv7y+6TGZd529nWv793uyjOWpF15g45nW2fqDPg/le9v4t1BYXr91GkMmMwP//o4XgUpXuC3aU1/numni1y1GdLZ9dI3Tg3TiuO/M6WuvMai05j01fvHZbQCXshIaXoqIiHDlyBOHhf80g7OjoiPDwcCQmVj1LZ25uLh599FE0atQIzz77LE6dqvrEV1hYCIvFUuFHCSWMy7/RL1hIHaLJGYxp2ZAQ/oUojIILuyMX/iwPAb0Xx8uaPkJtG47fLn8Nb67S71getrLlM6znTqwAZE3VobXOrCxYx7npvyReTCGMhIaXGzduoKSk5IGWE39/f2RkVN7ZtXnz5li5ciV+/PFHfPvttygtLUW3bt1w+fLlSpePjo6Gt7d3+U+jRo24v477pd1gbzrTa0fdqBj2E/iATg0EVKKcTzadlb2uPQaXCbFbEWgyY8jyfWqXwtWm5OLyILN083G1y1Gc3M/y0BVVfzHVi00T2UcD1+ukjaz9de7IeXJDAM09Kh0aGooRI0YgJCQEvXr1wvr16+Hn54evvvqq0uWnTZuGnJyc8p9Lly4Jr7HnIrZHA38Y201QJeKlMHZ3Wf78E2IKUUhJqYTPEs7LWteegsu+szfKL+xVPQVkJIvjLyHQZMbor/R/W4TFoenhD1+oEnq/fdSqIfvDHnqftJGFFm4dCQ0vderUgZOTEzIzMyv8PjMzEwEBAVZtw9nZGU888QRSUlIq/XdXV1d4eXlV+NGaDk0eUbsEWX45wB4E+3esL6AS5TSdvknWevYSXH7an4ZAkxnDYw+oXYoqtl9EeWjT8/w21vLzcoW7s7zLhN4DjJzWFy3OAWQN1teqhVtHQsOLi4sLOnTogLi4vzp8lpaWIi4uDqGhoVZto6SkBCdOnEC9evVElclErx9OOSZuYGsqXxrVRlAlyvhw/QlZ6+l9hnBrbD96FYEmM97eKO89MqLWs7eg9Qx9X6CtcebDgbLXzcnTb6ucnNaXgRq4qMvB+lq1cOtI+G2jyZMn4//+7//w9ddf48yZMxg/fjxu375dPpbLiBEjMG3atPLl586di61bt+LChQs4evQoXnnlFfzxxx8YPXq06FKt0ovhsTJAvxe2pNRs5nWe6dqYfyEKKSmV8O+DabLWNfK4ITl5d+7eLll3VO1SNCm3+G4Lw9vfsD+RpydyWxbbzd3KuRJlrRtt3ZfsMpm5+g1rrNT+Ii88vLz44otYtGgRZs6ciZCQECQlJWHz5s3lnXjT0tJw9erV8uX//PNPjBkzBi1btsSgQYNgsViwb98+tGrVSnSpVmGdhlGvF7aoZXuZlm/lr88xbMrQ7aIHBZrMur/4KOWnUwUINJmxJqHy29tGcH7+IFnrybn9rBWdg32Z12EdeV0rWKdyUXumaUVG2FWS6BF2We/j6vXixvo6T86OgKebPmebkDOqJqDfv+3D7Dt7Q9N9Wh4F8IfaRTyEUT8bn2w6K6tDu57fj82H0jHuhySmdfT6etW+vmlqhF0jYR2IiLXJUSs2JrJfGvQaXADICi6xwzsKqER9anayjIlsgWFPNuW6zenf78Oaw39y3aY1Ak1mLH7mcQzuFqj4vkWaMqiFrPCi59F3B3RqADCGl9yCYl2fE/WAWl4YqJ1KlcL6On8Y2023T1StSUjB9E3JD1/wPnr921ZHyeAyPTwIY8PVuRU86T87sfGEcsOc02flrv2mvgjwcRNQjXisr7eVvyc2vdtLUDXi7D93k2nAvY3jussa1K8qLNdvzY3zQtQl5+kAvQYXABRc8NeYLSJFtfFAakxk+Y9awQUAlrzcu7yO7ZPEX2D0/shwZVYO68C8jpxpRrRi7/t9mJY/nZkrqBKxWGeaZu0byRO1awnStLaH2iXI0puxg+bTrbQ375S1Pt/C/tiv3h8Hv5/IC+vzT3hh0YvsY2UoKTjAszyMXrcUokfMdhQKeAw00GTGbzP7w9vDmf/GVdAnJABYy77eibQcXT7EIGfizozsAt22NOkBtbxYKf545sMXusf6Cfp8RJp1VpJPhrcXUocSFu1kfzRaz4+D309UcEmYerdlQ+vB5X5+Xq5Inn+3RWbBUy25b7/d3K0Ina/f1of7/TazP/M6eh6F9vPn2jItH/4J20jshA2FFyuNXHOYaXk9fsPKL2KdbhJwqaHPj5CcuWrkzDarVbyDy+B2nuW3YhrX0Wer472G9miC1JhI/DazP2pxPJSvWgoMcxvJ28MZPjI6pZ6+rMzkubw91YVt3rxcLYzkJsN+U1+m5eWMCcaDPq88RIhRX25mWl7Ps0cvjmcfe4JnxzQ18bx4juxSB6kxkVj8kv46J1rD28MZJz68G8pqcdyuUQJM0uwI5nUGfb5bQCXaVFSsvwDDeqtLrX4vFF4EeL93M7VLkCWRcWwlvc4eHbuTvZPu0Rn9BFSiPF4XzSEhtZAaE4nZz3Xhsj09OBETiXPz5A+Vfz+jBJj5g5ozr6PXaQNWDGW7Tf7ef2lkalEovFghJYOt5/gb/YIFVUJ4mL2FfRRUX08XAZUoi9fFMjUmEp8MYxuN0yhcajgiNSYSq1/pxGV7Rggww3uyn+/6L94hoBLxwtuzzbH34ym2vpLEehRerBC+ZBfT8k6ODoIqEYd1AD69PnUjZ+jukzKaxrWGx0Xy8+faGu4xcbl6tq7L7b0Y+tl2LttRE2vrS2au8Wfk1jPWfi9qoPBCAIBpYCJAv0/djFp7hHkdvY+UySO4pMZEMndYtAepMZHMY4Dc72B6oazO8loip/Vl7W72kXq1gPWLW/IV1hnx1Mfa70WN+ZwovHDmpL9GF1KNhKm91S7BJraeVHzda1Bry0M08HW3+T1qOZOts7wWvd/nUablTeazgioRi/WL2wCVJzBUgpwvhbai8MLZlneM+dSFESzffpp5Hb0/9mvLSeW3mf1xdJb+b5kpxdYAo/f+L2/2b612CZpkqPl3NITCy0OUlLJ99IIDPAVVIs6JtBym5TdN1NfgY2Xmb7/ItPzy558QVIkybLkYpsZE6nKsIrXZOuWA3gMMq0XmY2qXQKogZwoIJVF4eYjlO9m/resN66iXrRrynfBSq/p3rK92CbLZGlyIfPdOOSDHvrM3OFajLNYvNp/vviKoErHWjurKtLwex3vpExKgdgnVovDyEB9vS1W7BMLBsm2nmJb3dNbvoWHLiJcUXPiR+14Ojz3AuRLl2MsXG9YJDO1hvBfWuxS20u8ZmhAGMXGpTMvvfM+2J0jUJHfESwou/P0wtpus9SZ+vY1zJcqZ+CRbi6Wcvmh6Yw/jvXy1jX38LFtQeOEodnhHtUsQTu7JWG/8vFzVLkGWmJ/lddDlOXIs+UuHJo/IWu+XM0WcK1HO1Ei2vmKsfdGINi3YeU7R/VF44Sisrb/aJTDLyC5gWl7uyVhNrBPB+Xvqt6Pqsr3sj0Z3rqvfCTb1QG6LVtRH9tV5V2/0OlCnUdAZy851jYlTuwThWCeC2zpZn2O7TP4v20jQZdZNpttFoskJMEn6G9vMruh1oE4WWr6bQOGFkPvo9RHh9b+xzcEFUD8XJcVEtmBe5+m5+mx9OTQ9nGn5xZuSxBSiIelZ+WqXwEzLdxMovFRDjx82Yp9Yn6YCgLHdtHtiMqJhTzZlXudEnoBCFMDaZ2xpQrqgSrSj7yc71S7BUCi8VKP7An3OfCqKlpsQeVnwVEu1S5CF9WkqAJj+jPH/nlrz28z+zOu88y2dh4ygoITG2uWJwguxmpabEKvCOrfP0B5NBFWiLXMHPKZ2CXbJ28MZrDclfzypzxbgt3s2ULsEYmAUXoihqTFhmNLeWr2deZ0RYc0EVEKs8bud9DOaPChE7RKEs4fWaK2i8MLJz2/2ULsEYqd+Pl3ItPyIzmyjgxL12cOto/jj+hvITY+t0SLZMro3KwovnLRp7K12CYRYZe5gtnlZCH+Ln3mcaXm93jpiMXLNYbVLIDaSO7q3HBRe7Nj2o1fVLkFTtD6LKjGOwd0C1S6BEF2j8GLHRq8z/mRhLLQ+i2pl3v6GbZDBz59rK6gSQogR1VC7gCpQeCFEx346xTa9w1NdGgmqhLBa8mxrtUsQbsXQ9mqXQGyUyDjgoFIovBCrPOLmpHYJhBhKVOijTMtPWZsgqBJxwtvXU7sEYiOtTlKrSHj54osvEBgYCDc3N3Tp0gUHDx6sdvnvvvsOLVq0gJubG9q0aYNNmzYpUSaphl7n+yHEKH6gyY4IKSc8vPzvf//D5MmTMWvWLBw9ehTt2rVDREQErl27Vuny+/btw0svvYTXX38dx44dQ1RUFKKionDy5EnRpZJqaDV9E+tRfxdCiFEIDy+LFy/GmDFj8Nprr6FVq1ZYtmwZPDw8sHLlykqX/+c//4kBAwbgvffeQ8uWLfHhhx+iffv2+Pzzz0WXSoihUX8XQohRCA0vRUVFOHLkCMLD/+rw4+joiPDwcCQmJla6TmJiYoXlASAiIqLK5QsLC2GxWCr8EEKIHix6upXaJRCiS0LDy40bN1BSUgJ//4qjEPr7+yMjo/I5ZzIyMpiWj46Ohre3d/lPo0b07ZIQog/Pdw9SuwRCdEn3TxtNmzYNOTk55T+XLl1SuyRCCCGECCR0/Jk6derAyckJmZkV56zIzMxEQEDlA4IFBAQwLe/q6gpXV+pMSgjRn293/a52CYToktCWFxcXF3To0AFxcX+NAlpaWoq4uDiEhoZWuk5oaGiF5QFg27ZtVS5PCCF6NePXc2qXQIguCb9tNHnyZPzf//0fvv76a5w5cwbjx4/H7du38dprrwEARowYgWnTppUv/84772Dz5s345JNPcPbsWcyePRuHDx/GxIkTRZdKiKF9v/ei2iUQQggXwqctePHFF3H9+nXMnDkTGRkZCAkJwebNm8s75aalpcHR8a8M1a1bN6xZswYzZszA9OnT8dhjj2Hjxo1o3dr4Q2lrWUZ2AQJ83NQug9hg6s+nqYMoIcQQFJlzaeLEiVW2nMTHxz/wuxdeeAEvvPCC4KoIi64xcUiNiVS7DELs1vCOj6hdAiGaofunjQghRI/W7j7PtPz857sJqkSczYfS1S6B2Cgjm23yV6VQeCFExwa382Ranvq9aIfJfFbtEoQb90OS2iUQG/X/NF7tEipF4cWOxQ7vqHYJmrL18BW1S2C2+KVeTMtP/fm0oEoIIUZkKSxRu4RKUXixY2Ft/R++kB0Z+/0xtUsgdoLGdyHENhReOElKzVa7BEKsMvV/u9Uuwe6xju8yJKSWoEq049sRndUugdjo5zd7KLYvCi+cRC3bq3YJxE4917Ym0/LfH6PJS9WUk3eHeZ1PhvUUUIlYRcWlTMv3aOUnqBJxdiRVPueevWrT2FuxfVF4IYa2+pVOapcg3KfDw5jXWbbtFP9CiFXazd2qdgmKeH2l8b/QjVp7RO0S7BaFl2p4uTqpXYKmbD96Ve0SmPVsXZdp+TUJKYIq0ZaYuFS1S7BLaTfymNfR6y2j3ReohY+IQ+GlGlvfDVO7BE0Zve6o2iUIN31TstolyDJ3wGPM68z4IVFAJaQ6PRftZF5Hj7eMCBGNwks1aDh8ohcjwpoxr/PtoSwBlZCqrNzBPq5Le+W6EHCVnpXPtPzksEaCKtGO/aa+apdgKBRe7Nwjbsa/NebAuPx1S6GQOkSTM3x8oMksoBJSmblb2UbUBYD10/Q5JUfYoh1My789oK2gSrRDj1+GtdwhmcKLnds6ubfaJQi3bRLbQG6d5m8XVIlYcoePb/aPTZwrIfeTExI7+QooRCF32B400qWNiX+oXYJwWu6QTOGFIy2n1Kr4ebkyLX8wRX+3GoID2IbQ1zM5ze9FJZJuW5v0QG7r1nfv67PVhfURab2a9ONJtUuwaxReONJySuVl6Ar76OTJes9eK+Q2v+u1tUnrxsXGy1qPdeweLXnt32yPSM/s31RQJURJK4d1UHR/FF6IXZjRrwnT8t0XsN2z15Kdk8NkrUf9X/gqKi7F5rO3Za0rZ+werdh7ke0R6VF9WgiqRDvG96yndgnC9QkJUHR/FF4eQuk0ScQY3bel2iUoJqiu/G/tFGD4aTbjV1nrKTnEOm/2Mk3KntPXmZafOuAJQZWIU1IqqV1CtSi8PITSaVINmyY+ybT8ibQcQZVoy6aDl9UuQbbUGPn9JSjA2OZgSpZN76GSQ6zzxjpNyqReDQVVItYrqw8yLe/kyPrMo/q+2qbtATspvHCWfOWW2iUwa9XQi2n5p7/cI6gSsVjvrb+5/jdBlSjD1gBz8Zq8Wx72LNBktqlfmC1/Mz2aNLCd2iWQKizYyTZ5qNIovHAWsTRB7RJIFeTcW9djGL3X2lFdZa/be3E8tcIwsPW90ntwWfKrvsM+0RcKL4RUQ+9htGuz2jZvgwJM9U6k5dh9cAGAJbvYbrMuerqVoErEWr8vlWl5uR3o9SR2eEfF90nhxQrOdvAu/TCWbYAzvQ7QJGeW6azcIgGVKIfHhTHQZMba3ewjxBpdoMls823UhKn6HyhydTz7LYbnuwcJqES8yT+xzchuSwd6tbAOFRHW1l9QJVWzg8uy7bZOCmNaXo+DNHVowja0vF4HaGKdZRoA2s/bJqASZfEIMCbzWQSazLr8fPP2y4FLXFqkHB2AxnU8OFSkrpmbf2davrmLoEIIF/0+jVe7hIei8GIF1uT89w37BVVCeJg3kH0SQ70OWncvXrcmms34FV3t9FZS2o08BJrMmLjhOJftXYjW/+2i2J3sM7FvnDFAQCXibT18hWn5N57U59OqeTqY34HCiwAbjvypdgmyjAt9lGn5Xw5cElSJWK/0eox5HT0PWncvXgEmA3dvmUz+7y4u29O63IJiBJrM6LloJ7dtGqGfCwDM3sL+SK27iz4nhB37/TGm5d8f2F5QJYTCCyn33tOPMy3P69unGt7vwxbUAH3O61QZnhfN9b/lItBkxviVm7ltU0tOX7Yg0GRG69lbuG7XKMFFzq2z7YwTpeqZHsd3YW1lVmtQRQovVmLt6KnHTp5yDrT8ohIBlYj3Zv/WzOsYaV4n3hfPX8+VINBkRqDJjP3nbnLdthqWbTuFQJMZgz7fzXW7CVN7Gya4ZGQXyFpPrxOl/rQ/jWl5Z0F1iMba30WtQRUpvFiJtaNnmE47ebb2ZzuxvPaFfr9xTw9nf9ph1vfGGctC1EV02Mr9CDSZ8eYqvq0Voh258Gd5AIuJS+W+/dSYSEN0zi3TNSaOeR09t7q8vfEE0/IHZvQTVIlYeujvAlB4EYZtajLtWDu+O9Py+zMFFaKAseHs40x8ffiy5uf8YJEaE8k8PYS1NiUXl4eBMcu12cE3/nhmeY1Dlu8Tth+jtLaUYe24WkavrS5pN/KY1/H1pEeqRKqhdgFEWzzd2D8SB1Oy0DnYV0A14i16uhWm/nyaaZ2m0zcZ6mLUqqEXUmMihQ5Gt+1Cxf4RIzrXxtzB8kf/lSsn7w7azd2q6D6N9Fkpw9pxFQB+m9lfQCXK6LuYraP2gFZugioRa9/ZG0zLqzmJqIMkScb5GgnAYrHA29sbOTk58PJim7PnYQ6mZDH1e1g7qiuXEU6V9suBS8ydcfV8gpZz0Y5oXhdfvcY+4J3WqTma7qiufpgZ1Znb9nILirl3tGWx/Pkn0L9jfdX2L4qcz4iPWw0kzY4QUI0yWF/zuXkD4VJDfzc2WF8n7/M+y/WbWl4YsLYuDFu5X5cX9ae6NGIOLzl5d+Dtoc8uamtHdcWwlWxj82xJvoai4lJdnqCqkxoTidOXLdw7qlpj5f7rWLlfm7eXWOnxuLfGUx/L+/voObjImV3eaOcFLRL6DmdlZeHll1+Gl5cXfHx88PrrryM3N7fadcLCwuDg4FDhZ9y4cSLLJByEKdwUz5Pc1rFmM37lXIk2lN1GGt6RbdRlAiyNamPY4JJfVIKTMoawWv78E/yLURDr7PJqzPPDw4m0HKblG6jcpUdoeHn55Zdx6tQpbNu2Db/88gsSEhIwduzYh643ZswYXL16tfxnwYIFIstk0rQ229MCSanZYgoRjPVepj6H5fuL3AuOkSctnP98N8NeiHkb2t4bqTGReKZrY7VLEablTHlPFur51hlrHxBAnXl+eGCdo2uTSd0+TMLCy5kzZ7B582asWLECXbp0QY8ePfDZZ59h7dq1uHKl+p7qHh4eCAgIKP/h3XfFFusnsF3Uo5btFVSJWHKe3dfrZI1llg0JkbXeO2vYOy/qSWpMJPPEnfYiPOju+7NgqHodF5UgN6SfmzeQcyXKGh57gGl5fT62II/a3QSEhZfExET4+PigY8e/mtDCw8Ph6OiIAweq/0D85z//QZ06ddC6dWtMmzYNeXlVP6ZWWFgIi8VS4Ucktf9gSvpycDum5fU6WWOZAZ0ayFrvx+NXDD9ZYYcmjyA1JlLWrNxGFNbobmhZ8YbxW6bkBpe+wX667vshp9V8p06fqGK9ZaQFwj5ZGRkZqFu34sBuNWrUgK+vLzIyMqpcb/jw4fj222+xc+dOTJs2Dd988w1eeeWVKpePjo6Gt7d3+U+jRo24vYaqeLmyzcuh12HlB3VuyLyOnM5tWiL3NolR+7/cr2frukiNicTOyWFql6KKp1u5IjUmErETjB9aANsmJP33aH5PjqlBTqu5Xr/cst4ySpjaW1Al1mMOLyaT6YEOtff/nD17VnZBY8eORUREBNq0aYOXX34Zq1evxoYNG3D+/PlKl582bRpycnLKfy5dEj9Z4NZ3w5iW1/Ow8k+3Yrt/y9q5TYvWjOwiaz0j93+5X1DdmkiNiURqTCSef0I7t3VFmTewGVJjIvHZiHC1S1GU3AlJ9d5XSk5LxJa3ewqoRJu0MFI086PSU6ZMwciRI6tdpkmTJggICMC1a9cq/L64uBhZWVkICLB+mvAuXe5eSFJSUtC0adMH/t3V1RWurq5Wb4+HAB/2AYj0+ljtJ8Pb42fGVoVNBy/LarXRim4t6sheN9Bk1v2Jm9WiF5/EoheB65ZCdJq/Xe1yuHn+CS8selHM6MN6IDeMqzlwGS+sLREA0Lx+LQGViLfn9HW1S5CFObz4+fnBz8/vocuFhoYiOzsbR44cQYcOHQAAO3bsQGlpaXkgsUZSUhIAoF69eqylClXfyxVXLIVWL//3Dfvx6Qv66/ToUsMRbk5AAcP8i2+u/w2pOg4vAGwacdYeAwwA+Hm5VnjdY5abse2CigXJMDq0LmY8S/16bGlFVGuiPl6OXGB/dnLd6FABlSjjldUHmZbXwi0jQPAIuwMHDkRmZiaWLVuGO3fu4LXXXkPHjh2xZs0aAEB6ejr69u2L1atXo3Pnzjh//jzWrFmDQYMGoXbt2jh+/DjeffddNGzYELt27bJqnyJH2L1XVm4R2jNOvqjXC5qc11rbwxlHdNp57V62nMT1+vcWgXV0aqU81dIFn7+qzwn0RLH3z7yc16/X1y2ntVTka9XMCLv/+c9/MHHiRPTt2xeOjo4YMmQIli5dWv7vd+7cQXJycvnTRC4uLti+fTuWLFmC27dvo1GjRhgyZAhmzJghskxZ5Ey6lXzlli6bFuW81pt5d3Q96m6ZnZPD0HtxvKx17bUFpjKdg30feC+UDjRPt3K1uz4rrOw9uOxIqvphkqroudVlwBLrGgXKtPYRU4ccNLeRDeKPZ2LkmsNM6+j1AJfT+gLo9/Xey9aOuEZ4D4jx2fI5Pz9/EJwcHThWow57anUB2F/vydkRsibvtRbL9Vt/PUg1RM5IinodD0Tu9O6nL4sdd0cJtp6c7OkpJKJPtnxGP4pqbbfBRc+DN8oZPVhkcGFF4cVGTozH7PvrtXff31pyRstUY4I/ESjAEKOy5bPp6AC83PVRjtWoIyO7QNZ6HZrod/4v1tGDtXZ7jMKLjba804tp+Y1Hs8UUogCXGo4IC2QboA8AOs1lv92kRRRgiJFkZBfY/Jm8EK3fWyb36hoTx7yOngdqTMmofoLkynQO1tbkBxRebBQc4Mm8jpzmOq2IHTeAeZ3reUXIybsjoBrl8Qgw+UUMz50TIkCzf2ySdcG+l577etxrXGy8rPWC6tbkW4iCwhk76qo9g3RlKLxwwDoFOmtzndbIGYG23dytAipRh60n7ZYzN+P5fxrn/SD6Emgyo6jEtuc0jBJciopLsfnsbeb19Dzh5HWG8cnKqD2DdGUovHAgp+OunEm/tELuCLSDP9fnDNuVsfXkffjqHbqNRBTH4zNnlOACyJuTrF8zfU842VnGKNhaHPJCv38BjWlam22uBzmTfmnJ3vf7MK9z9HK2oW6Z8DiJU4AhSjiRlkPB5T6DP90sa73/G6XfCSdzC4rB2uam1TmbKLxwsn4C+3ween6MuIGvu6z1Ws6Ud8LQKl4BZv+5mxyqIeRBgSazrLl67mek4JJfVIKjmexfpLR6IbfWE7O3MK+j1YFVKbxwIqdZTe+PEcs9mQVPN1ZrA4+T+rCV+6kVhnCVlVvE7TNlpOACyP8SpdULuTVyC4rB+tiE1h6PvheFF47k3EpJvnJLQCXK2TSRfdbd4lIgPStfQDXq4XVyDzSZcSIth8u2iP1qM3OLrBGxK2O04CI30On9fWgto9VFa49H34vCC0dybqVELE0QUIlyWjWUNwVD9wU7OFeiPl4nt6e/3EOtMESWsrFbbhUVc9me3i/Y90u7kSdrvY3junOuRFlyhqr4doS2+/ZQeOFs+yS2QesA6P6bttwTnBEv0DxP9oEmMxJOXuO2PWJsj0032zx2S5kXOgUYLrgAQM9FO2WtFxLow7cQhckZqqJHKz8BlfBD4YUzOYPW8ehMpzY5oQ0AggwaYFxY542owohvDyHQZJY1NgOxD3tOX0egyYw7nKZNOzdvIBYO6cBnYxpir7eL5Ex9oPVWF4DCixByLuR6f9pETmgDAAnAN/v+4FuMBpz7aBD2m/py216n+dsN2VJF5EvPykegyYxXVh/kts3UmEhdj2FSFbnHjtwvZVoipzVO660uAIUXIeRcyIet3C+gEmXJ/YbywU8nUVJq24ifWhTg48b9W1ugyYyNicYLe8R6JaUSgkxm7v3G9N7CUJVmNoR+uV/KtELOAyF6aHUBKLwIkzC1N/M6nTg9HaCmk7MjZK3XdPomzpVoB++LwqQfTyLQZMbWw1e4bpdo3/trj6Dp9E3MA41Vp72/k2GDS1ZuEYpkrmuE90TOAyF6aHUBKLwI07gO24i7AHA9V/8TGHq61UBjmV9WjHxbJDUmEmtHdeW6zbHfH0OgyYztR69y3S7RntXx5xBoMmNdUgbX7Z6ZOwDr32WfbFUv5D4urue5i8ocTMliXkfL47rcj8KLQEdn9GNexwgTGCbMkP+NxcgBpmuz2kK+zY1edxSBJjM2H0rnvm2irpU7ziLQZMbMzb9z33ZqTCTcXZy4b1cr5J5LhnZoaIh+P0NXJDKvo+VxXe6n/7+Qhvl6yptHXO8D1wG2NbkaOcAAd98bTg8jVTDuhyQEmsxYu/s8/40TxZSNjBtoMmPuVv5/y5/f7GGIWyLVseUcsuCFdhwrUYec16+3qQ8cJEkyVE9Ji8UCb29v5OTkwMtL3gBqPJWUSrL6cxjl5GLLScQo70FV0rPyhQ7W18DLBZsmhWlyRljyoCMX/sSQ5fuE7sPoxxRA5xy55xUtvHaW6ze1vAjm5OiAt/s2Yl6vrYyhnLXIlseFO8zVfwfm6jTwdUdqTCQ8BTXdp1uK0G7uVhrsTuPW7bmAQJNZaHBZNzpUExcn0ew9uADyRi/nOayDUmqoXYA9mNyvLZbGXWJax1JQjKzcItm3nrQiwMdN9ro384ow56dTmPXM4xwr0p6TcwcgK7eI21w0lRnx7SEAQEMvV5gn9aLWGJUdTMmS1SdBDqNclB/GluBihA66ANBaxnvg4uhg03laLXTbSCHJV27JemzNKCceW08sRuhAZ43445kYueawIvuKiWyBYU82VWRf5O5Ip7yG77dGwtTesp561CNbzi9/6/IoPnyuNcdq1JGTd0fWAx9ausawXL8pvChIzgHWtI4H4mSMGaNFtpxgzs8fBCdHAb1cNar9nC3IyuczuZ41Fj3dCs93D1Jsf/ZCdItaZVYO64A+IQGK7lNNtpxXajg6IGX+II7VqEfO+/DD2G7o0OQRAdXIQ+FFo+Elt6BY1rTkJ2dHwNPNGHf4bDnRfD4sBE+FNOBYjbbJ/SZlq06NvLB6TDdDP0Yr0unLFgz6fLfi+7XHljRbn0zUUquDLZ780IxLt9nX09rrp/Ci0fACAN1mm3GFfZ4szX3IbGHLCadXsA++Hq3v6elZpWTkInzJLtX2Hzu8I8La+qu2fz1Yu/s8TOazquzbCcB5A50frEXB5S65X4oPTQ+Hn5ergIrko/Ci4fACyDvoPJwdcfpDY3QqA2w78TgAuGiQEw8LJR6ltca3IzrrZghxUdbvS8Xkn06pXYahWmVZUHD5i5z3wt3ZEWc0eD2h8KLx8CK3454Wk7It6AQkj5JPqlijZd2a+O7NHoa9iJ5Iy8HTX+5Ru4xyHk5A4j/62+0TY3Te+Ivc90Kr7wGFF42HF8B4Hzq56EQkn9YuqvdbO6orujarrXYZTH7an4a3N55Qu4xKhT/2CL58tavdPHlXGTpf/OW6pRCd5m9nXm/v+33QwNddQEW2o/Cig/ACUIApQyck26jxRIutVgxtj/D29VTZ95qEFEzflKzKvuWwx464laHzREVy3o8ajkDKfO2+DxRedBJe5N4+Cg2qjf++wXeGYrXRiYkPvV2YSdWoo/RdPJ66M9r5wahffDUxPcBHH32Ebt26wcPDAz4+PlatI0kSZs6ciXr16sHd3R3h4eH4/Xf+s6lqhdxRDRMv3kR+UQnnatRl60EVaDKjqLiUUzX6NbxnMFJjInU1tT35Sx0PZxyd0Q+pMZEUXACEzo+j4HKfJz6QF1z0OAVAdYSFl6KiIrzwwgsYP3681essWLAAS5cuxbJly3DgwAHUrFkTERERKCiQ8WyxTsg9sFrO3My5EvXZepJpNuNXTPnuEKdq9K1zsC9SYyKRGhOJBU+1VLsc8hCrX+mE1JhIHJ7ZX/dTgvASaDLjqsW2c7/RgktO3h38eYd9PRcnfU4BUB3ht41iY2MxadIkZGdnV7ucJEmoX78+pkyZgqlTpwIAcnJy4O/vj9jYWAwbNsyq/enptlEZuc/pA8Y7OAHbbyEBxnxfbJVbUIxBc7cgjRqoNGHJs60RFfqo2mVoEp0DKmfU20VlNHHbiNXFixeRkZGB8PDw8t95e3ujS5cuSEys+rHQwsJCWCyWCj964+lWA4Eyc1bIHOVHYBWNx4HG4+RnNJ5uNZAw/25rzNEZ/dQuxy4tfubx8hYxCi4PysotouBSBbnvy5m5AzhXog2aGZghIyMDAODvX/E+r7+/f/m/VSY6Ohpz5swRWpsS4qdHyvpwZuffMcTs0/dLjZH3ftwr0GTG0Rn9DPfe8ODr6VLhBK/mCLFGZ29zDckVMmcrsvNl3BO5DwWXv4QG1TbsNB9MLS8mkwkODg7V/pw9q+wJcNq0acjJySn/uXTpkqL750nuQae3x2StxeMk1H7eNjSlVpiHGvZk0/IWgU0Tn1S7HF3r3MgLZ+YOKH8/Kbg8XKDJTMGlCuEL5HUpAGC4p1LvxdTyMmXKFIwcObLaZZo0aSKrkICAuwd4ZmYm6tX7a/yHzMxMhISEVLmeq6srXF2NM+rs0Rn9ZIWRQJPZkAcujxaYEhj3/RGhVUOvCu9VUmo2opbtVbEibWvk7Ypf3ulltyPe2kLucBGVMeLxnVtQjJQsebPLG/H9uBdTePHz84Ofn5g5TYKCghAQEIC4uLjysGKxWHDgwAGmJ5b0ztfTBa4ACmWsa9QLNI8AA9x9f3ZODkNQ3ZocqrIfIYE+D3yuvt97EVN/Pq1SRepa/Uon9GxdV+0ydK/ZPzahqITP8yJGPO8BkP0gh1H7udxL2NNGaWlpyMrKwk8//YSFCxdi9+67U8QHBwfD09MTANCiRQtER0fjueeeAwB8/PHHiImJwddff42goCB88MEHOH78OE6fPg03N+se89Lj00aVseVibdQDmWcnXKO+R2rSysSRvNSt6YzN74ZRnynO8otKuA31sOC5xzG0SyCXbWmN3PNd96a18Z8x+rxdpIkRdkeOHImvv/76gd/v3LkTYWFhd3fu4IBVq1aV34qSJAmzZs3C8uXLkZ2djR49euDLL79Es2bNrN6vUcILIP/D6+gAXIg25sX5iblb8Wee7ffGAW3P8WE0F6/dRu/F8WqX8QAael9Zz3+5F4fTsrls6/z8QXBydOCyLa2x1y+vmggvajFSeAHkf4h7PeaHr1/vzLkabeA9l4+eD3ZC9MCWsawqY+Rj1l6DC6DTcV5I5eSOx7Hr9+uGm0KgzP2P+doq0GTGxWu3uW2PEPKXPgt3UnCxkj0HF1YUXjTO19MFHjJbRo04hcC9eB6svRfH08B2hHCUkV2AQJMZF27mcdlekLeDoS/Qtpx/zs8fxLESfaDwogOnbei/YvQLcmpMJOp58ZuzI9BkxsGULG7bI8QeNZu+idsj0ABwcnYEdk4z7gW6+TT55+mFz7c1bN+f6lB40QlbvnEYPcAkTu+L32b257a9oSsSEWgyIz0rn9s2CbEHB1Oy7s7wXsqvK2VqTCQ83TQzGDx3gz7diUKZb5enixNe6NiIb0E6QR12dYbuiVZPRFCzh/eNEFukZ+Wj+4IdXLfp5Qgcn2/sY2/Ozyewam+a7PWNdm6iDrsGRi0w1UuNicTgkAZctxloMiP+eCbXbRJiBPlFJQgymbkHl99m9jd8cNl0/AoFFxtQy4tOUQtM9YqKS9Fsxq/ct7tudCg6B/ty3y4hevPMP3fj+FUL9+3S+enhjPoe0TgvdhBebB2l0qgf/vu1/OBX5N8p5b7dTROfRKuGxv18EVKV+OOZGLnmMPftbp/UC8EBnty3qzU//3YFb/33mOz1jXzupvBiB+EFAEb+ex/if/9T9vpGPgjudd1SiE7ztwvZtr2ccAnZc/o6Xll9UMi27eVcNCr2IHacvS57faO/TxRe7CS8AEATkxm2tCsY/WC4V9vZW2ApkDdD68NQiCFGte/sDQyPPSBk21ve7onm9WsJ2bbW9IiOw+WcAtnr28O5msKLHYUXwPaOuPZwUJThPbXA/ezpZEyMLeHkNYz49pCw7dvTeafZdDOKbPiWaS/vFYUXOwsvgO0B5ty8gXCpYT8Pn0389ih+OXlV2PY3juuOkEAfYdsnRJTtR69i9LqjwrafMLU3GtfxELZ9raEvl9aj8GKH4QWw/SD5W2hDfPhsO07VaJ+oJ5LutfqVTujZuq7QfRBiq/yiEoxdvhW7L/Pv3F6mbW3gp/fs50IMUHBhReHFTsMLYPvB4gLgnJ0dMCkZuQhfskvoPp5vWx/zh7azq9Yton1pN/LQc9FO4fs5M3cA3F2chO9HSyi4sKPwYsfhBeAzGJ09Hjgin6a4F3XuJWoT2Qn3Xvb6WafgIg+FFzsPLwAFGFu0mbkZt4pKhO9n5bAO6BMSIHw/hABATt4dRHy0FRniP9r4dkRn9GjlJ35HGkTBRT4KLxReAFCAsYXop5Lu5Qhgq51+QyXiKdXKAgDhwa5YMTpckX1pTW5BMVrP3mLTNuz1fFuGwguFl3I8Asz5+YPscsp1ALh47TZ6L45XbH9hTWrjXyM72V3/AMKX0p9bwP6eWLzXwCW7cCYj16Zt2HtwASi8UHi5j60D2QHAgucex9AugTzK0aUTaTl4+ss9iu6TbisRFlm5Reg5bxtsu4SyOzqjH3w9XRTeq3ZQCzc/FF4ovDygR8wOXM7Ot3k79n6QJaVmI2rZXsX3a899CEjV8otKMG7FduxKEzNydHUOTQ+Hn5er4vvVEgoufFF4ofBSqQ1H0/HuuiSbt0MHG3Dkwp8YsnyfKvuOHd4RYW39Vdk3UV9O3h0MjN6KK3fU2f9+U18E+Lips3ONsHVi3DJ0Lq2IwguFlyqVlEpoOn2Tzduxx3EbKqPG7aR7tavvhf+MDYWnWw3VaiDiqdGH5V7ODsCBf9j37aEyL//ffuw9f9Pm7VBweRCFFwovD8WjubN7oCf+M64Xh2r0T+2LSxmalsA4diRlYNTaI6rW0LNJbXxFHcjL8ThvAhRcqkLhhcKLVehA5C8n7w76z9uKTHGjrFvNw9kR294NQwNfd7VLIVZQqz9VZaizeEW8WqwBOl9Wh8ILhRer8Qow9vyYZFVET3DHisKMtqjZb6oqP4zthg5NHlG7DE35Zt8f+OCnk1y2RcGlehReKLww4RVgBrf3w+Khnblsy0jU7hdTnbWjuqJrs9pql2F4RcWl+PuG/dhw5E+1S3mAr7sztrzby+6fHKpMkMkMHhfIl7s2wEdRIRy2ZGwUXii8MBvw6S6czeQzQgR9u6icmo+1sqBv37bTQn+Vh6HH76vGY7TcMtQqbT0KLxReZOF5wNLjlNU7fdmCQZ/vVrsMq3m5OmHru2H0N72PllvVKlOvlivM7/Skp4aq0e+TXfj9On2RUwOFFwovNuF1Gwmgg9ca8cczMXLNYbXLkO3vfZthbN9gw04hkZFdgK4xcWqXYZOf3+yBNo291S5D04qKS9Fsxq/ctkfnPnYUXii82IxngLH34cOtVVRcivd+2Icfj+WoXQp3Wu1bk19UggmrdmLHxUK1S+Fuzcgu6Naijtpl6MLUdb/h+6OXuWyrjhtweDYFFzkovFB44aJbdByu5BRw2ZYTgPP0TcRqOXl3MGD+VlzVdvcYojE0+jIb3q0tv83sD28PZ27bszeaCC8fffQRzGYzkpKS4OLiguzs7IeuM3LkSHz99dcVfhcREYHNm60fhpnCC185eXfQbu5Wbtuj+VDY5ReVYPy/4xD/h0rjwRNNo4638kz5XxJ+OJbObXt0m8h2mggvs2bNgo+PDy5fvox///vfVoeXzMxMrFq1qvx3rq6ueOQR6598oPAiBs/bSAAd6LZIOHkNI749pHYZRCUOAMwTn0SrhnR+k4Pngwll6HzGB8v1W9iEKHPmzAEAxMbGMq3n6uqKgAAa2VFrUmMi0WrmZuQVlXDZXqDJjJ2TwxBUtyaX7dmTnq3rlp8s07Py0X3BDpUrIqJN6RmMNwc0M2ynaKX0XbQT52/kcdsetSSrR3OzucXHx6Nu3bp45JFH0KdPH8ybNw+1a1fd0a+wsBCFhX91trNYLEqUaZdOzx2A65ZCdJq/ncv2yuYConEQ5Gvg617hW5+Whpgn8tFoyHyJCPnU2qIu4R12Y2NjMWnSJKtuG61duxYeHh4ICgrC+fPnMX36dHh6eiIxMRFOTpVPDDZ79uzyVp570W0jsXjfRnq6jQ8+e7k7120S/Y1DYq9oHB0xSkolPPaPTSjleJXzqQEkzaPgIoKwPi8mkwkff/xxtcucOXMGLVq0KP9/lvByvwsXLqBp06bYvn07+vbtW+kylbW8NGrUiMKLAl76aj8SL9o+Nfy9Eqb2RuM6Hly3Sf6SlVuEsHnbQO2T6qKngsT7Ii4FC7clc90mPU0klrDwcv36ddy8Wf3FqkmTJnBx+WtMD1vCCwD4+flh3rx5eOONN6xanjrsKiu/qAQtZ1r/NJi1zswdAHeXylvbCF/XLYXoOn87+PRmIvejGZqVlXYjDz0X7eS+XbpNJJ6wDrt+fn7w81PukbzLly/j5s2bqFevnmL7JGzcXZyQGhOJrvPjkGHhMyYMALScuRlP1HXEhskDuW2TVM7Py7XSMXi0OOuxllE/FXWJeIoIAPa+34f+phokrM9LWloasrKy8NNPP2HhwoXYvfvuPC7BwcHw9PQEALRo0QLR0dF47rnnkJubizlz5mDIkCEICAjA+fPn8f777+PWrVs4ceIEXF2t69FNLS/q4T0mTJmN47ojJNCH+3aJfHqf0sAW9HnUnt4Ld+LiTX5PEZWh1hZlaWKcl8oGnAOAnTt3Iiws7O7OHRywatUqjBw5Evn5+YiKisKxY8eQnZ2N+vXro3///vjwww/h72/9vWEKL+rr8tE2ZN4q4r7d7ZN6ITjAk/t2CX85eXcQ+fFWXNbZqPs04Ju+HEzJwtAVidy3S8M4qEMT4UUtFF60QVQrDEAzVhNi70Q+RUetLeqh8ELhRTOG/isRB//IErJtGiCKEPty+rIFgz7fLWTb1LdFfRReKLxoiqgnksrQrNWEGFvylVuIWJogZNueAE5Sa4smUHih8KJJKRm5CF+yS9j2KcQQYiyizxk0bou2UHih8KJpbWdvgaWgWNj2KcQQom8iW1oAYBNNbKlJFF4ovGheVm4R2s/bJnQf1LGXEH1RYjoL6pCrXRReKLzohuhvWAA9Yk2I1ol65Ple1CKrfRReKLzozlvfHsXPJ68K3QcNLkaIdhQVl2L6T4fw/cEbQvdDX170g8ILhRddKiouRevZm1FULPYjOSOiJV7rFQQnRweh+yGEPEipubTWjQ5F52BfwXshPFF4ofCia0r0hwGAWq6O2DGlD40VQ4gClLhFDADt6gA/TqV+LXpE4YXCiyGImh22Mj+/2QNtGnsrsi9C7EVJqYTYhGR8uPm8Ivuj2ej1jcILhRdDUeobGwC0b+SD1a93gacb04TrhJB7KPnFA6DRto2CwguFF0MSOTR4Zag1hhDrlZRKWL37HOb8mqLYPmk4BGOh8ELhxdCUGAviXnVqOuPXd3rRNztCKqH0lwqA5iEyKgovFF7sQlJqNqKW7VV0n/SkEiF3nxh6cv52FCi8XwotxkbhhcKLXVG6JabMD2O7oUOTRxTfLyFqyC0oxogvt+PoNdEPOT+Ibg/ZBwovFF7skuhJ3KpDA+ARIyoqLsU/fj6M7w5cV3zfro5A4nQaFdeeUHih8GLX0rPy0X3BDtX2T0GG6JmagQUAwpvXxWcvt6dHnu0QhRcKLwR3m7lf+Xwrkm6o9xGnUT6JHuTk3cGzn2xF6m31alg7qiu6NqutXgFEdRReKLyQ++w7ewPDYw+oWkMzv5r4bnx3eHs4q1oHIYDyY7FUZcvbPdG8fi21yyAaQOGFwgupgpID3lXHEcCvdNImCtt/7iaGrdyvdhk0/ACpFIUXCi/kIXLy7uCZRVvxR57aldxFrTJEhIvXbqP34ni1yyhHt4ZIdSi8UHghDNQYL+ZhIlr6Y8lLT1CnRcIkI7sAoTFx0NJJ3dXJAdveDUPjOh5ql0I0jsILhRcig9ZaY+7Vr0VdLB1OT2CQitR+sq461MpCWFF4ofBCbKTmmDHWoG+z9kmLrYT3auDthp/fepLGZiGyUHih8EI4OpiShaErEtUu46Ga1/XEunHdqN+MQWjlaaCHcXIAtrzTC8EBnmqXQnSOwguFFyKA2oN3yeFX0xmb6KkOzdPKU3AsaDBGwhuFFwovRDA153nh4YOBLTHySZpgUmnXLYXoNX87NNityio0nxcRicILhReiIL0HmXu5OTlg67vUl8ZWWu+bwoICC1EKhRcKL0QlJaUSvt3zO2Zt+l3tUoRo5OOGHydSh0ytjZ/CkwMA88Qn0aohnT+Jsii8UHghGnH6sgWDPt+tdhmK02sLzom0HDz95R61y1BcTWdHbH03DA183dUuhdgxCi8UXogGZeUWoV/0NtzU/90lYgA0DgvRGpbrt6OoIlJTU/H6668jKCgI7u7uaNq0KWbNmoWioqJq1ysoKMCECRNQu3ZteHp6YsiQIcjMzBRVJiGK8fV0wZGPIpEac/dny9s91S6J2BFvNyfsN/Ut//xRcCF6VkPUhs+ePYvS0lJ89dVXCA4OxsmTJzFmzBjcvn0bixYtqnK9d999F2azGd999x28vb0xceJEDB48GHv3GqPzGyFlmtevhdSYyPL/N1InT6I+GsiQGJmit40WLlyIf/3rX7hw4UKl/56TkwM/Pz+sWbMGzz//PIC7Iahly5ZITExE165dH7oPum1EjMJe+18QefTaz4iQMizXb2EtL5XJycmBr69vlf9+5MgR3LlzB+Hh4eW/a9GiBRo3blxleCksLERhYWH5/1ssFr5FE6KSNo29K7TMZGQXoFtMHEpVrIlox986N8YHzzwOlxrC7v4TolmKhZeUlBR89tln1d4yysjIgIuLC3x8fCr83t/fHxkZGZWuEx0djTlz5vAslRBNCvBxw4V7wgxgv08z2RtvNydsmRSGAB83tUshRBOYw4vJZMLHH39c7TJnzpxBixYtyv8/PT0dAwYMwAsvvIAxY8awV1mNadOmYfLkyeX/b7FY0KhRI677IESrWjX0qtA6Axh7DBJ70MyvJr4b353mqCKkGszhZcqUKRg5cmS1yzRp0qT8v69cuYLevXujW7duWL58ebXrBQQEoKioCNnZ2RVaXzIzMxEQEFDpOq6urnB1pXlbCCkTVLfmA4EGAI5c+BNDlu9ToSJSGWdHYOukMATVral2KYToDnN48fPzg5+fn1XLpqeno3fv3ujQoQNWrVoFR8fq78126NABzs7OiIuLw5AhQwAAycnJSEtLQ2hoKGuphJB7dGjySKWhhvrSiPVql0fxj6dbUd8UQjgS9rRReno6wsLC8Oijj+Lrr7+Gk5NT+b+VtaKkp6ejb9++WL16NTp37gwAGD9+PDZt2oTY2Fh4eXnhrbfeAgDs22fdN0Z62ogQvlIychG+ZJfaZWgWTZlACB+aeNpo27ZtSElJQUpKCho2bFjh38ry0p07d5CcnIy8vL/mWP3000/h6OiIIUOGoLCwEBEREfjyyy9FlUkIeYjgAM9KW2zuZ7QWHGoxIUS7aHoAQgghhKhOE9MDEEIIIYSIQOGFEEIIIbpC4YUQQgghukLhhRBCCCG6QuGFEEIIIbpC4YUQQgghukLhhRBCCCG6QuGFEEIIIbpC4YUQQgghuiJsegC1lA0YbLFYVK6EEEIIIdYqu25bM/C/4cLLrVu3AACNGjVSuRJCCCGEsLp16xa8vb2rXcZwcxuVlpbiypUrqFWrFhwcHLhu22KxoFGjRrh06ZIh500y+usDjP8a6fXpn9FfI70+/RP1GiVJwq1bt1C/fn04Olbfq8VwLS+Ojo4PzGLNm5eXl2E/lIDxXx9g/NdIr0//jP4a6fXpn4jX+LAWlzLUYZcQQgghukLhhRBCCCG6QuGFgaurK2bNmgVXV1e1SxHC6K8PMP5rpNenf0Z/jfT69E8Lr9FwHXYJIYQQYmzU8kIIIYQQXaHwQgghhBBdofBCCCGEEF2h8EIIIYQQXaHwUo3U1FS8/vrrCAoKgru7O5o2bYpZs2ahqKio2vUKCgowYcIE1K5dG56enhgyZAgyMzMVqprNRx99hG7dusHDwwM+Pj5WrTNy5Eg4ODhU+BkwYIDYQmWS8/okScLMmTNRr149uLu7Izw8HL///rvYQm2QlZWFl19+GV5eXvDx8cHrr7+O3NzcatcJCwt74G84btw4hSqu3hdffIHAwEC4ubmhS5cuOHjwYLXLf/fdd2jRogXc3NzQpk0bbNq0SaFK5WN5jbGxsQ/8rdzc3BSslk1CQgKefvpp1K9fHw4ODti4ceND14mPj0f79u3h6uqK4OBgxMbGCq9TLtbXFx8f/8Dfz8HBARkZGcoUzCg6OhqdOnVCrVq1ULduXURFRSE5Ofmh6yl9HFJ4qcbZs2dRWlqKr776CqdOncKnn36KZcuWYfr06dWu9+677+Lnn3/Gd999h127duHKlSsYPHiwQlWzKSoqwgsvvIDx48czrTdgwABcvXq1/Oe///2voAptI+f1LViwAEuXLsWyZctw4MAB1KxZExERESgoKBBYqXwvv/wyTp06hW3btuGXX35BQkICxo4d+9D1xowZU+FvuGDBAgWqrd7//vc/TJ48GbNmzcLRo0fRrl07RERE4Nq1a5Uuv2/fPrz00kt4/fXXcezYMURFRSEqKgonT55UuHLrsb5G4O5Ipvf+rf744w8FK2Zz+/ZttGvXDl988YVVy1+8eBGRkZHo3bs3kpKSMGnSJIwePRpbtmwRXKk8rK+vTHJycoW/Yd26dQVVaJtdu3ZhwoQJ2L9/P7Zt24Y7d+6gf//+uH37dpXrqHIcSoTJggULpKCgoCr/PTs7W3J2dpa+++678t+dOXNGAiAlJiYqUaIsq1atkry9va1a9tVXX5WeffZZofXwZu3rKy0tlQICAqSFCxeW/y47O1tydXWV/vvf/wqsUJ7Tp09LAKRDhw6V/+7XX3+VHBwcpPT09CrX69Wrl/TOO+8oUCGbzp07SxMmTCj//5KSEql+/fpSdHR0pcsPHTpUioyMrPC7Ll26SG+88YbQOm3B+hpZjk2tASBt2LCh2mXef/996fHHH6/wuxdffFGKiIgQWBkf1ry+nTt3SgCkP//8U5GaeLt27ZoEQNq1a1eVy6hxHFLLC6OcnBz4+vpW+e9HjhzBnTt3EB4eXv67Fi1aoHHjxkhMTFSiREXEx8ejbt26aN68OcaPH4+bN2+qXRIXFy9eREZGRoW/n7e3N7p06aLJv19iYiJ8fHzQsWPH8t+Fh4fD0dERBw4cqHbd//znP6hTpw5at26NadOmIS8vT3S51SoqKsKRI0cqvPeOjo4IDw+v8r1PTEyssDwAREREaPJvBch7jQCQm5uLRx99FI0aNcKzzz6LU6dOKVGuIvT2N5QrJCQE9erVQ79+/bB37161y7FaTk4OAFR73VPjb2i4iRlFSklJwWeffYZFixZVuUxGRgZcXFwe6F/h7++v2XucrAYMGIDBgwcjKCgI58+fx/Tp0zFw4EAkJibCyclJ7fJsUvY38vf3r/B7rf79MjIyHmh+rlGjBnx9fautd/jw4Xj00UdRv359HD9+HH//+9+RnJyM9evXiy65Sjdu3EBJSUml7/3Zs2crXScjI0M3fytA3mts3rw5Vq5cibZt2yInJweLFi1Ct27dcOrUKeGT0Cqhqr+hxWJBfn4+3N3dVaqMj3r16mHZsmXo2LEjCgsLsWLFCoSFheHAgQNo37692uVVq7S0FJMmTUL37t3RunXrKpdT4zi0y5YXk8lUaQeqe3/uP5Gkp6djwIABeOGFFzBmzBiVKreOnNfHYtiwYXjmmWfQpk0bREVF4ZdffsGhQ4cQHx/P70VUQ/Tr0wLRr3Hs2LGIiIhAmzZt8PLLL2P16tXYsGEDzp8/z/FVEB5CQ0MxYsQIhISEoFevXli/fj38/Pzw1VdfqV0asULz5s3xxhtvoEOHDujWrRtWrlyJbt264dNPP1W7tIeaMGECTp48ibVr16pdygPssuVlypQpGDlyZLXLNGnSpPy/r1y5gt69e6Nbt25Yvnx5tesFBASgqKgI2dnZFVpfMjMzERAQYEvZVmN9fbZq0qQJ6tSpg5SUFPTt25fbdqsi8vWV/Y0yMzNRr1698t9nZmYiJCRE1jblsPY1BgQEPNDRs7i4GFlZWUyfty5dugC427rYtGlT5np5qFOnDpycnB54Mq+6YycgIIBpebXJeY33c3Z2xhNPPIGUlBQRJSquqr+hl5eX7ltdqtK5c2fs2bNH7TKqNXHixPIHAB7WwqfGcWiX4cXPzw9+fn5WLZueno7evXujQ4cOWLVqFRwdq2+s6tChA5ydnREXF4chQ4YAuNvLPC0tDaGhoTbXbg2W18fD5cuXcfPmzQoXe5FEvr6goCAEBAQgLi6uPKxYLBYcOHCA+YksW1j7GkNDQ5GdnY0jR46gQ4cOAIAdO3agtLS0PJBYIykpCQAU+xtWxsXFBR06dEBcXByioqIA3G22jouLw8SJEytdJzQ0FHFxcZg0aVL577Zt26bYscZKzmu8X0lJCU6cOIFBgwYJrFQ5oaGhDzxWq+W/IQ9JSUmqHmvVkSQJb731FjZs2ID4+HgEBQU9dB1VjkNhXYEN4PLly1JwcLDUt29f6fLly9LVq1fLf+5dpnnz5tKBAwfKfzdu3DipcePG0o4dO6TDhw9LoaGhUmhoqBov4aH++OMP6dixY9KcOXMkT09P6dixY9KxY8ekW7dulS/TvHlzaf369ZIkSdKtW7ekqVOnSomJidLFixel7du3S+3bt5cee+wxqaCgQK2XUSXW1ydJkhQTEyP5+PhIP/74o3T8+HHp2WeflYKCgqT8/Hw1XsJDDRgwQHriiSekAwcOSHv27JEee+wx6aWXXir/9/s/oykpKdLcuXOlw4cPSxcvXpR+/PFHqUmTJlLPnj3Vegnl1q5dK7m6ukqxsbHS6dOnpbFjx0o+Pj5SRkaGJEmS9Le//U0ymUzly+/du1eqUaOGtGjRIunMmTPSrFmzJGdnZ+nEiRNqvYSHYn2Nc+bMkbZs2SKdP39eOnLkiDRs2DDJzc1NOnXqlFovoVq3bt0qP84ASIsXL5aOHTsm/fHHH5IkSZLJZJL+9re/lS9/4cIFycPDQ3rvvfekM2fOSF988YXk5OQkbd68Wa2XUC3W1/fpp59KGzdulH7//XfpxIkT0jvvvCM5OjpK27dvV+slVGv8+PGSt7e3FB8fX+Gal5eXV76MFo5DCi/VWLVqlQSg0p8yFy9elABIO3fuLP9dfn6+9Oabb0qPPPKI5OHhIT333HMVAo+WvPrqq5W+vntfDwBp1apVkiRJUl5entS/f3/Jz89PcnZ2lh599FFpzJgx5SderWF9fZJ093HpDz74QPL395dcXV2lvn37SsnJycoXb6WbN29KL730kuTp6Sl5eXlJr732WoVwdv9nNC0tTerZs6fk6+srubq6SsHBwdJ7770n5eTkqPQKKvrss8+kxo0bSy4uLlLnzp2l/fv3l/9br169pFdffbXC8uvWrZOaNWsmubi4SI8//rhkNpsVrpgdy2ucNGlS+bL+/v7SoEGDpKNHj6pQtXXKHg2+/6fsNb366qtSr169HlgnJCREcnFxkZo0aVLheNQa1tf38ccfS02bNpXc3NwkX19fKSwsTNqxY4c6xVuhqmvevX8TLRyHDv+/WEIIIYQQXbDLp40IIYQQol8UXgghhBCiKxReCCGEEKIrFF4IIYQQoisUXgghhBCiKxReCCGEEKIrFF4IIYQQoisUXgghhBCiKxReCCGEEKIrFF4IIYQQoisUXgghhBCiKxReCCGEEKIr/w9sBbL6I3bHFQAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# from sklearn.utils import shuffle as util_shuffle\n", + "np.random.seed(192)\n", + "def make_circles(batch_size: int=1000, rng=None) -> np.ndarray:\n", + " n_samples4 = n_samples3 = n_samples2 = batch_size // 4\n", + " n_samples1 = batch_size - n_samples4 - n_samples3 - n_samples2\n", + "\n", + " # so as not to have the first point = last point, we set endpoint=False\n", + " linspace4 = np.linspace(0, 2 * np.pi, n_samples4, endpoint=False)\n", + " linspace3 = np.linspace(0, 2 * np.pi, n_samples3, endpoint=False)\n", + " linspace2 = np.linspace(0, 2 * np.pi, n_samples2, endpoint=False)\n", + " linspace1 = np.linspace(0, 2 * np.pi, n_samples1, endpoint=False)\n", + "\n", + " circ4_x = np.cos(linspace4)\n", + " circ4_y = np.sin(linspace4)\n", + " circ3_x = np.cos(linspace4) * 0.75\n", + " circ3_y = np.sin(linspace3) * 0.75\n", + " circ2_x = np.cos(linspace2) * 0.5\n", + " circ2_y = np.sin(linspace2) * 0.5\n", + " circ1_x = np.cos(linspace1) * 0.25\n", + " circ1_y = np.sin(linspace1) * 0.25\n", + "\n", + " X = np.vstack([\n", + " np.hstack([circ4_x, circ3_x, circ2_x, circ1_x]),\n", + " np.hstack([circ4_y, circ3_y, circ2_y, circ1_y])\n", + " ]).T * 3.0\n", + " # X = util_shuffle(X, random_state=rng)\n", + "\n", + " # Add noise\n", + " # X = X + rng.normal(scale=0.08, size=X.shape)\n", + "\n", + " return X.astype(\"float32\")\n", + "\n", + "\n", + "def transform(data: np.ndarray) -> np.ndarray:\n", + " assert data.shape[1] == 2\n", + " data[:, 0] = data[:, 0] / np.max(np.abs(data[:, 0]))\n", + " data[:, 1] = data[:, 1] / np.max(np.abs(data[:, 1]))\n", + " # data[:, 2] = data[:, 2] / np.max(np.abs(data[:, 2]))\n", + " data = (data - data.min()) / (data.max()\n", + " - data.min()) # Towards [0, 1]\n", + " data = data * 4 - 2 # [-1, 1]\n", + " return data\n", + "# get data from sklearn\n", + "data = make_circles(100000)\n", + "data = transform(data)\n", + "data = data.astype(np.float32)\n", + "plot2d(data)\n", + "def get_train_data(dataloader):\n", + " while True:\n", + " yield from dataloader" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "def task_run_flow_model(config, data, flow_model_type=IndependentConditionalFlowModel):\n", + " seed_value = set_seed()\n", + "\n", + " flow_model = flow_model_type(config=config.flow_model).to(config.flow_model.device)\n", + " flow_model = torch.compile(flow_model)\n", + "\n", + " data_loader = torch.utils.data.DataLoader(data, batch_size=config.parameter.batch_size, shuffle=True)\n", + " data_generator = get_train_data(data_loader)\n", + "\n", + " optimizer = torch.optim.Adam(\n", + " flow_model.parameters(),\n", + " lr=config.parameter.lr,\n", + " # weight_decay=config.parameter.weight_decay,\n", + " )\n", + " for iteration in track(range(config.parameter.iterations), description=config.project):\n", + "\n", + " batch_data = next(data_generator).to(config.device)\n", + " flow_model.train()\n", + " if config.parameter.training_loss_type == \"flow_matching\":\n", + " x0 = flow_model.gaussian_generator(batch_data.shape[0]).to(config.device)\n", + " loss = flow_model.flow_matching_loss(x0=x0, x1=batch_data)\n", + " else:\n", + " raise NotImplementedError(\n", + " f\"Unknown loss type {config.parameter.training_loss_type}, we need flow matching.\"\n", + " )\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " return flow_model\n", + "\n", + "def task_run_diffusion_model(config, data):\n", + " seed_value = set_seed()\n", + "\n", + " diffusion_model = DiffusionModel(config=config.diffusion_model).to(config.diffusion_model.device)\n", + " diffusion_model = torch.compile(diffusion_model)\n", + "\n", + " data_loader = torch.utils.data.DataLoader(data, batch_size=config.parameter.batch_size, shuffle=True)\n", + " data_generator = get_train_data(data_loader)\n", + "\n", + " optimizer = torch.optim.Adam(\n", + " diffusion_model.parameters(),\n", + " lr=config.parameter.lr,\n", + " # weight_decay=config.parameter.weight_decay,\n", + " )\n", + " for iteration in track(range(config.parameter.iterations), description=config.project):\n", + "\n", + " batch_data = next(data_generator).to(config.device)\n", + " diffusion_model.train()\n", + " if config.parameter.training_loss_type == \"flow_matching\":\n", + " loss = diffusion_model.flow_matching_loss(batch_data)\n", + " elif config.parameter.training_loss_type == \"score_matching_maximum_likelihhood\":\n", + " loss = diffusion_model.score_matching_loss(batch_data)\n", + " elif config.parameter.training_loss_type == \"score_matching\":\n", + " loss = diffusion_model.score_matching_loss(batch_data, weighting_scheme=\"vanilla\")\n", + " else:\n", + " raise NotImplementedError(\n", + " f\"Unknown loss type {config.parameter.training_loss_type}, we need flow matching or score matching.\"\n", + " )\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " return diffusion_model" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "45717a201101413eb7717ce540e9fa21", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b434ca73e5ec4cb797efa1054c61aa7f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "diffusion_model = task_run_diffusion_model(diffusion_model_config, data)\n", + "flow_model = task_run_flow_model(flow_model_config, data, IndependentConditionalFlowModel)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "diffusion_model: shape of x: torch.Size([1, 2])\n", + "diffusion_model: likelihood: tensor([0.1089], device='cuda:0')\n", + "flow_model: shape of x: torch.Size([1, 2])\n", + "flow_model: likelihood: tensor([0.0774], device='cuda:0')\n" + ] + } + ], + "source": [ + "t_span = torch.linspace(0.0, 1.0, 1000)\n", + "# diffusion model\n", + "\n", + "x = diffusion_model.sample(t_span=t_span, batch_size=1)\n", + "print(f\"diffusion_model: shape of x: {x.shape}\")\n", + "with torch.no_grad():\n", + " logp = compute_likelihood(diffusion_model, x)\n", + " print(f\"diffusion_model: likelihood: {torch.exp(logp)}\")\n", + "\n", + "# flow model\n", + "x = flow_model.sample(t_span=t_span, batch_size=1)\n", + "print(f\"flow_model: shape of x: {x.shape}\")\n", + "with torch.no_grad():\n", + " logp = compute_likelihood(flow_model, x)\n", + " print(f\"flow_model: likelihood: {torch.exp(logp)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "diffusion_model: shape of x: torch.Size([1, 2])\n", + "diffusion_model: likelihood: tensor([0.1333], device='cuda:0', grad_fn=)\n", + "flow_model: shape of x: torch.Size([1, 2])\n", + "flow_model: likelihood: tensor([0.1217], device='cuda:0', grad_fn=)\n" + ] + } + ], + "source": [ + "# diffusion model\n", + "\n", + "x = diffusion_model.sample(t_span=t_span, batch_size=1)\n", + "print(f\"diffusion_model: shape of x: {x.shape}\")\n", + "logp = compute_likelihood(diffusion_model, x)\n", + "print(f\"diffusion_model: likelihood: {torch.exp(logp)}\")\n", + "\n", + "# test if the tensor has grad_fn\n", + "assert logp.grad_fn is not None\n", + "\n", + "# flow model\n", + "\n", + "x = flow_model.sample(t_span=t_span, batch_size=1)\n", + "print(f\"flow_model: shape of x: {x.shape}\")\n", + "logp = compute_likelihood(flow_model, x)\n", + "print(f\"flow_model: likelihood: {torch.exp(logp)}\")\n", + "\n", + "# test if the tensor has grad_fn\n", + "assert logp.grad_fn is not None" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "diffusion_model: shape of x: torch.Size([1, 2])\n", + "diffusion_model: likelihood: tensor([0.1349], device='cuda:0', grad_fn=)\n", + "flow_model: shape of x: torch.Size([1, 2])\n", + "flow_model: likelihood: tensor([0.0012], device='cuda:0', grad_fn=)\n" + ] + } + ], + "source": [ + "# diffusion model\n", + "\n", + "x = diffusion_model.sample(t_span=t_span, batch_size=1)\n", + "print(f\"diffusion_model: shape of x: {x.shape}\")\n", + "logp = compute_likelihood(diffusion_model, x, using_Hutchinson_trace_estimator=True)\n", + "print(f\"diffusion_model: likelihood: {torch.exp(logp)}\")\n", + "\n", + "# test if the tensor has grad_fn\n", + "assert logp.grad_fn is not None\n", + "\n", + "# flow model\n", + "\n", + "x = flow_model.sample(t_span=t_span, batch_size=1)\n", + "print(f\"flow_model: shape of x: {x.shape}\")\n", + "logp = compute_likelihood(diffusion_model, x, using_Hutchinson_trace_estimator=True)\n", + "print(f\"flow_model: likelihood: {torch.exp(logp)}\")\n", + "\n", + "# test if the tensor has grad_fn\n", + "assert logp.grad_fn is not None\n" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d02c87b5b8ee4226b52b091ba4114b4c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def density_flow_of_generative_model(model):\n", + "\n", + " model.eval()\n", + " x_range = torch.linspace(-4, 4, 100, device=model.device)\n", + " y_range = torch.linspace(-4, 4, 100, device=model.device)\n", + " xx, yy = torch.meshgrid(x_range, y_range)\n", + " z_grid = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], 1)\n", + "\n", + " indexs = torch.arange(0, z_grid.shape[0], device=model.device)\n", + " memory = 0.01\n", + "\n", + " p_list = []\n", + " for t in track(range(100), description=\"Density Flow Training\"):\n", + " t_span = torch.linspace(0.01 * t, 1, 101 - t, device=model.device)\n", + " logp_list = []\n", + " for ii in torch.split(indexs, int(z_grid.shape[0] * memory)):\n", + " logp_ii = compute_likelihood(\n", + " model=model,\n", + " x=z_grid[ii],\n", + " t=t_span,\n", + " using_Hutchinson_trace_estimator=True,\n", + " )\n", + " logp_list.append(logp_ii.unsqueeze(0))\n", + " logp = torch.cat(logp_list, 1)\n", + " p = torch.exp(logp).reshape(100, 100)\n", + " p_list.append(p)\n", + "\n", + " return p_list\n", + "p_list = density_flow_of_generative_model(diffusion_model)" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[11:34:18] INFO     Animation.save using <class 'matplotlib.animation.FFMpegWriter'>              animation.py:1060\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[11:34:18]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Animation.save using \u001b[1m<\u001b[0m\u001b[1;95mclass\u001b[0m\u001b[39m \u001b[0m\u001b[32m'matplotlib.animation.FFMpegWriter'\u001b[0m\u001b[1m>\u001b[0m \u001b]8;id=198352;file:///opt/conda/lib/python3.10/site-packages/matplotlib/animation.py\u001b\\\u001b[2manimation.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=80104;file:///opt/conda/lib/python3.10/site-packages/matplotlib/animation.py#1060\u001b\\\u001b[2m1060\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
           INFO     MovieWriter._run: running command: ffmpeg -f rawvideo -vcodec rawvideo -s      animation.py:338\n",
+       "                    700x600 -pix_fmt rgba -framerate 20 -loglevel error -i pipe: -vcodec h264                      \n",
+       "                    -pix_fmt yuv420p -y ./video-diffusion/density_flow_diffuse.mp4                                 \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m MovieWriter._run: running command: ffmpeg -f rawvideo -vcodec rawvideo -s \u001b]8;id=561233;file:///opt/conda/lib/python3.10/site-packages/matplotlib/animation.py\u001b\\\u001b[2manimation.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=842454;file:///opt/conda/lib/python3.10/site-packages/matplotlib/animation.py#338\u001b\\\u001b[2m338\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m 70\u001b[1;36m0x600\u001b[0m -pix_fmt rgba -framerate \u001b[1;36m20\u001b[0m -loglevel error -i pipe: -vcodec h264 \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m -pix_fmt yuv420p -y .\u001b[35m/video-diffusion/\u001b[0m\u001b[95mdensity_flow_diffuse.mp4\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def render_density_flow_video(p_list, video_save_path, fps=20, dpi=100, generate_path=True):\n", + " if not os.path.exists(video_save_path):\n", + " os.makedirs(video_save_path)\n", + "\n", + " fig, ax = plt.subplots(figsize=(7, 6))\n", + " plt.xlim([-4, 4])\n", + " plt.ylim([-4, 4])\n", + "\n", + " ims = []\n", + " colors = np.linspace(0, 1, len(p_list))\n", + "\n", + " # Assuming p_list contains 2D arrays of the same shape\n", + " x = np.linspace(-4, 4, p_list[0].shape[1])\n", + " y = np.linspace(-4, 4, p_list[0].shape[0])\n", + " X, Y = np.meshgrid(x, y)\n", + "\n", + " cbar = None # Initialize color bar\n", + "\n", + " if generate_path:\n", + " enumerate_items = list(enumerate(p_list))[::-1]\n", + " enumerate_items = enumerate_items[:-1]\n", + " # enumerate_items = enumerate_items\n", + " else:\n", + " enumerate_items = list(enumerate(p_list))[1:]\n", + "\n", + " for i, p in enumerate_items:\n", + " p_max = 0.2\n", + " p_min = 0.0\n", + "\n", + " im = ax.pcolormesh(\n", + " Y, X, p.cpu().detach().numpy(),\n", + " cmap=\"viridis\",\n", + " vmin=p_min, vmax=p_max,\n", + " shading='auto'\n", + " )\n", + " title = ax.text(0.5, 1.05, f't={colors[i]:.2f}', size=plt.rcParams[\"axes.titlesize\"], ha=\"center\", transform=ax.transAxes)\n", + "\n", + " # Remove the previous color bar if it exists\n", + " if cbar:\n", + " cbar.remove()\n", + "\n", + " # Adding the colorbar inside the loop to update it each frame\n", + " cbar = fig.colorbar(im, ax=ax)\n", + " cbar.set_label('Density')\n", + "\n", + " ims.append([im, title])\n", + "\n", + " ani = animation.ArtistAnimation(fig, ims, interval=20/fps, blit=True)\n", + " ani.save(\n", + " os.path.join(video_save_path, f\"density_flow_diffuse.mp4\"),\n", + " fps=fps,\n", + " dpi=dpi,\n", + " )\n", + "\n", + " # clean up\n", + " plt.close(fig)\n", + " plt.clf()\n", + "render_density_flow_video(p_list=p_list, video_save_path=diffusion_model_config.parameter.video_save_path, generate_path=False)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py b/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py new file mode 100644 index 0000000..836d204 --- /dev/null +++ b/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py @@ -0,0 +1,266 @@ +from typing import Optional, Tuple, Literal +from dataclasses import dataclass + +import numpy as np +import torch +from torch import Tensor +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from easydict import EasyDict + +from .edm_preconditioner import PreConditioner +from grl.generative_models.intrinsic_model import IntrinsicModel + +class EDMModel(nn.Module): + + def __init__(self, config: Optional[EasyDict]=None) -> None: + + super().__init__() + self.config: EasyDict = config + self.device: torch.device = config.device + + # EDM Type ["VP_edm", "VE_edm", "iDDPM_edm", "EDM"] + self.edm_type: str = config.edm_model.path.edm_type + assert self.edm_type in ["VP_edm", "VE_edm", "iDDPM_edm", "EDM"], \ + f"Your edm type should in 'VP_edm', 'VE_edm', 'iDDPM_edm', 'EDM'], but got {self.edm_type}" + + #* 1. Construct basic Unet architecture through params in config + # TODO: construct basic denoise network here + + self.base_denoise_network: Optional[nn.Module] = IntrinsicModel(config.edm_model.model.args) + + #* 2. Precond setup + self.params: EasyDict = config.edm_model.path.params + self.preconditioner = PreConditioner( + self.edm_type, + base_denoise_model=self.base_denoise_network, + use_mixes_precision=False, + **self.params + ) + + #* 3. Solver setup + self.solver_type: str = config.edm_model.solver.solver_type + self.solver_schedule: str = config.edm_model.solver.schedule + self.solver_scaling: str = config.edm_model.solver.scaling + self.solver_params: EasyDict = config.edm_model.solver.params + assert self.solver_type in ['euler', 'heun'] + assert self.solver_schedule in ['VP', 'VE', 'Linear'] + assert self.solver_scaling in ["VP", "none"] + + # Initialize sigma_min and sigma_max if not provided + + if "sigma_min" not in self.params: + self._initialize_sigma_min() + else: + self.sigma_min = self.params.sigma_min + if "sigma_max" not in self.params: + self._initialize_sigma_max() + else: + self.sigma_max = self.params.sigma_max + + def get_type(self): + return "DiffusionModel" + + def _initialize_sigma_min(self): + vp_sigma = lambda beta_d, beta_min: lambda t: (np.exp(0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5 + self.sigma_min = { + "VP_edm": vp_sigma(19.9, 0.1)(1e-3), + "VE_edm": 0.02, + "iDDPM_edm": 0.002, + "EDM": 0.002 + }[self.edm_type] + + def _initialize_sigma_max(self): + vp_sigma = lambda beta_d, beta_min: lambda t: (np.exp(0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5 + self.sigma_max = { + "VP_edm": vp_sigma(19.9, 0.1)(1), + "VE_edm": 100, + "iDDPM_edm": 81, + "EDM": 80 + }[self.edm_type] + + # For VP_edm + + + def _sample_sigma_weight_train(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # assert the first dim of x is batch size + rand_shape = [x.shape[0]] + [1] * (x.ndim - 1) + if self.edm_type == "VP_edm": + def sigma_for_vp_edm(self, t): + t = torch.as_tensor(t) + return ((0.5 * self.params.beta_d * (t ** 2) + self.params.beta_min * t).exp() - 1).sqrt() + rand_uniform = torch.rand(*rand_shape, device=x.device) + sigma = sigma_for_vp_edm(1 + rand_uniform * (self.params.epsilon_t - 1)) + weight = 1 / sigma ** 2 + elif self.edm_type == "VE_edm": + rand_uniform = torch.rand(*rand_shape, device=x.device) + sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rand_uniform) + weight = 1 / sigma ** 2 + elif self.edm_type == "EDM": + rand_normal = torch.randn(*rand_shape, device=x.device) + sigma = (rand_normal * self.params.P_std + self.params.P_mean).exp() + weight = (sigma ** 2 + self.params.sigma_data ** 2) / (sigma * self.params.sigma_data) ** 2 + return sigma, weight + + def forward(self, x: torch.Tensor, class_labels=None): + x = x.to(self.device) + sigma, weight = self._sample_sigma_weight_train(x) + n = torch.randn_like(x) * sigma + D_xn = self.preconditioner(x+n, sigma, class_labels=class_labels) + loss = weight * ((D_xn - x) ** 2) + return loss + + + def _get_sigma_steps(self): + """ + Overview: + Get the schedule of sigma according to differernt t schedules. + + """ + self.sigma_min = max(self.sigma_min, self.preconditioner.sigma_min) + self.sigma_max = min(self.sigma_max, self.preconditioner.sigma_max) + + # Define time steps in terms of noise level + step_indices = torch.arange(self.solver_params.num_steps, dtype=torch.float64, device=self.device) + sigma_steps = None + if self.edm_type == "VP_edm": + vp_beta_d = 2 * (np.log(self.sigma_min ** 2 + 1) / self.params.epsilon_s - np.log(self.sigma_max ** 2 + 1)) / (self.params.epsilon_s - 1) + vp_beta_min = np.log(self.sigma_max ** 2 + 1) - 0.5 * vp_beta_d + vp_sigma = lambda beta_d, beta_min: lambda t: (np.exp((0.5 * beta_d * (t ** 2) + beta_min * t)) - 1) ** 0.5 + orig_t_steps = 1 + step_indices / (self.solver_params.num_steps - 1) * (self.params.epsilon_s - 1) + sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps) + + elif self.edm_type == "VE_edm": + ve_sigma = lambda t: t.sqrt() + orig_t_steps = (self.sigma_max ** 2) * ((self.sigma_min ** 2 / self.sigma_max ** 2) ** (step_indices / (self.solver_params.num_steps - 1))) + sigma_steps = ve_sigma(orig_t_steps) + + elif self.edm_type == "iDDPM_edm": + u = torch.zeros(self.params.M + 1, dtype=torch.float64, device=self.device) + alpha_bar = lambda j: (0.5 * np.pi * j / self.params.M / (self.params.C_2 + 1)).sin() ** 2 + for j in torch.arange(self.params.M, 0, -1, device=self.device): # M, ..., 1 + u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=self.params.C_1) - 1).sqrt() + u_filtered = u[torch.logical_and(u >= self.sigma_min, u <= self.sigma_max)] + sigma_steps = u_filtered[((len(u_filtered) - 1) / (self.solver_params.num_steps - 1) * step_indices).round().to(torch.int64)] + + elif self.edm_type == "EDM": + sigma_steps = (self.sigma_max ** (1 / self.solver_params.rho) + step_indices / (self.solver_params.num_steps - 1) * \ + (self.sigma_min ** (1 / self.solver_params.rho) - self.sigma_max ** (1 / self.solver_params.rho))) ** self.solver_params.rho + # Define noise level schedule. + return sigma_steps + + + def _get_sigma_deriv_inv(self): + """ + Overview: + Get sigma(t) for different solver schedules. + + Returns: + sigma(t), sigma'(t), sigma^{-1}(sigma) + """ + if self.solver_schedule == 'VP': # [VP_edm] + vp_beta_d = 2 * (np.log(self.sigma_min ** 2 + 1) / self.params.epsilon_s - np.log(self.sigma_max ** 2 + 1)) / (self.params.epsilon_s - 1) + vp_beta_min = np.log(self.sigma_max ** 2 + 1) - 0.5 * vp_beta_d + vp_sigma = lambda beta_d, beta_min: lambda t: (np.exp((0.5 * beta_d * (t ** 2) + beta_min * t)) - 1) ** 0.5 + vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t)) + vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d + vp_beta_d = 2 * (np.log(self.sigma_min ** 2 + 1) / self.params.epsilon_s - np.log(self.sigma_max ** 2 + 1)) / (self.params.epsilon_s - 1) + vp_beta_min = np.log(self.sigma_max ** 2 + 1) - 0.5 * vp_beta_d + + sigma = vp_sigma(vp_beta_d, vp_beta_min) + sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min) + sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min) + elif self.solver_schedule == 'VE': # [VE_edm] + sigma = lambda t: t.sqrt() + sigma_deriv = lambda t: 0.5 / t.sqrt() + sigma_inv = lambda sigma: sigma ** 2 + elif self.solver_schedule == 'Linear': # [iDDPM_edm, EDM] + sigma = lambda t: t + sigma_deriv = lambda t: 1 + sigma_inv = lambda sigma: sigma + + return sigma, sigma_deriv, sigma_inv + + + def _get_scaling(self, sigma, sigma_deriv, sigma_inv, sigma_steps): + """ + Overview: + Get s(t) for different solver schedules. and t_steps + + Returns: + sigma(t), sigma'(t), sigma^{-1}(sigma) + """ + # Define scaling schedule. + if self.solver_scaling == 'VP': + s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt() + s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3) + elif self.solver_scaling == 'none': # [VE_edm, iDDPM_edm, EDM] + s = lambda t: 1 + s_deriv = lambda t: 0 + # Compute final time steps based on the corresponding noise levels. + t_steps = sigma_inv(self.preconditioner.round_sigma(sigma_steps)) + t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 + + return s, s_deriv, t_steps + + + def sample(self, latents, class_labels=None, use_stochastic=False): + # Get sigmas, scales, and timesteps + latents = latents.to(self.device) + sigma_steps = self._get_sigma_steps() + sigma, sigma_deriv, sigma_inv = self._get_sigma_deriv_inv() + s, s_deriv, t_steps = self._get_scaling(sigma, sigma_deriv, sigma_inv, sigma_steps) + + if not use_stochastic: + # Main sampling loop + t_next = t_steps[0] + x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next)) + for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 + x_cur = x_next + + # Increase noise temporarily. + gamma = min(self.solver_params.S_churn / self.solver_params.num_steps, np.sqrt(2) - 1) if self.solver_params.S_min <= sigma(t_cur) <= self.solver_params.S_max else 0 + t_hat = sigma_inv(self.preconditioner.round_sigma(sigma(t_cur) + gamma * sigma(t_cur))) + x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(t_hat) * self.solver_params.S_noise * torch.randn_like(x_cur) + + # Euler step. + h = t_next - t_hat + denoised = self.preconditioner(x_hat / s(t_hat), sigma(t_hat), class_labels).to(torch.float64) + d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised + x_prime = x_hat + self.solver_params.alpha * h * d_cur + t_prime = t_hat + self.solver_params.alpha * h + + # Apply 2nd order correction. + if self.solver_type == 'euler' or i == self.solver_params.num_steps - 1: + x_next = x_hat + h * d_cur + else: + assert self.solver_type == 'heun' + denoised = self.preconditioner(x_prime / s(t_prime), sigma(t_prime), class_labels).to(torch.float64) + d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised + x_next = x_hat + h * ((1 - 1 / (2 * self.solver_params.alpha)) * d_cur + 1 / (2 * self.solver_params.alpha) * d_prime) + + else: + assert self.edm_type == "EDM", f"Stochastic can only use in EDM, but your precond type is {self.edm_type}" + x_next = latents.to(torch.float64) * t_steps[0] + for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 + x_cur = x_next + + # Increase noise temporarily. + gamma = min(self.solver_params.S_churn / self.solver_params.num_steps, np.sqrt(2) - 1) if self.solver_params.S_min <= t_cur <= self.solver_params.S_max else 0 + t_hat = self.preconditioner.round_sigma(t_cur + gamma * t_cur) + x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * self.solver_params.S_noise * torch.randn_like(x_cur) + + # Euler step. + denoised = self.preconditioner(x_hat, t_hat, class_labels).to(torch.float64) + d_cur = (x_hat - denoised) / t_hat + x_next = x_hat + (t_next - t_hat) * d_cur + + # Apply 2nd order correction. + if i < self.solver_params.num_steps - 1: + denoised = self.preconditioner(x_next, t_next, class_labels).to(torch.float64) + d_prime = (x_next - denoised) / t_next + x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) + + + return x_next diff --git a/grl/generative_models/edm_diffusion_model/edm_preconditioner.py b/grl/generative_models/edm_diffusion_model/edm_preconditioner.py new file mode 100644 index 0000000..e9f115a --- /dev/null +++ b/grl/generative_models/edm_diffusion_model/edm_preconditioner.py @@ -0,0 +1,116 @@ +from typing import Optional, Tuple, Literal +from dataclasses import dataclass + +import numpy as np +import torch +from torch import Tensor +import torch.nn as nn +import torch.nn.functional as F + +class PreConditioner(nn.Module): + + def __init__(self, + precondition_type: Literal["VP_edm", "VE_edm", "iDDPM_edm", "EDM"] = "EDM", + base_denoise_model: nn.Module = None, + use_mixes_precision: bool = False, + **precond_config_kwargs) -> None: + + super().__init__() + self.precondition_type = precondition_type + self.base_denoise_model = base_denoise_model + self.use_mixes_precision = use_mixes_precision + + if self.precondition_type == "VP_edm": + self.beta_d = precond_config_kwargs.get("beta_d", 19.9) + self.beta_min = precond_config_kwargs.get("beta_min", 0.1) + self.M = precond_config_kwargs.get("M", 1000) + self.epsilon_t = precond_config_kwargs.get("epsilon_t", 1e-5) + self.sigma_min = float(self.sigma_for_vp_edm(self.epsilon_t)) + self.sigma_max = float(self.sigma_for_vp_edm(1)) + + elif self.precondition_type == "VE_edm": + self.sigma_min = precond_config_kwargs.get("sigma_min", 0.02) + self.sigma_max = precond_config_kwargs.get("sigma_max", 100) + + elif self.precondition_type == "iDDPM_edm": + self.C_1 = precond_config_kwargs.get("C_1", 0.001) + self.C_2 = precond_config_kwargs.get("C_2", 0.008) + self.M = precond_config_kwargs.get("M", 1000) + u = torch.zeros(self.M + 1) + for j in range(self.M, 0, -1): # M, ..., 1 + u[j - 1] = ((u[j] ** 2 + 1) / (self.alpha_bar(j - 1) / self.alpha_bar(j)).clip(min=self.C_1) - 1).sqrt() + self.register_buffer('u', u) + self.sigma_min = float(u[self.M - 1]) + self.sigma_max = float(u[0]) + + elif self.precondition_type == "EDM": + self.sigma_min = precond_config_kwargs.get("sigma_min", 0.) + self.sigma_max = precond_config_kwargs.get("sigma_max", float("inf")) + self.sigma_data = precond_config_kwargs.get("sigma_data", 0.5) + + else: + raise ValueError(f"Please check your precond type {self.precondition_type} is in ['VP_edm', 'VE_edm', 'iDDPM_edm', 'EDM']") + + # For VP_edm + def sigma_for_vp_edm(self, t): + t = torch.as_tensor(t) + return ((0.5 * self.beta_d * (t ** 2) + self.beta_min * t).exp() - 1).sqrt() + # For VP_edm + def sigma_inv_for_vp_edm(self, sigma): + sigma = torch.as_tensor(sigma) + return ((self.beta_min ** 2 + 2 * self.beta_d * (1 + sigma ** 2).log()).sqrt() - self.beta_min) / self.beta_d + + # For iDDPM_edm + def alpha_bar(self, j): + assert self.precondition_type == "iDDPM_edm", f"Only iDDPM_edm supports the alpha bar function, but your precond type is {self.precondition_type}" + j = torch.as_tensor(j) + return (0.5 * np.pi * j / self.M / (self.C_2 + 1)).sin() ** 2 + + + def round_sigma(self, sigma, return_index=False): + + if self.precondition_type == "iDDPM_edm": + sigma = torch.as_tensor(sigma) + index = torch.cdist(sigma.to(self.u.device).to(torch.float32).reshape(1, -1, 1), self.u.reshape(1, -1, 1)).argmin(2) + result = index if return_index else self.u[index.flatten()].to(sigma.dtype) + return result.reshape(sigma.shape).to(sigma.device) + else: + return torch.as_tensor(sigma) + + def get_precondition_c(self, sigma: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + + if self.precondition_type == "VP_edm": + c_skip = 1 + c_out = -sigma + c_in = 1 / (sigma ** 2 + 1).sqrt() + c_noise = (self.M - 1) * self.sigma_inv_for_vp_edm(sigma) + elif self.precondition_type == "VE_edm": + c_skip = 1 + c_out = sigma + c_in = 1 + c_noise = (0.5 * sigma).log() + elif self.precondition_type == "iDDPM_edm": + c_skip = 1 + c_out = -sigma + c_in = 1 / (sigma ** 2 + 1).sqrt() + c_noise = self.M - 1 - self.round_sigma(sigma, return_index=True).to(torch.float32) + elif self.precondition_type == "EDM": + c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt() + c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt() + c_noise = sigma.log() / 4 + return c_skip, c_out, c_in, c_noise + + def forward(self, x: Tensor, sigma: Tensor, class_labels=None, **model_kwargs): + # Suppose the first dim of x is batch size + x = x.to(torch.float32) + sigma_shape = [x.shape[0]] + [1] * (x.ndim - 1) + if sigma.numel() == 1: + sigma = sigma.view(-1).expand(*sigma_shape) + + dtype = torch.float16 if (self.use_mixes_precision and x.device.type == 'cuda') else torch.float32 + c_skip, c_out, c_in, c_noise = self.get_precondition_c(sigma) + F_x = self.base_denoise_model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs) + assert F_x.dtype == dtype + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x \ No newline at end of file diff --git a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py new file mode 100644 index 0000000..733623e --- /dev/null +++ b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py @@ -0,0 +1,220 @@ +################################################################################################ +# This script demonstrates how to use edm diffusion to train Swiss Roll dataset. +################################################################################################ + +import os +import signal +import sys + +import matplotlib +import numpy as np +from easydict import EasyDict +from rich.progress import track +from sklearn.datasets import make_swiss_roll + +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import torch +from easydict import EasyDict +from matplotlib import animation + +from grl.generative_models.edm_diffusion_model.edm_diffusion_model import EDMModel +from grl.utils import set_seed +from grl.utils.log import log + +x_size = 2 +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +config = EasyDict( + dict( + device=device, + edm_model=dict( + path=dict( + edm_type="EDM", # *["VP_edm", "VE_edm", "iDDPM_edm", "EDM"] + params=dict( + #^ 1: VP_edm + # beta_d=19.9, + # beta_min=0.1, + # M=1000, + # epsilon_t=1e-5, + # epsilon_s=1e-3, + #^ 2: VE_edm + # sigma_min=0.02, + # sigma_max=100, + #^ 3: iDDPM_edm + # C_1=0.001, + # C_2=0.008, + # M=1000, + #^ 4: EDM + sigma_min=0.002, + sigma_max=80, + sigma_data=0.5, + P_mean=-1.2, + P_std=1.2, + ) + ), + + solver=dict( + solver_type="heun", + # *['euler', 'heun'] + schedule="Linear", + #* ['VP', 'VE', 'Linear'] Give "Linear" when edm type in ["iDDPM_edm", "EDM"] + scaling="none", + #* ["VP", "none"] Give "none" when edm type in ["VE_edm", "iDDPM_edm", "EDM"] + params=dict( + num_steps=18, + alpha=1, + S_churn=0, + S_min=0, + S_max=float("inf"), + S_noise=1, + rho=7, #* EDM needs rho + ) + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=x_size, + t_dim=t_embedding_dim, + ), + ), + ), + ), + ), + parameter=dict( + training_loss_type="score_matching", + lr=5e-3, + data_num=10000, + iterations=1000, + batch_size=2048, + clip_grad_norm=1.0, + eval_freq=500, + checkpoint_freq=100, + checkpoint_path="./checkpoint", + video_save_path="./video", + device=device, + ), + ) +) +if __name__ == "__main__": + seed_value = set_seed() + log.info(f"start exp with seed value {seed_value}.") + edm_diffusion_model = EDMModel(config=config).to(config.device) + edm_diffusion_model = torch.compile(edm_diffusion_model) + # get data + data = make_swiss_roll(n_samples=config.parameter.data_num, noise=0.01)[0].astype( + np.float32 + )[:, [0, 2]] + # transform data + data[:, 0] = data[:, 0] / np.max(np.abs(data[:, 0])) + data[:, 1] = data[:, 1] / np.max(np.abs(data[:, 1])) + data = (data - data.min()) / (data.max() - data.min()) + data = data * 10 - 5 + + # + optimizer = torch.optim.Adam( + edm_diffusion_model.parameters(), + lr=config.parameter.lr, + ) + if config.parameter.checkpoint_path is not None: + + if ( + not os.path.exists(config.parameter.checkpoint_path) + or len(os.listdir(config.parameter.checkpoint_path)) == 0 + ): + log.warning( + f"Checkpoint path {config.parameter.checkpoint_path} does not exist" + ) + last_iteration = -1 + else: + checkpoint_files = [ + f + for f in os.listdir(config.parameter.checkpoint_path) + if f.endswith(".pt") + ] + checkpoint_files = sorted( + checkpoint_files, key=lambda x: int(x.split("_")[-1].split(".")[0]) + ) + checkpoint = torch.load( + os.path.join(config.parameter.checkpoint_path, checkpoint_files[-1]), + map_location="cpu", + ) + edm_diffusion_model.load_state_dict(checkpoint["model"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + last_iteration = checkpoint["iteration"] + else: + last_iteration = -1 + + data_loader = torch.utils.data.DataLoader( + data, batch_size=config.parameter.batch_size, shuffle=True + ) + + def get_train_data(dataloader): + while True: + yield from dataloader + + data_generator = get_train_data(data_loader) + + gradient_sum = 0.0 + loss_sum = 0.0 + counter = 0 + iteration = 0 + + def plot2d(data): + + plt.scatter(data[:, 0], data[:, 1]) + plt.show() + + def render_video(data_list, video_save_path, iteration, fps=100, dpi=100): + if not os.path.exists(video_save_path): + os.makedirs(video_save_path) + fig = plt.figure(figsize=(6, 6)) + plt.xlim([-10, 10]) + plt.ylim([-10, 10]) + ims = [] + colors = np.linspace(0, 1, len(data_list)) + + for i, data in enumerate(data_list): + # image alpha frm 0 to 1 + im = plt.scatter(data[:, 0], data[:, 1], s=1) + ims.append([im]) + ani = animation.ArtistAnimation(fig, ims, interval=0.1, blit=True) + ani.save( + os.path.join(video_save_path, f"iteration_{iteration}.mp4"), + fps=fps, + dpi=dpi, + ) + # clean up + plt.close(fig) + plt.clf() + + def save_checkpoint(model, optimizer, iteration): + if not os.path.exists(config.parameter.checkpoint_path): + os.makedirs(config.parameter.checkpoint_path) + torch.save( + dict( + model=model.state_dict(), + optimizer=optimizer.state_dict(), + iteration=iteration, + ), + f=os.path.join( + config.parameter.checkpoint_path, f"checkpoint_{iteration}.pt" + ), + ) + + history_iteration = [-1] + batch_data = next(data_generator) + batch_data = batch_data.to(config.device) + \ No newline at end of file From 0ee86a9342ca5be7a606126b3d9377db6f0997a6 Mon Sep 17 00:00:00 2001 From: wrh12345 Date: Tue, 20 Aug 2024 13:41:50 +0000 Subject: [PATCH 02/14] feature(wrh): add edm --- .../edm_diffusion_model.py | 251 ++++++------ .../edm_diffusion_model/edm_preconditioner.py | 25 +- .../edm_diffusion_model/edm_utils.py | 101 +++++ .../edm_diffusion_model/test.ipynb | 370 ++++++++++++++++++ .../swiss_roll/swiss_roll_edm_diffusion.py | 27 +- 5 files changed, 625 insertions(+), 149 deletions(-) create mode 100644 grl/generative_models/edm_diffusion_model/edm_utils.py create mode 100644 grl/generative_models/edm_diffusion_model/test.ipynb diff --git a/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py b/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py index 836d204..5a2831d 100644 --- a/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py +++ b/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py @@ -8,17 +8,36 @@ import torch.nn.functional as F import torch.optim as optim from easydict import EasyDict +from functools import partial from .edm_preconditioner import PreConditioner +from .edm_utils import SIGMA_T, SIGMA_T_DERIV, SIGMA_T_INV, SCALE_T, SCALE_T_DERIV, INITIAL_SIGMA_MAX, INITIAL_SIGMA_MIN +from .edm_utils import DEFAULT_SOLVER_PARAM from grl.generative_models.intrinsic_model import IntrinsicModel +from grl.utils import set_seed +from grl.utils.log import log + +class Simple(nn.Module): + def __init__(self): + super().__init__() + self.model = nn.Sequential( + nn.Linear(2, 32), + nn.ReLU(), + nn.Linear(32, 32), + nn.ReLU(), + nn.Linear(32, 2) + ) + def forward(self, x, noise, class_labels=None): + return self.model(x) class EDMModel(nn.Module): def __init__(self, config: Optional[EasyDict]=None) -> None: super().__init__() - self.config: EasyDict = config - self.device: torch.device = config.device + self.config= config + # self.x_size = config.x_size + self.device = config.device # EDM Type ["VP_edm", "VE_edm", "iDDPM_edm", "EDM"] self.edm_type: str = config.edm_model.path.edm_type @@ -26,12 +45,10 @@ def __init__(self, config: Optional[EasyDict]=None) -> None: f"Your edm type should in 'VP_edm', 'VE_edm', 'iDDPM_edm', 'EDM'], but got {self.edm_type}" #* 1. Construct basic Unet architecture through params in config - # TODO: construct basic denoise network here - - self.base_denoise_network: Optional[nn.Module] = IntrinsicModel(config.edm_model.model.args) + self.base_denoise_network = Simple() #* 2. Precond setup - self.params: EasyDict = config.edm_model.path.params + self.params = config.edm_model.path.params self.preconditioner = PreConditioner( self.edm_type, base_denoise_model=self.base_denoise_network, @@ -40,79 +57,61 @@ def __init__(self, config: Optional[EasyDict]=None) -> None: ) #* 3. Solver setup - self.solver_type: str = config.edm_model.solver.solver_type - self.solver_schedule: str = config.edm_model.solver.schedule - self.solver_scaling: str = config.edm_model.solver.scaling - self.solver_params: EasyDict = config.edm_model.solver.params + self.solver_type = config.edm_model.solver.solver_type assert self.solver_type in ['euler', 'heun'] - assert self.solver_schedule in ['VP', 'VE', 'Linear'] - assert self.solver_scaling in ["VP", "none"] + + self.solver_params = DEFAULT_SOLVER_PARAM + self.solver_params.update(config.edm_model.solver.params) # Initialize sigma_min and sigma_max if not provided - if "sigma_min" not in self.params: - self._initialize_sigma_min() - else: - self.sigma_min = self.params.sigma_min - if "sigma_max" not in self.params: - self._initialize_sigma_max() - else: - self.sigma_max = self.params.sigma_max + + self.sigma_min = INITIAL_SIGMA_MIN[self.edm_type] if "sigma_min" not in self.params else self.params.sigma_min + self.sigma_max = INITIAL_SIGMA_MAX[self.edm_type] if "sigma_max" not in self.params else self.params.sigma_max + def get_type(self): - return "DiffusionModel" - - def _initialize_sigma_min(self): - vp_sigma = lambda beta_d, beta_min: lambda t: (np.exp(0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5 - self.sigma_min = { - "VP_edm": vp_sigma(19.9, 0.1)(1e-3), - "VE_edm": 0.02, - "iDDPM_edm": 0.002, - "EDM": 0.002 - }[self.edm_type] + return "EDMModel" - def _initialize_sigma_max(self): - vp_sigma = lambda beta_d, beta_min: lambda t: (np.exp(0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5 - self.sigma_max = { - "VP_edm": vp_sigma(19.9, 0.1)(1), - "VE_edm": 100, - "iDDPM_edm": 81, - "EDM": 80 - }[self.edm_type] - # For VP_edm - - - def _sample_sigma_weight_train(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def _sample_sigma_weight_train(self, x: Tensor, **params) -> Tuple[Tensor, Tensor]: # assert the first dim of x is batch size + log.info(f"Params of trainig is: {params}") rand_shape = [x.shape[0]] + [1] * (x.ndim - 1) if self.edm_type == "VP_edm": - def sigma_for_vp_edm(self, t): - t = torch.as_tensor(t) - return ((0.5 * self.params.beta_d * (t ** 2) + self.params.beta_min * t).exp() - 1).sqrt() + epsilon_t = params.get("epsilon_t", 1e-5) + beta_d = params.get("beta_d", 19.9) + beta_min = params.get("beta_min", 0.1) + rand_uniform = torch.rand(*rand_shape, device=x.device) - sigma = sigma_for_vp_edm(1 + rand_uniform * (self.params.epsilon_t - 1)) + sigma = SIGMA_T["VP_edm"](1 + rand_uniform * (epsilon_t - 1), beta_d, beta_min) weight = 1 / sigma ** 2 elif self.edm_type == "VE_edm": rand_uniform = torch.rand(*rand_shape, device=x.device) sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rand_uniform) weight = 1 / sigma ** 2 elif self.edm_type == "EDM": + P_mean = params.get("P_mean", -1.2) + P_std = params.get("P_mean", 1.2) + sigma_data = params.get("sigma_data", 0.5) + rand_normal = torch.randn(*rand_shape, device=x.device) - sigma = (rand_normal * self.params.P_std + self.params.P_mean).exp() - weight = (sigma ** 2 + self.params.sigma_data ** 2) / (sigma * self.params.sigma_data) ** 2 + sigma = (rand_normal * P_std + P_mean).exp() + weight = (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2 return sigma, weight - def forward(self, x: torch.Tensor, class_labels=None): + def forward(self, + x: Tensor, + class_labels=None) -> Tensor: x = x.to(self.device) - sigma, weight = self._sample_sigma_weight_train(x) + sigma, weight = self._sample_sigma_weight_train(x, **self.params) n = torch.randn_like(x) * sigma D_xn = self.preconditioner(x+n, sigma, class_labels=class_labels) loss = weight * ((D_xn - x) ** 2) return loss - def _get_sigma_steps(self): + def _get_sigma_steps_t_steps(self, num_steps=18, epsilon_s=1e-3, rho=7): """ Overview: Get the schedule of sigma according to differernt t schedules. @@ -120,38 +119,44 @@ def _get_sigma_steps(self): """ self.sigma_min = max(self.sigma_min, self.preconditioner.sigma_min) self.sigma_max = min(self.sigma_max, self.preconditioner.sigma_max) - + # Define time steps in terms of noise level - step_indices = torch.arange(self.solver_params.num_steps, dtype=torch.float64, device=self.device) + step_indices = torch.arange(num_steps, dtype=torch.float64, device=self.device) sigma_steps = None if self.edm_type == "VP_edm": - vp_beta_d = 2 * (np.log(self.sigma_min ** 2 + 1) / self.params.epsilon_s - np.log(self.sigma_max ** 2 + 1)) / (self.params.epsilon_s - 1) + vp_beta_d = 2 * (np.log(self.sigma_min ** 2 + 1) / epsilon_s - np.log(self.sigma_max ** 2 + 1)) / (epsilon_s - 1) vp_beta_min = np.log(self.sigma_max ** 2 + 1) - 0.5 * vp_beta_d - vp_sigma = lambda beta_d, beta_min: lambda t: (np.exp((0.5 * beta_d * (t ** 2) + beta_min * t)) - 1) ** 0.5 - orig_t_steps = 1 + step_indices / (self.solver_params.num_steps - 1) * (self.params.epsilon_s - 1) - sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps) + + orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1) + sigma_steps = SIGMA_T["VP_edm"](orig_t_steps, vp_beta_d, vp_beta_min) elif self.edm_type == "VE_edm": - ve_sigma = lambda t: t.sqrt() - orig_t_steps = (self.sigma_max ** 2) * ((self.sigma_min ** 2 / self.sigma_max ** 2) ** (step_indices / (self.solver_params.num_steps - 1))) - sigma_steps = ve_sigma(orig_t_steps) + orig_t_steps = (self.sigma_max ** 2) * ((self.sigma_min ** 2 / self.sigma_max ** 2) ** (step_indices / (num_steps - 1))) + sigma_steps = SIGMA_T["VE_edm"](orig_t_steps) elif self.edm_type == "iDDPM_edm": - u = torch.zeros(self.params.M + 1, dtype=torch.float64, device=self.device) - alpha_bar = lambda j: (0.5 * np.pi * j / self.params.M / (self.params.C_2 + 1)).sin() ** 2 + M, C_1, C_2 = self.params.M, self.params.C_1, self.params.C_2 + + u = torch.zeros(M + 1, dtype=torch.float64, device=self.device) + alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2 for j in torch.arange(self.params.M, 0, -1, device=self.device): # M, ..., 1 - u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=self.params.C_1) - 1).sqrt() + u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt() u_filtered = u[torch.logical_and(u >= self.sigma_min, u <= self.sigma_max)] - sigma_steps = u_filtered[((len(u_filtered) - 1) / (self.solver_params.num_steps - 1) * step_indices).round().to(torch.int64)] + + sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)] + orig_t_steps = SIGMA_T_INV[self.edm_type](self.preconditioner.round_sigma(sigma_steps)) elif self.edm_type == "EDM": - sigma_steps = (self.sigma_max ** (1 / self.solver_params.rho) + step_indices / (self.solver_params.num_steps - 1) * \ - (self.sigma_min ** (1 / self.solver_params.rho) - self.sigma_max ** (1 / self.solver_params.rho))) ** self.solver_params.rho - # Define noise level schedule. - return sigma_steps + sigma_steps = (self.sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * \ + (self.sigma_min ** (1 / rho) - self.sigma_max ** (1 / rho))) ** rho + orig_t_steps = SIGMA_T_INV[self.edm_type](self.preconditioner.round_sigma(sigma_steps)) + + t_steps = torch.cat([orig_t_steps, torch.zeros_like(orig_t_steps[:1])]) # t_N = 0 + + return sigma_steps, t_steps - def _get_sigma_deriv_inv(self): + def _get_sigma_deriv_inv_scale_deriv(self, epsilon_s=1e-3): """ Overview: Get sigma(t) for different solver schedules. @@ -159,86 +164,70 @@ def _get_sigma_deriv_inv(self): Returns: sigma(t), sigma'(t), sigma^{-1}(sigma) """ - if self.solver_schedule == 'VP': # [VP_edm] - vp_beta_d = 2 * (np.log(self.sigma_min ** 2 + 1) / self.params.epsilon_s - np.log(self.sigma_max ** 2 + 1)) / (self.params.epsilon_s - 1) - vp_beta_min = np.log(self.sigma_max ** 2 + 1) - 0.5 * vp_beta_d - vp_sigma = lambda beta_d, beta_min: lambda t: (np.exp((0.5 * beta_d * (t ** 2) + beta_min * t)) - 1) ** 0.5 - vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t)) - vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d - vp_beta_d = 2 * (np.log(self.sigma_min ** 2 + 1) / self.params.epsilon_s - np.log(self.sigma_max ** 2 + 1)) / (self.params.epsilon_s - 1) - vp_beta_min = np.log(self.sigma_max ** 2 + 1) - 0.5 * vp_beta_d - - sigma = vp_sigma(vp_beta_d, vp_beta_min) - sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min) - sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min) - elif self.solver_schedule == 'VE': # [VE_edm] - sigma = lambda t: t.sqrt() - sigma_deriv = lambda t: 0.5 / t.sqrt() - sigma_inv = lambda sigma: sigma ** 2 - elif self.solver_schedule == 'Linear': # [iDDPM_edm, EDM] - sigma = lambda t: t - sigma_deriv = lambda t: 1 - sigma_inv = lambda sigma: sigma - - return sigma, sigma_deriv, sigma_inv - + vp_beta_d = 2 * (np.log(self.sigma_min ** 2 + 1) / epsilon_s - np.log(self.sigma_max ** 2 + 1)) / (epsilon_s - 1) + vp_beta_min = np.log(self.sigma_max ** 2 + 1) - 0.5 * vp_beta_d + sigma = partial(SIGMA_T[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min) + sigma_deriv = partial(SIGMA_T_DERIV[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min) + sigma_inv = partial(SIGMA_T_INV[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min) + scale = partial(SCALE_T[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min) + scale_deriv = partial(SCALE_T_DERIV[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min) + + return sigma, sigma_deriv, sigma_inv, scale, scale_deriv + - def _get_scaling(self, sigma, sigma_deriv, sigma_inv, sigma_steps): - """ - Overview: - Get s(t) for different solver schedules. and t_steps - - Returns: - sigma(t), sigma'(t), sigma^{-1}(sigma) - """ - # Define scaling schedule. - if self.solver_scaling == 'VP': - s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt() - s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3) - elif self.solver_scaling == 'none': # [VE_edm, iDDPM_edm, EDM] - s = lambda t: 1 - s_deriv = lambda t: 0 - # Compute final time steps based on the corresponding noise levels. - t_steps = sigma_inv(self.preconditioner.round_sigma(sigma_steps)) - t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 + def sample(self, + t_span, + batch_size, + latents: Tensor, + class_labels: Tensor=None, + use_stochastic: bool=False, + **solver_kwargs + ) -> Tensor: - return s, s_deriv, t_steps - - - def sample(self, latents, class_labels=None, use_stochastic=False): # Get sigmas, scales, and timesteps + log.info(f"Solver param is {self.solver_params}") + num_steps = self.solver_params.num_steps + epsilon_s = self.solver_params.epsilon_s + rho = self.solver_params.rho + latents = latents.to(self.device) - sigma_steps = self._get_sigma_steps() - sigma, sigma_deriv, sigma_inv = self._get_sigma_deriv_inv() - s, s_deriv, t_steps = self._get_scaling(sigma, sigma_deriv, sigma_inv, sigma_steps) + sigma_steps, t_steps = self._get_sigma_steps_t_steps(num_steps=num_steps, epsilon_s=epsilon_s, rho=rho) + sigma, sigma_deriv, sigma_inv, scale, scale_deriv = self._get_sigma_deriv_inv_scale_deriv() + + S_churn = self.solver_params.S_churn + S_min = self.solver_params.S_min + S_max = self.solver_params.S_max + S_noise = self.solver_params.S_noise + alpha = self.solver_params.alpha + if not use_stochastic: # Main sampling loop t_next = t_steps[0] - x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next)) + x_next = latents.to(torch.float64) * (sigma(t_next) * scale(t_next)) for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 x_cur = x_next # Increase noise temporarily. - gamma = min(self.solver_params.S_churn / self.solver_params.num_steps, np.sqrt(2) - 1) if self.solver_params.S_min <= sigma(t_cur) <= self.solver_params.S_max else 0 + gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0 t_hat = sigma_inv(self.preconditioner.round_sigma(sigma(t_cur) + gamma * sigma(t_cur))) - x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(t_hat) * self.solver_params.S_noise * torch.randn_like(x_cur) + x_hat = scale(t_hat) / scale(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * scale(t_hat) * S_noise * torch.randn_like(x_cur) # Euler step. h = t_next - t_hat - denoised = self.preconditioner(x_hat / s(t_hat), sigma(t_hat), class_labels).to(torch.float64) - d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised - x_prime = x_hat + self.solver_params.alpha * h * d_cur - t_prime = t_hat + self.solver_params.alpha * h + denoised = self.preconditioner(x_hat / scale(t_hat), sigma(t_hat), class_labels).to(torch.float64) + d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + scale_deriv(t_hat) / scale(t_hat)) * x_hat - sigma_deriv(t_hat) * scale(t_hat) / sigma(t_hat) * denoised + x_prime = x_hat + alpha * h * d_cur + t_prime = t_hat + alpha * h # Apply 2nd order correction. - if self.solver_type == 'euler' or i == self.solver_params.num_steps - 1: + if self.solver_type == 'euler' or i == num_steps - 1: x_next = x_hat + h * d_cur else: assert self.solver_type == 'heun' - denoised = self.preconditioner(x_prime / s(t_prime), sigma(t_prime), class_labels).to(torch.float64) - d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised - x_next = x_hat + h * ((1 - 1 / (2 * self.solver_params.alpha)) * d_cur + 1 / (2 * self.solver_params.alpha) * d_prime) + denoised = self.preconditioner(x_prime / scale(t_prime), sigma(t_prime), class_labels).to(torch.float64) + d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + scale_deriv(t_prime) / scale(t_prime)) * x_prime - sigma_deriv(t_prime) * scale(t_prime) / sigma(t_prime) * denoised + x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime) else: assert self.edm_type == "EDM", f"Stochastic can only use in EDM, but your precond type is {self.edm_type}" @@ -247,9 +236,9 @@ def sample(self, latents, class_labels=None, use_stochastic=False): x_cur = x_next # Increase noise temporarily. - gamma = min(self.solver_params.S_churn / self.solver_params.num_steps, np.sqrt(2) - 1) if self.solver_params.S_min <= t_cur <= self.solver_params.S_max else 0 + gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 t_hat = self.preconditioner.round_sigma(t_cur + gamma * t_cur) - x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * self.solver_params.S_noise * torch.randn_like(x_cur) + x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * torch.randn_like(x_cur) # Euler step. denoised = self.preconditioner(x_hat, t_hat, class_labels).to(torch.float64) @@ -257,7 +246,7 @@ def sample(self, latents, class_labels=None, use_stochastic=False): x_next = x_hat + (t_next - t_hat) * d_cur # Apply 2nd order correction. - if i < self.solver_params.num_steps - 1: + if i < num_steps - 1: denoised = self.preconditioner(x_next, t_next, class_labels).to(torch.float64) d_prime = (x_next - denoised) / t_next x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) diff --git a/grl/generative_models/edm_diffusion_model/edm_preconditioner.py b/grl/generative_models/edm_diffusion_model/edm_preconditioner.py index e9f115a..7618332 100644 --- a/grl/generative_models/edm_diffusion_model/edm_preconditioner.py +++ b/grl/generative_models/edm_diffusion_model/edm_preconditioner.py @@ -3,10 +3,12 @@ import numpy as np import torch -from torch import Tensor +from torch import Tensor, as_tensor import torch.nn as nn import torch.nn.functional as F +from .edm_utils import SIGMA_T, SIGMA_T_INV + class PreConditioner(nn.Module): def __init__(self, @@ -25,8 +27,9 @@ def __init__(self, self.beta_min = precond_config_kwargs.get("beta_min", 0.1) self.M = precond_config_kwargs.get("M", 1000) self.epsilon_t = precond_config_kwargs.get("epsilon_t", 1e-5) - self.sigma_min = float(self.sigma_for_vp_edm(self.epsilon_t)) - self.sigma_max = float(self.sigma_for_vp_edm(1)) + + self.sigma_min = SIGMA_T["VP_edm"](self.epsilon_t, self.beta_d, self.beta_min) + self.sigma_max = SIGMA_T["VP_edm"](1, self.beta_d, self.beta_min) elif self.precondition_type == "VE_edm": self.sigma_min = precond_config_kwargs.get("sigma_min", 0.02) @@ -44,22 +47,14 @@ def __init__(self, self.sigma_max = float(u[0]) elif self.precondition_type == "EDM": - self.sigma_min = precond_config_kwargs.get("sigma_min", 0.) - self.sigma_max = precond_config_kwargs.get("sigma_max", float("inf")) + self.sigma_min = precond_config_kwargs.get("sigma_min", 0.002) + self.sigma_max = precond_config_kwargs.get("sigma_max", 80) self.sigma_data = precond_config_kwargs.get("sigma_data", 0.5) else: raise ValueError(f"Please check your precond type {self.precondition_type} is in ['VP_edm', 'VE_edm', 'iDDPM_edm', 'EDM']") - # For VP_edm - def sigma_for_vp_edm(self, t): - t = torch.as_tensor(t) - return ((0.5 * self.beta_d * (t ** 2) + self.beta_min * t).exp() - 1).sqrt() - # For VP_edm - def sigma_inv_for_vp_edm(self, sigma): - sigma = torch.as_tensor(sigma) - return ((self.beta_min ** 2 + 2 * self.beta_d * (1 + sigma ** 2).log()).sqrt() - self.beta_min) / self.beta_d - + # For iDDPM_edm def alpha_bar(self, j): assert self.precondition_type == "iDDPM_edm", f"Only iDDPM_edm supports the alpha bar function, but your precond type is {self.precondition_type}" @@ -83,7 +78,7 @@ def get_precondition_c(self, sigma: Tensor) -> Tuple[Tensor, Tensor, Tensor, Ten c_skip = 1 c_out = -sigma c_in = 1 / (sigma ** 2 + 1).sqrt() - c_noise = (self.M - 1) * self.sigma_inv_for_vp_edm(sigma) + c_noise = (self.M - 1) * SIGMA_T_INV["VP_edm"](sigma, self.beta_d, self.beta_min) elif self.precondition_type == "VE_edm": c_skip = 1 c_out = sigma diff --git a/grl/generative_models/edm_diffusion_model/edm_utils.py b/grl/generative_models/edm_diffusion_model/edm_utils.py new file mode 100644 index 0000000..8e54d89 --- /dev/null +++ b/grl/generative_models/edm_diffusion_model/edm_utils.py @@ -0,0 +1,101 @@ +import numpy as np +import torch +from easydict import EasyDict + +############# Sampling Section ############# + +# Scheduling in Table 1 + +SIGMA_T = { + "VP_edm": lambda t, beta_d=19.9, beta_min=0.1: ((0.5 * beta_d * (t ** 2) + beta_min * t).exp() - 1) ** 0.5, + "VE_edm": lambda t, **kwargs: t.sqrt(), + "iDDPM_edm": lambda t, **kwargs: t, + "EDM": lambda t, **kwargs: t +} + +SIGMA_T_DERIV = { + "VP_edm": lambda t, beta_d=19.9, beta_min=0.1: 0.5 * (beta_min + beta_d * t) * (SIGMA_T["VP_edm"](t, beta_d, beta_min) + 1 / SIGMA_T["VP_edm"](t, beta_d, beta_min)), + "VE_edm": lambda t, **kwargs: t.sqrt(), + "iDDPM_edm": lambda t, **kwargs: t, + "EDM": lambda t, **kwargs: t +} + +SIGMA_T_INV = { + "VP_edm": lambda sigma, beta_d=19.9, beta_min=0.1: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1)).log() - beta_min).sqrt() / beta_d, + "VE_edm": lambda sigma, **kwargs: sigma ** 2, + "iDDPM_edm": lambda sigma, **kwargs: sigma, + "EDM": lambda sigma, **kwargs: sigma +} + +# Scaling in Table 1 +SCALE_T = { + "VP_edm": lambda t, beta_d=19.9, beta_min=0.1: 1 / (1 + SIGMA_T["VP_edm"](t, beta_d, beta_min) ** 2).sqrt(), + "VE_edm": lambda t, **kwargs: 1, + "iDDPM_edm": lambda t, **kwargs: 1, + "EDM": lambda t, **kwargs: 1 +} + +SCALE_T_DERIV = { + "VP_edm": lambda t, beta_d=19.9, beta_min=0.1: -SIGMA_T["VP_edm"](t, beta_d, beta_min) * SIGMA_T_DERIV["VP_edm"](t, beta_d, beta_min) * (SCALE_T["VP_edm"](t, beta_d, beta_min) ** 3), + "VE_edm": lambda t, **kwargs: 0, + "iDDPM_edm": lambda t, **kwargs: 0, + "EDM": lambda t, **kwargs: 0 +} + + +INITIAL_SIGMA_MIN = { + "VP_edm": SIGMA_T["VP_edm"](torch.tensor(1e-3), 19.9, 0.1), + "VE_edm": 0.02, + "iDDPM_edm": 0.002, + "EDM": 0.002 +} + +INITIAL_SIGMA_MAX = { + "VP_edm": SIGMA_T["VP_edm"](torch.tensor(1.), 19.9, 0.1), + "VE_edm": 100, + "iDDPM_edm": 81, + "EDM": 80 +} + +###### Default Params ###### + +DEFAULT_PARAM = EasyDict({ + "VP_edm": + { + "beta_d": 19.9, + "beta_min": 0.1, + "M": 1000, + "epsilon_t": 1e-5, + }, + "VE_edm": + { + "sigma_min": 0.02, + "sigma_max": 100 + }, + "iDDPM_edm": + { + "C_1": 0.001, + "C_2": 0.008, + "M": 1000 + }, + "EDM": + { + "sigma_min": 0.002, + "sigma_max": 80, + "sigma_data": 0.5, + "P_mean": -1.2, + "P_std": 1.2 + } +}) + +DEFAULT_SOLVER_PARAM = EasyDict( + { + "num_steps": 18, + "epsilon_s": 1e-3, + "rho": 7, + "S_churn": 0., + "S_min": 0., + "S_max": float("inf"), + "S_noise": 1., + "alpha": 1 +}) \ No newline at end of file diff --git a/grl/generative_models/edm_diffusion_model/test.ipynb b/grl/generative_models/edm_diffusion_model/test.ipynb new file mode 100644 index 0000000..81b4bf0 --- /dev/null +++ b/grl/generative_models/edm_diffusion_model/test.ipynb @@ -0,0 +1,370 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Optional, Tuple, Literal\n", + "from dataclasses import dataclass\n", + "\n", + "import numpy as np\n", + "import torch\n", + "from torch import Tensor\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim\n", + "from easydict import EasyDict\n", + "from functools import partial\n", + "\n", + "from edm_preconditioner import PreConditioner\n", + "from edm_utils import SIGMA_T, SIGMA_T_DERIV, SIGMA_T_INV, SCALE_T, SCALE_T_DERIV, DEFAULT_SOLVER_PARAM\n", + "from grl.generative_models.intrinsic_model import IntrinsicModel\n", + "\n", + "class Simple(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.model = nn.Sequential(\n", + " nn.Linear(2, 32),\n", + " nn.ReLU(),\n", + " nn.Linear(32, 32), \n", + " nn.ReLU(),\n", + " nn.Linear(32, 2)\n", + " )\n", + " def forward(self, x, noise, class_labels=None):\n", + " return self.model(x)\n", + "\n", + "class EDMModel(nn.Module):\n", + " \n", + " def __init__(self, config: Optional[EasyDict]=None) -> None:\n", + " \n", + " super().__init__()\n", + " self.config= config\n", + " # self.x_size = config.x_size\n", + " self.device = config.device\n", + " \n", + " # EDM Type [\"VP_edm\", \"VE_edm\", \"iDDPM_edm\", \"EDM\"]\n", + " self.edm_type: str = config.edm_model.path.edm_type\n", + " assert self.edm_type in [\"VP_edm\", \"VE_edm\", \"iDDPM_edm\", \"EDM\"], \\\n", + " f\"Your edm type should in 'VP_edm', 'VE_edm', 'iDDPM_edm', 'EDM'], but got {self.edm_type}\"\n", + " \n", + " #* 1. Construct basic Unet architecture through params in config\n", + " self.base_denoise_network = Simple()\n", + "\n", + " #* 2. Precond setup\n", + " self.params = config.edm_model.path.params\n", + " self.preconditioner = PreConditioner(\n", + " self.edm_type, \n", + " base_denoise_model=self.base_denoise_network, \n", + " use_mixes_precision=False,\n", + " **self.params\n", + " )\n", + " \n", + " #* 3. Solver setup\n", + " self.solver_type = config.edm_model.solver.solver_type\n", + " assert self.solver_type in ['euler', 'heun']\n", + " \n", + " self.solver_params = DEFAULT_SOLVER_PARAM\n", + " self.solver_params.update(config.edm_model.solver.params)\n", + " \n", + " # Initialize sigma_min and sigma_max if not provided\n", + " \n", + " if \"sigma_min\" not in self.params:\n", + " min = torch.tensor(1e-3)\n", + " self.sigma_min = {\n", + " \"VP_edm\": SIGMA_T[\"VP_edm\"](min, 19.9, 0.1), \n", + " \"VE_edm\": 0.02, \n", + " \"iDDPM_edm\": 0.002, \n", + " \"EDM\": 0.002\n", + " }[self.edm_type]\n", + " else:\n", + " self.sigma_min = self.params.sigma_min\n", + " if \"sigma_max\" not in self.params:\n", + " max = torch.tensor(1)\n", + " self.sigma_max = {\n", + " \"VP_edm\": SIGMA_T[\"VP_edm\"](max, 19.9, 0.1), \n", + " \"VE_edm\": 100, \n", + " \"iDDPM_edm\": 81, \n", + " \"EDM\": 80\n", + " }[self.edm_type] \n", + " else:\n", + " self.sigma_max = self.params.sigma_max\n", + " \n", + " def get_type(self):\n", + " return \"EDMModel\"\n", + "\n", + " # For VP_edm\n", + " def _sample_sigma_weight_train(self, x: Tensor, **params) -> Tuple[Tensor, Tensor]:\n", + " # assert the first dim of x is batch size\n", + " print(f\"params is {params}\")\n", + " rand_shape = [x.shape[0]] + [1] * (x.ndim - 1) \n", + " if self.edm_type == \"VP_edm\":\n", + " epsilon_t = params.get(\"epsilon_t\", 1e-5)\n", + " beta_d = params.get(\"beta_d\", 19.9)\n", + " beta_min = params.get(\"beta_min\", 0.1)\n", + " \n", + " rand_uniform = torch.rand(*rand_shape, device=x.device)\n", + " sigma = SIGMA_T[\"VP_edm\"](1 + rand_uniform * (epsilon_t - 1), beta_d, beta_min)\n", + " weight = 1 / sigma ** 2\n", + " elif self.edm_type == \"VE_edm\":\n", + " rand_uniform = torch.rand(*rand_shape, device=x.device)\n", + " sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rand_uniform)\n", + " weight = 1 / sigma ** 2\n", + " elif self.edm_type == \"EDM\":\n", + " P_mean = params.get(\"P_mean\", -1.2)\n", + " P_std = params.get(\"P_mean\", 1.2)\n", + " sigma_data = params.get(\"sigma_data\", 0.5)\n", + " \n", + " rand_normal = torch.randn(*rand_shape, device=x.device)\n", + " sigma = (rand_normal * P_std + P_mean).exp()\n", + " weight = (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2\n", + " return sigma, weight\n", + " \n", + " def forward(self, \n", + " x: Tensor, \n", + " class_labels=None) -> Tensor:\n", + " x = x.to(self.device)\n", + " sigma, weight = self._sample_sigma_weight_train(x, **self.params)\n", + " n = torch.randn_like(x) * sigma\n", + " D_xn = self.preconditioner(x+n, sigma, class_labels=class_labels)\n", + " loss = weight * ((D_xn - x) ** 2)\n", + " return loss\n", + " \n", + " \n", + " def _get_sigma_steps_t_steps(self, num_steps=18, epsilon_s=1e-3, rho=7):\n", + " \"\"\"\n", + " Overview:\n", + " Get the schedule of sigma according to differernt t schedules.\n", + " \n", + " \"\"\"\n", + " self.sigma_min = max(self.sigma_min, self.preconditioner.sigma_min)\n", + " self.sigma_max = min(self.sigma_max, self.preconditioner.sigma_max)\n", + " \n", + " # Define time steps in terms of noise level\n", + " step_indices = torch.arange(num_steps, dtype=torch.float64, device=self.device)\n", + " sigma_steps = None\n", + " if self.edm_type == \"VP_edm\":\n", + " vp_beta_d = 2 * (np.log(self.sigma_min ** 2 + 1) / epsilon_s - np.log(self.sigma_max ** 2 + 1)) / (epsilon_s - 1)\n", + " vp_beta_min = np.log(self.sigma_max ** 2 + 1) - 0.5 * vp_beta_d\n", + " \n", + " orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)\n", + " sigma_steps = SIGMA_T[\"VP_edm\"](orig_t_steps, vp_beta_d, vp_beta_min)\n", + " \n", + " elif self.edm_type == \"VE_edm\":\n", + " orig_t_steps = (self.sigma_max ** 2) * ((self.sigma_min ** 2 / self.sigma_max ** 2) ** (step_indices / (num_steps - 1)))\n", + " sigma_steps = SIGMA_T[\"VE_edm\"](orig_t_steps)\n", + " \n", + " elif self.edm_type == \"iDDPM_edm\":\n", + " M, C_1, C_2 = self.params.M, self.params.C_1, self.params.C_2\n", + " \n", + " u = torch.zeros(M + 1, dtype=torch.float64, device=self.device)\n", + " alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2\n", + " for j in torch.arange(self.params.M, 0, -1, device=self.device): # M, ..., 1\n", + " u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()\n", + " u_filtered = u[torch.logical_and(u >= self.sigma_min, u <= self.sigma_max)]\n", + " \n", + " sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)] \n", + " orig_t_steps = SIGMA_T_INV[self.edm_type](self.preconditioner.round_sigma(sigma_steps)) \n", + " \n", + " elif self.edm_type == \"EDM\": \n", + " sigma_steps = (self.sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * \\\n", + " (self.sigma_min ** (1 / rho) - self.sigma_max ** (1 / rho))) ** rho\n", + " orig_t_steps = SIGMA_T_INV[self.edm_type](self.preconditioner.round_sigma(sigma_steps)) \n", + " \n", + " t_steps = torch.cat([orig_t_steps, torch.zeros_like(orig_t_steps[:1])]) # t_N = 0\n", + " \n", + " return sigma_steps, t_steps \n", + " \n", + " \n", + " def _get_sigma_deriv_inv_scale_deriv(self, epsilon_s=1e-3):\n", + " \"\"\"\n", + " Overview:\n", + " Get sigma(t) for different solver schedules.\n", + " \n", + " Returns:\n", + " sigma(t), sigma'(t), sigma^{-1}(sigma) \n", + " \"\"\"\n", + " vp_beta_d = 2 * (np.log(self.sigma_min ** 2 + 1) / epsilon_s - np.log(self.sigma_max ** 2 + 1)) / (epsilon_s - 1)\n", + " vp_beta_min = np.log(self.sigma_max ** 2 + 1) - 0.5 * vp_beta_d\n", + " sigma = partial(SIGMA_T[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)\n", + " sigma_deriv = partial(SIGMA_T_DERIV[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)\n", + " sigma_inv = partial(SIGMA_T_INV[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)\n", + " scale = partial(SCALE_T[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)\n", + " scale_deriv = partial(SCALE_T_DERIV[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)\n", + "\n", + " return sigma, sigma_deriv, sigma_inv, scale, scale_deriv\n", + " \n", + " \n", + " def sample(self, \n", + " latents: Tensor, \n", + " class_labels: Tensor=None, \n", + " use_stochastic: bool=False, \n", + " **solver_params) -> Tensor:\n", + " \n", + " # Get sigmas, scales, and timesteps\n", + " print(f\"solver_params is {solver_params}\")\n", + " num_steps = self.solver_params.num_steps\n", + " epsilon_s = self.solver_params.epsilon_s\n", + " rho = self.solver_params.rho\n", + " \n", + " latents = latents.to(self.device)\n", + " sigma_steps, t_steps = self._get_sigma_steps_t_steps(num_steps=num_steps, epsilon_s=epsilon_s, rho=rho)\n", + " sigma, sigma_deriv, sigma_inv, scale, scale_deriv = self._get_sigma_deriv_inv_scale_deriv()\n", + " \n", + " S_churn = self.solver_params.S_churn\n", + " S_min = self.solver_params.S_min\n", + " S_max = self.solver_params.S_max\n", + " S_noise = self.solver_params.S_noise\n", + " alpha = self.solver_params.alpha\n", + " \n", + " if not use_stochastic:\n", + " # Main sampling loop\n", + " t_next = t_steps[0]\n", + " x_next = latents.to(torch.float64) * (sigma(t_next) * scale(t_next))\n", + " for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1\n", + " x_cur = x_next\n", + "\n", + " # Increase noise temporarily.\n", + " gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0\n", + " t_hat = sigma_inv(self.preconditioner.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))\n", + " x_hat = scale(t_hat) / scale(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * scale(t_hat) * S_noise * torch.randn_like(x_cur)\n", + "\n", + " # Euler step.\n", + " h = t_next - t_hat\n", + " denoised = self.preconditioner(x_hat / scale(t_hat), sigma(t_hat), class_labels).to(torch.float64)\n", + " d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + scale_deriv(t_hat) / scale(t_hat)) * x_hat - sigma_deriv(t_hat) * scale(t_hat) / sigma(t_hat) * denoised\n", + " x_prime = x_hat + alpha * h * d_cur\n", + " t_prime = t_hat + alpha * h\n", + "\n", + " # Apply 2nd order correction.\n", + " if self.solver_type == 'euler' or i == num_steps - 1:\n", + " x_next = x_hat + h * d_cur\n", + " else:\n", + " assert self.solver_type == 'heun'\n", + " denoised = self.preconditioner(x_prime / scale(t_prime), sigma(t_prime), class_labels).to(torch.float64)\n", + " d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + scale_deriv(t_prime) / scale(t_prime)) * x_prime - sigma_deriv(t_prime) * scale(t_prime) / sigma(t_prime) * denoised\n", + " x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)\n", + " \n", + " else:\n", + " assert self.edm_type == \"EDM\", f\"Stochastic can only use in EDM, but your precond type is {self.edm_type}\"\n", + " x_next = latents.to(torch.float64) * t_steps[0]\n", + " for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1\n", + " x_cur = x_next\n", + "\n", + " # Increase noise temporarily.\n", + " gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0\n", + " t_hat = self.preconditioner.round_sigma(t_cur + gamma * t_cur)\n", + " x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * torch.randn_like(x_cur)\n", + "\n", + " # Euler step.\n", + " denoised = self.preconditioner(x_hat, t_hat, class_labels).to(torch.float64)\n", + " d_cur = (x_hat - denoised) / t_hat\n", + " x_next = x_hat + (t_next - t_hat) * d_cur\n", + "\n", + " # Apply 2nd order correction.\n", + " if i < num_steps - 1:\n", + " denoised = self.preconditioner(x_next, t_next, class_labels).to(torch.float64)\n", + " d_prime = (x_next - denoised) / t_next\n", + " x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)\n", + "\n", + "\n", + " return x_next\n" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "params is {}\n", + "solver_params is {}\n" + ] + } + ], + "source": [ + "import torch\n", + "from easydict import EasyDict\n", + "\n", + "config = EasyDict(\n", + " dict(\n", + " device=torch.device(\"cuda\"), # Test if all tensors are converted to the same device\n", + " edm_model=dict( \n", + " path=dict(\n", + " edm_type=\"EDM\", # *[\"VP_edm\", \"VE_edm\", \"iDDPM_edm\", \"EDM\"]\n", + " params=dict(\n", + " #^ 1: VP_edm\n", + " # beta_d=19.9, \n", + " # beta_min=0.1, \n", + " # M=1000, \n", + " # epsilon_t=1e-5,\n", + " # epsilon_s=1e-3,\n", + " #^ 2: VE_edm\n", + " # sigma_min=0.02,\n", + " # sigma_max=100,\n", + " #^ 3: iDDPM_edm\n", + " # C_1=0.001,\n", + " # C_2=0.008,\n", + " # M=1000,\n", + " #^ 4: EDM\n", + " # sigma_min=0.002,\n", + " # sigma_max=80,\n", + " # sigma_data=0.5,\n", + " # P_mean=-1.2,\n", + " # P_std=1.2,\n", + " )\n", + " ),\n", + " solver=dict(\n", + " solver_type=\"heun\", \n", + " # *['euler', 'heun']\n", + " params=dict(\n", + " num_steps=18,\n", + " alpha=1, \n", + " S_churn=0, \n", + " S_min=0, \n", + " S_max=float(\"inf\"),\n", + " S_noise=1,\n", + " rho=7, #* EDM needs rho \n", + " epsilon_s=1e-3 #* VP_edm needs epsilon_s\n", + " )\n", + " )\n", + " )\n", + " )\n", + ")\n", + "\n", + "edm = EDMModel(config).to(config.device)\n", + "x = torch.randn((1024, 2)).to(config.device)\n", + "noise = torch.randn_like(x)\n", + "loss = edm(x).mean()\n", + "sample = edm.sample(x)\n", + "sample.shape\n", + "loss.backward()\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py index 733623e..d56b1be 100644 --- a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py +++ b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py @@ -71,11 +71,12 @@ params=dict( num_steps=18, alpha=1, - S_churn=0, - S_min=0, + S_churn=0., + S_min=0., S_max=float("inf"), - S_noise=1, + S_noise=1., rho=7, #* EDM needs rho + epsilon_s=1e-3 #* VP needs epsilon_s ) ), model=dict( @@ -217,4 +218,24 @@ def save_checkpoint(model, optimizer, iteration): history_iteration = [-1] batch_data = next(data_generator) batch_data = batch_data.to(config.device) + + for i in range(10): + edm_diffusion_model.train() + loss = edm_diffusion_model(batch_data).mean() + optimizer.zero_grad() + loss.backward() + gradien_norm = torch.nn.utils.clip_grad_norm_( + edm_diffusion_model.parameters(), config.parameter.clip_grad_norm + ) + optimizer.step() + gradient_sum += gradien_norm.item() + loss_sum += loss.item() + counter += 1 + iteration += 1 + log.info(f"iteration {iteration}, gradient {gradient_sum/counter}, loss {loss_sum/counter}") + + edm_diffusion_model.eval() + latents = torch.randn((2048, 2)) + sampled = edm_diffusion_model.sample(None, None, latents=latents) + log.info(f"Sampled size: {sampled.shape}") \ No newline at end of file From 45ea69b2036d856a2f0c706ae6a13265f6c97682 Mon Sep 17 00:00:00 2001 From: wrh12345 Date: Wed, 21 Aug 2024 06:41:39 +0000 Subject: [PATCH 03/14] feature(wrh): add edm initial implementation --- density_func.ipynb | 771 ------------------ .../edm_diffusion_model.py | 50 +- .../edm_diffusion_model/edm_preconditioner.py | 19 +- .../edm_diffusion_model/edm_utils.py | 3 +- .../edm_diffusion_model/test.ipynb | 370 --------- 5 files changed, 44 insertions(+), 1169 deletions(-) delete mode 100644 density_func.ipynb delete mode 100644 grl/generative_models/edm_diffusion_model/test.ipynb diff --git a/density_func.ipynb b/density_func.ipynb deleted file mode 100644 index d788bec..0000000 --- a/density_func.ipynb +++ /dev/null @@ -1,771 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "from base64 import b64encode\n", - "import pickle\n", - "import numpy as np\n", - "import torch\n", - "import torch.nn as nn\n", - "from easydict import EasyDict\n", - "from IPython.display import HTML\n", - "from rich.progress import track\n", - "from sklearn.datasets import make_swiss_roll\n", - "import matplotlib\n", - "import matplotlib.pyplot as plt\n", - "from matplotlib import animation\n", - "matplotlib.use(\"Agg\")\n", - "\n", - "from grl.generative_models.diffusion_model import DiffusionModel\n", - "from grl.generative_models.conditional_flow_model import IndependentConditionalFlowModel, OptimalTransportConditionalFlowModel\n", - "from grl.generative_models.metric import compute_likelihood\n", - "from grl.utils import set_seed\n", - "from grl.utils.log import log" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from sklearn.utils import shuffle as util_shuffle\n" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "%matplotlib inline\n", - "\n", - "\n", - "def plot2d(data):\n", - " plt.scatter(data[:, 0], data[:, 1])\n", - " plt.show()\n", - "\n", - "\n", - "def show_video(video_path, video_width=600):\n", - "\n", - " video_file = open(video_path, \"r+b\").read()\n", - "\n", - " video_url = f\"data:video/mp4;base64,{b64encode(video_file).decode()}\"\n", - " return HTML(\n", - " f\"\"\"\"\"\"\n", - " )\n", - "\n", - "\n", - "def render_video(\n", - " data_list, video_save_path, iteration, fps=100, dpi=100\n", - "):\n", - " if not os.path.exists(video_save_path):\n", - " os.makedirs(video_save_path)\n", - " fig = plt.figure(figsize=(6, 6))\n", - " plt.xlim([-5, 5])\n", - " plt.ylim([-5, 5])\n", - " ims = []\n", - " colors = np.linspace(0, 1, len(data_list))\n", - "\n", - " for i, data in enumerate(data_list):\n", - " im = plt.scatter(data[:, 0], data[:, 1], s=1)\n", - " title = plt.text(0.5, 1.05, f't={i/len(data_list):.2f}', ha='center', va='bottom', transform=plt.gca().transAxes)\n", - " ims.append([im, title])\n", - "\n", - " ani = animation.ArtistAnimation(fig, ims, interval=0.1, blit=True)\n", - " ani.save(os.path.join(video_save_path, f'iteration_{iteration}.mp4'), fps=fps, dpi=dpi)\n", - " # clean up\n", - " plt.close(fig)\n", - " plt.clf()\n", - "\n", - "def load_and_plot_results(file_path):\n", - " try:\n", - " with open(file_path, \"rb\") as f:\n", - " results = pickle.load(f)\n", - " except Exception as e:\n", - " print(f\"Failed to load the file: {e}\")\n", - " return\n", - "\n", - " plt.figure(figsize=(10, 6))\n", - " x = results[\"iterations\"]\n", - " if \"gradients\" in results and results[\"gradients\"]:\n", - " plt.plot(x, results[\"gradients\"], label=\"Gradients\")\n", - " if \"losses\" in results and results[\"losses\"]:\n", - " plt.plot(x, results[\"losses\"], label=\"Losses\")\n", - " plt.xlabel(\"Iteration\")\n", - " plt.ylabel(\"Log(Value)\")\n", - " plt.yscale(\"log\")\n", - " # Specify y-ticks\n", - " y_ticks = [1e-1, 5e-1, 1, 5, 10]\n", - " plt.yticks(y_ticks, [f\"{y:.0e}\" for y in y_ticks])\n", - " plt.title(\"Training Metrics Over Iterations\")\n", - " plt.legend()\n", - " plt.grid(True)\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [], - "source": [ - "x_size = 2\n", - "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", - "t_embedding_dim = 32\n", - "\n", - "diffusion_model_config = EasyDict(\n", - " dict(\n", - " device=device,\n", - " project=\"linear_vp_sde_noise_function_score_matching\",\n", - " diffusion_model=dict(\n", - " device=device,\n", - " x_size=x_size,\n", - " alpha=1.0,\n", - " solver=dict(\n", - " type=\"ODESolver\",\n", - " args=dict(\n", - " library=\"torchdiffeq_adjoint\",\n", - " ),\n", - " ),\n", - " path=dict(\n", - " type=\"linear_vp_sde\",\n", - " beta_0=0.1,\n", - " beta_1=20.0,\n", - " ),\n", - " model=dict(\n", - " type=\"noise_function\",\n", - " args=dict(\n", - " t_encoder=dict(\n", - " type=\"GaussianFourierProjectionTimeEncoder\",\n", - " args=dict(\n", - " embed_dim=t_embedding_dim,\n", - " scale=30.0,\n", - " ),\n", - " ),\n", - " backbone=dict(\n", - " type=\"TemporalSpatialResidualNet\",\n", - " args=dict(\n", - " hidden_sizes=[128, 64, 32],\n", - " output_dim=x_size,\n", - " t_dim=t_embedding_dim,\n", - " ),\n", - " ),\n", - " ),\n", - " ),\n", - " ),\n", - " parameter=dict(\n", - " training_loss_type=\"score_matching\",\n", - " lr=5e-4,\n", - " data_num=100000,\n", - " # weight_decay=1e-4,\n", - " iterations=100000,\n", - " batch_size=4096,\n", - " # clip_grad_norm=1.0,\n", - " eval_freq=1000,\n", - " video_save_path=\"./video-diffusion\",\n", - " device=device,\n", - " ),\n", - " )\n", - ")\n", - "\n", - "flow_model_config = EasyDict(\n", - " dict(\n", - " device=device,\n", - " project=\"icfm_velocity_function_flow_matching\",\n", - " flow_model=dict(\n", - " device=device,\n", - " x_size=x_size,\n", - " alpha=1.0,\n", - " solver=dict(\n", - " type=\"ODESolver\",\n", - " args=dict(\n", - " library=\"torchdiffeq_adjoint\",\n", - " ),\n", - " ),\n", - " path=dict(\n", - " sigma=0.1,\n", - " ),\n", - " model=dict(\n", - " type=\"velocity_function\",\n", - " args=dict(\n", - " t_encoder=dict(\n", - " type=\"GaussianFourierProjectionTimeEncoder\",\n", - " args=dict(\n", - " embed_dim=t_embedding_dim,\n", - " scale=30.0,\n", - " ),\n", - " ),\n", - " backbone=dict(\n", - " type=\"TemporalSpatialResidualNet\",\n", - " args=dict(\n", - " hidden_sizes=[128, 64, 32],\n", - " output_dim=x_size,\n", - " t_dim=t_embedding_dim,\n", - " ),\n", - " ),\n", - " ),\n", - " ),\n", - " ),\n", - " parameter=dict(\n", - " training_loss_type=\"flow_matching\",\n", - " lr=5e-4,\n", - " data_num=100000,\n", - " # weight_decay=1e-4,\n", - " iterations=100000,\n", - " batch_size=4096,\n", - " # clip_grad_norm=1.0,\n", - " eval_freq=1000,\n", - " video_save_path=\"./video-flow\",\n", - " device=device,\n", - " ),\n", - " )\n", - ")\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 44, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAi8AAAGdCAYAAADaPpOnAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAABrrUlEQVR4nO3deVhV1f4/8DcgowiEIjgGSk45kDNqioqi0kBaZtY1MzVNK1PrHr3mlAmpmdeGa369Stb1eq3UhmNOKOKAs+SMoRKKggPBEZkE9u8Pf5Ao4Fn7rLWn83k9D89TuIfPOZy99/usvfZaDpIkSSCEEEII0QlHtQsghBBCCGFB4YUQQgghukLhhRBCCCG6QuGFEEIIIbpC4YUQQgghukLhhRBCCCG6QuGFEEIIIbpC4YUQQgghulJD7QJ4Ky0txZUrV1CrVi04ODioXQ4hhBBCrCBJEm7duoX69evD0bH6thXDhZcrV66gUaNGapdBCCGEEBkuXbqEhg0bVruM4cJLrVq1ANx98V5eXipXQwghhBBrWCwWNGrUqPw6Xh3DhZeyW0VeXl4UXgghhBCdsabLB3XYJYQQQoiuUHghhBBCiK5QeCGEEEKIrlB4IYQQQoiuUHghhBBCiK5QeCGEEEKIrlB4IYQQQoiuUHghhBBCiK4YbpA6QohxnL5sQeTnuyFZseze9/ugga+78JoIIeoTGl6io6Oxfv16nD17Fu7u7ujWrRs+/vhjNG/evNr1vvvuO3zwwQdITU3FY489ho8//hiDBg0SWSohxEYHU7IwdEWiavvvvmAH921un9QLwQGe3LdLCLGN0PCya9cuTJgwAZ06dUJxcTGmT5+O/v374/Tp06hZs2al6+zbtw8vvfQSoqOj8dRTT2HNmjWIiorC0aNH0bp1a5HlEkIqoXYoUVP4kl1WLUchhxBlOUiSZE2LLBfXr19H3bp1sWvXLvTs2bPSZV588UXcvn0bv/zyS/nvunbtipCQECxbtuyh+7BYLPD29kZOTg7NbUSIFZKv3MLApQkoVbsQg/IAsGt6OPy8XNUuhRBNY7l+K9rnJScnBwDg6+tb5TKJiYmYPHlyhd9FRERg48aNlS5fWFiIwsLC8v+3WCy2F0qIweQWFOO1lQdxKO1PtUuxO3kAOs3fXum/bRzXHSGBPorWQ4gRKBZeSktLMWnSJHTv3r3a2z8ZGRnw9/ev8Dt/f39kZGRUunx0dDTmzJnDtVZC9CwpNRtRy/aqXQaxQlV/p4SpvdG4jofC1RCiH4qFlwkTJuDkyZPYs2cP1+1OmzatQkuNxWJBo0aNuO6DEK2y5/4oRtZz0c4Hfkf9agj5iyLhZeLEifjll1+QkJCAhg0bVrtsQEAAMjMzK/wuMzMTAQEBlS7v6uoKV1e6l0yMLyUjFxH/3IUSxXqpES25v/NwbSdg27R+8PV0UakiQtQjNLxIkoS33noLGzZsQHx8PIKCgh66TmhoKOLi4jBp0qTy323btg2hoaECKyVEe6hVhVTnZgnQft62Cr/b8nZPNK9fS6WKCFGO0PAyYcIErFmzBj/++CNq1apV3m/F29sb7u53B5MaMWIEGjRogOjoaADAO++8g169euGTTz5BZGQk1q5di8OHD2P58uUiSyVEdfvP3cSwlfvVLoPoWMTShAr/T2GGGJXQR6UdHBwq/f2qVaswcuRIAEBYWBgCAwMRGxtb/u/fffcdZsyYUT5I3YIFC6wepI4elSZ6QR1ridJ2Tg5DUN3Kx9giRG0s129Fx3lRAoUXolXpWfno92k88u7QiCpyRDTzxJJXesDdxQnA3fez/6fxuE3vpyz9H6uJf/7tyfL3kxC1UXih8EI04siFPzFk+T61y1DcutGh6Bxc9XhOWpWTdwdDl+1D8rVctUtRHD3NRNRG4YXCC1FJflEJ3l5zFNvOXlO7FCH0Gkp4M3rIocHziBoovFB4IQq6binEwH/uwo3bd9QuxWZBnsDGyf3h7eGsdim6dvHabfRfEg8j3NGaM+gxvNLjMTg5Vt6HkRBeKLxQeCGC6T2w/PxmD7Rp7K12GXanqLgUH/18Gl8f+EPtUmShIENEovBC4YUIkFtQjBH/PoCjl7LVLsVqL4X6Y05ke7jUcFS7FFKNjOwCRCyJR05BidqlWI1uIRLeKLxQeCGclJRKiN19ER/+ekbtUh7qw8hmGN49mL4VG0RWbhGe/Xw3LmUXqF3KQ22a+CRaNaTzLbENhRcKL8RGpy9bMOjz3WqXUSUaGt4+6WGKiP2mvgjwcVO7DKJDFF4ovBAZtNyPZe2orujarLbaZRCN0XLrzKyBwRjxZDNqCSRWo/BC4YUw0OIcQtSfgMiRk3cHL/xrL85dv612KRXQGDLEGhReKLyQh9BaK8uLXeviw6c6UMdawlVGdgH6L4mHRSMdgak1hlSHwguFF1KFE2k5ePrLPWqXAQBImNobjet4qF0GsSNamk9r7/t90MDXXe0yiIZQeKHwQu5RUirh3/EXMH/rWVXroNYVoiVaaX2kMYdIGQovFF4I7nZm7Lt4B/7MU6/JnIZZJ3qghWktvh3RGT1a+am2f6I+Ci8UXuzaxWu30XtxvGr73/J2TzSvX0u1/RNiq/3nbmLYyv2q7PtRD+CnqTRFhT2i8ELhxS6pOTYLBRZiVGoGGRozxr5QeKHwYleOXPgTQ5bvU3y/dK+e2JOi4lJ8+NMpfHMwTfF9U+d2+0DhhcKLXVBjfBYaLI6Qu+PJPP+vvfhd4fFkdk4OQ1DdmorukyiHwguFF0NTuqWF7sETUjU1piygEGNMFF4ovBiS0mNUUD8WQtgo3T+GbicZC4UXCi+GomRH3Be6+OGjpzvSWCyE2EDpkX2pY68xUHih8GIIKRm5CF+yS5F9USsLIWLsO3sDw2MPKLKv32bS7V09o/BC4UXXMrIL0DUmTvh+qC8LIcpRavwlDxdHJM2MoNZTHaLwQuFFl3ILitF69hbh+6EnhghRj1JPKj3Vuh4+f6W90H0Qvii8UHjRndBZZlwtFLuPTROfRKuG9JkgRCv2nL6OV1YfFLoPOu71g8ILhRfdUGKWZ3oigRBtU+JJQurUq30UXii8aJ7ofi2OAA7P6AdfTxdh+yCE8JV85RYiliYI234NRyB53iA4OToI2weRj8ILhRdNCzSZhW3b3Qk49EEEPN1qCNsHIUQs0Z17/96/Bcb3aSps+0QeCi8UXjRJ5HgtFFoIMZ60G3nouWinsO0fmh4OPy9XYdsnbCi8UHjRlJy8O2g3d6uQbbs4AkdnUmghxMhEtsS4OzvizIcDhWybsKHwQuFFM9pMM+OWgE+YA4AkGpCKELsicuBKmi9JfRReKLyo7rqlEJ3mbxeybWrqtR/XLYUYsGQXbubdqXa5lcM6oE9IgEJVEbWJugXtAOBiTCT37RLraCa8JCQkYOHChThy5AiuXr2KDRs2ICoqqsrl4+Pj0bt37wd+f/XqVQQEWHdiovCiPlEdcumRZ/1QehJNnuiRWv0Q9Tmj6ULUwXL9FtpR4Pbt22jXrh1GjRqFwYMHW71ecnJyhcLr1q0rojzCmajWFhpkSlusbQ3Rq4c9wk+tPNoREuiD1JhIvLk6AZtO3+K23bLHtVOpFUazFLtt5ODgYHXLy59//gkfHx9Z+6GWF3WIaG2Z0rcZ3ur3GPftEuvsSMrAqLVH1C5DVzaO646QQB+1y7BLRcWlaDbjV+7bpb4wytFMy4tcISEhKCwsROvWrTF79mx07969ymULCwtRWPjXuPIWi0WJEsn/J2I+IhpISlnJV25h4NIElKpdiAFUdgvDC0A8DZgonEsNR6TGRHIfALP34njqC6NBmgov9erVw7Jly9CxY0cUFhZixYoVCAsLw4EDB9C+feUTbEVHR2POnDkKV0oAoNdHZvzBr6UWAPU3EO3itdvovyQedyipKMYCoP28bRV+F/aoM/71el+4uzipU5SBBfi4ITUmkuvUIxLuti7TwwLaoanbRpXp1asXGjdujG+++abSf6+s5aVRo0Z020gw3reJfn6zB9o09ua6TQLsO3sDw2MPqF0GsQJ1EhWD97nKw8URp+fSuDAi6P620b06d+6MPXuqTs+urq5wdaUkrBTeTbI1HIGU+dQcywuFFf26f04fCjN8pMZEcn2YIK+oFIEmM3XmVZnmw0tSUhLq1aundhkE/L/B7H2/Dxr4unPdpr2h/irGdW+YqVcD2DydBmWUy8/LFakxkXh7zT78dPxPLtsMNJnpHKYioeElNzcXKSkp5f9/8eJFJCUlwdfXF40bN8a0adOQnp6O1atXAwCWLFmCoKAgPP744ygoKMCKFSuwY8cObN0qZmh5Yj2ewcW/lgsO/KMft+3Zm+1Hr2L0uqNql0EUdLUYFabY+GFsN3Ro8oiKFenT0uHdsGgov6eSui/YQa3HKhEaXg4fPlxh0LnJkycDAF599VXExsbi6tWrSEtLK//3oqIiTJkyBenp6fDw8EDbtm2xffv2SgeuI8rgPXbLbzSkP7Os3CI8tTQBVyyFD1+Y2IUhy/eV/zeNO8Om7KkkXlMNFJeCbiOpgKYHIFUKMpnB68MR7OeJ7VN6cdqa8WVkF6D/4nhYikrULkVT7h1HJf54JkauOaxuQRqz/Pkn0L9jfbXL0BWercpH6ZF4m2hmegA1UHjhg+cBfXI2zfpsjazcIjz1zwRcuWW8FhYtPwLP85FaLaEWGevxnLXax90ZSbP6c9mWvaHwQuFFtpJSCU2nb+KyLW/3GvhtVgSXbRnZ1sNXMPb7Y2qXIZs9jCqr99GG140ORedgX7XL0DyeX9roNhI7Ci8UXmT5/uAfmLr+JJdtUfNp9Q6mZGHoikS1y7Da3/sFYmzvVjTq8X0ysgvQ/9N4WAr1c3uPJjit3rAvdmD/pXwu26IAw4bCC4UXZk1NZvA4/To6ABei6YCtjF5uC9FcLrbTQzil20pVyy8qQcuZm7ls68zcATSSspUovFB4YcKrqZS+0VVOyxcymrFbOXtOX8crqw+qXUal6NitHK9z45PBdfDN6C5ctmVkFF4ovFiN18FJzaMVlZRK+HxzMj5NOK92KeVa+wBrJ1Hnaa1IychFxD93oURDZ+A1I7ugW4s6apehKbxGFafJHR+OwguFF6vwCC40QFNFGdkFCFsYhwINdIGo7wz8Oo3G1dELLYWZZUNCMKBTA7XL0BT6oicehRcKLw/F40CMfq4NXurSmEM1+pd85dYDc9OoYfukXggO8FS7DMKBFuapCnACtvyDAnAZCjBiUXih8FKlomI+Q2Ofnz+InjwBcOTCnxVGO1Xae+GPYlyfx+lvYXBaGLRQy2P1KGnt/oswbTxt83YowDyIwguFl0p98ONv+Cbxss3boYMO2H/uJoat3K/Kvqlfgn0rKi7FjO+OY91v6arsnzr38hsPi86lFVF4ofDygObTzCi08S9d39sN+6b15VOQTqkVWmiQMVIVtQY5pBDD5zYSBZi/UHih8FIBjwPM3idUVOP2ELWwEBZFxaWYse43rDt+RdH92vvtJAow/FB4ofBSjg4s2yjdEZcm1iM8KN1HxgHACTuew2zahmP47wHbQqM9n2fLUHih8AKAgost0rPy0X3BDsX2Z+/fXok4SanZiFq2V5F9NfZxR4KpjyL70hoeD0PY+4MQFF4ovFBwkYnX01jW+HZEZ/Ro5afIvggpKi7F9HW/4XsFbivFDu+IsLb+wvejRbaee78c3h6D2tbjVI2+UHix8/Bi68HjCOCCHQaXN1btxJbkPOH72ft+HzTwdRe+H0KqolQfLnudfsLWc/CYJ4Pwj8hWnKrRDwovdhxebD1oWtXzwqZ3nuRUjT6cvmzBoM93C90H9WUhWnTdUoiIxbuQVXBH6H7sscO/7QEmEP+IfJxTNfpA4cVOw4utB8tJO+twl1tQjNaztwjdx9pRXdG1WW2h+yCEh00HL+PN9b8J275fTRcc+qCfsO1rke23kJ7AoLb286WHwosdhhdbDxJ769/SdaYZGUXitm+vzeVE/+KPZ2LkmsPCtm9vU1jYem62p068FF7sLLxQcLEerxliq7JzchiC6tYUtn1ClCJ6biV7Ou+0/cAMiw135uzlvaLwYkfhpYnJjFIb1reXgwLgN6laZWi0UWJUB1OyMHRFopBt//xmD7Rp7C1k21rz2spE7DyXJXt9ezhXU3ixk/DSI2YHLmfny17fHg4GQOyYLfTkELEXIkOMvZyLfklKx8S1SbLXN/r7ROHFDsLLhqPpeHddkuz1jX4QlBHV2mJv9+0JKbPn9HW8svog9+3aSz8xWyd1NPK5m8KLwcMLffgfLifvDtrN3cp9uxvHdUdIoA/37RKiN9uPXsXodUe5b9cezk+AbV+sjPoeUXgxeHihD3312k43w2JLR6BK0CSJhFROROvmoenh8PNy5b5draFzeUUUXgwcXujDXj3eJ1L3Go44M28g120SYjTXLYXoNH871216ODvi9IfGPvZsnY7EaOd0luu3o0I1EQ4ouFTtuqWQe3A5ND2cggshVvDzckVqTCQ2juvObZt5d0qFPiGoBS41HDGqe6Ds9Y3+/lSHWl50goJL1XgfwNSvhRDb8D4mj87oB19PF67b1BJbnhx9sokXvhlrjCld6LaRwcJLsxm/oqhYXicOCi7Wq+EApEQb+/0iRClZuUVoP28bt+35uDsjaVZ/btvTmqYmM0pkrntm7gC4uzhxrUcNdNvIQP538JLs4HLO4Lc8eAaXve/3oeBCCEe+ni5IjYnEU635DEKXnX/H0LdJztvwRbPlzM0cK9EHannRMFseiX65awN8FBXCtyCN4DnEv7uzI84YvFOg3m09fAVjvz9W7TJ0q0/bbO2Yej8jtyjbcxcBzdw2SkhIwMKFC3HkyBFcvXoVGzZsQFRUVLXrxMfHY/LkyTh16hQaNWqEGTNmYOTIkVbv00jhxZ4/xFXh+c3LXh7H1AJRY4LwYC+Do2nBsvgziNl8gcu2jHKrpDL2eu7XTHj59ddfsXfvXnTo0AGDBw9+aHi5ePEiWrdujXHjxmH06NGIi4vDpEmTYDabERERYdU+jRJe7PXDWx1ewaWGI5Ay35jvkZqsaSHRG3uae0cptg6yea8ng2vjm9FduWxLa+Se75o+4oS4vw/gXI0yNBNeKuzIweGh4eXvf/87zGYzTp48Wf67YcOGITs7G5s3W3dPzwjhpdXMzcgrktd1i4JL9WguItvlFhTj+S924+z1PLVLUcXSqDZ4pmtjtcvQvTbTzLjF6epD572KTs6OgKdbDc7ViMdy/dbUq0tMTER4eHiF30VERGDSpElVrlNYWIjCwsLy/7dYLKLKU8R1SyEFl/vwCi5GfX9Eiz+eiZFrDqtdhma8vfEE3t54osLvaK4rdieiI7lN4xFoMhvy+D4zd4CszritZ28x5PtxL02Fl4yMDPj7+1f4nb+/PywWC/Lz8+Hu/uA35ujoaMyZM0epEoWTO0qlUT+oPIJLsJ8ntk/pxaEa+7Dv7A0Mjz2gdhm6Er5kV4X/pxY+63h7OCM1JpLLcW7EAOPu4oQng2tjd8pN5nWN+H7cS/ePSk+bNg05OTnlP5cuXVK7JNnkHsC/zTTm2Ac8TmgnZ0dQcHmIouJSjFt1AIEmMwJNZgouHHRfsKP8/fzlgH7PSUpJjYlEIIe7/EZ8lNqWPj1h8433fpTRVMtLQEAAMjMzK/wuMzMTXl5elba6AICrqytcXfX/xEjofHmP/vq63P32YjQ8TkJG/tZhq4vXbiP803iUGGqgBG2auOE4Jm44DgCYN7AZXun1mMoVaVP89EjkFhSj9ewtNm3HiC0OclunUi13+6jpsf/Lw2iq5SU0NBRxcRUv4tu2bUNoaKhKFSkjJ+8OrloKZK17dK6xDtKiYtvnM6nhSMGlMulZ+Wjxj00INJnRezEFFzXM+PVceYvMmoQUtcvRHE+3GlyOXSO2wMh9X2wNg1ol9Gmj3NxcpKTcPUCfeOIJLF68GL1794avry8aN26MadOmIT09HatXrwbw16PSEyZMwKhRo7Bjxw68/fbbhn9UWu6BZrQL9Ie/nMS/9/xh0zaor8GDYncmY/YWulBq2bIhIRjQqYHaZWgKtb4+yJaWKT28F5p5VDo+Ph69e/d+4PevvvoqYmNjMXLkSKSmpiI+Pr7COu+++y5Onz6Nhg0b4oMPPjD0IHXB082QM/r/uXkD4VJDUw1nNnlq6S6cvJJr0zb0cHAqJeHkNYz49pDaZRAZEqb2RuM6HmqXoQnD/xWPfX/ctmkbRjsv9F+8C+eusZ8r9fC50kx4UYOewkt6Vj66L9jBvN7g9n5YPLSzgIrU8fRnu3Ei3bZH3I12gpJr5Y6zmLv1vNplCLX3/T7449ptw3csptaYu/KLSmyeu8do5wejttZTeNFJeDHqB5DFh7+cxr/3XLRpG0Z6P+TQY+fbnZPDEFS3pir71lur1Mz+TTGqTwu1y1CdrbeRjHaeMOL1g8KLDsJLk2lmlMp457X8wWPFY7I2I70frPafu4lhK/erXUaV9NBMfb89p6/jldUH1S6jSkdn9IOvp4vaZaiGAsxfsnKL0H7eNub1tHxcUnjReHhJu5GHnot2Mq9ntBOXLSciBwAXDXQiYqG1loPFzzyOwd0C1S5DmOuWQoQv3omcAnkjX4tgzx3TKcD8pd2cLcjJL2ZeT6vvAYUXjYcXOQdfDQApGv3AyWHLCaihjzv2mPpwrEYftNIqQLMwa2eW7P2mvgjwcVO7DMVRgPmLkW4fUXjRcHjpPG8bruUWMa+nxQ+aXLaceD4dGoLn2ttXJ8ak1GxELdur2v5/GNsNHZo8otr+9eCXA5fKB6JTmgOA03MHwN3FSZX9q4UCzF1yOzRrMfhSeNFoeJE7CdkZA52YbDnhnJ8/CE6ODhyr0baM7AJ0jZE38rIt5g9qjuE9gxXfr1GkZ+Wj7yc7UaBwD+oWfh7YPOXBoSmMjALMXUP/lYiDf2Qxr6e110/hRaPhRc6B1rGeM75/xxhzF9lyotHaQSaa0iOEfjm4HQZ1bqjoPu1BSamEuRuS8PWhK4rt097+lhRg7pLzPkztH4iJfR4XUI08FF40GF5OX7Zg0Oe7mdez5wOrjFHeA2tsPpSOcT8kKbKvpVFt8EzXxorsi9xt3h/xf/tw6JJtYxpZy576JtlyfnmstjO2vaf/L4hybx9pqUWbwosGw4ucg+u3mf0NMelikMkMuR8yI90yq47cAQtZ0Zgh2pB85RYGLE2QfVywsJfwb0uAOTk7whCTFz7/5V4cTstmXk8rnxGW67dxxpfXsPGrDzOv4wFjzBb9zb4/ZJ+gwx57xC6CS6DJLDy4rBsditSYSAouGtG8fi1cjIlEakwklka1EbqvQJMZe05fF7oPLbDlAmyUyQu/f7O7rPXSs/I5VyIetbwIJncgNq0kYVuUlEpoOn2T7PWN8B5UJyUjF+FLdgnb/gsdvREzuLtmmoRJ9ZKv3ELE0gSh+zD6MQXQLWq5Hf218Nqp5UVDWnzAHly2T+oloBLlUXCpWqDJLCy4LI1qg9SYSCx8vgcFFx1pXr8WUmMicX7+ILzSoZ6QfQSazDiYwv5Uip7Ycu5QuqO8CAE+bnCRcdyfvqxMXyxeKLwIlJFdIGsKgOAAT/7FKMzev/1UJb+oRNgJMnZ4R6TGRFInXJ1zcnTAvBfaIzUmEvMHNee+/aErEg1xka6OLeeQFz9XfngC3s7NH8S8jpwHStRE4UUgOU13J2dHCKhEWe3myL9/bOTgEvXJJptnx63MtyM6IzUmEmFt/blvm6hreM9gpMZEYsFTLblvO9BkRk7eHe7b1Qq555IDlwuQX6SdqSDk+mFsN+Z19p+7KaASMSi8CJKSkcu8jr8zdN/jPSu3SNZcG4Cxg0ugyYyk63y7l61+pRNSYyLRo5Uf1+0S7RnaowlSYyIRE8m3w3W7uVvRWcbkfnoh98ugiC8ZSpMzKraWJ3q9H4UXQeT0Zzjwof4v3nJmOQWMH1x4WjYkBKkxkejZui7X7RLtG/ZkU+7HyrXcIsPeRvJ0q4FmdeXdhjfCeyKn/2TCyWsCKuGPwosAyVduMa+zbnSogEqUJfdg3/u+MSdZTLuRx/UE2KKuB1JjIjGgk33N7UQelBoTiaMz+nHdphEu1pXZOln+AxByzuVaIqf/pJZmrK8OhRcB5Dzu2DnYV0AlyrGlp3oDX3eOlWhDoMmMnot2ctvembkDsHmyfc1bQ6rn6+mC1JhIWX0bqmLUfjByW6tEP7quBDkhN/54poBK+KLwwpmcpL5p4pMCKlGW3J7qRrxdxPMbbNngcvYwWB+Rp0OTR7geR+3mbkW3aP0/cXM/ue+R3lukfD1d4Mx4pR+5hn1gVaVReOFMTlLX+/wjcg9uCi7VS42J1H2LHFFOakwkfpvJZ46eKzkFur9oV2a/qa+s9fQ+Ns7x2QOY19F63xcKLxyl3chjXkfvA9IlpWbLWu/Q9HC+hWgAr5P99km9DBnsiHjeHs5IjYnE4z58tme0ABPg4wYXJ/YB3IauSBRQjXLcXZxQ253tdWu97wuFF47k9HHQ+4B0Ucv2ylrPz8uVcyXq4nWST42J1P1ngqjPbIrEmbns37YrY7QAc+4j9gHcAP2/D3umsT82ruVxXyi8cJKVW8S8jt5bXYLodhEAPie1jeO6G+59Iepyd3Hi9pnS+4X7fudljEALACfScjhXohw5rS9aHveFwgsnvWSMb6Lnb9gXr92WNVv0uXkDudeiJh4n9dSYSIQE+theDCGVSI2J5PJFyUgBxsnRAVP6NmNe7+kv9wioRjlyWl+0OucRhRdOWJ8x+vnNHkLqUErvxfHM60Q+7gWXGsb5yPEKLoSIFhzgyeWzZqQA81a/x2St117H74G7ixNYn1vU6pxHxrmSqGhHUgbzOm0aewuoRBlye6F/8Tf9PxJextaT+LSIFhRciOIowFQk5/3IApBbIG8KFC04JGPcFzndIkSj8MLBqLVHmJZfO6qroEqUIacXOq/Og1pg68n7/PxBeKN3U07VEMImNSYSts6gZqQAs2ZkF+Z1Ws+WP/ms2nw9XZjXeUqDg/VReLGRnEHpujarLaASZTz1KfucTW18YZhB1mw9aafGRMLJkf1RTUJ4SuEwvYBRAky3FnVkrSdn8l2t2Dk5jGn5K5ZCMYXYgMKLjVgHpZsaJu8+qxbkF5XgZCb7Afvz+8a4PcIjuBCiFWXTC9jCKAFGTpCTM/muVgTVrcm8jtYem6bworDx/fUbXh6XMU38lrd7CqhEeRRciFHZ+tl86Ut+c3ipxdfTBZ4yWof1PPIua/cFrT02beutT7vG2lHXAdDtLYPrlkKUylivef1a3GtRGgUX9U2I3Qrz2conDHy7ZwNMHhSibEEGkxoTKftznpiWh/yiEt3fGj45dwDzezB0RaJuj2853Rdy8u7A28NZQDXsHCRJkjNcB5MvvvgCCxcuREZGBtq1a4fPPvsMnTt3rnTZ2NhYvPbaaxV+5+rqioKCAqv2ZbFY4O3tjZycHHh5iZ0ziPWDvn1SL92O7SLnxHZm7gDdn9CSr9yyaWZZvZ7YlPDtrt8x49dziu7z5zd76PpJP9FsCepG+KwnnLzG/EDCmpFdZPebUVsbk5lpmI+mtT0Q95642e1Zrt/CW17+97//YfLkyVi2bBm6dOmCJUuWICIiAsnJyahbt26l63h5eSE5Obn8/x0ctNdacV1GBya9BpeMbOuC472a1TRGJ10KLnys35eKyT+dUruMSgcZ2/t+HzTwdVehGu2xpQUm0GTW/We+Z+vKr0nVGR57QLeve9eMfmjPMMDq+Zvs8/eJIrzPy+LFizFmzBi89tpraNWqFZYtWwYPDw+sXLmyynUcHBwQEBBQ/uPv7y+6TGZd529nWv793uyjOWpF15g45nW2fqDPg/le9v4t1BYXr91GkMmMwP//o4XgUpXuC3aU1/numni1y1GdLZ9dI3Tg3TiuO/M6WuvMai05j01fvHZbQCXshIaXoqIiHDlyBOHhf80g7OjoiPDwcCQmVj1LZ25uLh599FE0atQIzz77LE6dqvrEV1hYCIvFUuFHCSWMy7/RL1hIHaLJGYxp2ZAQ/oUojIILuyMX/iwPAb0Xx8uaPkJtG47fLn8Nb67S71getrLlM6znTqwAZE3VobXOrCxYx7npvyReTCGMhIaXGzduoKSk5IGWE39/f2RkVN7ZtXnz5li5ciV+/PFHfPvttygtLUW3bt1w+fLlSpePjo6Gt7d3+U+jRo24v477pd1gbzrTa0fdqBj2E/iATg0EVKKcTzadlb2uPQaXCbFbEWgyY8jyfWqXwtWm5OLyILN083G1y1Gc3M/y0BVVfzHVi00T2UcD1+ukjaz9de7IeXJDAM09Kh0aGooRI0YgJCQEvXr1wvr16+Hn54evvvqq0uWnTZuGnJyc8p9Lly4Jr7HnIrZHA38Y201QJeKlMHZ3Wf78E2IKUUhJqYTPEs7LWteegsu+szfKL+xVPQVkJIvjLyHQZMbor/R/W4TFoenhD1+oEnq/fdSqIfvDHnqftJGFFm4dCQ0vderUgZOTEzIzMyv8PjMzEwEBAVZtw9nZGU888QRSUlIq/XdXV1d4eXlV+NGaDk0eUbsEWX45wB4E+3esL6AS5TSdvknWevYSXH7an4ZAkxnDYw+oXYoqtl9EeWjT8/w21vLzcoW7s7zLhN4DjJzWFy3OAWQN1teqhVtHQsOLi4sLOnTogLi4vzp8lpaWIi4uDqGhoVZto6SkBCdOnEC9evVElclErx9OOSZuYGsqXxrVRlAlyvhw/QlZ6+l9hnBrbD96FYEmM97eKO89MqLWs7eg9Qx9X6CtcebDgbLXzcnTb6ucnNaXgRq4qMvB+lq1cOtI+G2jyZMn4//+7//w9ddf48yZMxg/fjxu375dPpbLiBEjMG3atPLl586di61bt+LChQs4evQoXnnlFfzxxx8YPXq06FKt0ovhsTJAvxe2pNRs5nWe6dqYfyEKKSmV8O+DabLWNfK4ITl5d+7eLll3VO1SNCm3+G4Lw9vfsD+RpydyWxbbzd3KuRJlrRtt3ZfsMpm5+g1rrNT+Ii88vLz44otYtGgRZs6ciZCQECQlJWHz5s3lnXjT0tJw9erV8uX//PNPjBkzBi1btsSgQYNgsViwb98+tGrVSnSpVmGdhlGvF7aoZXuZlm/lr88xbMrQ7aIHBZrMur/4KOWnUwUINJmxJqHy29tGcH7+IFnrybn9rBWdg32Z12EdeV0rWKdyUXumaUVG2FWS6BF2We/j6vXixvo6T86OgKebPmebkDOqJqDfv+3D7Dt7Q9N9Wh4F8IfaRTyEUT8bn2w6K6tDu57fj82H0jHuhySmdfT6etW+vmlqhF0jYR2IiLXJUSs2JrJfGvQaXADICi6xwzsKqER9anayjIlsgWFPNuW6zenf78Oaw39y3aY1Ak1mLH7mcQzuFqj4vkWaMqiFrPCi59F3B3RqADCGl9yCYl2fE/WAWl4YqJ1KlcL6On8Y2023T1StSUjB9E3JD1/wPnr921ZHyeAyPTwIY8PVuRU86T87sfGEcsOc02flrv2mvgjwcRNQjXisr7eVvyc2vdtLUDXi7D93k2nAvY3jussa1K8qLNdvzY3zQtQl5+kAvQYXABRc8NeYLSJFtfFAakxk+Y9awQUAlrzcu7yO7ZPEX2D0/shwZVYO68C8jpxpRrRi7/t9mJY/nZkrqBKxWGeaZu0byRO1awnStLaH2iXI0puxg+bTrbQ375S1Pt/C/tiv3h8Hv5/IC+vzT3hh0YvsY2UoKTjAszyMXrcUokfMdhQKeAw00GTGbzP7w9vDmf/GVdAnJABYy77eibQcXT7EIGfizozsAt22NOkBtbxYKf545sMXusf6Cfp8RJp1VpJPhrcXUocSFu1kfzRaz4+D309UcEmYerdlQ+vB5X5+Xq5Inn+3RWbBUy25b7/d3K0Ina/f1of7/TazP/M6eh6F9vPn2jItH/4J20jshA2FFyuNXHOYaXk9fsPKL2KdbhJwqaHPj5CcuWrkzDarVbyDy+B2nuW3YhrX0Wer472G9miC1JhI/DazP2pxPJSvWgoMcxvJ28MZPjI6pZ6+rMzkubw91YVt3rxcLYzkJsN+U1+m5eWMCcaDPq88RIhRX25mWl7Ps0cvjmcfe4JnxzQ18bx4juxSB6kxkVj8kv46J1rD28MZJz68G8pqcdyuUQJM0uwI5nUGfb5bQCXaVFSsvwDDeqtLrX4vFF4EeL93M7VLkCWRcWwlvc4eHbuTvZPu0Rn9BFSiPF4XzSEhtZAaE4nZz3Xhsj09OBETiXPz5A+Vfz+jBJj5g5ozr6PXaQNWDGW7Tf7ef2lkalEovFghJYOt5/gb/YIFVUJ4mL2FfRRUX08XAZUoi9fFMjUmEp8MYxuN0yhcajgiNSYSq1/pxGV7Rggww3uyn+/6L94hoBLxwtuzzbH34ym2vpLEehRerBC+ZBfT8k6ODoIqEYd1AD69PnUjZ+jukzKaxrWGx0Xy8+faGu4xcbl6tq7L7b0Y+tl2LttRE2vrS2au8Wfk1jPWfi9qoPBCAIBpYCJAv0/djFp7hHkdvY+UySO4pMZEMndYtAepMZHMY4Dc72B6oazO8loip/Vl7W72kXq1gPWLW/IV1hnx1Mfa70WN+ZwovHDmpL9GF1KNhKm91S7BJraeVHzda1Bry0M08HW3+T1qOZOts7wWvd/nUablTeazgioRi/WL2wCVJzBUgpwvhbai8MLZlneM+dSFESzffpp5Hb0/9mvLSeW3mf1xdJb+b5kpxdYAo/f+L2/2b612CZpkqPl3NITCy0OUlLJ99IIDPAVVIs6JtBym5TdN1NfgY2Xmb7/ItPzy558QVIkybLkYpsZE6nKsIrXZOuWA3gMMq0XmY2qXQKogZwoIJVF4eYjlO9m/resN66iXrRrynfBSq/p3rK92CbLZGlyIfPdOOSDHvrM3OFajLNYvNp/vviKoErHWjurKtLwex3vpExKgdgnVovDyEB9vS1W7BMLBsm2nmJb3dNbvoWHLiJcUXPiR+14Ojz3AuRLl2MsXG9YJDO1hvBfWuxS20u8ZmhAGMXGpTMvvfM+2J0jUJHfESwou/P0wtpus9SZ+vY1zJcqZ+CRbi6Wcvmh6Yw/jvXy1jX38LFtQeOEodnhHtUsQTu7JWG/8vFzVLkGWmJ/lddDlOXIs+UuHJo/IWu+XM0WcK1HO1Ei2vmKsfdGINi3YeU7R/VF44Sisrb/aJTDLyC5gWl7uyVhNrBPB+Xvqt6Pqsr3sj0Z3rqvfCTb1QG6LVtRH9tV5V2/0OlCnUdAZy851jYlTuwThWCeC2zpZn2O7TP4v20jQZdZNpttFoskJMEn6G9vMruh1oE4WWr6bQOGFkPvo9RHh9b+xzcEFUD8XJcVEtmBe5+m5+mx9OTQ9nGn5xZuSxBSiIelZ+WqXwEzLdxMovFRDjx82Yp9Yn6YCgLHdtHtiMqJhTzZlXudEnoBCFMDaZ2xpQrqgSrSj7yc71S7BUCi8VKP7An3OfCqKlpsQeVnwVEu1S5CF9WkqAJj+jPH/nlrz28z+zOu88y2dh4ygoITG2uWJwguxmpabEKvCOrfP0B5NBFWiLXMHPKZ2CXbJ28MZrDclfzypzxbgt3s2ULsEYmAUXoihqTFhmNLeWr2deZ0RYc0EVEKs8bud9DOaPChE7RKEs4fWaK2i8MLJz2/2ULsEYqd+Pl3ItPyIzmyjgxL12cOto/jj+hvITY+t0SLZMro3KwovnLRp7K12CYRYZe5gtnlZCH+Ln3mcaXm93jpiMXLNYbVLIDaSO7q3HBRe7Nj2o1fVLkFTtD6LKjGOwd0C1S6BEF2j8GLHRq8z/mRhLLQ+i2pl3v6GbZDBz59rK6gSQogR1VC7gCpQeCFEx346xTa9w1NdGgmqhLBa8mxrtUsQbsXQ9mqXQGyUyDjgoFIovBCrPOLmpHYJhBhKVOijTMtPWZsgqBJxwtvXU7sEYiOtTlKrSHj54osvEBgYCDc3N3Tp0gUHDx6sdvnvvvsOLVq0gJubG9q0aYNNmzYpUSaphl7n+yHEKH6gyY4IKSc8vPzvf//D5MmTMWvWLBw9ehTt2rVDREQErl27Vuny+/btw0svvYTXX38dx44dQ1RUFKKionDy5EnRpZJqaDV9E+tRfxdCiFEIDy+LFy/GmDFj8Nprr6FVq1ZYtmwZPDw8sHLlykqX/+c//4kBAwbgvffeQ8uWLfHhhx+iffv2+Pzzz0WXSoihUX8XQohRCA0vRUVFOHLkCMLD/+rw4+joiPDwcCQmJla6TmJiYoXlASAiIqLK5QsLC2GxWCr8EEKIHix6upXaJRCiS0LDy40bN1BSUgJ//4qjEPr7+yMjo/I5ZzIyMpiWj46Ohre3d/lPo0b07ZIQog/Pdw9SuwRCdEn3TxtNmzYNOTk55T+XLl1SuyRCCCGECCR0/Jk6derAyckJmZkV56zIzMxEQEDlA4IFBAQwLe/q6gpXV+pMSgjRn293/a52CYToktCWFxcXF3To0AFxcX+NAlpaWoq4uDiEhoZWuk5oaGiF5QFg27ZtVS5PCCF6NePXc2qXQIguCb9tNHnyZPzf//0fvv76a5w5cwbjx4/H7du38dprrwEARowYgWnTppUv/84772Dz5s345JNPcPbsWcyePRuHDx/GxIkTRZdKiKF9v/ei2iUQQggXwqctePHFF3H9+nXMnDkTGRkZCAkJwebNm8s75aalpcHR8a8M1a1bN6xZswYzZszA9OnT8dhjj2Hjxo1o3dr4Q2lrWUZ2AQJ83NQug9hg6s+nqYMoIcQQFJlzaeLEiVW2nMTHxz/wuxdeeAEvvPCC4KoIi64xcUiNiVS7DELs1vCOj6hdAiGaofunjQghRI/W7j7PtPz857sJqkSczYfS1S6B2Cgjm23yV6VQeCFExwa382Ranvq9aIfJfFbtEoQb90OS2iUQG/X/NF7tEipF4cWOxQ7vqHYJmrL18BW1S2C2+KVeTMtP/fm0oEoIIUZkKSxRu4RKUXixY2Ft/R++kB0Z+/0xtUsgdoLGdyHENhReOElKzVa7BEKsMvV/u9Uuwe6xju8yJKSWoEq049sRndUugdjo5zd7KLYvCi+cRC3bq3YJxE4917Ym0/LfH6PJS9WUk3eHeZ1PhvUUUIlYRcWlTMv3aOUnqBJxdiRVPueevWrT2FuxfVF4IYa2+pVOapcg3KfDw5jXWbbtFP9CiFXazd2qdgmKeH2l8b/QjVp7RO0S7BaFl2p4uTqpXYKmbD96Ve0SmPVsXZdp+TUJKYIq0ZaYuFS1S7BLaTfymNfR6y2j3ReohY+IQ+GlGlvfDVO7BE0Zve6o2iUIN31TstolyDJ3wGPM68z4IVFAJaQ6PRftZF5Hj7eMCBGNwks1aDh8ohcjwpoxr/PtoSwBlZCqrNzBPq5Le+W6EHCVnpXPtPzksEaCKtGO/aa+apdgKBRe7Nwjbsa/NebAuPx1S6GQOkSTM3x8oMksoBJSmblb2UbUBYD10/Q5JUfYoh1My789oK2gSrRDj1+GtdwhmcKLnds6ubfaJQi3bRLbQG6d5m8XVIlYcoePb/aPTZwrIfeTExI7+QooRCF32B400qWNiX+oXYJwWu6QTOGFIy2n1Kr4ebkyLX8wRX+3GoID2IbQ1zM5ze9FJZJuW5v0QG7r1nfv67PVhfURab2a9ONJtUuwaxReONJySuVl6Ar76OTJes9eK+Q2v+u1tUnrxsXGy1qPdeweLXnt32yPSM/s31RQJURJK4d1UHR/FF6IXZjRrwnT8t0XsN2z15Kdk8NkrUf9X/gqKi7F5rO3Za0rZ+werdh7ke0R6VF9WgiqRDvG96yndgnC9QkJUHR/FF4eQuk0ScQY3bel2iUoJqiu/G/tFGD4aTbjV1nrKTnEOm/2Mk3KntPXmZafOuAJQZWIU1IqqV1CtSi8PITSaVINmyY+ybT8ibQcQZVoy6aDl9UuQbbUGPn9JSjA2OZgSpZN76GSQ6zzxjpNyqReDQVVItYrqw8yLe/kyPrMo/q+2qbtATspvHCWfOWW2iUwa9XQi2n5p7/cI6gSsVjvrb+5/jdBlSjD1gBz8Zq8Wx72LNBktqlfmC1/Mz2aNLCd2iWQKizYyTZ5qNIovHAWsTRB7RJIFeTcW9djGL3X2lFdZa/be3E8tcIwsPW90ntwWfKrvsM+0RcKL4RUQ+9htGuz2jZvgwJM9U6k5dh9cAGAJbvYbrMuerqVoErEWr8vlWl5uR3o9SR2eEfF90nhxQrOdvAu/TCWbYAzvQ7QJGeW6azcIgGVKIfHhTHQZMba3ewjxBpdoMls823UhKn6HyhydTz7LYbnuwcJqES8yT+xzchuSwd6tbAOFRHW1l9QJVWzg8uy7bZOCmNaXo+DNHVowja0vF4HaGKdZRoA2s/bJqASZfEIMCbzWQSazLr8fPP2y4FLXFqkHB2AxnU8OFSkrpmbf2davrmLoEIIF/0+jVe7hIei8GIF1uT89w37BVVCeJg3kH0SQ70OWncvXrcmms34FV3t9FZS2o08BJrMmLjhOJftXYjW/+2i2J3sM7FvnDFAQCXibT18hWn5N57U59OqeTqY34HCiwAbjvypdgmyjAt9lGn5Xw5cElSJWK/0eox5HT0PWncvXgEmA3dvmUz+7y4u29O63IJiBJrM6LloJ7dtGqGfCwDM3sL+SK27iz4nhB37/TGm5d8f2F5QJYTCCyn33tOPMy3P69unGt7vwxbUAH3O61QZnhfN9b/lItBkxviVm7ltU0tOX7Yg0GRG69lbuG7XKMFFzq2z7YwTpeqZHsd3YW1lVmtQRQovVmLt6KnHTp5yDrT8ohIBlYj3Zv/WzOsYaV4n3hfPX8+VINBkRqDJjP3nbnLdthqWbTuFQJMZgz7fzXW7CVN7Gya4ZGQXyFpPrxOl/rQ/jWl5Z0F1iMba30WtQRUpvFiJtaNnmE47ebb2ZzuxvPaFfr9xTw9nf9ph1vfGGctC1EV02Mr9CDSZ8eYqvq0Voh258Gd5AIuJS+W+/dSYSEN0zi3TNSaOeR09t7q8vfEE0/IHZvQTVIlYeujvAlB4EYZtajLtWDu+O9Py+zMFFaKAseHs40x8ffiy5uf8YJEaE8k8PYS1NiUXl4eBMcu12cE3/nhmeY1Dlu8Tth+jtLaUYe24WkavrS5pN/KY1/H1pEeqRKqhdgFEWzzd2D8SB1Oy0DnYV0A14i16uhWm/nyaaZ2m0zcZ6mLUqqEXUmMihQ5Gt+1Cxf4RIzrXxtzB8kf/lSsn7w7azd2q6D6N9Fkpw9pxFQB+m9lfQCXK6LuYraP2gFZugioRa9/ZG0zLqzmJqIMkScb5GgnAYrHA29sbOTk58PJim7PnYQ6mZDH1e1g7qiuXEU6V9suBS8ydcfV8gpZz0Y5oXhdfvcY+4J3WqTma7qiufpgZ1Znb9nILirl3tGWx/Pkn0L9jfdX2L4qcz4iPWw0kzY4QUI0yWF/zuXkD4VJDfzc2WF8n7/M+y/WbWl4YsLYuDFu5X5cX9ae6NGIOLzl5d+Dtoc8uamtHdcWwlWxj82xJvoai4lJdnqCqkxoTidOXLdw7qlpj5f7rWLlfm7eXWOnxuLfGUx/L+/voObjImV3eaOcFLRL6DmdlZeHll1+Gl5cXfHx88PrrryM3N7fadcLCwuDg4FDhZ9y4cSLLJByEKdwUz5Pc1rFmM37lXIk2lN1GGt6RbdRlAiyNamPY4JJfVIKTMoawWv78E/yLURDr7PJqzPPDw4m0HKblG6jcpUdoeHn55Zdx6tQpbNu2Db/88gsSEhIwduzYh643ZswYXL16tfxnwYIFIstk0rQ229MCSanZYgoRjPVepj6H5fuL3AuOkSctnP98N8NeiHkb2t4bqTGReKZrY7VLEablTHlPFur51hlrHxBAnXl+eGCdo2uTSd0+TMLCy5kzZ7B582asWLECXbp0QY8ePfDZZ59h7dq1uHKl+p7qHh4eCAgIKP/h3XfFFusnsF3Uo5btFVSJWHKe3dfrZI1llg0JkbXeO2vYOy/qSWpMJPPEnfYiPOju+7NgqHodF5UgN6SfmzeQcyXKGh57gGl5fT62II/a3QSEhZfExET4+PigY8e/mtDCw8Ph6OiIAweq/0D85z//QZ06ddC6dWtMmzYNeXlVP6ZWWFgIi8VS4Ucktf9gSvpycDum5fU6WWOZAZ0ayFrvx+NXDD9ZYYcmjyA1JlLWrNxGFNbobmhZ8YbxW6bkBpe+wX667vshp9V8p06fqGK9ZaQFwj5ZGRkZqFu34sBuNWrUgK+vLzIyMqpcb/jw4fj222+xc+dOTJs2Dd988w1eeeWVKpePjo6Gt7d3+U+jRo24vYaqeLmyzcuh12HlB3VuyLyOnM5tWiL3NolR+7/cr2frukiNicTOyWFql6KKp1u5IjUmErETjB9aANsmJP33aH5PjqlBTqu5Xr/cst4ySpjaW1Al1mMOLyaT6YEOtff/nD17VnZBY8eORUREBNq0aYOXX34Zq1evxoYNG3D+/PlKl582bRpycnLKfy5dEj9Z4NZ3w5iW1/Ow8k+3Yrt/y9q5TYvWjOwiaz0j93+5X1DdmkiNiURqTCSef0I7t3VFmTewGVJjIvHZiHC1S1GU3AlJ9d5XSk5LxJa3ewqoRJu0MFI086PSU6ZMwciRI6tdpkmTJggICMC1a9cq/L64uBhZWVkICLB+mvAuXe5eSFJSUtC0adMH/t3V1RWurq5Wb4+HAB/2AYj0+ljtJ8Pb42fGVoVNBy/LarXRim4t6sheN9Bk1v2Jm9WiF5/EoheB65ZCdJq/Xe1yuHn+CS8selHM6MN6IDeMqzlwGS+sLREA0Lx+LQGViLfn9HW1S5CFObz4+fnBz8/vocuFhoYiOzsbR44cQYcOHQAAO3bsQGlpaXkgsUZSUhIAoF69eqylClXfyxVXLIVWL//3Dfvx6Qv66/ToUsMRbk5AAcP8i2+u/w2pOg4vAGwacdYeAwwA+Hm5VnjdY5abse2CigXJMDq0LmY8S/16bGlFVGuiPl6OXGB/dnLd6FABlSjjldUHmZbXwi0jQPAIuwMHDkRmZiaWLVuGO3fu4LXXXkPHjh2xZs0aAEB6ejr69u2L1atXo3Pnzjh//jzWrFmDQYMGoXbt2jh+/DjeffddNGzYELt27bJqnyJH2L1XVm4R2jNOvqjXC5qc11rbwxlHdNp57V62nMT1+vcWgXV0aqU81dIFn7+qzwn0RLH3z7yc16/X1y2ntVTka9XMCLv/+c9/MHHiRPTt2xeOjo4YMmQIli5dWv7vd+7cQXJycvnTRC4uLti+fTuWLFmC27dvo1GjRhgyZAhmzJghskxZ5Ey6lXzlli6bFuW81pt5d3Q96m6ZnZPD0HtxvKx17bUFpjKdg30feC+UDjRPt3K1uz4rrOw9uOxIqvphkqroudVlwBLrGgXKtPYRU4ccNLeRDeKPZ2LkmsNM6+j1AJfT+gLo9/Xey9aOuEZ4D4jx2fI5Pz9/EJwcHThWow57anUB2F/vydkRsibvtRbL9Vt/PUg1RM5IinodD0Tu9O6nL4sdd0cJtp6c7OkpJKJPtnxGP4pqbbfBRc+DN8oZPVhkcGFF4cVGTozH7PvrtXff31pyRstUY4I/ESjAEKOy5bPp6AC83PVRjtWoIyO7QNZ6HZrod/4v1tGDtXZ7jMKLjba804tp+Y1Hs8UUogCXGo4IC2QboA8AOs1lv92kRRRgiJFkZBfY/Jm8EK3fWyb36hoTx7yOngdqTMmofoLkynQO1tbkBxRebBQc4Mm8jpzmOq2IHTeAeZ3reUXIybsjoBrl8Qgw+UUMz50TIkCzf2ySdcG+l577etxrXGy8rPWC6tbkW4iCwhk76qo9g3RlKLxwwDoFOmtzndbIGYG23dytAipRh60n7ZYzN+P5fxrn/SD6Emgyo6jEtuc0jBJciopLsfnsbeb19Dzh5HWG8cnKqD2DdGUovHAgp+OunEm/tELuCLSDP9fnDNuVsfXkffjqHbqNRBTH4zNnlOACyJuTrF8zfU842VnGKNhaHPJCv38BjWlam22uBzmTfmnJ3vf7MK9z9HK2oW6Z8DiJU4AhSjiRlkPB5T6DP90sa73/G6XfCSdzC4rB2uam1TmbKLxwsn4C+3ween6MuIGvu6z1Ws6Ud8LQKl4BZv+5mxyqIeRBgSazrLl67mek4JJfVIKjmexfpLR6IbfWE7O3MK+j1YFVKbxwIqdZTe+PEcs9mQVPN1ZrA4+T+rCV+6kVhnCVlVvE7TNlpOACyP8SpdULuTVyC4rB+tiE1h6PvheFF47k3EpJvnJLQCXK2TSRfdbd4lIgPStfQDXq4XVyDzSZcSIth8u2iP1qM3OLrBGxK2O04CI30On9fWgto9VFa49H34vCC0dybqVELE0QUIlyWjWUNwVD9wU7OFeiPl4nt6e/3EOtMESWsrFbbhUVc9me3i/Y90u7kSdrvY3junOuRFlyhqr4doS2+/ZQeOFs+yS2QesA6P6bttwTnBEv0DxP9oEmMxJOXuO2PWJsj0032zx2S5kXOgUYLrgAQM9FO2WtFxLow7cQhckZqqJHKz8BlfBD4YUzOYPW8ehMpzY5oQ0AggwaYFxY542owohvDyHQZJY1NgOxD3tOX0egyYw7nKZNOzdvIBYO6cBnYxpir7eL5Ex9oPVWF4DCixByLuR6f9pETmgDAAnAN/v+4FuMBpz7aBD2m/py216n+dsN2VJF5EvPykegyYxXVh/kts3UmEhdj2FSFbnHjtwvZVoipzVO660uAIUXIeRcyIet3C+gEmXJ/YbywU8nUVJq24ifWhTg48b9W1ugyYyNicYLe8R6JaUSgkxm7v3G9N7CUJVmNoR+uV/KtELOAyF6aHUBKLwIkzC1N/M6nTg9HaCmk7MjZK3XdPomzpVoB++LwqQfTyLQZMbWw1e4bpdo3/trj6Dp9E3MA41Vp72/k2GDS1ZuEYpkrmuE90TOAyF6aHUBKLwI07gO24i7AHA9V/8TGHq61UBjmV9WjHxbJDUmEmtHdeW6zbHfH0OgyYztR69y3S7RntXx5xBoMmNdUgbX7Z6ZOwDr32WfbFUv5D4urue5i8ocTMliXkfL47rcj8KLQEdn9GNexwgTGCbMkP+NxcgBpmuz2kK+zY1edxSBJjM2H0rnvm2irpU7ziLQZMbMzb9z33ZqTCTcXZy4b1cr5J5LhnZoaIh+P0NXJDKvo+VxXe6n/7+Qhvl6yptHXO8D1wG2NbkaOcAAd98bTg8jVTDuhyQEmsxYu/s8/40TxZSNjBtoMmPuVv5/y5/f7GGIWyLVseUcsuCFdhwrUYec16+3qQ8cJEkyVE9Ji8UCb29v5OTkwMtL3gBqPJWUSrL6cxjl5GLLScQo70FV0rPyhQ7W18DLBZsmhWlyRljyoCMX/sSQ5fuE7sPoxxRA5xy55xUtvHaW6ze1vAjm5OiAt/s2Yl6vrYyhnLXIlseFO8zVfwfm6jTwdUdqTCQ8BTXdp1uK0G7uVhrsTuPW7bmAQJNZaHBZNzpUExcn0ew9uADyRi/nOayDUmqoXYA9mNyvLZbGXWJax1JQjKzcItm3nrQiwMdN9ro384ow56dTmPXM4xwr0p6TcwcgK7eI21w0lRnx7SEAQEMvV5gn9aLWGJUdTMmS1SdBDqNclB/GluBihA66ANBaxnvg4uhg03laLXTbSCHJV27JemzNKCceW08sRuhAZ43445kYueawIvuKiWyBYU82VWRf5O5Ip7yG77dGwtTesp561CNbzi9/6/IoPnyuNcdq1JGTd0fWAx9ausawXL8pvChIzgHWtI4H4mSMGaNFtpxgzs8fBCdHAb1cNar9nC3IyuczuZ41Fj3dCs93D1Jsf/ZCdItaZVYO64A+IQGK7lNNtpxXajg6IGX+II7VqEfO+/DD2G7o0OQRAdXIQ+FFo+Elt6BY1rTkJ2dHwNPNGHf4bDnRfD4sBE+FNOBYjbbJ/SZlq06NvLB6TDdDP0Yr0unLFgz6fLfi+7XHljRbn0zUUquDLZ780IxLt9nX09rrp/Ci0fACAN1mm3GFfZ4szX3IbGHLCadXsA++Hq3v6elZpWTkInzJLtX2Hzu8I8La+qu2fz1Yu/s8TOazquzbCcB5A50frEXB5S65X4oPTQ+Hn5ergIrko/Ci4fACyDvoPJwdcfpDY3QqA2w78TgAuGiQEw8LJR6ltca3IzrrZghxUdbvS8Xkn06pXYahWmVZUHD5i5z3wt3ZEWc0eD2h8KLx8CK3454Wk7It6AQkj5JPqlijZd2a+O7NHoa9iJ5Iy8HTX+5Ru4xyHk5A4j/62+0TY3Te+Ivc90Kr7wGFF42HF8B4Hzq56EQkn9YuqvdbO6orujarrXYZTH7an4a3N55Qu4xKhT/2CL58tavdPHlXGTpf/OW6pRCd5m9nXm/v+33QwNddQEW2o/Cig/ACUIApQyck26jxRIutVgxtj/D29VTZ95qEFEzflKzKvuWwx464laHzREVy3o8ajkDKfO2+DxRedBJe5N4+Cg2qjf++wXeGYrXRiYkPvV2YSdWoo/RdPJ66M9r5wahffDUxPcBHH32Ebt26wcPDAz4+PlatI0kSZs6ciXr16sHd3R3h4eH4/Xf+s6lqhdxRDRMv3kR+UQnnatRl60EVaDKjqLiUUzX6NbxnMFJjInU1tT35Sx0PZxyd0Q+pMZEUXACEzo+j4HKfJz6QF1z0OAVAdYSFl6KiIrzwwgsYP3681essWLAAS5cuxbJly3DgwAHUrFkTERERKCiQ8WyxTsg9sFrO3My5EvXZepJpNuNXTPnuEKdq9K1zsC9SYyKRGhOJBU+1VLsc8hCrX+mE1JhIHJ7ZX/dTgvASaDLjqsW2c7/RgktO3h38eYd9PRcnfU4BUB3ht41iY2MxadIkZGdnV7ucJEmoX78+pkyZgqlTpwIAcnJy4O/vj9jYWAwbNsyq/enptlEZuc/pA8Y7OAHbbyEBxnxfbJVbUIxBc7cgjRqoNGHJs60RFfqo2mVoEp0DKmfU20VlNHHbiNXFixeRkZGB8PDw8t95e3ujS5cuSEys+rHQwsJCWCyWCj964+lWA4Eyc1bIHOVHYBWNx4HG4+RnNJ5uNZAw/25rzNEZ/dQuxy4tfubx8hYxCi4PysotouBSBbnvy5m5AzhXog2aGZghIyMDAODvX/E+r7+/f/m/VSY6Ohpz5swRWpsS4qdHyvpwZuffMcTs0/dLjZH3ftwr0GTG0Rn9DPfe8ODr6VLhBK/mCLFGZ29zDckVMmcrsvNl3BO5DwWXv4QG1TbsNB9MLS8mkwkODg7V/pw9q+wJcNq0acjJySn/uXTpkqL750nuQae3x2StxeMk1H7eNjSlVpiHGvZk0/IWgU0Tn1S7HF3r3MgLZ+YOKH8/Kbg8XKDJTMGlCuEL5HUpAGC4p1LvxdTyMmXKFIwcObLaZZo0aSKrkICAuwd4ZmYm6tX7a/yHzMxMhISEVLmeq6srXF2NM+rs0Rn9ZIWRQJPZkAcujxaYEhj3/RGhVUOvCu9VUmo2opbtVbEibWvk7Ypf3ulltyPe2kLucBGVMeLxnVtQjJQsebPLG/H9uBdTePHz84Ofn5g5TYKCghAQEIC4uLjysGKxWHDgwAGmJ5b0ztfTBa4ACmWsa9QLNI8AA9x9f3ZODkNQ3ZocqrIfIYE+D3yuvt97EVN/Pq1SRepa/Uon9GxdV+0ydK/ZPzahqITP8yJGPO8BkP0gh1H7udxL2NNGaWlpyMrKwk8//YSFCxdi9+67U8QHBwfD09MTANCiRQtER0fjueeeAwB8/PHHiImJwddff42goCB88MEHOH78OE6fPg03N+se89Lj00aVseVibdQDmWcnXKO+R2rSysSRvNSt6YzN74ZRnynO8otKuA31sOC5xzG0SyCXbWmN3PNd96a18Z8x+rxdpIkRdkeOHImvv/76gd/v3LkTYWFhd3fu4IBVq1aV34qSJAmzZs3C8uXLkZ2djR49euDLL79Es2bNrN6vUcILIP/D6+gAXIg25sX5iblb8Wee7ffGAW3P8WE0F6/dRu/F8WqX8QAael9Zz3+5F4fTsrls6/z8QXBydOCyLa2x1y+vmggvajFSeAHkf4h7PeaHr1/vzLkabeA9l4+eD3ZC9MCWsawqY+Rj1l6DC6DTcV5I5eSOx7Hr9+uGm0KgzP2P+doq0GTGxWu3uW2PEPKXPgt3UnCxkj0HF1YUXjTO19MFHjJbRo04hcC9eB6svRfH08B2hHCUkV2AQJMZF27mcdlekLeDoS/Qtpx/zs8fxLESfaDwogOnbei/YvQLcmpMJOp58ZuzI9BkxsGULG7bI8QeNZu+idsj0ABwcnYEdk4z7gW6+TT55+mFz7c1bN+f6lB40QlbvnEYPcAkTu+L32b257a9oSsSEWgyIz0rn9s2CbEHB1Oy7s7wXsqvK2VqTCQ83TQzGDx3gz7diUKZb5enixNe6NiIb0E6QR12dYbuiVZPRFCzh/eNEFukZ+Wj+4IdXLfp5Qgcn2/sY2/Ozyewam+a7PWNdm6iDrsGRi0w1UuNicTgkAZctxloMiP+eCbXbRJiBPlFJQgymbkHl99m9jd8cNl0/AoFFxtQy4tOUQtM9YqKS9Fsxq/ct7tudCg6B/ty3y4hevPMP3fj+FUL9+3S+enhjPoe0TgvdhBebB2l0qgf/vu1/OBX5N8p5b7dTROfRKuGxv18EVKV+OOZGLnmMPftbp/UC8EBnty3qzU//3YFb/33mOz1jXzupvBiB+EFAEb+ex/if/9T9vpGPgjudd1SiE7ztwvZtr2ccAnZc/o6Xll9UMi27eVcNCr2IHacvS57faO/TxRe7CS8AEATkxm2tCsY/WC4V9vZW2ApkDdD68NQiCFGte/sDQyPPSBk21ve7onm9WsJ2bbW9IiOw+WcAtnr28O5msKLHYUXwPaOuPZwUJThPbXA/ezpZEyMLeHkNYz49pCw7dvTeafZdDOKbPiWaS/vFYUXOwsvgO0B5ty8gXCpYT8Pn0389ih+OXlV2PY3juuOkEAfYdsnRJTtR69i9LqjwrafMLU3GtfxELZ9raEvl9aj8GKH4QWw/SD5W2hDfPhsO07VaJ+oJ5LutfqVTujZuq7QfRBiq/yiEoxdvhW7L/Pv3F6mbW3gp/fs50IMUHBhReHFTsMLYPvB4gLgnJ0dMCkZuQhfskvoPp5vWx/zh7azq9Yton1pN/LQc9FO4fs5M3cA3F2chO9HSyi4sKPwYsfhBeAzGJ09Hjgin6a4F3XuJWoT2Qn3Xvb6WafgIg+FFzsPLwAFGFu0mbkZt4pKhO9n5bAO6BMSIHw/hABATt4dRHy0FRniP9r4dkRn9GjlJ35HGkTBRT4KLxReAFCAsYXop5Lu5Qhgq51+QyXiKdXKAgDhwa5YMTpckX1pTW5BMVrP3mLTNuz1fFuGwguFl3I8Asz5+YPscsp1ALh47TZ6L45XbH9hTWrjXyM72V3/AMKX0p9bwP6eWLzXwCW7cCYj16Zt2HtwASi8UHi5j60D2QHAgucex9AugTzK0aUTaTl4+ss9iu6TbisRFlm5Reg5bxtsu4SyOzqjH3w9XRTeq3ZQCzc/FF4ovDygR8wOXM7Ot3k79n6QJaVmI2rZXsX3a899CEjV8otKMG7FduxKEzNydHUOTQ+Hn5er4vvVEgoufFF4ofBSqQ1H0/HuuiSbt0MHG3Dkwp8YsnyfKvuOHd4RYW39Vdk3UV9O3h0MjN6KK3fU2f9+U18E+Lips3ONsHVi3DJ0Lq2IwguFlyqVlEpoOn2Tzduxx3EbKqPG7aR7tavvhf+MDYWnWw3VaiDiqdGH5V7ODsCBf9j37aEyL//ffuw9f9Pm7VBweRCFFwovD8WjubN7oCf+M64Xh2r0T+2LSxmalsA4diRlYNTaI6rW0LNJbXxFHcjL8ThvAhRcqkLhhcKLVehA5C8n7w76z9uKTHGjrFvNw9kR294NQwNfd7VLIVZQqz9VZaizeEW8WqwBOl9Wh8ILhRer8Qow9vyYZFVET3DHisKMtqjZb6oqP4zthg5NHlG7DE35Zt8f+OCnk1y2RcGlehReKLww4RVgBrf3w+Khnblsy0jU7hdTnbWjuqJrs9pql2F4RcWl+PuG/dhw5E+1S3mAr7sztrzby+6fHKpMkMkMHhfIl7s2wEdRIRy2ZGwUXii8MBvw6S6czeQzQgR9u6icmo+1sqBv37bTQn+Vh6HH76vGY7TcMtQqbT0KLxReZOF5wNLjlNU7fdmCQZ/vVrsMq3m5OmHru2H0N72PllvVKlOvlivM7/Skp4aq0e+TXfj9On2RUwOFFwovNuF1Gwmgg9ca8cczMXLNYbXLkO3vfZthbN9gw04hkZFdgK4xcWqXYZOf3+yBNo291S5D04qKS9Fsxq/ctkfnPnYUXii82IxngLH34cOtVVRcivd+2Icfj+WoXQp3Wu1bk19UggmrdmLHxUK1S+Fuzcgu6Naijtpl6MLUdb/h+6OXuWyrjhtweDYFFzkovFB44aJbdByu5BRw2ZYTgPP0TcRqOXl3MGD+VlzVdvcYojE0+jIb3q0tv83sD28PZ27bszeaCC8fffQRzGYzkpKS4OLiguzs7IeuM3LkSHz99dcVfhcREYHNm60fhpnCC185eXfQbu5Wbtuj+VDY5ReVYPy/4xD/h0rjwRNNo4638kz5XxJ+OJbObXt0m8h2mggvs2bNgo+PDy5fvox///vfVoeXzMxMrFq1qvx3rq6ueOQR6598oPAiBs/bSAAd6LZIOHkNI749pHYZRCUOAMwTn0SrhnR+k4Pngwll6HzGB8v1W9iEKHPmzAEAxMbGMq3n6uqKgAAa2VFrUmMi0WrmZuQVlXDZXqDJjJ2TwxBUtyaX7dmTnq3rlp8s07Py0X3BDpUrIqJN6RmMNwc0M2ynaKX0XbQT52/kcdsetSSrR3OzucXHx6Nu3bp45JFH0KdPH8ybNw+1a1fd0a+wsBCFhX91trNYLEqUaZdOzx2A65ZCdJq/ncv2yuYConEQ5Gvg617hW5+Whpgn8tFoyHyJCPnU2qIu4R12Y2NjMWnSJKtuG61duxYeHh4ICgrC+fPnMX36dHh6eiIxMRFOTpVPDDZ79uzyVp570W0jsXjfRnq6jQ8+e7k7120S/Y1DYq9oHB0xSkolPPaPTSjleJXzqQEkzaPgIoKwPi8mkwkff/xxtcucOXMGLVq0KP9/lvByvwsXLqBp06bYvn07+vbtW+kylbW8NGrUiMKLAl76aj8SL9o+Nfy9Eqb2RuM6Hly3Sf6SlVuEsHnbQO2T6qKngsT7Ii4FC7clc90mPU0klrDwcv36ddy8Wf3FqkmTJnBx+WtMD1vCCwD4+flh3rx5eOONN6xanjrsKiu/qAQtZ1r/NJi1zswdAHeXylvbCF/XLYXoOn87+PRmIvejGZqVlXYjDz0X7eS+XbpNJJ6wDrt+fn7w81PukbzLly/j5s2bqFevnmL7JGzcXZyQGhOJrvPjkGHhMyYMALScuRlP1HXEhskDuW2TVM7Py7XSMXi0OOuxllE/FXWJeIoIAPa+34f+phokrM9LWloasrKy8NNPP2HhwoXYvfvuPC7BwcHw9PQEALRo0QLR0dF47rnnkJubizlz5mDIkCEICAjA+fPn8f777+PWrVs4ceIEXF2t69FNLS/q4T0mTJmN47ojJNCH+3aJfHqf0sAW9HnUnt4Ld+LiTX5PEZWh1hZlaWKcl8oGnAOAnTt3Iiws7O7OHRywatUqjBw5Evn5+YiKisKxY8eQnZ2N+vXro3///vjwww/h72/9vWEKL+rr8tE2ZN4q4r7d7ZN6ITjAk/t2CX85eXcQ+fFWXNbZqPs04Ju+HEzJwtAVidy3S8M4qEMT4UUtFF60QVQrDEAzVhNi70Q+RUetLeqh8ELhRTOG/isRB//IErJtGiCKEPty+rIFgz7fLWTb1LdFfRReKLxoiqgnksrQrNWEGFvylVuIWJogZNueAE5Sa4smUHih8KJJKRm5CF+yS9j2KcQQYiyizxk0bou2UHih8KJpbWdvgaWgWNj2KcQQom8iW1oAYBNNbKlJFF4ovGheVm4R2s/bJnQf1LGXEH1RYjoL6pCrXRReKLzohuhvWAA9Yk2I1ol65Ple1CKrfRReKLzozlvfHsXPJ68K3QcNLkaIdhQVl2L6T4fw/cEbQvdDX170g8ILhRddKiouRevZm1FULPYjOSOiJV7rFQQnRweh+yGEPEipubTWjQ5F52BfwXshPFF4ofCia0r0hwGAWq6O2DGlD40VQ4gClLhFDADt6gA/TqV+LXpE4YXCiyGImh22Mj+/2QNtGnsrsi9C7EVJqYTYhGR8uPm8Ivuj2ej1jcILhRdDUeobGwC0b+SD1a93gacb04TrhJB7KPnFA6DRto2CwguFF0MSOTR4Zag1hhDrlZRKWL37HOb8mqLYPmk4BGOh8ELhxdCUGAviXnVqOuPXd3rRNztCKqH0lwqA5iEyKgovFF7sQlJqNqKW7VV0n/SkEiF3nxh6cv52FCi8XwotxkbhhcKLXVG6JabMD2O7oUOTRxTfLyFqyC0oxogvt+PoNdEPOT+Ibg/ZBwovFF7skuhJ3KpDA+ARIyoqLsU/fj6M7w5cV3zfro5A4nQaFdeeUHih8GLX0rPy0X3BDtX2T0GG6JmagQUAwpvXxWcvt6dHnu0QhRcKLwR3m7lf+Xwrkm6o9xGnUT6JHuTk3cGzn2xF6m31alg7qiu6NqutXgFEdRReKLyQ++w7ewPDYw+oWkMzv5r4bnx3eHs4q1oHIYDyY7FUZcvbPdG8fi21yyAaQOGFwgupgpID3lXHEcCvdNImCtt/7iaGrdyvdhk0/ACpFIUXCi/kIXLy7uCZRVvxR57aldxFrTJEhIvXbqP34ni1yyhHt4ZIdSi8UHghDNQYL+ZhIlr6Y8lLT1CnRcIkI7sAoTFx0NJJ3dXJAdveDUPjOh5ql0I0jsILhRcig9ZaY+7Vr0VdLB1OT2CQitR+sq461MpCWFF4ofBCbKTmmDHWoG+z9kmLrYT3auDthp/fepLGZiGyUHih8EI4OpiShaErEtUu46Ga1/XEunHdqN+MQWjlaaCHcXIAtrzTC8EBnmqXQnSOwguFFyKA2oN3yeFX0xmb6KkOzdPKU3AsaDBGwhuFFwovRDA153nh4YOBLTHySZpgUmnXLYXoNX87NNityio0nxcRicILhReiIL0HmXu5OTlg67vUl8ZWWu+bwoICC1EKhRcKL0QlJaUSvt3zO2Zt+l3tUoRo5OOGHydSh0ytjZ/CkwMA88Qn0aohnT+Jsii8UHghGnH6sgWDPt+tdhmK02sLzom0HDz95R61y1BcTWdHbH03DA183dUuhdgxCi8UXogGZeUWoV/0NtzU/90lYgA0DgvRGpbrt6OoIlJTU/H6668jKCgI7u7uaNq0KWbNmoWioqJq1ysoKMCECRNQu3ZteHp6YsiQIcjMzBRVJiGK8fV0wZGPIpEac/dny9s91S6J2BFvNyfsN/Ut//xRcCF6VkPUhs+ePYvS0lJ89dVXCA4OxsmTJzFmzBjcvn0bixYtqnK9d999F2azGd999x28vb0xceJEDB48GHv3GqPzGyFlmtevhdSYyPL/N1InT6I+GsiQGJmit40WLlyIf/3rX7hw4UKl/56TkwM/Pz+sWbMGzz//PIC7Iahly5ZITExE165dH7oPum1EjMJe+18QefTaz4iQMizXb2EtL5XJycmBr69vlf9+5MgR3LlzB+Hh4eW/a9GiBRo3blxleCksLERhYWH5/1ssFr5FE6KSNo29K7TMZGQXoFtMHEpVrIlox986N8YHzzwOlxrC7v4TolmKhZeUlBR89tln1d4yysjIgIuLC3x8fCr83t/fHxkZGZWuEx0djTlz5vAslRBNCvBxw4V7wgxgv08z2RtvNydsmRSGAB83tUshRBOYw4vJZMLHH39c7TJnzpxBixYtyv8/PT0dAwYMwAsvvIAxY8awV1mNadOmYfLkyeX/b7FY0KhRI677IESrWjX0qtA6Axh7DBJ70MyvJr4b353mqCKkGszhZcqUKRg5cmS1yzRp0qT8v69cuYLevXujW7duWL58ebXrBQQEoKioCNnZ2RVaXzIzMxEQEFDpOq6urnB1pXlbCCkTVLfmA4EGAI5c+BNDlu9ToSJSGWdHYOukMATVral2KYToDnN48fPzg5+fn1XLpqeno3fv3ujQoQNWrVoFR8fq78126NABzs7OiIuLw5AhQwAAycnJSEtLQ2hoKGuphJB7dGjySKWhhvrSiPVql0fxj6dbUd8UQjgS9rRReno6wsLC8Oijj+Lrr7+Gk5NT+b+VtaKkp6ejb9++WL16NTp37gwAGD9+PDZt2oTY2Fh4eXnhrbfeAgDs22fdN0Z62ogQvlIychG+ZJfaZWgWTZlACB+aeNpo27ZtSElJQUpKCho2bFjh38ry0p07d5CcnIy8vL/mWP3000/h6OiIIUOGoLCwEBEREfjyyy9FlUkIeYjgAM9KW2zuZ7QWHGoxIUS7aHoAQgghhKhOE9MDEEIIIYSIQOGFEEIIIbpC4YUQQgghukLhhRBCCCG6QuGFEEIIIbpC4YUQQgghukLhhRBCCCG6QuGFEEIIIbpC4YUQQgghuiJsegC1lA0YbLFYVK6EEEIIIdYqu25bM/C/4cLLrVu3AACNGjVSuRJCCCGEsLp16xa8vb2rXcZwcxuVlpbiypUrqFWrFhwcHLhu22KxoFGjRrh06ZIh500y+usDjP8a6fXpn9FfI70+/RP1GiVJwq1bt1C/fn04Olbfq8VwLS+Ojo4PzGLNm5eXl2E/lIDxXx9g/NdIr0//jP4a6fXpn4jX+LAWlzLUYZcQQgghukLhhRBCCCG6QuGFgaurK2bNmgVXV1e1SxHC6K8PMP5rpNenf0Z/jfT69E8Lr9FwHXYJIYQQYmzU8kIIIYQQXaHwQgghhBBdofBCCCGEEF2h8EIIIYQQXaHwUo3U1FS8/vrrCAoKgru7O5o2bYpZs2ahqKio2vUKCgowYcIE1K5dG56enhgyZAgyMzMVqprNRx99hG7dusHDwwM+Pj5WrTNy5Eg4ODhU+BkwYIDYQmWS8/okScLMmTNRr149uLu7Izw8HL///rvYQm2QlZWFl19+GV5eXvDx8cHrr7+O3NzcatcJCwt74G84btw4hSqu3hdffIHAwEC4ubmhS5cuOHjwYLXLf/fdd2jRogXc3NzQpk0bbNq0SaFK5WN5jbGxsQ/8rdzc3BSslk1CQgKefvpp1K9fHw4ODti4ceND14mPj0f79u3h6uqK4OBgxMbGCq9TLtbXFx8f/8Dfz8HBARkZGcoUzCg6OhqdOnVCrVq1ULduXURFRSE5Ofmh6yl9HFJ4qcbZs2dRWlqKr776CqdOncKnn36KZcuWYfr06dWu9+677+Lnn3/Gd999h127duHKlSsYPHiwQlWzKSoqwgsvvIDx48czrTdgwABcvXq1/Oe///2voAptI+f1LViwAEuXLsWyZctw4MAB1KxZExERESgoKBBYqXwvv/wyTp06hW3btuGXX35BQkICxo4d+9D1xowZU+FvuGDBAgWqrd7//vc/TJ48GbNmzcLRo0fRrl07RERE4Nq1a5Uuv2/fPrz00kt4/fXXcezYMURFRSEqKgonT55UuHLrsb5G4O5Ipvf+rf744w8FK2Zz+/ZttGvXDl988YVVy1+8eBGRkZHo3bs3kpKSMGnSJIwePRpbtmwRXKk8rK+vTHJycoW/Yd26dQVVaJtdu3ZhwoQJ2L9/P7Zt24Y7d+6gf//+uH37dpXrqHIcSoTJggULpKCgoCr/PTs7W3J2dpa+++678t+dOXNGAiAlJiYqUaIsq1atkry9va1a9tVXX5WeffZZofXwZu3rKy0tlQICAqSFCxeW/y47O1tydXWV/vvf/wqsUJ7Tp09LAKRDhw6V/+7XX3+VHBwcpPT09CrX69Wrl/TOO+8oUCGbzp07SxMmTCj//5KSEql+/fpSdHR0pcsPHTpUioyMrPC7Ll26SG+88YbQOm3B+hpZjk2tASBt2LCh2mXef/996fHHH6/wuxdffFGKiIgQWBkf1ry+nTt3SgCkP//8U5GaeLt27ZoEQNq1a1eVy6hxHFLLC6OcnBz4+vpW+e9HjhzBnTt3EB4eXv67Fi1aoHHjxkhMTFSiREXEx8ejbt26aN68OcaPH4+bN2+qXRIXFy9eREZGRoW/n7e3N7p06aLJv19iYiJ8fHzQsWPH8t+Fh4fD0dERBw4cqHbd//znP6hTpw5at26NadOmIS8vT3S51SoqKsKRI0cqvPeOjo4IDw+v8r1PTEyssDwAREREaPJvBch7jQCQm5uLRx99FI0aNcKzzz6LU6dOKVGuIvT2N5QrJCQE9erVQ79+/bB37161y7FaTk4OAFR73VPjb2i4iRlFSklJwWeffYZFixZVuUxGRgZcXFwe6F/h7++v2XucrAYMGIDBgwcjKCgI58+fx/Tp0zFw4EAkJibCyclJ7fJsUvY38vf3r/B7rf79MjIyHmh+rlGjBnx9fautd/jw4Xj00UdRv359HD9+HH//+9+RnJyM9evXiy65Sjdu3EBJSUml7/3Zs2crXScjI0M3fytA3mts3rw5Vq5cibZt2yInJweLFi1Ct27dcOrUKeGT0Cqhqr+hxWJBfn4+3N3dVaqMj3r16mHZsmXo2LEjCgsLsWLFCoSFheHAgQNo37692uVVq7S0FJMmTUL37t3RunXrKpdT4zi0y5YXk8lUaQeqe3/uP5Gkp6djwIABeOGFFzBmzBiVKreOnNfHYtiwYXjmmWfQpk0bREVF4ZdffsGhQ4cQHx/P70VUQ/Tr0wLRr3Hs2LGIiIhAmzZt8PLLL2P16tXYsGEDzp8/z/FVEB5CQ0MxYsQIhISEoFevXli/fj38/Pzw1VdfqV0asULz5s3xxhtvoEOHDujWrRtWrlyJbt264dNPP1W7tIeaMGECTp48ibVr16pdygPssuVlypQpGDlyZLXLNGnSpPy/r1y5gt69e6Nbt25Yvnx5tesFBASgqKgI2dnZFVpfMjMzERAQYEvZVmN9fbZq0qQJ6tSpg5SUFPTt25fbdqsi8vWV/Y0yMzNRr1698t9nZmYiJCRE1jblsPY1BgQEPNDRs7i4GFlZWUyfty5dugC427rYtGlT5np5qFOnDpycnB54Mq+6YycgIIBpebXJeY33c3Z2xhNPPIGUlBQRJSquqr+hl5eX7ltdqtK5c2fs2bNH7TKqNXHixPIHAB7WwqfGcWiX4cXPzw9+fn5WLZueno7evXujQ4cOWLVqFRwdq2+s6tChA5ydnREXF4chQ4YAuNvLPC0tDaGhoTbXbg2W18fD5cuXcfPmzQoXe5FEvr6goCAEBAQgLi6uPKxYLBYcOHCA+YksW1j7GkNDQ5GdnY0jR46gQ4cOAIAdO3agtLS0PJBYIykpCQAU+xtWxsXFBR06dEBcXByioqIA3G22jouLw8SJEytdJzQ0FHFxcZg0aVL577Zt26bYscZKzmu8X0lJCU6cOIFBgwYJrFQ5oaGhDzxWq+W/IQ9JSUmqHmvVkSQJb731FjZs2ID4+HgEBQU9dB1VjkNhXYEN4PLly1JwcLDUt29f6fLly9LVq1fLf+5dpnnz5tKBAwfKfzdu3DipcePG0o4dO6TDhw9LoaGhUmhoqBov4aH++OMP6dixY9KcOXMkT09P6dixY9KxY8ekW7dulS/TvHlzaf369ZIkSdKtW7ekqVOnSomJidLFixel7du3S+3bt5cee+wxqaCgQK2XUSXW1ydJkhQTEyP5+PhIP/74o3T8+HHp2WeflYKCgqT8/Hw1XsJDDRgwQHriiSekAwcOSHv27JEee+wx6aWXXir/9/s/oykpKdLcuXOlw4cPSxcvXpR+/PFHqUmTJlLPnj3Vegnl1q5dK7m6ukqxsbHS6dOnpbFjx0o+Pj5SRkaGJEmS9Le//U0ymUzly+/du1eqUaOGtGjRIunMmTPSrFmzJGdnZ+nEiRNqvYSHYn2Nc+bMkbZs2SKdP39eOnLkiDRs2DDJzc1NOnXqlFovoVq3bt0qP84ASIsXL5aOHTsm/fHHH5IkSZLJZJL+9re/lS9/4cIFycPDQ3rvvfekM2fOSF988YXk5OQkbd68Wa2XUC3W1/fpp59KGzdulH7//XfpxIkT0jvvvCM5OjpK27dvV+slVGv8+PGSt7e3FB8fX+Gal5eXV76MFo5DCi/VWLVqlQSg0p8yFy9elABIO3fuLP9dfn6+9Oabb0qPPPKI5OHhIT333HMVAo+WvPrqq5W+vntfDwBp1apVkiRJUl5entS/f3/Jz89PcnZ2lh599FFpzJgx5SderWF9fZJ093HpDz74QPL395dcXV2lvn37SsnJycoXb6WbN29KL730kuTp6Sl5eXlJr732WoVwdv9nNC0tTerZs6fk6+srubq6SsHBwdJ7770n5eTkqPQKKvrss8+kxo0bSy4uLlLnzp2l/fv3l/9br169pFdffbXC8uvWrZOaNWsmubi4SI8//rhkNpsVrpgdy2ucNGlS+bL+/v7SoEGDpKNHj6pQtXXKHg2+/6fsNb366qtSr169HlgnJCREcnFxkZo0aVLheNQa1tf38ccfS02bNpXc3NwkX19fKSwsTNqxY4c6xVuhqmvevX8TLRyHDv+/WEIIIYQQXbDLp40IIYQQol8UXgghhBCiKxReCCGEEKIrFF4IIYQQoisUXgghhBCiKxReCCGEEKIrFF4IIYQQoisUXgghhBCiKxReCCGEEKIrFF4IIYQQoisUXgghhBCiKxReCCGEEKIr/w9sBbL6I3bHFQAAAABJRU5ErkJggg==", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# from sklearn.utils import shuffle as util_shuffle\n", - "np.random.seed(192)\n", - "def make_circles(batch_size: int=1000, rng=None) -> np.ndarray:\n", - " n_samples4 = n_samples3 = n_samples2 = batch_size // 4\n", - " n_samples1 = batch_size - n_samples4 - n_samples3 - n_samples2\n", - "\n", - " # so as not to have the first point = last point, we set endpoint=False\n", - " linspace4 = np.linspace(0, 2 * np.pi, n_samples4, endpoint=False)\n", - " linspace3 = np.linspace(0, 2 * np.pi, n_samples3, endpoint=False)\n", - " linspace2 = np.linspace(0, 2 * np.pi, n_samples2, endpoint=False)\n", - " linspace1 = np.linspace(0, 2 * np.pi, n_samples1, endpoint=False)\n", - "\n", - " circ4_x = np.cos(linspace4)\n", - " circ4_y = np.sin(linspace4)\n", - " circ3_x = np.cos(linspace4) * 0.75\n", - " circ3_y = np.sin(linspace3) * 0.75\n", - " circ2_x = np.cos(linspace2) * 0.5\n", - " circ2_y = np.sin(linspace2) * 0.5\n", - " circ1_x = np.cos(linspace1) * 0.25\n", - " circ1_y = np.sin(linspace1) * 0.25\n", - "\n", - " X = np.vstack([\n", - " np.hstack([circ4_x, circ3_x, circ2_x, circ1_x]),\n", - " np.hstack([circ4_y, circ3_y, circ2_y, circ1_y])\n", - " ]).T * 3.0\n", - " # X = util_shuffle(X, random_state=rng)\n", - "\n", - " # Add noise\n", - " # X = X + rng.normal(scale=0.08, size=X.shape)\n", - "\n", - " return X.astype(\"float32\")\n", - "\n", - "\n", - "def transform(data: np.ndarray) -> np.ndarray:\n", - " assert data.shape[1] == 2\n", - " data[:, 0] = data[:, 0] / np.max(np.abs(data[:, 0]))\n", - " data[:, 1] = data[:, 1] / np.max(np.abs(data[:, 1]))\n", - " # data[:, 2] = data[:, 2] / np.max(np.abs(data[:, 2]))\n", - " data = (data - data.min()) / (data.max()\n", - " - data.min()) # Towards [0, 1]\n", - " data = data * 4 - 2 # [-1, 1]\n", - " return data\n", - "# get data from sklearn\n", - "data = make_circles(100000)\n", - "data = transform(data)\n", - "data = data.astype(np.float32)\n", - "plot2d(data)\n", - "def get_train_data(dataloader):\n", - " while True:\n", - " yield from dataloader" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [], - "source": [ - "def task_run_flow_model(config, data, flow_model_type=IndependentConditionalFlowModel):\n", - " seed_value = set_seed()\n", - "\n", - " flow_model = flow_model_type(config=config.flow_model).to(config.flow_model.device)\n", - " flow_model = torch.compile(flow_model)\n", - "\n", - " data_loader = torch.utils.data.DataLoader(data, batch_size=config.parameter.batch_size, shuffle=True)\n", - " data_generator = get_train_data(data_loader)\n", - "\n", - " optimizer = torch.optim.Adam(\n", - " flow_model.parameters(),\n", - " lr=config.parameter.lr,\n", - " # weight_decay=config.parameter.weight_decay,\n", - " )\n", - " for iteration in track(range(config.parameter.iterations), description=config.project):\n", - "\n", - " batch_data = next(data_generator).to(config.device)\n", - " flow_model.train()\n", - " if config.parameter.training_loss_type == \"flow_matching\":\n", - " x0 = flow_model.gaussian_generator(batch_data.shape[0]).to(config.device)\n", - " loss = flow_model.flow_matching_loss(x0=x0, x1=batch_data)\n", - " else:\n", - " raise NotImplementedError(\n", - " f\"Unknown loss type {config.parameter.training_loss_type}, we need flow matching.\"\n", - " )\n", - " optimizer.zero_grad()\n", - " loss.backward()\n", - " optimizer.step()\n", - "\n", - " return flow_model\n", - "\n", - "def task_run_diffusion_model(config, data):\n", - " seed_value = set_seed()\n", - "\n", - " diffusion_model = DiffusionModel(config=config.diffusion_model).to(config.diffusion_model.device)\n", - " diffusion_model = torch.compile(diffusion_model)\n", - "\n", - " data_loader = torch.utils.data.DataLoader(data, batch_size=config.parameter.batch_size, shuffle=True)\n", - " data_generator = get_train_data(data_loader)\n", - "\n", - " optimizer = torch.optim.Adam(\n", - " diffusion_model.parameters(),\n", - " lr=config.parameter.lr,\n", - " # weight_decay=config.parameter.weight_decay,\n", - " )\n", - " for iteration in track(range(config.parameter.iterations), description=config.project):\n", - "\n", - " batch_data = next(data_generator).to(config.device)\n", - " diffusion_model.train()\n", - " if config.parameter.training_loss_type == \"flow_matching\":\n", - " loss = diffusion_model.flow_matching_loss(batch_data)\n", - " elif config.parameter.training_loss_type == \"score_matching_maximum_likelihhood\":\n", - " loss = diffusion_model.score_matching_loss(batch_data)\n", - " elif config.parameter.training_loss_type == \"score_matching\":\n", - " loss = diffusion_model.score_matching_loss(batch_data, weighting_scheme=\"vanilla\")\n", - " else:\n", - " raise NotImplementedError(\n", - " f\"Unknown loss type {config.parameter.training_loss_type}, we need flow matching or score matching.\"\n", - " )\n", - " optimizer.zero_grad()\n", - " loss.backward()\n", - " optimizer.step()\n", - "\n", - " return diffusion_model" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "45717a201101413eb7717ce540e9fa21", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "b434ca73e5ec4cb797efa1054c61aa7f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "diffusion_model = task_run_diffusion_model(diffusion_model_config, data)\n", - "flow_model = task_run_flow_model(flow_model_config, data, IndependentConditionalFlowModel)" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "diffusion_model: shape of x: torch.Size([1, 2])\n", - "diffusion_model: likelihood: tensor([0.1089], device='cuda:0')\n", - "flow_model: shape of x: torch.Size([1, 2])\n", - "flow_model: likelihood: tensor([0.0774], device='cuda:0')\n" - ] - } - ], - "source": [ - "t_span = torch.linspace(0.0, 1.0, 1000)\n", - "# diffusion model\n", - "\n", - "x = diffusion_model.sample(t_span=t_span, batch_size=1)\n", - "print(f\"diffusion_model: shape of x: {x.shape}\")\n", - "with torch.no_grad():\n", - " logp = compute_likelihood(diffusion_model, x)\n", - " print(f\"diffusion_model: likelihood: {torch.exp(logp)}\")\n", - "\n", - "# flow model\n", - "x = flow_model.sample(t_span=t_span, batch_size=1)\n", - "print(f\"flow_model: shape of x: {x.shape}\")\n", - "with torch.no_grad():\n", - " logp = compute_likelihood(flow_model, x)\n", - " print(f\"flow_model: likelihood: {torch.exp(logp)}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "diffusion_model: shape of x: torch.Size([1, 2])\n", - "diffusion_model: likelihood: tensor([0.1333], device='cuda:0', grad_fn=)\n", - "flow_model: shape of x: torch.Size([1, 2])\n", - "flow_model: likelihood: tensor([0.1217], device='cuda:0', grad_fn=)\n" - ] - } - ], - "source": [ - "# diffusion model\n", - "\n", - "x = diffusion_model.sample(t_span=t_span, batch_size=1)\n", - "print(f\"diffusion_model: shape of x: {x.shape}\")\n", - "logp = compute_likelihood(diffusion_model, x)\n", - "print(f\"diffusion_model: likelihood: {torch.exp(logp)}\")\n", - "\n", - "# test if the tensor has grad_fn\n", - "assert logp.grad_fn is not None\n", - "\n", - "# flow model\n", - "\n", - "x = flow_model.sample(t_span=t_span, batch_size=1)\n", - "print(f\"flow_model: shape of x: {x.shape}\")\n", - "logp = compute_likelihood(flow_model, x)\n", - "print(f\"flow_model: likelihood: {torch.exp(logp)}\")\n", - "\n", - "# test if the tensor has grad_fn\n", - "assert logp.grad_fn is not None" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "diffusion_model: shape of x: torch.Size([1, 2])\n", - "diffusion_model: likelihood: tensor([0.1349], device='cuda:0', grad_fn=)\n", - "flow_model: shape of x: torch.Size([1, 2])\n", - "flow_model: likelihood: tensor([0.0012], device='cuda:0', grad_fn=)\n" - ] - } - ], - "source": [ - "# diffusion model\n", - "\n", - "x = diffusion_model.sample(t_span=t_span, batch_size=1)\n", - "print(f\"diffusion_model: shape of x: {x.shape}\")\n", - "logp = compute_likelihood(diffusion_model, x, using_Hutchinson_trace_estimator=True)\n", - "print(f\"diffusion_model: likelihood: {torch.exp(logp)}\")\n", - "\n", - "# test if the tensor has grad_fn\n", - "assert logp.grad_fn is not None\n", - "\n", - "# flow model\n", - "\n", - "x = flow_model.sample(t_span=t_span, batch_size=1)\n", - "print(f\"flow_model: shape of x: {x.shape}\")\n", - "logp = compute_likelihood(diffusion_model, x, using_Hutchinson_trace_estimator=True)\n", - "print(f\"flow_model: likelihood: {torch.exp(logp)}\")\n", - "\n", - "# test if the tensor has grad_fn\n", - "assert logp.grad_fn is not None\n" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "d02c87b5b8ee4226b52b091ba4114b4c", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "def density_flow_of_generative_model(model):\n", - "\n", - " model.eval()\n", - " x_range = torch.linspace(-4, 4, 100, device=model.device)\n", - " y_range = torch.linspace(-4, 4, 100, device=model.device)\n", - " xx, yy = torch.meshgrid(x_range, y_range)\n", - " z_grid = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], 1)\n", - "\n", - " indexs = torch.arange(0, z_grid.shape[0], device=model.device)\n", - " memory = 0.01\n", - "\n", - " p_list = []\n", - " for t in track(range(100), description=\"Density Flow Training\"):\n", - " t_span = torch.linspace(0.01 * t, 1, 101 - t, device=model.device)\n", - " logp_list = []\n", - " for ii in torch.split(indexs, int(z_grid.shape[0] * memory)):\n", - " logp_ii = compute_likelihood(\n", - " model=model,\n", - " x=z_grid[ii],\n", - " t=t_span,\n", - " using_Hutchinson_trace_estimator=True,\n", - " )\n", - " logp_list.append(logp_ii.unsqueeze(0))\n", - " logp = torch.cat(logp_list, 1)\n", - " p = torch.exp(logp).reshape(100, 100)\n", - " p_list.append(p)\n", - "\n", - " return p_list\n", - "p_list = density_flow_of_generative_model(diffusion_model)" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
[11:34:18] INFO     Animation.save using <class 'matplotlib.animation.FFMpegWriter'>              animation.py:1060\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m[11:34:18]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Animation.save using \u001b[1m<\u001b[0m\u001b[1;95mclass\u001b[0m\u001b[39m \u001b[0m\u001b[32m'matplotlib.animation.FFMpegWriter'\u001b[0m\u001b[1m>\u001b[0m \u001b]8;id=198352;file:///opt/conda/lib/python3.10/site-packages/matplotlib/animation.py\u001b\\\u001b[2manimation.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=80104;file:///opt/conda/lib/python3.10/site-packages/matplotlib/animation.py#1060\u001b\\\u001b[2m1060\u001b[0m\u001b]8;;\u001b\\\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
           INFO     MovieWriter._run: running command: ffmpeg -f rawvideo -vcodec rawvideo -s      animation.py:338\n",
-       "                    700x600 -pix_fmt rgba -framerate 20 -loglevel error -i pipe: -vcodec h264                      \n",
-       "                    -pix_fmt yuv420p -y ./video-diffusion/density_flow_diffuse.mp4                                 \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m MovieWriter._run: running command: ffmpeg -f rawvideo -vcodec rawvideo -s \u001b]8;id=561233;file:///opt/conda/lib/python3.10/site-packages/matplotlib/animation.py\u001b\\\u001b[2manimation.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=842454;file:///opt/conda/lib/python3.10/site-packages/matplotlib/animation.py#338\u001b\\\u001b[2m338\u001b[0m\u001b]8;;\u001b\\\n", - "\u001b[2;36m \u001b[0m 70\u001b[1;36m0x600\u001b[0m -pix_fmt rgba -framerate \u001b[1;36m20\u001b[0m -loglevel error -i pipe: -vcodec h264 \u001b[2m \u001b[0m\n", - "\u001b[2;36m \u001b[0m -pix_fmt yuv420p -y .\u001b[35m/video-diffusion/\u001b[0m\u001b[95mdensity_flow_diffuse.mp4\u001b[0m \u001b[2m \u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "def render_density_flow_video(p_list, video_save_path, fps=20, dpi=100, generate_path=True):\n", - " if not os.path.exists(video_save_path):\n", - " os.makedirs(video_save_path)\n", - "\n", - " fig, ax = plt.subplots(figsize=(7, 6))\n", - " plt.xlim([-4, 4])\n", - " plt.ylim([-4, 4])\n", - "\n", - " ims = []\n", - " colors = np.linspace(0, 1, len(p_list))\n", - "\n", - " # Assuming p_list contains 2D arrays of the same shape\n", - " x = np.linspace(-4, 4, p_list[0].shape[1])\n", - " y = np.linspace(-4, 4, p_list[0].shape[0])\n", - " X, Y = np.meshgrid(x, y)\n", - "\n", - " cbar = None # Initialize color bar\n", - "\n", - " if generate_path:\n", - " enumerate_items = list(enumerate(p_list))[::-1]\n", - " enumerate_items = enumerate_items[:-1]\n", - " # enumerate_items = enumerate_items\n", - " else:\n", - " enumerate_items = list(enumerate(p_list))[1:]\n", - "\n", - " for i, p in enumerate_items:\n", - " p_max = 0.2\n", - " p_min = 0.0\n", - "\n", - " im = ax.pcolormesh(\n", - " Y, X, p.cpu().detach().numpy(),\n", - " cmap=\"viridis\",\n", - " vmin=p_min, vmax=p_max,\n", - " shading='auto'\n", - " )\n", - " title = ax.text(0.5, 1.05, f't={colors[i]:.2f}', size=plt.rcParams[\"axes.titlesize\"], ha=\"center\", transform=ax.transAxes)\n", - "\n", - " # Remove the previous color bar if it exists\n", - " if cbar:\n", - " cbar.remove()\n", - "\n", - " # Adding the colorbar inside the loop to update it each frame\n", - " cbar = fig.colorbar(im, ax=ax)\n", - " cbar.set_label('Density')\n", - "\n", - " ims.append([im, title])\n", - "\n", - " ani = animation.ArtistAnimation(fig, ims, interval=20/fps, blit=True)\n", - " ani.save(\n", - " os.path.join(video_save_path, f\"density_flow_diffuse.mp4\"),\n", - " fps=fps,\n", - " dpi=dpi,\n", - " )\n", - "\n", - " # clean up\n", - " plt.close(fig)\n", - " plt.clf()\n", - "render_density_flow_video(p_list=p_list, video_save_path=diffusion_model_config.parameter.video_save_path, generate_path=False)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "base", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.14" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py b/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py index 5a2831d..2fc0e34 100644 --- a/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py +++ b/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Literal +from typing import Optional, Tuple, Union from dataclasses import dataclass import numpy as np @@ -10,13 +10,16 @@ from easydict import EasyDict from functools import partial -from .edm_preconditioner import PreConditioner -from .edm_utils import SIGMA_T, SIGMA_T_DERIV, SIGMA_T_INV, SCALE_T, SCALE_T_DERIV, INITIAL_SIGMA_MAX, INITIAL_SIGMA_MIN -from .edm_utils import DEFAULT_SOLVER_PARAM from grl.generative_models.intrinsic_model import IntrinsicModel +from grl.utils import find_parameters from grl.utils import set_seed from grl.utils.log import log +from .edm_preconditioner import PreConditioner +from .edm_utils import SIGMA_T, SIGMA_T_DERIV, SIGMA_T_INV +from .edm_utils import SCALE_T, SCALE_T_DERIV +from .edm_utils import INITIAL_SIGMA_MAX, INITIAL_SIGMA_MIN, DEFAULT_SOLVER_PARAM + class Simple(nn.Module): def __init__(self): super().__init__() @@ -31,16 +34,28 @@ def forward(self, x, noise, class_labels=None): return self.model(x) class EDMModel(nn.Module): - + """ + Overview: + An implementation of EDM, which eludicates diffusion based generative model through preconditioning, training, sampling. + This implementation supports 4 types: `VP_edm`(DDPM-SDE), `VE_edm` (SGM-SDE), `iDDPM_edm`, `EDM`. More details see Table 1 in paper + EDM class utilizes different params and executes different scheules during precondition, training and sample process. + Sampling supports 1st order Euler step and 2nd order Heun step as Algorithm 1 in paper. + For EDM type itself, stochastic sampler as Algorithm 2 in paper is also supported. + Interface: + ``__init__``, ``forward``, ``sample`` + Reference: + EDM original paper: https://arxiv.org/abs/2206.00364 + Code reference: https://github.com/NVlabs/edm + """ def __init__(self, config: Optional[EasyDict]=None) -> None: super().__init__() - self.config= config + self.config = config # self.x_size = config.x_size self.device = config.device # EDM Type ["VP_edm", "VE_edm", "iDDPM_edm", "EDM"] - self.edm_type: str = config.edm_model.path.edm_type + self.edm_type = config.edm_model.path.edm_type assert self.edm_type in ["VP_edm", "VE_edm", "iDDPM_edm", "EDM"], \ f"Your edm type should in 'VP_edm', 'VE_edm', 'iDDPM_edm', 'EDM'], but got {self.edm_type}" @@ -58,7 +73,8 @@ def __init__(self, config: Optional[EasyDict]=None) -> None: #* 3. Solver setup self.solver_type = config.edm_model.solver.solver_type - assert self.solver_type in ['euler', 'heun'] + assert self.solver_type in ['euler', 'heun'], \ + f"Your solver type should in ['euler', 'heun'], but got {self.solver_type}" self.solver_params = DEFAULT_SOLVER_PARAM self.solver_params.update(config.edm_model.solver.params) @@ -70,7 +86,7 @@ def __init__(self, config: Optional[EasyDict]=None) -> None: self.sigma_max = INITIAL_SIGMA_MAX[self.edm_type] if "sigma_max" not in self.params else self.params.sigma_max - def get_type(self): + def get_type(self) -> str: return "EDMModel" # For VP_edm @@ -102,7 +118,7 @@ def _sample_sigma_weight_train(self, x: Tensor, **params) -> Tuple[Tensor, Tenso def forward(self, x: Tensor, - class_labels=None) -> Tensor: + class_labels: Tensor=None) -> Tensor: x = x.to(self.device) sigma, weight = self._sample_sigma_weight_train(x, **self.params) n = torch.randn_like(x) * sigma @@ -111,7 +127,7 @@ def forward(self, return loss - def _get_sigma_steps_t_steps(self, num_steps=18, epsilon_s=1e-3, rho=7): + def _get_sigma_steps_t_steps(self, num_steps: int=18, epsilon_s: float=1e-3, rho: Union[int, float]=7): """ Overview: Get the schedule of sigma according to differernt t schedules. @@ -204,7 +220,7 @@ def sample(self, if not use_stochastic: # Main sampling loop t_next = t_steps[0] - x_next = latents.to(torch.float64) * (sigma(t_next) * scale(t_next)) + x_next = latents * (sigma(t_next) * scale(t_next)) for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 x_cur = x_next @@ -215,7 +231,7 @@ def sample(self, # Euler step. h = t_next - t_hat - denoised = self.preconditioner(x_hat / scale(t_hat), sigma(t_hat), class_labels).to(torch.float64) + denoised = self.preconditioner(x_hat / scale(t_hat), sigma(t_hat), class_labels) d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + scale_deriv(t_hat) / scale(t_hat)) * x_hat - sigma_deriv(t_hat) * scale(t_hat) / sigma(t_hat) * denoised x_prime = x_hat + alpha * h * d_cur t_prime = t_hat + alpha * h @@ -225,13 +241,13 @@ def sample(self, x_next = x_hat + h * d_cur else: assert self.solver_type == 'heun' - denoised = self.preconditioner(x_prime / scale(t_prime), sigma(t_prime), class_labels).to(torch.float64) + denoised = self.preconditioner(x_prime / scale(t_prime), sigma(t_prime), class_labels) d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + scale_deriv(t_prime) / scale(t_prime)) * x_prime - sigma_deriv(t_prime) * scale(t_prime) / sigma(t_prime) * denoised x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime) else: assert self.edm_type == "EDM", f"Stochastic can only use in EDM, but your precond type is {self.edm_type}" - x_next = latents.to(torch.float64) * t_steps[0] + x_next = latents * t_steps[0] for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 x_cur = x_next @@ -241,13 +257,13 @@ def sample(self, x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * torch.randn_like(x_cur) # Euler step. - denoised = self.preconditioner(x_hat, t_hat, class_labels).to(torch.float64) + denoised = self.preconditioner(x_hat, t_hat, class_labels) d_cur = (x_hat - denoised) / t_hat x_next = x_hat + (t_next - t_hat) * d_cur # Apply 2nd order correction. if i < num_steps - 1: - denoised = self.preconditioner(x_next, t_next, class_labels).to(torch.float64) + denoised = self.preconditioner(x_next, t_next, class_labels) d_prime = (x_next - denoised) / t_next x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) diff --git a/grl/generative_models/edm_diffusion_model/edm_preconditioner.py b/grl/generative_models/edm_diffusion_model/edm_preconditioner.py index 7618332..8b0294a 100644 --- a/grl/generative_models/edm_diffusion_model/edm_preconditioner.py +++ b/grl/generative_models/edm_diffusion_model/edm_preconditioner.py @@ -39,9 +39,16 @@ def __init__(self, self.C_1 = precond_config_kwargs.get("C_1", 0.001) self.C_2 = precond_config_kwargs.get("C_2", 0.008) self.M = precond_config_kwargs.get("M", 1000) + + # For iDDPM_edm + def alpha_bar(j): + assert self.precondition_type == "iDDPM_edm", f"Only iDDPM_edm supports the alpha bar function, but your precond type is {self.precondition_type}" + j = torch.as_tensor(j) + return (0.5 * np.pi * j / self.M / (self.C_2 + 1)).sin() ** 2 + u = torch.zeros(self.M + 1) for j in range(self.M, 0, -1): # M, ..., 1 - u[j - 1] = ((u[j] ** 2 + 1) / (self.alpha_bar(j - 1) / self.alpha_bar(j)).clip(min=self.C_1) - 1).sqrt() + u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=self.C_1) - 1).sqrt() self.register_buffer('u', u) self.sigma_min = float(u[self.M - 1]) self.sigma_max = float(u[0]) @@ -54,21 +61,15 @@ def __init__(self, else: raise ValueError(f"Please check your precond type {self.precondition_type} is in ['VP_edm', 'VE_edm', 'iDDPM_edm', 'EDM']") - - # For iDDPM_edm - def alpha_bar(self, j): - assert self.precondition_type == "iDDPM_edm", f"Only iDDPM_edm supports the alpha bar function, but your precond type is {self.precondition_type}" - j = torch.as_tensor(j) - return (0.5 * np.pi * j / self.M / (self.C_2 + 1)).sin() ** 2 def round_sigma(self, sigma, return_index=False): if self.precondition_type == "iDDPM_edm": sigma = torch.as_tensor(sigma) - index = torch.cdist(sigma.to(self.u.device).to(torch.float32).reshape(1, -1, 1), self.u.reshape(1, -1, 1)).argmin(2) + index = torch.cdist(sigma.to(torch.float32).reshape(1, -1, 1), self.u.reshape(1, -1, 1)).argmin(2) result = index if return_index else self.u[index.flatten()].to(sigma.dtype) - return result.reshape(sigma.shape).to(sigma.device) + return result.reshape(sigma.shape) else: return torch.as_tensor(sigma) diff --git a/grl/generative_models/edm_diffusion_model/edm_utils.py b/grl/generative_models/edm_diffusion_model/edm_utils.py index 8e54d89..7293ae7 100644 --- a/grl/generative_models/edm_diffusion_model/edm_utils.py +++ b/grl/generative_models/edm_diffusion_model/edm_utils.py @@ -4,8 +4,7 @@ ############# Sampling Section ############# -# Scheduling in Table 1 - +# Scheduling in Table 1 in paper https://arxiv.org/abs/2206.00364 SIGMA_T = { "VP_edm": lambda t, beta_d=19.9, beta_min=0.1: ((0.5 * beta_d * (t ** 2) + beta_min * t).exp() - 1) ** 0.5, "VE_edm": lambda t, **kwargs: t.sqrt(), diff --git a/grl/generative_models/edm_diffusion_model/test.ipynb b/grl/generative_models/edm_diffusion_model/test.ipynb deleted file mode 100644 index 81b4bf0..0000000 --- a/grl/generative_models/edm_diffusion_model/test.ipynb +++ /dev/null @@ -1,370 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 33, - "metadata": {}, - "outputs": [], - "source": [ - "from typing import Optional, Tuple, Literal\n", - "from dataclasses import dataclass\n", - "\n", - "import numpy as np\n", - "import torch\n", - "from torch import Tensor\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "import torch.optim as optim\n", - "from easydict import EasyDict\n", - "from functools import partial\n", - "\n", - "from edm_preconditioner import PreConditioner\n", - "from edm_utils import SIGMA_T, SIGMA_T_DERIV, SIGMA_T_INV, SCALE_T, SCALE_T_DERIV, DEFAULT_SOLVER_PARAM\n", - "from grl.generative_models.intrinsic_model import IntrinsicModel\n", - "\n", - "class Simple(nn.Module):\n", - " def __init__(self):\n", - " super().__init__()\n", - " self.model = nn.Sequential(\n", - " nn.Linear(2, 32),\n", - " nn.ReLU(),\n", - " nn.Linear(32, 32), \n", - " nn.ReLU(),\n", - " nn.Linear(32, 2)\n", - " )\n", - " def forward(self, x, noise, class_labels=None):\n", - " return self.model(x)\n", - "\n", - "class EDMModel(nn.Module):\n", - " \n", - " def __init__(self, config: Optional[EasyDict]=None) -> None:\n", - " \n", - " super().__init__()\n", - " self.config= config\n", - " # self.x_size = config.x_size\n", - " self.device = config.device\n", - " \n", - " # EDM Type [\"VP_edm\", \"VE_edm\", \"iDDPM_edm\", \"EDM\"]\n", - " self.edm_type: str = config.edm_model.path.edm_type\n", - " assert self.edm_type in [\"VP_edm\", \"VE_edm\", \"iDDPM_edm\", \"EDM\"], \\\n", - " f\"Your edm type should in 'VP_edm', 'VE_edm', 'iDDPM_edm', 'EDM'], but got {self.edm_type}\"\n", - " \n", - " #* 1. Construct basic Unet architecture through params in config\n", - " self.base_denoise_network = Simple()\n", - "\n", - " #* 2. Precond setup\n", - " self.params = config.edm_model.path.params\n", - " self.preconditioner = PreConditioner(\n", - " self.edm_type, \n", - " base_denoise_model=self.base_denoise_network, \n", - " use_mixes_precision=False,\n", - " **self.params\n", - " )\n", - " \n", - " #* 3. Solver setup\n", - " self.solver_type = config.edm_model.solver.solver_type\n", - " assert self.solver_type in ['euler', 'heun']\n", - " \n", - " self.solver_params = DEFAULT_SOLVER_PARAM\n", - " self.solver_params.update(config.edm_model.solver.params)\n", - " \n", - " # Initialize sigma_min and sigma_max if not provided\n", - " \n", - " if \"sigma_min\" not in self.params:\n", - " min = torch.tensor(1e-3)\n", - " self.sigma_min = {\n", - " \"VP_edm\": SIGMA_T[\"VP_edm\"](min, 19.9, 0.1), \n", - " \"VE_edm\": 0.02, \n", - " \"iDDPM_edm\": 0.002, \n", - " \"EDM\": 0.002\n", - " }[self.edm_type]\n", - " else:\n", - " self.sigma_min = self.params.sigma_min\n", - " if \"sigma_max\" not in self.params:\n", - " max = torch.tensor(1)\n", - " self.sigma_max = {\n", - " \"VP_edm\": SIGMA_T[\"VP_edm\"](max, 19.9, 0.1), \n", - " \"VE_edm\": 100, \n", - " \"iDDPM_edm\": 81, \n", - " \"EDM\": 80\n", - " }[self.edm_type] \n", - " else:\n", - " self.sigma_max = self.params.sigma_max\n", - " \n", - " def get_type(self):\n", - " return \"EDMModel\"\n", - "\n", - " # For VP_edm\n", - " def _sample_sigma_weight_train(self, x: Tensor, **params) -> Tuple[Tensor, Tensor]:\n", - " # assert the first dim of x is batch size\n", - " print(f\"params is {params}\")\n", - " rand_shape = [x.shape[0]] + [1] * (x.ndim - 1) \n", - " if self.edm_type == \"VP_edm\":\n", - " epsilon_t = params.get(\"epsilon_t\", 1e-5)\n", - " beta_d = params.get(\"beta_d\", 19.9)\n", - " beta_min = params.get(\"beta_min\", 0.1)\n", - " \n", - " rand_uniform = torch.rand(*rand_shape, device=x.device)\n", - " sigma = SIGMA_T[\"VP_edm\"](1 + rand_uniform * (epsilon_t - 1), beta_d, beta_min)\n", - " weight = 1 / sigma ** 2\n", - " elif self.edm_type == \"VE_edm\":\n", - " rand_uniform = torch.rand(*rand_shape, device=x.device)\n", - " sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rand_uniform)\n", - " weight = 1 / sigma ** 2\n", - " elif self.edm_type == \"EDM\":\n", - " P_mean = params.get(\"P_mean\", -1.2)\n", - " P_std = params.get(\"P_mean\", 1.2)\n", - " sigma_data = params.get(\"sigma_data\", 0.5)\n", - " \n", - " rand_normal = torch.randn(*rand_shape, device=x.device)\n", - " sigma = (rand_normal * P_std + P_mean).exp()\n", - " weight = (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2\n", - " return sigma, weight\n", - " \n", - " def forward(self, \n", - " x: Tensor, \n", - " class_labels=None) -> Tensor:\n", - " x = x.to(self.device)\n", - " sigma, weight = self._sample_sigma_weight_train(x, **self.params)\n", - " n = torch.randn_like(x) * sigma\n", - " D_xn = self.preconditioner(x+n, sigma, class_labels=class_labels)\n", - " loss = weight * ((D_xn - x) ** 2)\n", - " return loss\n", - " \n", - " \n", - " def _get_sigma_steps_t_steps(self, num_steps=18, epsilon_s=1e-3, rho=7):\n", - " \"\"\"\n", - " Overview:\n", - " Get the schedule of sigma according to differernt t schedules.\n", - " \n", - " \"\"\"\n", - " self.sigma_min = max(self.sigma_min, self.preconditioner.sigma_min)\n", - " self.sigma_max = min(self.sigma_max, self.preconditioner.sigma_max)\n", - " \n", - " # Define time steps in terms of noise level\n", - " step_indices = torch.arange(num_steps, dtype=torch.float64, device=self.device)\n", - " sigma_steps = None\n", - " if self.edm_type == \"VP_edm\":\n", - " vp_beta_d = 2 * (np.log(self.sigma_min ** 2 + 1) / epsilon_s - np.log(self.sigma_max ** 2 + 1)) / (epsilon_s - 1)\n", - " vp_beta_min = np.log(self.sigma_max ** 2 + 1) - 0.5 * vp_beta_d\n", - " \n", - " orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)\n", - " sigma_steps = SIGMA_T[\"VP_edm\"](orig_t_steps, vp_beta_d, vp_beta_min)\n", - " \n", - " elif self.edm_type == \"VE_edm\":\n", - " orig_t_steps = (self.sigma_max ** 2) * ((self.sigma_min ** 2 / self.sigma_max ** 2) ** (step_indices / (num_steps - 1)))\n", - " sigma_steps = SIGMA_T[\"VE_edm\"](orig_t_steps)\n", - " \n", - " elif self.edm_type == \"iDDPM_edm\":\n", - " M, C_1, C_2 = self.params.M, self.params.C_1, self.params.C_2\n", - " \n", - " u = torch.zeros(M + 1, dtype=torch.float64, device=self.device)\n", - " alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2\n", - " for j in torch.arange(self.params.M, 0, -1, device=self.device): # M, ..., 1\n", - " u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()\n", - " u_filtered = u[torch.logical_and(u >= self.sigma_min, u <= self.sigma_max)]\n", - " \n", - " sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)] \n", - " orig_t_steps = SIGMA_T_INV[self.edm_type](self.preconditioner.round_sigma(sigma_steps)) \n", - " \n", - " elif self.edm_type == \"EDM\": \n", - " sigma_steps = (self.sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * \\\n", - " (self.sigma_min ** (1 / rho) - self.sigma_max ** (1 / rho))) ** rho\n", - " orig_t_steps = SIGMA_T_INV[self.edm_type](self.preconditioner.round_sigma(sigma_steps)) \n", - " \n", - " t_steps = torch.cat([orig_t_steps, torch.zeros_like(orig_t_steps[:1])]) # t_N = 0\n", - " \n", - " return sigma_steps, t_steps \n", - " \n", - " \n", - " def _get_sigma_deriv_inv_scale_deriv(self, epsilon_s=1e-3):\n", - " \"\"\"\n", - " Overview:\n", - " Get sigma(t) for different solver schedules.\n", - " \n", - " Returns:\n", - " sigma(t), sigma'(t), sigma^{-1}(sigma) \n", - " \"\"\"\n", - " vp_beta_d = 2 * (np.log(self.sigma_min ** 2 + 1) / epsilon_s - np.log(self.sigma_max ** 2 + 1)) / (epsilon_s - 1)\n", - " vp_beta_min = np.log(self.sigma_max ** 2 + 1) - 0.5 * vp_beta_d\n", - " sigma = partial(SIGMA_T[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)\n", - " sigma_deriv = partial(SIGMA_T_DERIV[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)\n", - " sigma_inv = partial(SIGMA_T_INV[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)\n", - " scale = partial(SCALE_T[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)\n", - " scale_deriv = partial(SCALE_T_DERIV[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)\n", - "\n", - " return sigma, sigma_deriv, sigma_inv, scale, scale_deriv\n", - " \n", - " \n", - " def sample(self, \n", - " latents: Tensor, \n", - " class_labels: Tensor=None, \n", - " use_stochastic: bool=False, \n", - " **solver_params) -> Tensor:\n", - " \n", - " # Get sigmas, scales, and timesteps\n", - " print(f\"solver_params is {solver_params}\")\n", - " num_steps = self.solver_params.num_steps\n", - " epsilon_s = self.solver_params.epsilon_s\n", - " rho = self.solver_params.rho\n", - " \n", - " latents = latents.to(self.device)\n", - " sigma_steps, t_steps = self._get_sigma_steps_t_steps(num_steps=num_steps, epsilon_s=epsilon_s, rho=rho)\n", - " sigma, sigma_deriv, sigma_inv, scale, scale_deriv = self._get_sigma_deriv_inv_scale_deriv()\n", - " \n", - " S_churn = self.solver_params.S_churn\n", - " S_min = self.solver_params.S_min\n", - " S_max = self.solver_params.S_max\n", - " S_noise = self.solver_params.S_noise\n", - " alpha = self.solver_params.alpha\n", - " \n", - " if not use_stochastic:\n", - " # Main sampling loop\n", - " t_next = t_steps[0]\n", - " x_next = latents.to(torch.float64) * (sigma(t_next) * scale(t_next))\n", - " for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1\n", - " x_cur = x_next\n", - "\n", - " # Increase noise temporarily.\n", - " gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0\n", - " t_hat = sigma_inv(self.preconditioner.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))\n", - " x_hat = scale(t_hat) / scale(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * scale(t_hat) * S_noise * torch.randn_like(x_cur)\n", - "\n", - " # Euler step.\n", - " h = t_next - t_hat\n", - " denoised = self.preconditioner(x_hat / scale(t_hat), sigma(t_hat), class_labels).to(torch.float64)\n", - " d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + scale_deriv(t_hat) / scale(t_hat)) * x_hat - sigma_deriv(t_hat) * scale(t_hat) / sigma(t_hat) * denoised\n", - " x_prime = x_hat + alpha * h * d_cur\n", - " t_prime = t_hat + alpha * h\n", - "\n", - " # Apply 2nd order correction.\n", - " if self.solver_type == 'euler' or i == num_steps - 1:\n", - " x_next = x_hat + h * d_cur\n", - " else:\n", - " assert self.solver_type == 'heun'\n", - " denoised = self.preconditioner(x_prime / scale(t_prime), sigma(t_prime), class_labels).to(torch.float64)\n", - " d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + scale_deriv(t_prime) / scale(t_prime)) * x_prime - sigma_deriv(t_prime) * scale(t_prime) / sigma(t_prime) * denoised\n", - " x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)\n", - " \n", - " else:\n", - " assert self.edm_type == \"EDM\", f\"Stochastic can only use in EDM, but your precond type is {self.edm_type}\"\n", - " x_next = latents.to(torch.float64) * t_steps[0]\n", - " for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1\n", - " x_cur = x_next\n", - "\n", - " # Increase noise temporarily.\n", - " gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0\n", - " t_hat = self.preconditioner.round_sigma(t_cur + gamma * t_cur)\n", - " x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * torch.randn_like(x_cur)\n", - "\n", - " # Euler step.\n", - " denoised = self.preconditioner(x_hat, t_hat, class_labels).to(torch.float64)\n", - " d_cur = (x_hat - denoised) / t_hat\n", - " x_next = x_hat + (t_next - t_hat) * d_cur\n", - "\n", - " # Apply 2nd order correction.\n", - " if i < num_steps - 1:\n", - " denoised = self.preconditioner(x_next, t_next, class_labels).to(torch.float64)\n", - " d_prime = (x_next - denoised) / t_next\n", - " x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)\n", - "\n", - "\n", - " return x_next\n" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "params is {}\n", - "solver_params is {}\n" - ] - } - ], - "source": [ - "import torch\n", - "from easydict import EasyDict\n", - "\n", - "config = EasyDict(\n", - " dict(\n", - " device=torch.device(\"cuda\"), # Test if all tensors are converted to the same device\n", - " edm_model=dict( \n", - " path=dict(\n", - " edm_type=\"EDM\", # *[\"VP_edm\", \"VE_edm\", \"iDDPM_edm\", \"EDM\"]\n", - " params=dict(\n", - " #^ 1: VP_edm\n", - " # beta_d=19.9, \n", - " # beta_min=0.1, \n", - " # M=1000, \n", - " # epsilon_t=1e-5,\n", - " # epsilon_s=1e-3,\n", - " #^ 2: VE_edm\n", - " # sigma_min=0.02,\n", - " # sigma_max=100,\n", - " #^ 3: iDDPM_edm\n", - " # C_1=0.001,\n", - " # C_2=0.008,\n", - " # M=1000,\n", - " #^ 4: EDM\n", - " # sigma_min=0.002,\n", - " # sigma_max=80,\n", - " # sigma_data=0.5,\n", - " # P_mean=-1.2,\n", - " # P_std=1.2,\n", - " )\n", - " ),\n", - " solver=dict(\n", - " solver_type=\"heun\", \n", - " # *['euler', 'heun']\n", - " params=dict(\n", - " num_steps=18,\n", - " alpha=1, \n", - " S_churn=0, \n", - " S_min=0, \n", - " S_max=float(\"inf\"),\n", - " S_noise=1,\n", - " rho=7, #* EDM needs rho \n", - " epsilon_s=1e-3 #* VP_edm needs epsilon_s\n", - " )\n", - " )\n", - " )\n", - " )\n", - ")\n", - "\n", - "edm = EDMModel(config).to(config.device)\n", - "x = torch.randn((1024, 2)).to(config.device)\n", - "noise = torch.randn_like(x)\n", - "loss = edm(x).mean()\n", - "sample = edm.sample(x)\n", - "sample.shape\n", - "loss.backward()\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "base", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.14" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} From 4ca066ec29fcd3e177014f652da45e131b5638ce Mon Sep 17 00:00:00 2001 From: wrh12345 Date: Wed, 21 Aug 2024 07:02:32 +0000 Subject: [PATCH 04/14] feature(wrh): add edm initial implementation --- .../edm_diffusion_model.py | 55 ++++++++++++++----- 1 file changed, 42 insertions(+), 13 deletions(-) diff --git a/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py b/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py index 2fc0e34..27c8d33 100644 --- a/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py +++ b/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, Callable from dataclasses import dataclass import numpy as np @@ -85,14 +85,25 @@ def __init__(self, config: Optional[EasyDict]=None) -> None: self.sigma_min = INITIAL_SIGMA_MIN[self.edm_type] if "sigma_min" not in self.params else self.params.sigma_min self.sigma_max = INITIAL_SIGMA_MAX[self.edm_type] if "sigma_max" not in self.params else self.params.sigma_max - + @property def get_type(self) -> str: return "EDMModel" # For VP_edm def _sample_sigma_weight_train(self, x: Tensor, **params) -> Tuple[Tensor, Tensor]: - # assert the first dim of x is batch size + """ + Overview: + Sample sigma from given distribution for training according to edm type. + + Arguments: + x (:obj:`torch.Tensor`): The sample which needs to add noise. + + Returns: + sigma (:obj:`torch.Tensor`): Sampled sigma from the distribution. + weight (:obj:`torch.Tensor`): Loss weight obtained from sampled sigma. + """ log.info(f"Params of trainig is: {params}") + # assert the first dim of x is batch size rand_shape = [x.shape[0]] + [1] * (x.ndim - 1) if self.edm_type == "VP_edm": epsilon_t = params.get("epsilon_t", 1e-5) @@ -116,22 +127,32 @@ def _sample_sigma_weight_train(self, x: Tensor, **params) -> Tuple[Tensor, Tenso weight = (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2 return sigma, weight - def forward(self, - x: Tensor, - class_labels: Tensor=None) -> Tensor: + def forward(self, x: Tensor, class_labels: Tensor=None) -> Tensor: x = x.to(self.device) sigma, weight = self._sample_sigma_weight_train(x, **self.params) n = torch.randn_like(x) * sigma D_xn = self.preconditioner(x+n, sigma, class_labels=class_labels) loss = weight * ((D_xn - x) ** 2) - return loss + return loss.mean() - def _get_sigma_steps_t_steps(self, num_steps: int=18, epsilon_s: float=1e-3, rho: Union[int, float]=7): + def _get_sigma_steps_t_steps(self, + num_steps: int=18, + epsilon_s: float=1e-3, rho: Union[int, float]=7 + )-> Tuple[Tensor, Tensor]: """ Overview: Get the schedule of sigma according to differernt t schedules. + Arguments: + num_steps (:obj:`int`): The number of timesteps during denoise sampling. Default setting: 18. + epsilon_s (:obj:`float`): Parameter epsilon_s (only VP_edm needs). + rho (:obj:`Union[int, float]`): Parameter rho (only EDM needs). + + Returns: + sigma_steps (:obj:`torch.Tensor`): The scheduled sigma. + t_steps (:obj:`torch.Tensor`): The scheduled t. + """ self.sigma_min = max(self.sigma_min, self.preconditioner.sigma_min) self.sigma_max = min(self.sigma_max, self.preconditioner.sigma_max) @@ -144,11 +165,11 @@ def _get_sigma_steps_t_steps(self, num_steps: int=18, epsilon_s: float=1e-3, rho vp_beta_min = np.log(self.sigma_max ** 2 + 1) - 0.5 * vp_beta_d orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1) - sigma_steps = SIGMA_T["VP_edm"](orig_t_steps, vp_beta_d, vp_beta_min) + sigma_steps = SIGMA_T[self.edm_type](orig_t_steps, vp_beta_d, vp_beta_min) elif self.edm_type == "VE_edm": orig_t_steps = (self.sigma_max ** 2) * ((self.sigma_min ** 2 / self.sigma_max ** 2) ** (step_indices / (num_steps - 1))) - sigma_steps = SIGMA_T["VE_edm"](orig_t_steps) + sigma_steps = SIGMA_T[self.edm_type](orig_t_steps) elif self.edm_type == "iDDPM_edm": M, C_1, C_2 = self.params.M, self.params.C_1, self.params.C_2 @@ -166,19 +187,27 @@ def _get_sigma_steps_t_steps(self, num_steps: int=18, epsilon_s: float=1e-3, rho sigma_steps = (self.sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * \ (self.sigma_min ** (1 / rho) - self.sigma_max ** (1 / rho))) ** rho orig_t_steps = SIGMA_T_INV[self.edm_type](self.preconditioner.round_sigma(sigma_steps)) + else: + raise NotImplementedError(f"Please check your edm_type: {self.edm_type}, which is not in ['VP_edm', 'VE_edm', 'iDDPM_edm', 'EDM']") - t_steps = torch.cat([orig_t_steps, torch.zeros_like(orig_t_steps[:1])]) # t_N = 0 + t_steps = torch.cat([orig_t_steps, torch.zeros_like(orig_t_steps[:1])]) # t_N = 0E return sigma_steps, t_steps - def _get_sigma_deriv_inv_scale_deriv(self, epsilon_s=1e-3): + def _get_sigma_deriv_inv_scale_deriv(self, epsilon_s: Union[int, float]=1e-3) \ + -> Tuple[Callable, Callable, Callable, Callable, Callable]: """ Overview: Get sigma(t) for different solver schedules. Returns: - sigma(t), sigma'(t), sigma^{-1}(sigma) + sigma: (:obj:`Callable`): sigma(t) + sigma_deriv: (:obj:`Callable`): sigma'(t) + sigma_inv: (:obj:`Callable`): sigma^{-1} (sigma) + scale: (:obj:`Callable`): s(t) + scale_deriv: (:obj:`Callable`): s'(t) + """ vp_beta_d = 2 * (np.log(self.sigma_min ** 2 + 1) / epsilon_s - np.log(self.sigma_max ** 2 + 1)) / (epsilon_s - 1) vp_beta_min = np.log(self.sigma_max ** 2 + 1) - 0.5 * vp_beta_d From 09a74251aebcdebd82b75dbc3559ff21eba5d45a Mon Sep 17 00:00:00 2001 From: wrh12345 Date: Wed, 21 Aug 2024 07:51:32 +0000 Subject: [PATCH 05/14] feature(wrh): add initial version of edm --- .../edm_diffusion_model/edm_diffusion_model.py | 18 ++++++++++-------- .../swiss_roll/swiss_roll_edm_diffusion.py | 4 ---- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py b/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py index 27c8d33..c5d3985 100644 --- a/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py +++ b/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py @@ -18,7 +18,8 @@ from .edm_preconditioner import PreConditioner from .edm_utils import SIGMA_T, SIGMA_T_DERIV, SIGMA_T_INV from .edm_utils import SCALE_T, SCALE_T_DERIV -from .edm_utils import INITIAL_SIGMA_MAX, INITIAL_SIGMA_MIN, DEFAULT_SOLVER_PARAM +from .edm_utils import INITIAL_SIGMA_MAX, INITIAL_SIGMA_MIN +from .edm_utils import DEFAULT_PARAM, DEFAULT_SOLVER_PARAM class Simple(nn.Module): def __init__(self): @@ -63,7 +64,9 @@ def __init__(self, config: Optional[EasyDict]=None) -> None: self.base_denoise_network = Simple() #* 2. Precond setup - self.params = config.edm_model.path.params + self.params = DEFAULT_PARAM[self.edm_type] + self.params.update(config.edm_model.path.params) + log.info(f"Using edm type: {self.edm_type}\nParam is {self.params}") self.preconditioner = PreConditioner( self.edm_type, base_denoise_model=self.base_denoise_network, @@ -78,7 +81,7 @@ def __init__(self, config: Optional[EasyDict]=None) -> None: self.solver_params = DEFAULT_SOLVER_PARAM self.solver_params.update(config.edm_model.solver.params) - + log.info(f"Using solver type: {self.solver_type}\nSolver param is {self.solver_params}") # Initialize sigma_min and sigma_max if not provided @@ -102,7 +105,6 @@ def _sample_sigma_weight_train(self, x: Tensor, **params) -> Tuple[Tensor, Tenso sigma (:obj:`torch.Tensor`): Sampled sigma from the distribution. weight (:obj:`torch.Tensor`): Loss weight obtained from sampled sigma. """ - log.info(f"Params of trainig is: {params}") # assert the first dim of x is batch size rand_shape = [x.shape[0]] + [1] * (x.ndim - 1) if self.edm_type == "VP_edm": @@ -110,11 +112,11 @@ def _sample_sigma_weight_train(self, x: Tensor, **params) -> Tuple[Tensor, Tenso beta_d = params.get("beta_d", 19.9) beta_min = params.get("beta_min", 0.1) - rand_uniform = torch.rand(*rand_shape, device=x.device) + rand_uniform = torch.rand(*rand_shape, device=self.device) sigma = SIGMA_T["VP_edm"](1 + rand_uniform * (epsilon_t - 1), beta_d, beta_min) weight = 1 / sigma ** 2 elif self.edm_type == "VE_edm": - rand_uniform = torch.rand(*rand_shape, device=x.device) + rand_uniform = torch.rand(*rand_shape, device=self.device) sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rand_uniform) weight = 1 / sigma ** 2 elif self.edm_type == "EDM": @@ -122,7 +124,7 @@ def _sample_sigma_weight_train(self, x: Tensor, **params) -> Tuple[Tensor, Tenso P_std = params.get("P_mean", 1.2) sigma_data = params.get("sigma_data", 0.5) - rand_normal = torch.randn(*rand_shape, device=x.device) + rand_normal = torch.randn(*rand_shape, device=self.device) sigma = (rand_normal * P_std + P_mean).exp() weight = (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2 return sigma, weight @@ -230,7 +232,7 @@ def sample(self, ) -> Tensor: # Get sigmas, scales, and timesteps - log.info(f"Solver param is {self.solver_params}") + log.info(f"Start sampling!") num_steps = self.solver_params.num_steps epsilon_s = self.solver_params.epsilon_s rho = self.solver_params.rho diff --git a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py index d56b1be..e61aa0c 100644 --- a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py +++ b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py @@ -64,10 +64,6 @@ solver=dict( solver_type="heun", # *['euler', 'heun'] - schedule="Linear", - #* ['VP', 'VE', 'Linear'] Give "Linear" when edm type in ["iDDPM_edm", "EDM"] - scaling="none", - #* ["VP", "none"] Give "none" when edm type in ["VE_edm", "iDDPM_edm", "EDM"] params=dict( num_steps=18, alpha=1, From a6e9555be5012eeea2c66738b0bc185569ac97ba Mon Sep 17 00:00:00 2001 From: wrh12345 Date: Wed, 21 Aug 2024 08:23:20 +0000 Subject: [PATCH 06/14] feature(wrh): add initial version of edm --- .../edm_diffusion_model.py | 18 +++++------ .../edm_diffusion_model/edm_preconditioner.py | 32 +++++++++++-------- .../swiss_roll/swiss_roll_edm_diffusion.py | 4 +-- 3 files changed, 29 insertions(+), 25 deletions(-) diff --git a/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py b/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py index c5d3985..3586561 100644 --- a/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py +++ b/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py @@ -64,7 +64,7 @@ def __init__(self, config: Optional[EasyDict]=None) -> None: self.base_denoise_network = Simple() #* 2. Precond setup - self.params = DEFAULT_PARAM[self.edm_type] + self.params = EasyDict(DEFAULT_PARAM[self.edm_type]) self.params.update(config.edm_model.path.params) log.info(f"Using edm type: {self.edm_type}\nParam is {self.params}") self.preconditioner = PreConditioner( @@ -79,7 +79,7 @@ def __init__(self, config: Optional[EasyDict]=None) -> None: assert self.solver_type in ['euler', 'heun'], \ f"Your solver type should in ['euler', 'heun'], but got {self.solver_type}" - self.solver_params = DEFAULT_SOLVER_PARAM + self.solver_params = EasyDict(DEFAULT_SOLVER_PARAM) self.solver_params.update(config.edm_model.solver.params) log.info(f"Using solver type: {self.solver_type}\nSolver param is {self.solver_params}") # Initialize sigma_min and sigma_max if not provided @@ -92,7 +92,6 @@ def __init__(self, config: Optional[EasyDict]=None) -> None: def get_type(self) -> str: return "EDMModel" - # For VP_edm def _sample_sigma_weight_train(self, x: Tensor, **params) -> Tuple[Tensor, Tensor]: """ Overview: @@ -106,11 +105,12 @@ def _sample_sigma_weight_train(self, x: Tensor, **params) -> Tuple[Tensor, Tenso weight (:obj:`torch.Tensor`): Loss weight obtained from sampled sigma. """ # assert the first dim of x is batch size + params = EasyDict(params) rand_shape = [x.shape[0]] + [1] * (x.ndim - 1) if self.edm_type == "VP_edm": - epsilon_t = params.get("epsilon_t", 1e-5) - beta_d = params.get("beta_d", 19.9) - beta_min = params.get("beta_min", 0.1) + epsilon_t = params.epsilon_t + beta_d = params.beta_d + beta_min = params.beta_min rand_uniform = torch.rand(*rand_shape, device=self.device) sigma = SIGMA_T["VP_edm"](1 + rand_uniform * (epsilon_t - 1), beta_d, beta_min) @@ -120,9 +120,9 @@ def _sample_sigma_weight_train(self, x: Tensor, **params) -> Tuple[Tensor, Tenso sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rand_uniform) weight = 1 / sigma ** 2 elif self.edm_type == "EDM": - P_mean = params.get("P_mean", -1.2) - P_std = params.get("P_mean", 1.2) - sigma_data = params.get("sigma_data", 0.5) + P_mean = params.P_mean + P_std = params.P_std + sigma_data = params.sigma_data rand_normal = torch.randn(*rand_shape, device=self.device) sigma = (rand_normal * P_std + P_mean).exp() diff --git a/grl/generative_models/edm_diffusion_model/edm_preconditioner.py b/grl/generative_models/edm_diffusion_model/edm_preconditioner.py index 8b0294a..ced73c6 100644 --- a/grl/generative_models/edm_diffusion_model/edm_preconditioner.py +++ b/grl/generative_models/edm_diffusion_model/edm_preconditioner.py @@ -1,12 +1,14 @@ from typing import Optional, Tuple, Literal from dataclasses import dataclass +from torch import Tensor, as_tensor +from easydict import EasyDict import numpy as np import torch -from torch import Tensor, as_tensor import torch.nn as nn import torch.nn.functional as F +from grl.utils.log import log from .edm_utils import SIGMA_T, SIGMA_T_INV class PreConditioner(nn.Module): @@ -15,30 +17,32 @@ def __init__(self, precondition_type: Literal["VP_edm", "VE_edm", "iDDPM_edm", "EDM"] = "EDM", base_denoise_model: nn.Module = None, use_mixes_precision: bool = False, - **precond_config_kwargs) -> None: + **precond_params) -> None: super().__init__() + log.info(f"Precond_params: {precond_params}") + precond_params = EasyDict(precond_params) self.precondition_type = precondition_type self.base_denoise_model = base_denoise_model self.use_mixes_precision = use_mixes_precision if self.precondition_type == "VP_edm": - self.beta_d = precond_config_kwargs.get("beta_d", 19.9) - self.beta_min = precond_config_kwargs.get("beta_min", 0.1) - self.M = precond_config_kwargs.get("M", 1000) - self.epsilon_t = precond_config_kwargs.get("epsilon_t", 1e-5) + self.beta_d = precond_params.beta_d + self.beta_min = precond_params.beta_min + self.M = precond_params.M + self.epsilon_t = precond_params.epsilon_t self.sigma_min = SIGMA_T["VP_edm"](self.epsilon_t, self.beta_d, self.beta_min) self.sigma_max = SIGMA_T["VP_edm"](1, self.beta_d, self.beta_min) elif self.precondition_type == "VE_edm": - self.sigma_min = precond_config_kwargs.get("sigma_min", 0.02) - self.sigma_max = precond_config_kwargs.get("sigma_max", 100) + self.sigma_min = precond_params.sigma_min + self.sigma_max = precond_params.sigma_max elif self.precondition_type == "iDDPM_edm": - self.C_1 = precond_config_kwargs.get("C_1", 0.001) - self.C_2 = precond_config_kwargs.get("C_2", 0.008) - self.M = precond_config_kwargs.get("M", 1000) + self.C_1 = precond_params.C_1 + self.C_2 = precond_params.C_2 + self.M = precond_params.M # For iDDPM_edm def alpha_bar(j): @@ -54,9 +58,9 @@ def alpha_bar(j): self.sigma_max = float(u[0]) elif self.precondition_type == "EDM": - self.sigma_min = precond_config_kwargs.get("sigma_min", 0.002) - self.sigma_max = precond_config_kwargs.get("sigma_max", 80) - self.sigma_data = precond_config_kwargs.get("sigma_data", 0.5) + self.sigma_min = precond_params.sigma_min + self.sigma_max = precond_params.sigma_max + self.sigma_data = precond_params.sigma_data else: raise ValueError(f"Please check your precond type {self.precondition_type} is in ['VP_edm', 'VE_edm', 'iDDPM_edm', 'EDM']") diff --git a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py index e61aa0c..99b588e 100644 --- a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py +++ b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py @@ -56,8 +56,8 @@ sigma_min=0.002, sigma_max=80, sigma_data=0.5, - P_mean=-1.2, - P_std=1.2, + P_mean=-1.21, + P_std=1.21, ) ), From 9ab4216b18f056a8a5410dff8e28fd5d6290808d Mon Sep 17 00:00:00 2001 From: wrh12345 Date: Wed, 21 Aug 2024 08:41:21 +0000 Subject: [PATCH 07/14] feature(wrh): add initial version of edm --- .../edm_diffusion_model/edm_diffusion_model.py | 5 +++++ .../edm_diffusion_model/edm_preconditioner.py | 4 ++-- .../swiss_roll/swiss_roll_edm_diffusion.py | 16 ++++++++-------- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py b/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py index 3586561..409490a 100644 --- a/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py +++ b/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py @@ -119,6 +119,11 @@ def _sample_sigma_weight_train(self, x: Tensor, **params) -> Tuple[Tensor, Tenso rand_uniform = torch.rand(*rand_shape, device=self.device) sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rand_uniform) weight = 1 / sigma ** 2 + elif self.edm_type == "iDDPM_edm": + u = self.preconditioner.u + sigma_index = torch.randint(0, self.params.M - 1, rand_shape, device=self.device) + sigma = u[sigma_index] + weight = 1 / sigma ** 2 elif self.edm_type == "EDM": P_mean = params.P_mean P_std = params.P_std diff --git a/grl/generative_models/edm_diffusion_model/edm_preconditioner.py b/grl/generative_models/edm_diffusion_model/edm_preconditioner.py index ced73c6..670bc4e 100644 --- a/grl/generative_models/edm_diffusion_model/edm_preconditioner.py +++ b/grl/generative_models/edm_diffusion_model/edm_preconditioner.py @@ -32,8 +32,8 @@ def __init__(self, self.M = precond_params.M self.epsilon_t = precond_params.epsilon_t - self.sigma_min = SIGMA_T["VP_edm"](self.epsilon_t, self.beta_d, self.beta_min) - self.sigma_max = SIGMA_T["VP_edm"](1, self.beta_d, self.beta_min) + self.sigma_min = float(SIGMA_T["VP_edm"](torch.tensor(self.epsilon_t), self.beta_d, self.beta_min)) + self.sigma_max = float(SIGMA_T["VP_edm"](torch.tensor(1), self.beta_d, self.beta_min)) elif self.precondition_type == "VE_edm": self.sigma_min = precond_params.sigma_min diff --git a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py index 99b588e..74d5614 100644 --- a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py +++ b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py @@ -37,14 +37,14 @@ device=device, edm_model=dict( path=dict( - edm_type="EDM", # *["VP_edm", "VE_edm", "iDDPM_edm", "EDM"] + edm_type="iDDPM_edm", # *["VP_edm", "VE_edm", "iDDPM_edm", "EDM"] params=dict( #^ 1: VP_edm # beta_d=19.9, # beta_min=0.1, # M=1000, # epsilon_t=1e-5, - # epsilon_s=1e-3, + # epsilon_s=1e-4, #^ 2: VE_edm # sigma_min=0.02, # sigma_max=100, @@ -53,11 +53,11 @@ # C_2=0.008, # M=1000, #^ 4: EDM - sigma_min=0.002, - sigma_max=80, - sigma_data=0.5, - P_mean=-1.21, - P_std=1.21, + # sigma_min=0.002, + # sigma_max=80, + # sigma_data=0.5, + # P_mean=-1.21, + # P_std=1.21, ) ), @@ -217,7 +217,7 @@ def save_checkpoint(model, optimizer, iteration): for i in range(10): edm_diffusion_model.train() - loss = edm_diffusion_model(batch_data).mean() + loss = edm_diffusion_model(batch_data) optimizer.zero_grad() loss.backward() gradien_norm = torch.nn.utils.clip_grad_norm_( From 207f35dd0c2b099ec9b3a0503120601480749df8 Mon Sep 17 00:00:00 2001 From: wrh12345 Date: Wed, 21 Aug 2024 10:46:41 +0000 Subject: [PATCH 08/14] feature(wrh): add initial version of edm --- .../edm_diffusion_model.py | 22 +++++++++---------- .../edm_diffusion_model/edm_preconditioner.py | 6 +++-- .../swiss_roll/swiss_roll_edm_diffusion.py | 4 ++-- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py b/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py index 409490a..5f7ea86 100644 --- a/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py +++ b/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py @@ -31,7 +31,7 @@ def __init__(self): nn.ReLU(), nn.Linear(32, 2) ) - def forward(self, x, noise, class_labels=None): + def forward(self, x, noise, condition=None): return self.model(x) class EDMModel(nn.Module): @@ -61,7 +61,7 @@ def __init__(self, config: Optional[EasyDict]=None) -> None: f"Your edm type should in 'VP_edm', 'VE_edm', 'iDDPM_edm', 'EDM'], but got {self.edm_type}" #* 1. Construct basic Unet architecture through params in config - self.base_denoise_network = Simple() + self.base_denoise_network = IntrinsicModel(config.edm_model.model.args) #* 2. Precond setup self.params = EasyDict(DEFAULT_PARAM[self.edm_type]) @@ -134,11 +134,11 @@ def _sample_sigma_weight_train(self, x: Tensor, **params) -> Tuple[Tensor, Tenso weight = (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2 return sigma, weight - def forward(self, x: Tensor, class_labels: Tensor=None) -> Tensor: + def forward(self, x: Tensor, condition: Tensor=None) -> Tensor: x = x.to(self.device) sigma, weight = self._sample_sigma_weight_train(x, **self.params) n = torch.randn_like(x) * sigma - D_xn = self.preconditioner(x+n, sigma, class_labels=class_labels) + D_xn = self.preconditioner(sigma, x+n, condition=condition) loss = weight * ((D_xn - x) ** 2) return loss.mean() @@ -165,7 +165,7 @@ def _get_sigma_steps_t_steps(self, self.sigma_max = min(self.sigma_max, self.preconditioner.sigma_max) # Define time steps in terms of noise level - step_indices = torch.arange(num_steps, dtype=torch.float64, device=self.device) + step_indices = torch.arange(num_steps, dtype=torch.float32, device=self.device) sigma_steps = None if self.edm_type == "VP_edm": vp_beta_d = 2 * (np.log(self.sigma_min ** 2 + 1) / epsilon_s - np.log(self.sigma_max ** 2 + 1)) / (epsilon_s - 1) @@ -181,7 +181,7 @@ def _get_sigma_steps_t_steps(self, elif self.edm_type == "iDDPM_edm": M, C_1, C_2 = self.params.M, self.params.C_1, self.params.C_2 - u = torch.zeros(M + 1, dtype=torch.float64, device=self.device) + u = torch.zeros(M + 1, dtype=torch.float, device=self.device) alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2 for j in torch.arange(self.params.M, 0, -1, device=self.device): # M, ..., 1 u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt() @@ -231,7 +231,7 @@ def sample(self, t_span, batch_size, latents: Tensor, - class_labels: Tensor=None, + condition: Tensor=None, use_stochastic: bool=False, **solver_kwargs ) -> Tensor: @@ -267,7 +267,7 @@ def sample(self, # Euler step. h = t_next - t_hat - denoised = self.preconditioner(x_hat / scale(t_hat), sigma(t_hat), class_labels) + denoised = self.preconditioner(sigma(t_hat), x_hat / scale(t_hat), condition) d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + scale_deriv(t_hat) / scale(t_hat)) * x_hat - sigma_deriv(t_hat) * scale(t_hat) / sigma(t_hat) * denoised x_prime = x_hat + alpha * h * d_cur t_prime = t_hat + alpha * h @@ -277,7 +277,7 @@ def sample(self, x_next = x_hat + h * d_cur else: assert self.solver_type == 'heun' - denoised = self.preconditioner(x_prime / scale(t_prime), sigma(t_prime), class_labels) + denoised = self.preconditioner(sigma(t_prime), x_prime / scale(t_prime), condition) d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + scale_deriv(t_prime) / scale(t_prime)) * x_prime - sigma_deriv(t_prime) * scale(t_prime) / sigma(t_prime) * denoised x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime) @@ -293,13 +293,13 @@ def sample(self, x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * torch.randn_like(x_cur) # Euler step. - denoised = self.preconditioner(x_hat, t_hat, class_labels) + denoised = self.preconditioner(t_hat, x_hat, condition) d_cur = (x_hat - denoised) / t_hat x_next = x_hat + (t_next - t_hat) * d_cur # Apply 2nd order correction. if i < num_steps - 1: - denoised = self.preconditioner(x_next, t_next, class_labels) + denoised = self.preconditioner(t_next, x_next, condition) d_prime = (x_next - denoised) / t_next x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) diff --git a/grl/generative_models/edm_diffusion_model/edm_preconditioner.py b/grl/generative_models/edm_diffusion_model/edm_preconditioner.py index 670bc4e..98ab5fc 100644 --- a/grl/generative_models/edm_diffusion_model/edm_preconditioner.py +++ b/grl/generative_models/edm_diffusion_model/edm_preconditioner.py @@ -101,16 +101,18 @@ def get_precondition_c(self, sigma: Tensor) -> Tuple[Tensor, Tensor, Tensor, Ten c_noise = sigma.log() / 4 return c_skip, c_out, c_in, c_noise - def forward(self, x: Tensor, sigma: Tensor, class_labels=None, **model_kwargs): + def forward(self, sigma: Tensor, x: Tensor, condition: Tensor=None, **model_kwargs): # Suppose the first dim of x is batch size x = x.to(torch.float32) sigma_shape = [x.shape[0]] + [1] * (x.ndim - 1) + + if sigma.numel() == 1: sigma = sigma.view(-1).expand(*sigma_shape) dtype = torch.float16 if (self.use_mixes_precision and x.device.type == 'cuda') else torch.float32 c_skip, c_out, c_in, c_noise = self.get_precondition_c(sigma) - F_x = self.base_denoise_model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs) + F_x = self.base_denoise_model(c_noise.flatten(), (c_in * x).to(dtype), condition=condition, **model_kwargs) assert F_x.dtype == dtype D_x = c_skip * x + c_out * F_x.to(torch.float32) return D_x \ No newline at end of file diff --git a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py index 74d5614..1640ec9 100644 --- a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py +++ b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py @@ -37,7 +37,7 @@ device=device, edm_model=dict( path=dict( - edm_type="iDDPM_edm", # *["VP_edm", "VE_edm", "iDDPM_edm", "EDM"] + edm_type="EDM", # *["VP_edm", "VE_edm", "iDDPM_edm", "EDM"] params=dict( #^ 1: VP_edm # beta_d=19.9, @@ -109,7 +109,7 @@ seed_value = set_seed() log.info(f"start exp with seed value {seed_value}.") edm_diffusion_model = EDMModel(config=config).to(config.device) - edm_diffusion_model = torch.compile(edm_diffusion_model) + # edm_diffusion_model = torch.compile(edm_diffusion_model) # get data data = make_swiss_roll(n_samples=config.parameter.data_num, noise=0.01)[0].astype( np.float32 From fff231d68050a3cc39ddae4df1fa4dfcb402626d Mon Sep 17 00:00:00 2001 From: wrh12345 Date: Wed, 21 Aug 2024 13:14:59 +0000 Subject: [PATCH 09/14] feature(wrh): add initial version of edm --- .../edm_diffusion_model/edm_diffusion_model.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py b/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py index 5f7ea86..225610b 100644 --- a/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py +++ b/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py @@ -21,18 +21,6 @@ from .edm_utils import INITIAL_SIGMA_MAX, INITIAL_SIGMA_MIN from .edm_utils import DEFAULT_PARAM, DEFAULT_SOLVER_PARAM -class Simple(nn.Module): - def __init__(self): - super().__init__() - self.model = nn.Sequential( - nn.Linear(2, 32), - nn.ReLU(), - nn.Linear(32, 32), - nn.ReLU(), - nn.Linear(32, 2) - ) - def forward(self, x, noise, condition=None): - return self.model(x) class EDMModel(nn.Module): """ From 249574b099b5415f93d8ba6b2562f699ee4f81e0 Mon Sep 17 00:00:00 2001 From: wrh12345 Date: Fri, 23 Aug 2024 05:49:30 +0000 Subject: [PATCH 10/14] feature(wrh): fit edm into known format --- .../edm_diffusion_model.py | 394 +++++++++++++----- .../edm_diffusion_model/edm_preconditioner.py | 19 +- .../edm_diffusion_model/edm_utils.py | 4 +- .../swiss_roll/swiss_roll_edm_diffusion.py | 90 ++-- 4 files changed, 367 insertions(+), 140 deletions(-) diff --git a/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py b/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py index 225610b..7f968d1 100644 --- a/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py +++ b/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py @@ -1,16 +1,28 @@ -from typing import Optional, Tuple, Union, Callable +from typing import Any, Callable, Dict, List, Tuple, Union, Optional from dataclasses import dataclass import numpy as np -import torch from torch import Tensor +import torch import torch.nn as nn import torch.nn.functional as F +import treetensor +from tensordict import TensorDict + import torch.optim as optim from easydict import EasyDict from functools import partial from grl.generative_models.intrinsic_model import IntrinsicModel +from grl.generative_models.random_generator import gaussian_random_variable +from grl.numerical_methods.numerical_solvers import get_solver +from grl.numerical_methods.numerical_solvers.dpm_solver import DPMSolver +from grl.numerical_methods.numerical_solvers.ode_solver import ( + DictTensorODESolver, + ODESolver, +) +from grl.numerical_methods.numerical_solvers.sde_solver import SDESolver + from grl.utils import find_parameters from grl.utils import set_seed from grl.utils.log import log @@ -40,39 +52,43 @@ def __init__(self, config: Optional[EasyDict]=None) -> None: super().__init__() self.config = config - # self.x_size = config.x_size + self.x_size = config.x_size self.device = config.device + + self.gaussian_generator = gaussian_random_variable( + config.x_size, + config.device, + config.use_tree_tensor if hasattr(config, "use_tree_tensor") else False, + ) + + if hasattr(config, "solver"): + self.solver = get_solver(config.solver.type)(**config.solver.args) + # EDM Type ["VP_edm", "VE_edm", "iDDPM_edm", "EDM"] - self.edm_type = config.edm_model.path.edm_type + self.edm_type = config.path.edm_type assert self.edm_type in ["VP_edm", "VE_edm", "iDDPM_edm", "EDM"], \ f"Your edm type should in 'VP_edm', 'VE_edm', 'iDDPM_edm', 'EDM'], but got {self.edm_type}" #* 1. Construct basic Unet architecture through params in config - self.base_denoise_network = IntrinsicModel(config.edm_model.model.args) + self.model = IntrinsicModel(config.model.args) #* 2. Precond setup self.params = EasyDict(DEFAULT_PARAM[self.edm_type]) - self.params.update(config.edm_model.path.params) + self.params.update(config.path.params) log.info(f"Using edm type: {self.edm_type}\nParam is {self.params}") self.preconditioner = PreConditioner( self.edm_type, - base_denoise_model=self.base_denoise_network, + denoise_model=self.model, use_mixes_precision=False, **self.params ) - - #* 3. Solver setup - self.solver_type = config.edm_model.solver.solver_type - assert self.solver_type in ['euler', 'heun'], \ - f"Your solver type should in ['euler', 'heun'], but got {self.solver_type}" - + self.solver_params = EasyDict(DEFAULT_SOLVER_PARAM) - self.solver_params.update(config.edm_model.solver.params) - log.info(f"Using solver type: {self.solver_type}\nSolver param is {self.solver_params}") + self.solver_params.update(config.sample_params) + # Initialize sigma_min and sigma_max if not provided - self.sigma_min = INITIAL_SIGMA_MIN[self.edm_type] if "sigma_min" not in self.params else self.params.sigma_min self.sigma_max = INITIAL_SIGMA_MAX[self.edm_type] if "sigma_max" not in self.params else self.params.sigma_max @@ -80,7 +96,7 @@ def __init__(self, config: Optional[EasyDict]=None) -> None: def get_type(self) -> str: return "EDMModel" - def _sample_sigma_weight_train(self, x: Tensor, **params) -> Tuple[Tensor, Tensor]: + def _sample_sigma_weight_train(self, x: Tensor) -> Tuple[Tensor, Tensor]: """ Overview: Sample sigma from given distribution for training according to edm type. @@ -93,12 +109,12 @@ def _sample_sigma_weight_train(self, x: Tensor, **params) -> Tuple[Tensor, Tenso weight (:obj:`torch.Tensor`): Loss weight obtained from sampled sigma. """ # assert the first dim of x is batch size - params = EasyDict(params) + rand_shape = [x.shape[0]] + [1] * (x.ndim - 1) if self.edm_type == "VP_edm": - epsilon_t = params.epsilon_t - beta_d = params.beta_d - beta_min = params.beta_min + epsilon_t = self.params.epsilon_t + beta_d = self.params.beta_d + beta_min = self.params.beta_min rand_uniform = torch.rand(*rand_shape, device=self.device) sigma = SIGMA_T["VP_edm"](1 + rand_uniform * (epsilon_t - 1), beta_d, beta_min) @@ -113,28 +129,43 @@ def _sample_sigma_weight_train(self, x: Tensor, **params) -> Tuple[Tensor, Tenso sigma = u[sigma_index] weight = 1 / sigma ** 2 elif self.edm_type == "EDM": - P_mean = params.P_mean - P_std = params.P_std - sigma_data = params.sigma_data + P_mean = self.params.P_mean + P_std = self.params.P_std + sigma_data = self.params.sigma_data rand_normal = torch.randn(*rand_shape, device=self.device) sigma = (rand_normal * P_std + P_mean).exp() weight = (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2 return sigma, weight - def forward(self, x: Tensor, condition: Tensor=None) -> Tensor: - x = x.to(self.device) - sigma, weight = self._sample_sigma_weight_train(x, **self.params) + def forward(self, x, condition=None): + return self.sample(x, condition) + + def L2_denoising_matching_loss( + self, + x: Tensor, + condition: Optional[Tensor]=None + ): + """ + Overview: + Calculate the L2 denoising matching loss. + Arguments: + x (:obj:`torch.Tensor`): The sample which needs to add noise. + condition (:obj:`torch.Tensor`): The condition for the sample. Default setting: None. + Returns: + loss (:obj:`torch.Tensor`): The L2 denoising matching loss. + """ + + sigma, weight = self._sample_sigma_weight_train(x) n = torch.randn_like(x) * sigma D_xn = self.preconditioner(sigma, x+n, condition=condition) loss = weight * ((D_xn - x) ** 2) return loss.mean() - - + def _get_sigma_steps_t_steps(self, num_steps: int=18, epsilon_s: float=1e-3, rho: Union[int, float]=7 - )-> Tuple[Tensor, Tensor]: + ): """ Overview: Get the schedule of sigma according to differernt t schedules. @@ -149,9 +180,6 @@ def _get_sigma_steps_t_steps(self, t_steps (:obj:`torch.Tensor`): The scheduled t. """ - self.sigma_min = max(self.sigma_min, self.preconditioner.sigma_min) - self.sigma_max = min(self.sigma_max, self.preconditioner.sigma_max) - # Define time steps in terms of noise level step_indices = torch.arange(num_steps, dtype=torch.float32, device=self.device) sigma_steps = None @@ -213,25 +241,106 @@ def _get_sigma_deriv_inv_scale_deriv(self, epsilon_s: Union[int, float]=1e-3) \ scale_deriv = partial(SCALE_T_DERIV[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min) return sigma, sigma_deriv, sigma_inv, scale, scale_deriv - - - def sample(self, - t_span, - batch_size, - latents: Tensor, - condition: Tensor=None, - use_stochastic: bool=False, - **solver_kwargs - ) -> Tensor: + + def sample( + self, + t_span: torch.Tensor = None, + batch_size: Union[torch.Size, int, Tuple[int], List[int]] = None, + x_0: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, + condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, + with_grad: bool = False, + solver_config: EasyDict = None, + ): + + return self.sample_forward_process( + t_span=t_span, + batch_size=batch_size, + x_0=x_0, + condition=condition, + with_grad=with_grad, + solver_config=solver_config, + )[-1] + + def sample_forward_process( + self, + t_span: torch.Tensor = None, + batch_size: Union[torch.Size, int, Tuple[int], List[int]] = None, + x_0: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, + condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, + with_grad: bool = False, + solver_config: EasyDict = None, + ): - # Get sigmas, scales, and timesteps - log.info(f"Start sampling!") - num_steps = self.solver_params.num_steps - epsilon_s = self.solver_params.epsilon_s - rho = self.solver_params.rho + if t_span is not None: + t_span = t_span.to(self.device) + + if batch_size is None: + extra_batch_size = torch.tensor((1,), device=self.device) + elif isinstance(batch_size, int): + extra_batch_size = torch.tensor((batch_size,), device=self.device) + else: + if ( + isinstance(batch_size, torch.Size) + or isinstance(batch_size, Tuple) + or isinstance(batch_size, List) + ): + extra_batch_size = torch.tensor(batch_size, device=self.device) + else: + assert False, "Invalid batch size" + + if x_0 is not None and condition is not None: + assert ( + x_0.shape[0] == condition.shape[0] + ), "The batch size of x_0 and condition must be the same" + data_batch_size = x_0.shape[0] + elif x_0 is not None: + data_batch_size = x_0.shape[0] + elif condition is not None: + data_batch_size = condition.shape[0] + else: + data_batch_size = 1 + + if solver_config is not None: + solver = get_solver(solver_config.type)(**solver_config.args) + else: + assert hasattr( + self, "solver" + ), "solver must be specified in config or solver_config" + solver = self.solver + + if x_0 is None: + x = self.gaussian_generator( + batch_size=torch.prod(extra_batch_size) * data_batch_size + ) + # x.shape = (B*N, D) + else: + if isinstance(self.x_size, int): + assert ( + torch.Size([self.x_size]) == x_0[0].shape + ), "The shape of x_0 must be the same as the x_size that is specified in the config" + elif ( + isinstance(self.x_size, Tuple) + or isinstance(self.x_size, List) + or isinstance(self.x_size, torch.Size) + ): + assert ( + torch.Size(self.x_size) == x_0[0].shape + ), "The shape of x_0 must be the same as the x_size that is specified in the config" + else: + assert False, "Invalid x_size" + + x = torch.repeat_interleave(x_0, torch.prod(extra_batch_size), dim=0) + # x.shape = (B*N, D) + + if condition is not None: + condition = torch.repeat_interleave( + condition, torch.prod(extra_batch_size), dim=0 + ) + # condition.shape = (B*N, D) + - latents = latents.to(self.device) - sigma_steps, t_steps = self._get_sigma_steps_t_steps(num_steps=num_steps, epsilon_s=epsilon_s, rho=rho) + sigma_steps, t_steps = self._get_sigma_steps_t_steps(num_steps=self.solver_params.num_steps, epsilon_s=self.solver_params.epsilon_s, rho=self.solver_params.rho) + sigma, sigma_deriv, sigma_inv, scale, scale_deriv = self._get_sigma_deriv_inv_scale_deriv() S_churn = self.solver_params.S_churn @@ -240,56 +349,145 @@ def sample(self, S_noise = self.solver_params.S_noise alpha = self.solver_params.alpha - - if not use_stochastic: - # Main sampling loop - t_next = t_steps[0] - x_next = latents * (sigma(t_next) * scale(t_next)) - for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 - x_cur = x_next - - # Increase noise temporarily. - gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0 - t_hat = sigma_inv(self.preconditioner.round_sigma(sigma(t_cur) + gamma * sigma(t_cur))) - x_hat = scale(t_hat) / scale(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * scale(t_hat) * S_noise * torch.randn_like(x_cur) - - # Euler step. - h = t_next - t_hat - denoised = self.preconditioner(sigma(t_hat), x_hat / scale(t_hat), condition) - d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + scale_deriv(t_hat) / scale(t_hat)) * x_hat - sigma_deriv(t_hat) * scale(t_hat) / sigma(t_hat) * denoised - x_prime = x_hat + alpha * h * d_cur - t_prime = t_hat + alpha * h - - # Apply 2nd order correction. - if self.solver_type == 'euler' or i == num_steps - 1: - x_next = x_hat + h * d_cur - else: - assert self.solver_type == 'heun' - denoised = self.preconditioner(sigma(t_prime), x_prime / scale(t_prime), condition) - d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + scale_deriv(t_prime) / scale(t_prime)) * x_prime - sigma_deriv(t_prime) * scale(t_prime) / sigma(t_prime) * denoised - x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime) - - else: - assert self.edm_type == "EDM", f"Stochastic can only use in EDM, but your precond type is {self.edm_type}" - x_next = latents * t_steps[0] - for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 - x_cur = x_next + # # Main sampling loop + # t_next = t_steps[0] + # x_next = x_0 * (sigma(t_next) * scale(t_next)) + # for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 + # x_cur = x_next + + # # Euler step. + # h = t_next - t_cur + # denoised = self.preconditioner(sigma(t_cur), x_cur / scale(t_cur), condition) + # d_cur = (sigma_deriv(t_cur) / sigma(t_cur) + scale_deriv(t_cur) / scale(t_cur)) * x_cur - sigma_deriv(t_cur) * scale(t_cur) / sigma(t_cur) * denoised + + # x_next = x_cur + h * d_cur + + def drift(t, x): + t_shape = [x.shape[0]] + [1] * (x.ndim - 1) + t = t.view(*t_shape) + denoised = self.preconditioner(sigma(t), x / scale(t), condition) + f=(sigma_deriv(t) / sigma(t) + scale_deriv(t) / scale(t)) * x - sigma_deriv(t) * scale(t) / sigma(t) * denoised + return f + + t_span = torch.tensor(t_steps, device=self.device) + if isinstance(solver, ODESolver): + # TODO: make it compatible with TensorDict + if with_grad: + data = solver.integrate( + drift=drift, + x0=x, + t_span=t_span, + adjoint_params=find_parameters(self.model), + ) + else: + with torch.no_grad(): + data = solver.integrate( + drift=drift, + x0=x, + t_span=t_span, + adjoint_params=find_parameters(self.model), + ) - # Increase noise temporarily. - gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 - t_hat = self.preconditioner.round_sigma(t_cur + gamma * t_cur) - x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * torch.randn_like(x_cur) - # Euler step. - denoised = self.preconditioner(t_hat, x_hat, condition) - d_cur = (x_hat - denoised) / t_hat - x_next = x_hat + (t_next - t_hat) * d_cur + if isinstance(data, torch.Tensor): + # data.shape = (T, B*N, D) + if len(extra_batch_size.shape) == 0: + if isinstance(self.x_size, int): + data = data.reshape( + -1, extra_batch_size, data_batch_size, self.x_size + ) + elif ( + isinstance(self.x_size, Tuple) + or isinstance(self.x_size, List) + or isinstance(self.x_size, torch.Size) + ): + data = data.reshape( + -1, extra_batch_size, data_batch_size, *self.x_size + ) + else: + assert False, "Invalid x_size" + else: + if isinstance(self.x_size, int): + data = data.reshape( + -1, *extra_batch_size, data_batch_size, self.x_size + ) + elif ( + isinstance(self.x_size, Tuple) + or isinstance(self.x_size, List) + or isinstance(self.x_size, torch.Size) + ): + data = data.reshape( + -1, *extra_batch_size, data_batch_size, *self.x_size + ) + else: + assert False, "Invalid x_size" + # data.shape = (T, B, N, D) - # Apply 2nd order correction. - if i < num_steps - 1: - denoised = self.preconditioner(t_next, x_next, condition) - d_prime = (x_next - denoised) / t_next - x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) + if batch_size is None: + if x_0 is None and condition is None: + data = data.squeeze(1).squeeze(1) + # data.shape = (T, D) + else: + data = data.squeeze(1) + # data.shape = (T, N, D) + else: + if x_0 is None and condition is None: + data = data.squeeze(1 + len(extra_batch_size.shape)) + # data.shape = (T, B, D) + else: + # data.shape = (T, B, N, D) + pass + elif isinstance(data, TensorDict): + raise NotImplementedError("Not implemented") + elif isinstance(data, treetensor.torch.Tensor): + for key in data.keys(): + if len(extra_batch_size.shape) == 0: + if isinstance(self.x_size[key], int): + data[key] = data[key].reshape( + -1, extra_batch_size, data_batch_size, self.x_size[key] + ) + elif ( + isinstance(self.x_size[key], Tuple) + or isinstance(self.x_size[key], List) + or isinstance(self.x_size[key], torch.Size) + ): + data[key] = data[key].reshape( + -1, extra_batch_size, data_batch_size, *self.x_size[key] + ) + else: + assert False, "Invalid x_size" + else: + if isinstance(self.x_size[key], int): + data[key] = data[key].reshape( + -1, *extra_batch_size, data_batch_size, self.x_size[key] + ) + elif ( + isinstance(self.x_size[key], Tuple) + or isinstance(self.x_size[key], List) + or isinstance(self.x_size[key], torch.Size) + ): + data[key] = data[key].reshape( + -1, *extra_batch_size, data_batch_size, *self.x_size[key] + ) + else: + assert False, "Invalid x_size" + # data.shape = (T, B, N, D) + if batch_size is None: + if x_0 is None and condition is None: + data[key] = data[key].squeeze(1).squeeze(1) + # data.shape = (T, D) + else: + data[key] = data[key].squeeze(1) + # data.shape = (T, N, D) + else: + if x_0 is None and condition is None: + data[key] = data[key].squeeze(1 + len(extra_batch_size.shape)) + # data.shape = (T, B, D) + else: + # data.shape = (T, B, N, D) + pass + else: + raise NotImplementedError("Not implemented") - return x_next + return data \ No newline at end of file diff --git a/grl/generative_models/edm_diffusion_model/edm_preconditioner.py b/grl/generative_models/edm_diffusion_model/edm_preconditioner.py index 98ab5fc..d8527d0 100644 --- a/grl/generative_models/edm_diffusion_model/edm_preconditioner.py +++ b/grl/generative_models/edm_diffusion_model/edm_preconditioner.py @@ -12,10 +12,16 @@ from .edm_utils import SIGMA_T, SIGMA_T_INV class PreConditioner(nn.Module): - + """ + Overview: + Precondition step in EDM. + + Interface: + ``__init__``, ``round_sigma``, ``get_precondition_c``, ``forward`` + """ def __init__(self, precondition_type: Literal["VP_edm", "VE_edm", "iDDPM_edm", "EDM"] = "EDM", - base_denoise_model: nn.Module = None, + denoise_model: nn.Module = None, use_mixes_precision: bool = False, **precond_params) -> None: @@ -23,7 +29,7 @@ def __init__(self, log.info(f"Precond_params: {precond_params}") precond_params = EasyDict(precond_params) self.precondition_type = precondition_type - self.base_denoise_model = base_denoise_model + self.denoise_model = denoise_model self.use_mixes_precision = use_mixes_precision if self.precondition_type == "VP_edm": @@ -67,7 +73,7 @@ def alpha_bar(j): - def round_sigma(self, sigma, return_index=False): + def round_sigma(self, sigma: Tensor, return_index: bool=False) -> Tensor: if self.precondition_type == "iDDPM_edm": sigma = torch.as_tensor(sigma) @@ -109,10 +115,11 @@ def forward(self, sigma: Tensor, x: Tensor, condition: Tensor=None, **model_kwar if sigma.numel() == 1: sigma = sigma.view(-1).expand(*sigma_shape) - + else: + sigma = sigma.view(*sigma_shape) dtype = torch.float16 if (self.use_mixes_precision and x.device.type == 'cuda') else torch.float32 c_skip, c_out, c_in, c_noise = self.get_precondition_c(sigma) - F_x = self.base_denoise_model(c_noise.flatten(), (c_in * x).to(dtype), condition=condition, **model_kwargs) + F_x = self.denoise_model(c_noise.flatten(), (c_in * x).to(dtype), condition=condition, **model_kwargs) assert F_x.dtype == dtype D_x = c_skip * x + c_out * F_x.to(torch.float32) return D_x \ No newline at end of file diff --git a/grl/generative_models/edm_diffusion_model/edm_utils.py b/grl/generative_models/edm_diffusion_model/edm_utils.py index 7293ae7..e31c16e 100644 --- a/grl/generative_models/edm_diffusion_model/edm_utils.py +++ b/grl/generative_models/edm_diffusion_model/edm_utils.py @@ -43,14 +43,14 @@ INITIAL_SIGMA_MIN = { - "VP_edm": SIGMA_T["VP_edm"](torch.tensor(1e-3), 19.9, 0.1), + "VP_edm": float(SIGMA_T["VP_edm"](torch.tensor(1e-3), 19.9, 0.1)), "VE_edm": 0.02, "iDDPM_edm": 0.002, "EDM": 0.002 } INITIAL_SIGMA_MAX = { - "VP_edm": SIGMA_T["VP_edm"](torch.tensor(1.), 19.9, 0.1), + "VP_edm": float(SIGMA_T["VP_edm"](torch.tensor(1.), 19.9, 0.1)), "VE_edm": 100, "iDDPM_edm": 81, "EDM": 80 diff --git a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py index 1640ec9..1e94e4f 100644 --- a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py +++ b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py @@ -34,47 +34,68 @@ ) config = EasyDict( dict( - device=device, - edm_model=dict( + device=device, + + edm_model=dict( + device=device, + x_size=[2], + sample_params=dict( + num_steps=18, + alpha=1, + S_churn=0.0, + S_min=0.0, + S_max=float("inf"), + S_noise=1.0, + rho=7, # * EDM needs rho + epsilon_s=1e-3, # * VP needs epsilon_s + ), + solver=dict( + type="ODESolver", + args=dict( + library="torchdyn", + ode_solver="euler", + ), + ), path=dict( - edm_type="EDM", # *["VP_edm", "VE_edm", "iDDPM_edm", "EDM"] + edm_type="EDM", # *["VP_edm", "VE_edm", "iDDPM_edm", "EDM"] + # solver=dict( + # solver_type="heun", + # # *['euler', 'heun'] + # params=dict( + # num_steps=18, + # alpha=1, + # S_churn=0.0, + # S_min=0.0, + # S_max=float("inf"), + # S_noise=1.0, + # rho=7, # * EDM needs rho + # epsilon_s=1e-3, # * VP needs epsilon_s + # ), + # ), params=dict( - #^ 1: VP_edm - # beta_d=19.9, - # beta_min=0.1, - # M=1000, + # ^ 1: VP_edm + # beta_d=19.9, + # beta_min=0.1, + # M=1000, # epsilon_t=1e-5, # epsilon_s=1e-4, - #^ 2: VE_edm + # ^ 2: VE_edm # sigma_min=0.02, # sigma_max=100, - #^ 3: iDDPM_edm + # ^ 3: iDDPM_edm # C_1=0.001, # C_2=0.008, # M=1000, - #^ 4: EDM + # ^ 4: EDM # sigma_min=0.002, # sigma_max=80, # sigma_data=0.5, # P_mean=-1.21, # P_std=1.21, - ) + ), ), + - solver=dict( - solver_type="heun", - # *['euler', 'heun'] - params=dict( - num_steps=18, - alpha=1, - S_churn=0., - S_min=0., - S_max=float("inf"), - S_noise=1., - rho=7, #* EDM needs rho - epsilon_s=1e-3 #* VP needs epsilon_s - ) - ), model=dict( type="noise_function", args=dict( @@ -108,7 +129,7 @@ if __name__ == "__main__": seed_value = set_seed() log.info(f"start exp with seed value {seed_value}.") - edm_diffusion_model = EDMModel(config=config).to(config.device) + edm_diffusion_model = EDMModel(config=config.edm_model).to(config.device) # edm_diffusion_model = torch.compile(edm_diffusion_model) # get data data = make_swiss_roll(n_samples=config.parameter.data_num, noise=0.01)[0].astype( @@ -214,10 +235,10 @@ def save_checkpoint(model, optimizer, iteration): history_iteration = [-1] batch_data = next(data_generator) batch_data = batch_data.to(config.device) - + for i in range(10): edm_diffusion_model.train() - loss = edm_diffusion_model(batch_data) + loss = edm_diffusion_model.L2_denoising_matching_loss(batch_data) optimizer.zero_grad() loss.backward() gradien_norm = torch.nn.utils.clip_grad_norm_( @@ -228,10 +249,11 @@ def save_checkpoint(model, optimizer, iteration): loss_sum += loss.item() counter += 1 iteration += 1 - log.info(f"iteration {iteration}, gradient {gradient_sum/counter}, loss {loss_sum/counter}") - + log.info( + f"iteration {iteration}, gradient {gradient_sum/counter}, loss {loss_sum/counter}" + ) + edm_diffusion_model.eval() - latents = torch.randn((2048, 2)) - sampled = edm_diffusion_model.sample(None, None, latents=latents) - log.info(f"Sampled size: {sampled.shape}") - \ No newline at end of file + + sampled = edm_diffusion_model.sample(batch_size=10) + log.info(f"Sampled size: {sampled.shape}") \ No newline at end of file From 8c3c90dfee94f6cd6cffb1ca8c739a434093bf33 Mon Sep 17 00:00:00 2001 From: wrh12345 Date: Fri, 23 Aug 2024 07:53:23 +0000 Subject: [PATCH 11/14] feature(wrh): polish code with formula given in comments --- .../edm_diffusion_model.py | 72 ++++++++++++++----- .../edm_diffusion_model/edm_preconditioner.py | 65 +++++++++++++++-- 2 files changed, 114 insertions(+), 23 deletions(-) diff --git a/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py b/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py index 7f968d1..f5eaa1f 100644 --- a/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py +++ b/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py @@ -42,14 +42,22 @@ class EDMModel(nn.Module): EDM class utilizes different params and executes different scheules during precondition, training and sample process. Sampling supports 1st order Euler step and 2nd order Heun step as Algorithm 1 in paper. For EDM type itself, stochastic sampler as Algorithm 2 in paper is also supported. + Interface: ``__init__``, ``forward``, ``sample`` + Reference: - EDM original paper: https://arxiv.org/abs/2206.00364 + EDM original paper link: https://arxiv.org/abs/2206.00364 Code reference: https://github.com/NVlabs/edm """ def __init__(self, config: Optional[EasyDict]=None) -> None: - + """ + Overview: + Initialization of EDMModel. + + Arguments: + config (:obj:`EasyDict`): The configuration. + """ super().__init__() self.config = config self.x_size = config.x_size @@ -100,13 +108,17 @@ def _sample_sigma_weight_train(self, x: Tensor) -> Tuple[Tensor, Tensor]: """ Overview: Sample sigma from given distribution for training according to edm type. + More details refer to Training section in the Table 1 of EDM paper. + + ..math: + \sigma\sim p_{\mathrm{train}}, \lambda(\sigma) Arguments: x (:obj:`torch.Tensor`): The sample which needs to add noise. Returns: sigma (:obj:`torch.Tensor`): Sampled sigma from the distribution. - weight (:obj:`torch.Tensor`): Loss weight obtained from sampled sigma. + weight (:obj:`torch.Tensor`): Loss weight lambda(sigma) obtained from sampled sigma. """ # assert the first dim of x is batch size @@ -138,20 +150,25 @@ def _sample_sigma_weight_train(self, x: Tensor) -> Tuple[Tensor, Tensor]: weight = (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2 return sigma, weight - def forward(self, x, condition=None): + def forward(self, x: Tensor, condition: Optional[Tensor]=None): return self.sample(x, condition) def L2_denoising_matching_loss( self, x: Tensor, condition: Optional[Tensor]=None - ): + ) -> Tensor: """ Overview: - Calculate the L2 denoising matching loss. + Calculate the L2 denoising matching loss. The denoise matching loss is given in Equation 2, 3 in EDM paper. + + ..math: + \mathbb{E}_{\sigma\sim p_{\mathrm{train}}}\mathbb{E}_{y\sim p_\mathrm{data}}\mathbb{E}_{n\sim \mathcal{N}(0, \sigma^2 \mathbf{I})} \left[\lambda(\sigma) \| \mathbf{D}(y+n) - y \|_2^2\right] + Arguments: x (:obj:`torch.Tensor`): The sample which needs to add noise. - condition (:obj:`torch.Tensor`): The condition for the sample. Default setting: None. + condition (:obj:`Optional[torch.Tensor]`): The condition for the sample. + Returns: loss (:obj:`torch.Tensor`): The L2 denoising matching loss. """ @@ -164,11 +181,15 @@ def L2_denoising_matching_loss( def _get_sigma_steps_t_steps(self, num_steps: int=18, - epsilon_s: float=1e-3, rho: Union[int, float]=7 - ): + epsilon_s: float=1e-3, + rho: Union[int, float]=7 + ) -> Tuple[Tensor, Tensor]: """ Overview: - Get the schedule of sigma according to differernt t schedules. + Get the schedule of sigma steps and t steps according to differernt t schedules (or sigma schedules). + + ..math: + \sigma_{i Tuple[Callable, Callable, Callable, Callable, Callable]: """ Overview: - Get sigma(t) for different solver schedules. + Get sigma(t) and scale(t) for different solver schedules. + More details in sampling section of Table 1 in EDM paper. + + ..math: + \sigma(t), \sigma^\prime(t), \sigma^{-1}(\sigma), s(t), s^\prime(t) Returns: sigma: (:obj:`Callable`): sigma(t) @@ -244,14 +269,27 @@ def _get_sigma_deriv_inv_scale_deriv(self, epsilon_s: Union[int, float]=1e-3) \ def sample( self, - t_span: torch.Tensor = None, + t_span: Tensor = None, batch_size: Union[torch.Size, int, Tuple[int], List[int]] = None, - x_0: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, - condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, + x_0: Union[Tensor, TensorDict, treetensor.torch.Tensor] = None, + condition: Union[Tensor, TensorDict, treetensor.torch.Tensor] = None, with_grad: bool = False, solver_config: EasyDict = None, - ): + ) -> Tensor: + """ + Overview: + Use forward path of the diffusion model given the sampled x. Note that this is not the reverse process, and thus is not designed for sampling form the diffusion model. + Rather, it is used for encode a sampled x to the latent space. + Arguments: + t_span (:obj:`torch.Tensor`): The time span. + batch_size: (:obj:`Union[torch.Size, int, Tuple[int], List[int]]`): The batch size of sampling. + x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input state. + condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition. + with_grad (:obj:`bool`): Whether to return the gradient. + solver_config (:obj:`EasyDict`): The configuration of the solver. + + """ return self.sample_forward_process( t_span=t_span, batch_size=batch_size, diff --git a/grl/generative_models/edm_diffusion_model/edm_preconditioner.py b/grl/generative_models/edm_diffusion_model/edm_preconditioner.py index d8527d0..d01277d 100644 --- a/grl/generative_models/edm_diffusion_model/edm_preconditioner.py +++ b/grl/generative_models/edm_diffusion_model/edm_preconditioner.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Literal +from typing import Union, Optional, Tuple, Literal from dataclasses import dataclass from torch import Tensor, as_tensor @@ -21,10 +21,23 @@ class PreConditioner(nn.Module): """ def __init__(self, precondition_type: Literal["VP_edm", "VE_edm", "iDDPM_edm", "EDM"] = "EDM", - denoise_model: nn.Module = None, + denoise_model: Optional[nn.Module] = None, use_mixes_precision: bool = False, **precond_params) -> None: + """ + Overview: + Initialize preconditioner for Network preconditioning in EDM. + More details in Network and Preconditioning in Section 5 of EDM paper. + + Arguments: + precondition_type (:obj:`Literal["VP_edm", "VE_edm", "iDDPM_edm", "EDM"]`): The precond type. + denoise_model (:obj:`Optional[nn.Module]`): The basic denoise network. + use_mixes_precision (:obj:`bool`): If mixes precision is used. + Reference: + EDM original paper link: https://arxiv.org/abs/2206.00364 + Code reference: https://github.com/NVlabs/edm + """ super().__init__() log.info(f"Precond_params: {precond_params}") precond_params = EasyDict(precond_params) @@ -70,11 +83,20 @@ def alpha_bar(j): else: raise ValueError(f"Please check your precond type {self.precondition_type} is in ['VP_edm', 'VE_edm', 'iDDPM_edm', 'EDM']") - - + - def round_sigma(self, sigma: Tensor, return_index: bool=False) -> Tensor: + def round_sigma(self, sigma: Union[Tensor, float], return_index: bool=False) -> Tensor: + """ + Overview: + return sigma as tensor. When in iDDPM_edm mode, we need index as sigma. + Arguments: + sigma (:obj:`Union[torch.Tensor, float]`): Input sigma. + return_index (:obj:`bool`): whether index is returned. Only iDDPM_edm type needs it. + + Returns: + sigma (:obj:`torch.Tensor`): Output sigma in Tensor format. + """ if self.precondition_type == "iDDPM_edm": sigma = torch.as_tensor(sigma) index = torch.cdist(sigma.to(torch.float32).reshape(1, -1, 1), self.u.reshape(1, -1, 1)).argmin(2) @@ -84,7 +106,23 @@ def round_sigma(self, sigma: Tensor, return_index: bool=False) -> Tensor: return torch.as_tensor(sigma) def get_precondition_c(self, sigma: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """ + Overview: + Obtain precondition c according to sigma including c_skip, c_out, c_in, c_noise + Accordig to section Network and preconditioning Table 1, 4 precondition functions are shown as follows: + + .. math:: + \mathbf{c}_{\mathrm{skip}}(\sigma), \mathbf{c}_{\mathrm{out}}(\sigma), \mathbf{c}_{\mathrm{in}}(\sigma), \mathbf{c}_{\mathrm{noise}}(\sigma) + Arguments: + sigma (:obj:`torch.Tensor`): Input sigma. + + Returns: + c_skip (:obj:`torch.Tensor`): Output c_skip(sigma). + c_out (:obj:`torch.Tensor`): Output c_out(sigma). + c_in (:obj:`torch.Tensor`): Output c_in(sigma). + c_noise (:obj:`torch.Tensor`): Output c_noise(sigma). + """ if self.precondition_type == "VP_edm": c_skip = 1 c_out = -sigma @@ -107,7 +145,22 @@ def get_precondition_c(self, sigma: Tensor) -> Tuple[Tensor, Tensor, Tensor, Ten c_noise = sigma.log() / 4 return c_skip, c_out, c_in, c_noise - def forward(self, sigma: Tensor, x: Tensor, condition: Tensor=None, **model_kwargs): + def forward(self, sigma: Tensor, x: Tensor, condition: Optional[Tensor]=None, **model_kwargs): + """ + Overview: + Obtain denoiser from basic denoise network and precondition scaling functions, which is given as follows: + + .. math: + \mathbf{D}_{\theta} (\mathbf{x}; \sigma; c) = \mathbf{c}_{\mathrm{skip}}(\sigma) \mathbf{x} + \mathbf{c}_{\mathrm{out}}(\sigma) \mathbf{F}_{\theta}(\mathbf{c}_{\mathrm{in}}(\sigma)\mathbf{x}; \mathbf{c}_{\mathrm{noise}}(\sigma); c) + + Arguments: + sigma (:obj:`torch.Tensor`): Input sigma. + x (:obj:`torch.Tensor`): Input x. + condition: (:obj:`Optional[torch.Tensor]`): Input condition. + + Returns: + D_x (:obj:`torch.Tensor`): Output denoiser. + """ # Suppose the first dim of x is batch size x = x.to(torch.float32) sigma_shape = [x.shape[0]] + [1] * (x.ndim - 1) From d004867391f4e43601fb5de5b1a3c84b31001f2d Mon Sep 17 00:00:00 2001 From: wrh12345 Date: Tue, 27 Aug 2024 09:40:53 +0000 Subject: [PATCH 12/14] feature(wrh): edm convergence tested on swiss roll --- .../edm_diffusion_model.py | 42 ++++++++++++------- .../edm_diffusion_model/edm_preconditioner.py | 4 +- .../edm_diffusion_model/edm_utils.py | 10 ++--- .../swiss_roll/swiss_roll_edm_diffusion.py | 16 ++++--- 4 files changed, 43 insertions(+), 29 deletions(-) diff --git a/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py b/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py index f5eaa1f..db7503e 100644 --- a/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py +++ b/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py @@ -27,11 +27,11 @@ from grl.utils import set_seed from grl.utils.log import log -from .edm_preconditioner import PreConditioner -from .edm_utils import SIGMA_T, SIGMA_T_DERIV, SIGMA_T_INV -from .edm_utils import SCALE_T, SCALE_T_DERIV -from .edm_utils import INITIAL_SIGMA_MAX, INITIAL_SIGMA_MIN -from .edm_utils import DEFAULT_PARAM, DEFAULT_SOLVER_PARAM +from grl.generative_models.edm_diffusion_model.edm_preconditioner import PreConditioner +from grl.generative_models.edm_diffusion_model.edm_utils import SIGMA_T, SIGMA_T_DERIV, SIGMA_T_INV +from grl.generative_models.edm_diffusion_model.edm_utils import SCALE_T, SCALE_T_DERIV +from grl.generative_models.edm_diffusion_model.edm_utils import INITIAL_SIGMA_MAX, INITIAL_SIGMA_MIN +from grl.generative_models.edm_diffusion_model.edm_utils import DEFAULT_PARAM, DEFAULT_SOLVER_PARAM class EDMModel(nn.Module): @@ -175,6 +175,7 @@ def L2_denoising_matching_loss( sigma, weight = self._sample_sigma_weight_train(x) n = torch.randn_like(x) * sigma + inv_t = SIGMA_T_INV[self.edm_type](sigma) # TODO: Use t? or sigma? as input D_xn = self.preconditioner(sigma, x+n, condition=condition) loss = weight * ((D_xn - x) ** 2) return loss.mean() @@ -308,7 +309,18 @@ def sample_forward_process( with_grad: bool = False, solver_config: EasyDict = None, ): + sigma_steps, t_steps = self._get_sigma_steps_t_steps(num_steps=self.solver_params.num_steps, epsilon_s=self.solver_params.epsilon_s, rho=self.solver_params.rho) + + sigma, sigma_deriv, sigma_inv, scale, scale_deriv = self._get_sigma_deriv_inv_scale_deriv() + + S_churn = self.solver_params.S_churn + S_min = self.solver_params.S_min + S_max = self.solver_params.S_max + S_noise = self.solver_params.S_noise + alpha = self.solver_params.alpha + t_next = t_steps[0] + # x_next = x_0 * (sigma(t_next) * scale(t_next)) if t_span is not None: t_span = t_span.to(self.device) @@ -350,6 +362,7 @@ def sample_forward_process( x = self.gaussian_generator( batch_size=torch.prod(extra_batch_size) * data_batch_size ) + x = x * (sigma(t_next) * scale(t_next)) # x.shape = (B*N, D) else: if isinstance(self.x_size, int): @@ -377,28 +390,25 @@ def sample_forward_process( # condition.shape = (B*N, D) - sigma_steps, t_steps = self._get_sigma_steps_t_steps(num_steps=self.solver_params.num_steps, epsilon_s=self.solver_params.epsilon_s, rho=self.solver_params.rho) - sigma, sigma_deriv, sigma_inv, scale, scale_deriv = self._get_sigma_deriv_inv_scale_deriv() - - S_churn = self.solver_params.S_churn - S_min = self.solver_params.S_min - S_max = self.solver_params.S_max - S_noise = self.solver_params.S_noise - alpha = self.solver_params.alpha # # Main sampling loop - # t_next = t_steps[0] - # x_next = x_0 * (sigma(t_next) * scale(t_next)) + + # x_next = torch.randn_like(x) + # x_list = [x_next] # for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 # x_cur = x_next # # Euler step. # h = t_next - t_cur # denoised = self.preconditioner(sigma(t_cur), x_cur / scale(t_cur), condition) - # d_cur = (sigma_deriv(t_cur) / sigma(t_cur) + scale_deriv(t_cur) / scale(t_cur)) * x_cur - sigma_deriv(t_cur) * scale(t_cur) / sigma(t_cur) * denoised + # d_cur = ((sigma_deriv(t_cur) / sigma(t_cur)) + (scale_deriv(t_cur) / scale(t_cur))) * x_cur - ((sigma_deriv(t_cur) * scale(t_cur)) / sigma(t_cur)) * denoised # x_next = x_cur + h * d_cur + # x_list.append(x_next) + + # return x_list + def drift(t, x): t_shape = [x.shape[0]] + [1] * (x.ndim - 1) diff --git a/grl/generative_models/edm_diffusion_model/edm_preconditioner.py b/grl/generative_models/edm_diffusion_model/edm_preconditioner.py index d01277d..e562554 100644 --- a/grl/generative_models/edm_diffusion_model/edm_preconditioner.py +++ b/grl/generative_models/edm_diffusion_model/edm_preconditioner.py @@ -9,7 +9,7 @@ import torch.nn.functional as F from grl.utils.log import log -from .edm_utils import SIGMA_T, SIGMA_T_INV +from grl.generative_models.edm_diffusion_model.edm_utils import SIGMA_T, SIGMA_T_INV class PreConditioner(nn.Module): """ @@ -169,7 +169,7 @@ def forward(self, sigma: Tensor, x: Tensor, condition: Optional[Tensor]=None, ** if sigma.numel() == 1: sigma = sigma.view(-1).expand(*sigma_shape) else: - sigma = sigma.view(*sigma_shape) + sigma = sigma.reshape(*sigma_shape) dtype = torch.float16 if (self.use_mixes_precision and x.device.type == 'cuda') else torch.float32 c_skip, c_out, c_in, c_noise = self.get_precondition_c(sigma) F_x = self.denoise_model(c_noise.flatten(), (c_in * x).to(dtype), condition=condition, **model_kwargs) diff --git a/grl/generative_models/edm_diffusion_model/edm_utils.py b/grl/generative_models/edm_diffusion_model/edm_utils.py index e31c16e..7e4b2a2 100644 --- a/grl/generative_models/edm_diffusion_model/edm_utils.py +++ b/grl/generative_models/edm_diffusion_model/edm_utils.py @@ -4,7 +4,7 @@ ############# Sampling Section ############# -# Scheduling in Table 1 in paper https://arxiv.org/abs/2206.00364 +# Scheduling in Table 1 of paper https://arxiv.org/abs/2206.00364 SIGMA_T = { "VP_edm": lambda t, beta_d=19.9, beta_min=0.1: ((0.5 * beta_d * (t ** 2) + beta_min * t).exp() - 1) ** 0.5, "VE_edm": lambda t, **kwargs: t.sqrt(), @@ -13,10 +13,10 @@ } SIGMA_T_DERIV = { - "VP_edm": lambda t, beta_d=19.9, beta_min=0.1: 0.5 * (beta_min + beta_d * t) * (SIGMA_T["VP_edm"](t, beta_d, beta_min) + 1 / SIGMA_T["VP_edm"](t, beta_d, beta_min)), - "VE_edm": lambda t, **kwargs: t.sqrt(), - "iDDPM_edm": lambda t, **kwargs: t, - "EDM": lambda t, **kwargs: t + "VP_edm": lambda t, beta_d=19.9, beta_min=0.1: 0.5 * (beta_min + beta_d * t) * (SIGMA_T["VP_edm"](t, beta_d, beta_min) + (1 / SIGMA_T["VP_edm"](t, beta_d, beta_min))), + "VE_edm": lambda t, **kwargs: 1 / (2 * t.sqrt()), + "iDDPM_edm": lambda t, **kwargs: 1, + "EDM": lambda t, **kwargs: 1 } SIGMA_T_INV = { diff --git a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py index 1e94e4f..93bf2c6 100644 --- a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py +++ b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py @@ -113,7 +113,7 @@ ), parameter=dict( training_loss_type="score_matching", - lr=5e-3, + lr=1e-4, data_num=10000, iterations=1000, batch_size=2048, @@ -233,10 +233,10 @@ def save_checkpoint(model, optimizer, iteration): ) history_iteration = [-1] - batch_data = next(data_generator) - batch_data = batch_data.to(config.device) + # batch_data = next(data_generator).to(config.device) - for i in range(10): + for i in range(10000): + batch_data = next(data_generator).to(config.device) edm_diffusion_model.train() loss = edm_diffusion_model.L2_denoising_matching_loss(batch_data) optimizer.zero_grad() @@ -255,5 +255,9 @@ def save_checkpoint(model, optimizer, iteration): edm_diffusion_model.eval() - sampled = edm_diffusion_model.sample(batch_size=10) - log.info(f"Sampled size: {sampled.shape}") \ No newline at end of file + sampled = edm_diffusion_model.sample(batch_size=1000) + log.info(f"Sampled size: {sampled.shape}") + + plt.scatter(sampled[:, 0].detach().cpu(), sampled[:, 1].detach().cpu(), s=1) + + plt.savefig("./result.png") \ No newline at end of file From c406de4f8dea92bfd5d85d6518f249741f0eea49 Mon Sep 17 00:00:00 2001 From: wrh12345 Date: Tue, 27 Aug 2024 09:53:24 +0000 Subject: [PATCH 13/14] feature(wrh): edm convergence tested on swiss roll --- .../edm_diffusion_model.py | 42 ++++++++++++------- .../edm_diffusion_model/edm_preconditioner.py | 4 +- .../edm_diffusion_model/edm_utils.py | 12 +++--- 3 files changed, 34 insertions(+), 24 deletions(-) diff --git a/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py b/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py index f5eaa1f..db7503e 100644 --- a/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py +++ b/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py @@ -27,11 +27,11 @@ from grl.utils import set_seed from grl.utils.log import log -from .edm_preconditioner import PreConditioner -from .edm_utils import SIGMA_T, SIGMA_T_DERIV, SIGMA_T_INV -from .edm_utils import SCALE_T, SCALE_T_DERIV -from .edm_utils import INITIAL_SIGMA_MAX, INITIAL_SIGMA_MIN -from .edm_utils import DEFAULT_PARAM, DEFAULT_SOLVER_PARAM +from grl.generative_models.edm_diffusion_model.edm_preconditioner import PreConditioner +from grl.generative_models.edm_diffusion_model.edm_utils import SIGMA_T, SIGMA_T_DERIV, SIGMA_T_INV +from grl.generative_models.edm_diffusion_model.edm_utils import SCALE_T, SCALE_T_DERIV +from grl.generative_models.edm_diffusion_model.edm_utils import INITIAL_SIGMA_MAX, INITIAL_SIGMA_MIN +from grl.generative_models.edm_diffusion_model.edm_utils import DEFAULT_PARAM, DEFAULT_SOLVER_PARAM class EDMModel(nn.Module): @@ -175,6 +175,7 @@ def L2_denoising_matching_loss( sigma, weight = self._sample_sigma_weight_train(x) n = torch.randn_like(x) * sigma + inv_t = SIGMA_T_INV[self.edm_type](sigma) # TODO: Use t? or sigma? as input D_xn = self.preconditioner(sigma, x+n, condition=condition) loss = weight * ((D_xn - x) ** 2) return loss.mean() @@ -308,7 +309,18 @@ def sample_forward_process( with_grad: bool = False, solver_config: EasyDict = None, ): + sigma_steps, t_steps = self._get_sigma_steps_t_steps(num_steps=self.solver_params.num_steps, epsilon_s=self.solver_params.epsilon_s, rho=self.solver_params.rho) + + sigma, sigma_deriv, sigma_inv, scale, scale_deriv = self._get_sigma_deriv_inv_scale_deriv() + + S_churn = self.solver_params.S_churn + S_min = self.solver_params.S_min + S_max = self.solver_params.S_max + S_noise = self.solver_params.S_noise + alpha = self.solver_params.alpha + t_next = t_steps[0] + # x_next = x_0 * (sigma(t_next) * scale(t_next)) if t_span is not None: t_span = t_span.to(self.device) @@ -350,6 +362,7 @@ def sample_forward_process( x = self.gaussian_generator( batch_size=torch.prod(extra_batch_size) * data_batch_size ) + x = x * (sigma(t_next) * scale(t_next)) # x.shape = (B*N, D) else: if isinstance(self.x_size, int): @@ -377,28 +390,25 @@ def sample_forward_process( # condition.shape = (B*N, D) - sigma_steps, t_steps = self._get_sigma_steps_t_steps(num_steps=self.solver_params.num_steps, epsilon_s=self.solver_params.epsilon_s, rho=self.solver_params.rho) - sigma, sigma_deriv, sigma_inv, scale, scale_deriv = self._get_sigma_deriv_inv_scale_deriv() - - S_churn = self.solver_params.S_churn - S_min = self.solver_params.S_min - S_max = self.solver_params.S_max - S_noise = self.solver_params.S_noise - alpha = self.solver_params.alpha # # Main sampling loop - # t_next = t_steps[0] - # x_next = x_0 * (sigma(t_next) * scale(t_next)) + + # x_next = torch.randn_like(x) + # x_list = [x_next] # for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 # x_cur = x_next # # Euler step. # h = t_next - t_cur # denoised = self.preconditioner(sigma(t_cur), x_cur / scale(t_cur), condition) - # d_cur = (sigma_deriv(t_cur) / sigma(t_cur) + scale_deriv(t_cur) / scale(t_cur)) * x_cur - sigma_deriv(t_cur) * scale(t_cur) / sigma(t_cur) * denoised + # d_cur = ((sigma_deriv(t_cur) / sigma(t_cur)) + (scale_deriv(t_cur) / scale(t_cur))) * x_cur - ((sigma_deriv(t_cur) * scale(t_cur)) / sigma(t_cur)) * denoised # x_next = x_cur + h * d_cur + # x_list.append(x_next) + + # return x_list + def drift(t, x): t_shape = [x.shape[0]] + [1] * (x.ndim - 1) diff --git a/grl/generative_models/edm_diffusion_model/edm_preconditioner.py b/grl/generative_models/edm_diffusion_model/edm_preconditioner.py index d01277d..e562554 100644 --- a/grl/generative_models/edm_diffusion_model/edm_preconditioner.py +++ b/grl/generative_models/edm_diffusion_model/edm_preconditioner.py @@ -9,7 +9,7 @@ import torch.nn.functional as F from grl.utils.log import log -from .edm_utils import SIGMA_T, SIGMA_T_INV +from grl.generative_models.edm_diffusion_model.edm_utils import SIGMA_T, SIGMA_T_INV class PreConditioner(nn.Module): """ @@ -169,7 +169,7 @@ def forward(self, sigma: Tensor, x: Tensor, condition: Optional[Tensor]=None, ** if sigma.numel() == 1: sigma = sigma.view(-1).expand(*sigma_shape) else: - sigma = sigma.view(*sigma_shape) + sigma = sigma.reshape(*sigma_shape) dtype = torch.float16 if (self.use_mixes_precision and x.device.type == 'cuda') else torch.float32 c_skip, c_out, c_in, c_noise = self.get_precondition_c(sigma) F_x = self.denoise_model(c_noise.flatten(), (c_in * x).to(dtype), condition=condition, **model_kwargs) diff --git a/grl/generative_models/edm_diffusion_model/edm_utils.py b/grl/generative_models/edm_diffusion_model/edm_utils.py index e31c16e..7ddb26c 100644 --- a/grl/generative_models/edm_diffusion_model/edm_utils.py +++ b/grl/generative_models/edm_diffusion_model/edm_utils.py @@ -4,7 +4,7 @@ ############# Sampling Section ############# -# Scheduling in Table 1 in paper https://arxiv.org/abs/2206.00364 +# Scheduling in Table 1 of paper https://arxiv.org/abs/2206.00364 SIGMA_T = { "VP_edm": lambda t, beta_d=19.9, beta_min=0.1: ((0.5 * beta_d * (t ** 2) + beta_min * t).exp() - 1) ** 0.5, "VE_edm": lambda t, **kwargs: t.sqrt(), @@ -13,10 +13,10 @@ } SIGMA_T_DERIV = { - "VP_edm": lambda t, beta_d=19.9, beta_min=0.1: 0.5 * (beta_min + beta_d * t) * (SIGMA_T["VP_edm"](t, beta_d, beta_min) + 1 / SIGMA_T["VP_edm"](t, beta_d, beta_min)), - "VE_edm": lambda t, **kwargs: t.sqrt(), - "iDDPM_edm": lambda t, **kwargs: t, - "EDM": lambda t, **kwargs: t + "VP_edm": lambda t, beta_d=19.9, beta_min=0.1: 0.5 * (beta_min + beta_d * t) * (SIGMA_T["VP_edm"](t, beta_d, beta_min) + (1 / SIGMA_T["VP_edm"](t, beta_d, beta_min))), + "VE_edm": lambda t, **kwargs: 1 / (2 * t.sqrt()), + "iDDPM_edm": lambda t, **kwargs: 1, + "EDM": lambda t, **kwargs: 1 } SIGMA_T_INV = { @@ -97,4 +97,4 @@ "S_max": float("inf"), "S_noise": 1., "alpha": 1 -}) \ No newline at end of file +}) From 5265393d6ac5fb3eb88a8a6b412fb0b6f22c34ff Mon Sep 17 00:00:00 2001 From: wrh12345 Date: Tue, 27 Aug 2024 13:08:54 +0000 Subject: [PATCH 14/14] feature(wrh): add edm interface in doc folder --- docs/source/api_doc/generative_models/index.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/source/api_doc/generative_models/index.rst b/docs/source/api_doc/generative_models/index.rst index 6e560e2..43f478f 100644 --- a/docs/source/api_doc/generative_models/index.rst +++ b/docs/source/api_doc/generative_models/index.rst @@ -28,3 +28,9 @@ OptimalTransportConditionalFlowModel .. autoclass:: OptimalTransportConditionalFlowModel :special-members: __init__ :members: + +EDMDiffusionModel +------------------------------- +.. autoclass:: EDMModel + :special-members: __init__ + :members: \ No newline at end of file