12 changes: 6 additions & 6 deletions animatediff/models/attention.py
@@ -8,10 +8,10 @@
 from torch import nn
 
 from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.modeling_utils import ModelMixin
+from diffusers import ModelMixin
 from diffusers.utils import BaseOutput
 from diffusers.utils.import_utils import is_xformers_available
-from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm
+from diffusers.models.attention import Attention, FeedForward, AdaLayerNorm
 
 from einops import rearrange, repeat
 import pdb
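
These two renames track newer diffusers releases: the `diffusers.modeling_utils` module was removed (`ModelMixin` is re-exported from the package root), and `CrossAttention` became `Attention`. If both old and new layouts must be supported, a guarded import is one option; the sketch below is an illustration, not code from this PR:

    # Sketch: tolerate both diffusers layouts (assumes one path of each pair exists).
    try:
        from diffusers import ModelMixin  # newer releases re-export at the package root
    except ImportError:
        from diffusers.modeling_utils import ModelMixin  # pre-0.12 location

    try:
        from diffusers.models.attention import Attention  # new name
    except ImportError:
        from diffusers.models.attention import CrossAttention as Attention  # old name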
@@ -178,7 +178,7 @@ def __init__(
                 upcast_attention=upcast_attention,
             )
         else:
-            self.attn1 = CrossAttention(
+            self.attn1 = Attention(
                 query_dim=dim,
                 heads=num_attention_heads,
                 dim_head=attention_head_dim,
@@ -190,7 +190,7 @@
 
         # Cross-Attn
         if cross_attention_dim is not None:
-            self.attn2 = CrossAttention(
+            self.attn2 = Attention(
                 query_dim=dim,
                 cross_attention_dim=cross_attention_dim,
                 heads=num_attention_heads,
@@ -214,7 +214,7 @@ def __init__(
         # Temp-Attn
         assert unet_use_temporal_attention is not None
         if unet_use_temporal_attention:
-            self.attn_temp = CrossAttention(
+            self.attn_temp = Attention(
                 query_dim=dim,
                 heads=num_attention_heads,
                 dim_head=attention_head_dim,
@@ -225,7 +225,7 @@ def __init__(
             nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
             self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
 
-    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs):
         if not is_xformers_available():
             print("Here is how to install it")
             raise ModuleNotFoundError(
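
The widened setter signature matters because newer diffusers walks the module tree and forwards an extra optional argument (an xformers `attention_op`) to every `set_use_memory_efficient_attention_xformers`; without `*args, **kwargs` that call raises a TypeError. A minimal sketch of the intended behavior, assuming the extra arguments can simply be ignored here:

    # Sketch (assumed caller: newer diffusers also passes an optional attention_op).
    from diffusers.utils.import_utils import is_xformers_available

    def set_use_memory_efficient_attention_xformers(self, enabled: bool, *args, **kwargs):
        if enabled and not is_xformers_available():
            raise ModuleNotFoundError(
                "Refer to https://github.com/facebookresearch/xformers for installation.",
                name="xformers",
            )
        # Extra args (e.g. attention_op) are accepted only for API compatibility.
        self._use_memory_efficient_attention_xformers = enabled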
81 changes: 49 additions & 32 deletions animatediff/models/motion_module.py
@@ -6,12 +6,14 @@
 import torch.nn.functional as F
 from torch import nn
 import torchvision
+import diffusers
+from packaging import version
 
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers import ModelMixin
 from diffusers.utils import BaseOutput
 from diffusers.utils.import_utils import is_xformers_available
-from diffusers.models.attention import CrossAttention, FeedForward
+from diffusers.models.attention import Attention, FeedForward
 
 from einops import rearrange, repeat
 import math
@@ -250,7 +252,7 @@ def forward(self, x):
         return self.dropout(x)
 
 
-class VersatileAttention(CrossAttention):
+class VersatileAttention(Attention):
     def __init__(
             self,
             attention_mode = None,
@@ -290,45 +292,60 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None
 
         encoder_hidden_states = encoder_hidden_states
 
-        if self.group_norm is not None:
-            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+        if version.parse(diffusers.__version__) > version.parse("0.11.1"):
+            hidden_states = self.processor(self, hidden_states, encoder_hidden_states)
+        else:
+            if self.group_norm is not None:
+                hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-        query = self.to_q(hidden_states)
-        dim = query.shape[-1]
-        query = self.reshape_heads_to_batch_dim(query)
+            query = self.to_q(hidden_states)
+            dim = query.shape[-1]
+            query = self.head_to_batch_dim(query)
 
-        if self.added_kv_proj_dim is not None:
-            raise NotImplementedError
+            if self.added_kv_proj_dim is not None:
+                raise NotImplementedError
 
-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
-        key = self.to_k(encoder_hidden_states)
-        value = self.to_v(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+            key = self.to_k(encoder_hidden_states)
+            value = self.to_v(encoder_hidden_states)
 
-        key = self.reshape_heads_to_batch_dim(key)
-        value = self.reshape_heads_to_batch_dim(value)
+            key = self.head_to_batch_dim(key)
+            value = self.head_to_batch_dim(value)
 
-        if attention_mask is not None:
-            if attention_mask.shape[-1] != query.shape[1]:
-                target_length = query.shape[1]
-                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
-                attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+            if attention_mask is not None:
+                if attention_mask.shape[-1] != query.shape[1]:
+                    target_length = query.shape[1]
+                    attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+                    attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
 
-        # attention, what we cannot get enough of
-        if self._use_memory_efficient_attention_xformers:
-            hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
-            # Some versions of xformers return output in fp32, cast it back to the dtype of the input
-            hidden_states = hidden_states.to(query.dtype)
-        else:
-            if self._slice_size is None or query.shape[0] // self._slice_size == 1:
-                hidden_states = self._attention(query, key, value, attention_mask)
-            else:
-                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+            # attention, what we cannot get enough of
+
+            if self._use_memory_efficient_attention_xformers:
+                hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+                # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+                hidden_states = hidden_states.to(query.dtype)
+            else:
+                if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+                    hidden_states = self._attention(query, key, value, attention_mask)
+                else:
+                    hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+            #if "xformers" in self.processor.__class__.__name__.lower():
+            #    hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attention_mask)
+            #    # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+            #    hidden_states = hidden_states.to(query.dtype)
+            #else:
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            )
+            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+            hidden_states = hidden_states.to(query.dtype)
 
-        # linear proj
-        hidden_states = self.to_out[0](hidden_states)
+            # linear proj
+            hidden_states = self.to_out[0](hidden_states)
 
-        # dropout
-        hidden_states = self.to_out[1](hidden_states)
+            # dropout
+            hidden_states = self.to_out[1](hidden_states)
 
         if self.attention_mode == "Temporal":
             hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
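
The core of this change is version gating: on diffusers newer than 0.11.1 the forward pass delegates to the module's attention processor, while older versions keep the original manual QKV path (where the pre-rename helper is `reshape_heads_to_batch_dim`; newer releases call it `head_to_batch_dim`). Note that the trailing `F.scaled_dot_product_attention` lines in the new branch reference `attn` and `head_dim`, which are not defined in this scope and appear carried over from processor-style code. The dispatch pattern itself, reduced to a standalone sketch with an abbreviated legacy body:

    # Sketch of the version-gated dispatch; assumes an Attention-compatible `self`.
    import diffusers
    from packaging import version

    def forward_sketch(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
        if version.parse(diffusers.__version__) > version.parse("0.11.1"):
            # Modern path: the processor handles projection, attention, and output.
            return self.processor(self, hidden_states, encoder_hidden_states)
        # Legacy path (diffusers <= 0.11.1), abbreviated: manual QKV projection.
        query = self.reshape_heads_to_batch_dim(self.to_q(hidden_states))
        context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = self.reshape_heads_to_batch_dim(self.to_k(context))
        value = self.reshape_heads_to_batch_dim(self.to_v(context))
        hidden_states = self._attention(query, key, value, attention_mask)
        return self.to_out[1](self.to_out[0](hidden_states))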
4 changes: 2 additions & 2 deletions animatediff/models/sparse_controlnet.py
@@ -23,7 +23,7 @@
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.utils import BaseOutput, logging
 from diffusers.models.embeddings import TimestepEmbedding, Timesteps
-from diffusers.modeling_utils import ModelMixin
+from diffusers import ModelMixin
 
 
 from .unet_blocks import (
@@ -35,7 +35,7 @@
 from einops import repeat, rearrange
 from .resnet import InflatedConv3d
 
-from diffusers.models.unet_2d_condition import UNet2DConditionModel
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 
 logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
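
`UNet2DConditionModel` moved into the `diffusers.models.unets` subpackage in newer releases. If the repository needs to keep working against older diffusers as well, a guarded import covers both layouts (a sketch, assuming only the module path changed):

    try:
        from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel  # newer layout
    except ImportError:
        from diffusers.models.unet_2d_condition import UNet2DConditionModel  # older layout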
17 changes: 14 additions & 3 deletions animatediff/models/unet.py
@@ -12,7 +12,7 @@
 import torch.utils.checkpoint
 
 from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.modeling_utils import ModelMixin
+from diffusers import ModelMixin
 from diffusers.utils import BaseOutput, logging
 from diffusers.models.embeddings import TimestepEmbedding, Timesteps
 from .unet_blocks import (
@@ -499,12 +499,23 @@ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs
             "CrossAttnUpBlock3D"
         ]
 
-        from diffusers.utils import WEIGHTS_NAME
+        from diffusers.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME
         model = cls.from_config(config, **unet_additional_kwargs)
 
         model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
+        model_file_safe = os.path.join(pretrained_model_path, SAFETENSORS_WEIGHTS_NAME)
+
+        if os.path.isfile(model_file_safe):
+            model_file = model_file_safe
+
         if not os.path.isfile(model_file):
             raise RuntimeError(f"{model_file} does not exist")
-        state_dict = torch.load(model_file, map_location="cpu")
+
+        if SAFETENSORS_WEIGHTS_NAME in model_file:
+            from safetensors.torch import load_file
+            state_dict = load_file(model_file)
+        else:
+            state_dict = torch.load(model_file, map_location="cpu")
 
         m, u = model.load_state_dict(state_dict, strict=False)
         print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
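
`from_pretrained_2d` now prefers `diffusion_pytorch_model.safetensors` when it exists and only falls back to the pickled `.bin` checkpoint. The selection and loading logic, isolated as a sketch (using `endswith` rather than the diff's substring test, which behaves the same for these filenames):

    import os
    import torch
    from diffusers.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME

    def resolve_and_load_state_dict(pretrained_model_path: str) -> dict:
        # Prefer the safetensors checkpoint when both formats are present.
        model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
        model_file_safe = os.path.join(pretrained_model_path, SAFETENSORS_WEIGHTS_NAME)
        if os.path.isfile(model_file_safe):
            model_file = model_file_safe
        if not os.path.isfile(model_file):
            raise RuntimeError(f"{model_file} does not exist")
        if model_file.endswith(".safetensors"):
            from safetensors.torch import load_file
            return load_file(model_file)  # no pickle execution involved
        return torch.load(model_file, map_location="cpu")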
2 changes: 1 addition & 1 deletion animatediff/pipelines/pipeline_animation.py
@@ -14,7 +14,7 @@
 
 from diffusers.configuration_utils import FrozenDict
 from diffusers.models import AutoencoderKL
-from diffusers.pipeline_utils import DiffusionPipeline
+from diffusers import DiffusionPipeline
 from diffusers.schedulers import (
     DDIMScheduler,
     DPMSolverMultistepScheduler,
5 changes: 3 additions & 2 deletions requirements.txt
@@ -1,4 +1,4 @@
-diffusers==0.11.1
+diffusers
 torch
 pyav
 torchvision
@@ -14,4 +14,5 @@ omegaconf
 safetensors
 gradio
 wandb
-lion-pytorch
+lion-pytorch
+peft
2 changes: 1 addition & 1 deletion train.py
@@ -19,7 +19,7 @@
 
 import diffusers
 from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler
-from diffusers.models import UNet2DConditionModel
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.pipelines import StableDiffusionPipeline
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version
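
An alternative to hard-coding the internal module path: `UNet2DConditionModel` has long been re-exported from the package root, which survives internal reorganizations on both old and new diffusers (a sketch):

    from diffusers import UNet2DConditionModel  # stable public re-export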