diff --git a/pasd/myutils/vaehook.py b/pasd/myutils/vaehook.py
index 0fc09e5..4bde630 100644
--- a/pasd/myutils/vaehook.py
+++ b/pasd/myutils/vaehook.py
@@ -171,7 +171,6 @@ def attn_forward_new(self, h_):
     return hidden_states
 
 def attn_forward_new_pt2_0(self, hidden_states,):
-    scale = 1
     attention_mask = None
     encoder_hidden_states = None
 
@@ -194,15 +193,15 @@ def attn_forward_new_pt2_0(self, hidden_states,):
     if self.group_norm is not None:
         hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-    query = self.to_q(hidden_states, scale=scale)
+    query = self.to_q(hidden_states)
 
     if encoder_hidden_states is None:
         encoder_hidden_states = hidden_states
     elif self.norm_cross:
         encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)
 
-    key = self.to_k(encoder_hidden_states, scale=scale)
-    value = self.to_v(encoder_hidden_states, scale=scale)
+    key = self.to_k(encoder_hidden_states)
+    value = self.to_v(encoder_hidden_states)
 
     inner_dim = key.shape[-1]
     head_dim = inner_dim // self.heads
@@ -222,7 +221,7 @@ def attn_forward_new_pt2_0(self, hidden_states,):
     hidden_states = hidden_states.to(query.dtype)
 
     # linear proj
-    hidden_states = self.to_out[0](hidden_states, scale=scale)
+    hidden_states = self.to_out[0](hidden_states)
     # dropout
     hidden_states = self.to_out[1](hidden_states)
 
@@ -232,7 +231,6 @@
     return hidden_states
 
 def attn_forward_new_xformers(self, hidden_states):
-    scale = 1
     attention_op = None
     attention_mask = None
     encoder_hidden_states = None
@@ -261,15 +259,15 @@ def attn_forward_new_xformers(self, hidden_states):
     if self.group_norm is not None:
         hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-    query = self.to_q(hidden_states, scale=scale)
+    query = self.to_q(hidden_states)
 
     if encoder_hidden_states is None:
         encoder_hidden_states = hidden_states
     elif self.norm_cross:
         encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)
 
-    key = self.to_k(encoder_hidden_states, scale=scale)
-    value = self.to_v(encoder_hidden_states, scale=scale)
+    key = self.to_k(encoder_hidden_states)
+    value = self.to_v(encoder_hidden_states)
 
     query = self.head_to_batch_dim(query).contiguous()
     key = self.head_to_batch_dim(key).contiguous()
@@ -282,7 +280,7 @@ def attn_forward_new_xformers(self, hidden_states):
     hidden_states = self.batch_to_head_dim(hidden_states)
 
     # linear proj
-    hidden_states = self.to_out[0](hidden_states, scale=scale)
+    hidden_states = self.to_out[0](hidden_states)
     # dropout
     hidden_states = self.to_out[1](hidden_states)
 
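The patch drops the `scale=` keyword from the `to_q`/`to_k`/`to_v`/`to_out[0]` projection calls, presumably because recent diffusers releases replaced the LoRA-compatible projection layers with plain `torch.nn.Linear`, whose `forward` takes no `scale` argument (passing it raises a `TypeError`). If one wanted to support both old and new diffusers at once instead of removing the keyword outright, a minimal sketch could check the layer's signature first; the `_call_proj` helper below is hypothetical and not part of this patch:

```python
import inspect

import torch.nn as nn


def _call_proj(proj: nn.Module, x, scale: float = 1.0):
    """Call a q/k/v/out projection, passing `scale` only if the layer accepts it.

    Older diffusers projections (LoRACompatibleLinear) take forward(x, scale=...),
    while newer releases use plain nn.Linear, which takes forward(x) only.
    """
    if "scale" in inspect.signature(proj.forward).parameters:
        return proj(x, scale=scale)
    return proj(x)


# usage inside the attention forward, e.g.:
# query = _call_proj(self.to_q, hidden_states)
```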