18 changes: 8 additions & 10 deletions pasd/myutils/vaehook.py
@@ -171,7 +171,6 @@ def attn_forward_new(self, h_):
     return hidden_states

 def attn_forward_new_pt2_0(self, hidden_states,):
-    scale = 1
     attention_mask = None
     encoder_hidden_states = None

@@ -194,15 +193,15 @@ def attn_forward_new_pt2_0(self, hidden_states,):
     if self.group_norm is not None:
         hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-    query = self.to_q(hidden_states, scale=scale)
+    query = self.to_q(hidden_states)

     if encoder_hidden_states is None:
         encoder_hidden_states = hidden_states
     elif self.norm_cross:
         encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)

-    key = self.to_k(encoder_hidden_states, scale=scale)
-    value = self.to_v(encoder_hidden_states, scale=scale)
+    key = self.to_k(encoder_hidden_states)
+    value = self.to_v(encoder_hidden_states)

     inner_dim = key.shape[-1]
     head_dim = inner_dim // self.heads
@@ -222,7 +221,7 @@ def attn_forward_new_pt2_0(self, hidden_states,):
     hidden_states = hidden_states.to(query.dtype)

     # linear proj
-    hidden_states = self.to_out[0](hidden_states, scale=scale)
+    hidden_states = self.to_out[0](hidden_states)
     # dropout
     hidden_states = self.to_out[1](hidden_states)

@@ -232,7 +231,6 @@ def attn_forward_new_pt2_0(self, hidden_states,):
     return hidden_states

 def attn_forward_new_xformers(self, hidden_states):
-    scale = 1
     attention_op = None
     attention_mask = None
     encoder_hidden_states = None
@@ -261,15 +259,15 @@ def attn_forward_new_xformers(self, hidden_states):
     if self.group_norm is not None:
         hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-    query = self.to_q(hidden_states, scale=scale)
+    query = self.to_q(hidden_states)

     if encoder_hidden_states is None:
         encoder_hidden_states = hidden_states
     elif self.norm_cross:
         encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)

-    key = self.to_k(encoder_hidden_states, scale=scale)
-    value = self.to_v(encoder_hidden_states, scale=scale)
+    key = self.to_k(encoder_hidden_states)
+    value = self.to_v(encoder_hidden_states)

     query = self.head_to_batch_dim(query).contiguous()
     key = self.head_to_batch_dim(key).contiguous()
@@ -282,7 +280,7 @@ def attn_forward_new_xformers(self, hidden_states):
     hidden_states = self.batch_to_head_dim(hidden_states)

     # linear proj
-    hidden_states = self.to_out[0](hidden_states, scale=scale)
+    hidden_states = self.to_out[0](hidden_states)
     # dropout
     hidden_states = self.to_out[1](hidden_states)

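For context (not part of the diff itself): the dropped `scale` keyword comes from older diffusers releases, where the attention projections `to_q`, `to_k`, `to_v`, and `to_out[0]` were LoRA-compatible linear layers whose forward accepted a per-call `scale`. Recent releases use plain `torch.nn.Linear` for these projections, so passing `scale=` raises a TypeError; since the value was hard-coded to 1 here anyway, removing the argument should be behavior-preserving. If a patched forward had to run against both old and new diffusers, a shim along these lines could forward `scale` only when the layer declares it. This is a hypothetical helper (`call_projection` is not part of this PR or of diffusers), sketched under the assumption above:

import inspect

import torch


def call_projection(layer: torch.nn.Module, x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    # Hypothetical compatibility shim: pass `scale` only to layers whose
    # forward() actually declares it (e.g. older LoRA-compatible linears);
    # plain nn.Linear layers are called without the extra keyword.
    if "scale" in inspect.signature(layer.forward).parameters:
        return layer(x, scale=scale)
    return layer(x)

Inside the patched attention, the call sites would then read e.g. `query = call_projection(self.to_q, hidden_states)`, keeping a single code path across diffusers versions.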