diff --git a/animatediff/models/attention.py b/animatediff/models/attention.py
index ad23583..d5268f7 100644
--- a/animatediff/models/attention.py
+++ b/animatediff/models/attention.py
@@ -8,10 +8,10 @@
 from torch import nn
 
 from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.modeling_utils import ModelMixin
+from diffusers import ModelMixin
 from diffusers.utils import BaseOutput
 from diffusers.utils.import_utils import is_xformers_available
-from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm
+from diffusers.models.attention import Attention, FeedForward, AdaLayerNorm
 
 from einops import rearrange, repeat
 import pdb
@@ -178,7 +178,7 @@ def __init__(
                 upcast_attention=upcast_attention,
             )
         else:
-            self.attn1 = CrossAttention(
+            self.attn1 = Attention(
                 query_dim=dim,
                 heads=num_attention_heads,
                 dim_head=attention_head_dim,
@@ -190,7 +190,7 @@ def __init__(
 
         # Cross-Attn
         if cross_attention_dim is not None:
-            self.attn2 = CrossAttention(
+            self.attn2 = Attention(
                 query_dim=dim,
                 cross_attention_dim=cross_attention_dim,
                 heads=num_attention_heads,
@@ -214,7 +214,7 @@ def __init__(
         # Temp-Attn
         assert unet_use_temporal_attention is not None
         if unet_use_temporal_attention:
-            self.attn_temp = CrossAttention(
+            self.attn_temp = Attention(
                 query_dim=dim,
                 heads=num_attention_heads,
                 dim_head=attention_head_dim,
@@ -225,7 +225,7 @@ def __init__(
             nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
             self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
 
-    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs):
         if not is_xformers_available():
             print("Here is how to install it")
             raise ModuleNotFoundError(
diff --git a/animatediff/models/motion_module.py b/animatediff/models/motion_module.py
index 494ce32..9071e46 100644
--- a/animatediff/models/motion_module.py
+++ b/animatediff/models/motion_module.py
@@ -6,12 +6,14 @@
 import torch.nn.functional as F
 from torch import nn
 import torchvision
 
+import diffusers
+from packaging import version
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers import ModelMixin
 from diffusers.utils import BaseOutput
 from diffusers.utils.import_utils import is_xformers_available
-from diffusers.models.attention import CrossAttention, FeedForward
+from diffusers.models.attention import Attention, FeedForward
 
 from einops import rearrange, repeat
 import math
@@ -250,7 +252,7 @@ def forward(self, x):
         return self.dropout(x)
 
 
-class VersatileAttention(CrossAttention):
+class VersatileAttention(Attention):
     def __init__(
             self,
             attention_mode = None,
@@ -290,45 +292,52 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None
         encoder_hidden_states = encoder_hidden_states
 
-        if self.group_norm is not None:
-            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = self.to_q(hidden_states)
-        dim = query.shape[-1]
-        query = self.reshape_heads_to_batch_dim(query)
-
-        if self.added_kv_proj_dim is not None:
-            raise NotImplementedError
-
-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
-        key = self.to_k(encoder_hidden_states)
-        value = self.to_v(encoder_hidden_states)
-
-        key = self.reshape_heads_to_batch_dim(key)
-        value = self.reshape_heads_to_batch_dim(value)
-
-        if attention_mask is not None:
-            if attention_mask.shape[-1] != query.shape[1]:
-                target_length = query.shape[1]
-                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
-                attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
-
-        # attention, what we cannot get enough of
-        if self._use_memory_efficient_attention_xformers:
-            hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
-            # Some versions of xformers return output in fp32, cast it back to the dtype of the input
-            hidden_states = hidden_states.to(query.dtype)
-        else:
-            if self._slice_size is None or query.shape[0] // self._slice_size == 1:
-                hidden_states = self._attention(query, key, value, attention_mask)
-            else:
-                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
-
-        # linear proj
-        hidden_states = self.to_out[0](hidden_states)
-
-        # dropout
-        hidden_states = self.to_out[1](hidden_states)
+        if version.parse(diffusers.__version__) > version.parse("0.11.1"):
+            # Newer diffusers route the whole pass through the attached attention
+            # processor (e.g. AttnProcessor2_0 backed by torch's scaled-dot-product
+            # attention): group norm, the q/k/v projections, the attention itself,
+            # and the to_out projection all happen inside this call.
+            hidden_states = self.processor(self, hidden_states, encoder_hidden_states)
+        else:
+            if self.group_norm is not None:
+                hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+            query = self.to_q(hidden_states)
+            dim = query.shape[-1]
+            query = self.head_to_batch_dim(query)
+
+            if self.added_kv_proj_dim is not None:
+                raise NotImplementedError
+
+            encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+            key = self.to_k(encoder_hidden_states)
+            value = self.to_v(encoder_hidden_states)
+
+            key = self.head_to_batch_dim(key)
+            value = self.head_to_batch_dim(value)
+
+            if attention_mask is not None:
+                if attention_mask.shape[-1] != query.shape[1]:
+                    target_length = query.shape[1]
+                    attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+                    attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+            # attention, what we cannot get enough of
+            if self._use_memory_efficient_attention_xformers:
+                hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+                # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+                hidden_states = hidden_states.to(query.dtype)
+            else:
+                if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+                    hidden_states = self._attention(query, key, value, attention_mask)
+                else:
+                    hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+            # linear proj
+            hidden_states = self.to_out[0](hidden_states)
+
+            # dropout
+            hidden_states = self.to_out[1](hidden_states)
 
         if self.attention_mode == "Temporal":
             hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
 
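Note on the VersatileAttention change: on recent diffusers releases the Attention base class does all of its work through an attached attention-processor object, which is why the patched forward above can hand off to self.processor instead of projecting q/k/v by hand. A minimal sketch of what that call amounts to, assuming a diffusers version new enough to ship AttnProcessor2_0 and PyTorch 2.x (the module sizes here are illustrative, not taken from the repo):

    import torch
    from diffusers.models.attention_processor import Attention, AttnProcessor2_0

    attn = Attention(query_dim=320, heads=8, dim_head=40)
    attn.set_processor(AttnProcessor2_0())  # SDPA-backed processor; requires torch >= 2.0

    hidden_states = torch.randn(2, 64, 320)  # (batch, sequence, channels)
    # one call covers the q/k/v projections, the attention, and to_out
    out = attn.processor(attn, hidden_states)
    assert out.shape == hidden_states.shape
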
diff --git a/animatediff/models/sparse_controlnet.py b/animatediff/models/sparse_controlnet.py
index f319e12..1138798 100644
--- a/animatediff/models/sparse_controlnet.py
+++ b/animatediff/models/sparse_controlnet.py
@@ -23,7 +23,7 @@
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.utils import BaseOutput, logging
 from diffusers.models.embeddings import TimestepEmbedding, Timesteps
-from diffusers.modeling_utils import ModelMixin
+from diffusers import ModelMixin
 
 from .unet_blocks import (
@@ -35,7 +35,7 @@
 from einops import repeat, rearrange
 from .resnet import InflatedConv3d
 
-from diffusers.models.unet_2d_condition import UNet2DConditionModel
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 
 logger = logging.get_logger(__name__) # pylint: disable=invalid-name
diff --git a/animatediff/models/unet.py b/animatediff/models/unet.py
index 1d77e78..3dadb7f 100644
--- a/animatediff/models/unet.py
+++ b/animatediff/models/unet.py
@@ -12,7 +12,7 @@
 import torch.utils.checkpoint
 
 from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.modeling_utils import ModelMixin
+from diffusers import ModelMixin
 from diffusers.utils import BaseOutput, logging
 from diffusers.models.embeddings import TimestepEmbedding, Timesteps
 from .unet_blocks import (
@@ -499,12 +499,23 @@ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
             "CrossAttnUpBlock3D"
         ]
 
-        from diffusers.utils import WEIGHTS_NAME
+        from diffusers.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME
         model = cls.from_config(config, **unet_additional_kwargs)
+        model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
+        model_file_safe = os.path.join(pretrained_model_path, SAFETENSORS_WEIGHTS_NAME)
+
+        if os.path.isfile(model_file_safe):
+            model_file = model_file_safe
+
         if not os.path.isfile(model_file):
             raise RuntimeError(f"{model_file} does not exist")
-        state_dict = torch.load(model_file, map_location="cpu")
+
+        if SAFETENSORS_WEIGHTS_NAME in model_file:
+            from safetensors.torch import load_file
+            state_dict = load_file(model_file)
+        else:
+            state_dict = torch.load(model_file, map_location="cpu")
 
         m, u = model.load_state_dict(state_dict, strict=False)
         print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
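With the loader change above, from_pretrained_2d picks up diffusion_pytorch_model.safetensors in preference to the .bin file whenever both sit in the checkpoint folder. A hedged usage sketch (the local path is a placeholder, and in practice unet_additional_kwargs carries the motion-module settings from the inference config):

    from animatediff.models.unet import UNet3DConditionModel

    # expects a local Stable Diffusion checkpoint directory with a unet/ subfolder
    unet = UNet3DConditionModel.from_pretrained_2d(
        "./models/StableDiffusion/stable-diffusion-v1-5",
        subfolder="unet",
        unet_additional_kwargs={},
    )
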
diff --git a/animatediff/pipelines/pipeline_animation.py b/animatediff/pipelines/pipeline_animation.py
index bcc1ddb..8e2486e 100644
--- a/animatediff/pipelines/pipeline_animation.py
+++ b/animatediff/pipelines/pipeline_animation.py
@@ -14,7 +14,7 @@
 from diffusers.configuration_utils import FrozenDict
 from diffusers.models import AutoencoderKL
-from diffusers.pipeline_utils import DiffusionPipeline
+from diffusers import DiffusionPipeline
 from diffusers.schedulers import (
     DDIMScheduler,
     DPMSolverMultistepScheduler,
diff --git a/requirements.txt b/requirements.txt
index 38107c5..53f48f3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-diffusers==0.11.1
+diffusers
 torch
 pyav
 torchvision
@@ -14,4 +14,5 @@ omegaconf
 safetensors
 gradio
 wandb
-lion-pytorch
\ No newline at end of file
+lion-pytorch
+peft
\ No newline at end of file
diff --git a/train.py b/train.py
index 1b17922..a0ef1c7 100644
--- a/train.py
+++ b/train.py
@@ -19,7 +19,7 @@
 import diffusers
 from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler
-from diffusers.models import UNet2DConditionModel
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.pipelines import StableDiffusionPipeline
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version
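
Since requirements.txt no longer pins diffusers, it is worth confirming up front that the installed release actually exposes every import path the patch switches to. A quick sanity check, assuming a diffusers version new enough to have diffusers.models.unets (0.25 or later):

    import diffusers
    from packaging import version

    from diffusers import DiffusionPipeline, ModelMixin
    from diffusers.models.attention import AdaLayerNorm, Attention, FeedForward
    from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel

    assert version.parse(diffusers.__version__) > version.parse("0.11.1")
    print(f"diffusers {diffusers.__version__}: patched import paths resolve")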