@@ -41,46 +41,50 @@ def __init__(
4141 # Initialize decoders based on tying configuration
4242 if config .decoder_tying == "per_source" :
4343 # Tied decoders: one decoder per source layer
44- self .decoders = nn .ModuleList ([
45- RowParallelLinear (
46- in_features = self .config .num_features ,
47- out_features = self .config .d_model ,
48- bias = True ,
49- process_group = self .process_group ,
50- input_is_parallel = False ,
51- d_model_for_init = self .config .d_model ,
52- num_layers_for_init = self .config .num_layers ,
53- device = self .device ,
54- dtype = self .dtype ,
55- )
56- for _ in range (self .config .num_layers )
57- ])
44+ self .decoders = nn .ModuleList (
45+ [
46+ RowParallelLinear (
47+ in_features = self .config .num_features ,
48+ out_features = self .config .d_model ,
49+ bias = True ,
50+ process_group = self .process_group ,
51+ input_is_parallel = False ,
52+ d_model_for_init = self .config .d_model ,
53+ num_layers_for_init = self .config .num_layers ,
54+ device = self .device ,
55+ dtype = self .dtype ,
56+ )
57+ for _ in range (self .config .num_layers )
58+ ]
59+ )
5860 elif config .decoder_tying == "per_target" :
5961 # Tied decoders: one decoder per target layer (EleutherAI style)
60- self .decoders = nn .ModuleList ([
61- RowParallelLinear (
62- in_features = self .config .num_features ,
63- out_features = self .config .d_model ,
64- bias = True ,
65- process_group = self .process_group ,
66- input_is_parallel = False ,
67- d_model_for_init = self .config .d_model ,
68- num_layers_for_init = self .config .num_layers ,
69- device = self .device ,
70- dtype = self .dtype ,
71- )
72- for _ in range (self .config .num_layers )
73- ])
74-
62+ self .decoders = nn .ModuleList (
63+ [
64+ RowParallelLinear (
65+ in_features = self .config .num_features ,
66+ out_features = self .config .d_model ,
67+ bias = True ,
68+ process_group = self .process_group ,
69+ input_is_parallel = False ,
70+ d_model_for_init = self .config .d_model ,
71+ num_layers_for_init = self .config .num_layers ,
72+ device = self .device ,
73+ dtype = self .dtype ,
74+ )
75+ for _ in range (self .config .num_layers )
76+ ]
77+ )
78+
7579 # Initialize decoder weights to zeros for tied decoders (both per_source and per_target)
7680 if config .decoder_tying in ["per_source" , "per_target" ]:
7781 for decoder in self .decoders :
7882 nn .init .zeros_ (decoder .weight )
79- if hasattr (decoder , ' bias_param' ) and decoder .bias_param is not None :
83+ if hasattr (decoder , " bias_param" ) and decoder .bias_param is not None :
8084 nn .init .zeros_ (decoder .bias_param )
81- elif hasattr (decoder , ' bias' ) and decoder .bias is not None :
85+ elif hasattr (decoder , " bias" ) and decoder .bias is not None :
8286 nn .init .zeros_ (decoder .bias )
83-
87+
8488 # Note: EleutherAI doesn't have per-target scale/bias parameters
8589 # These have been removed to match their architecture exactly
8690 else :
@@ -103,64 +107,75 @@ def __init__(
103107 }
104108 )
105109 # Note: EleutherAI doesn't have per-target scale/bias parameters
106-
110+
107111 # Initialize skip connection weights if enabled
108112 if config .skip_connection :
109113 if config .decoder_tying in ["per_source" , "per_target" ]:
110114 # For tied decoders, one skip connection per target layer
111- self .skip_weights = nn .ParameterList ([
112- nn .Parameter (torch .zeros (self .config .d_model , self .config .d_model ,
113- device = self .device , dtype = self .dtype ))
114- for _ in range (self .config .num_layers )
115- ])
115+ self .skip_weights = nn .ParameterList (
116+ [
117+ nn .Parameter (
118+ torch .zeros (self .config .d_model , self .config .d_model , device = self .device , dtype = self .dtype )
119+ )
120+ for _ in range (self .config .num_layers )
121+ ]
122+ )
116123 else :
117124 # For untied decoders, one skip connection per src->tgt pair
118- self .skip_weights = nn .ParameterDict ({
119- f"{ src_layer } ->{ tgt_layer } " : nn .Parameter (
120- torch .zeros (self .config .d_model , self .config .d_model ,
121- device = self .device , dtype = self .dtype )
122- )
123- for src_layer in range (self .config .num_layers )
124- for tgt_layer in range (src_layer , self .config .num_layers )
125- })
125+ self .skip_weights = nn .ParameterDict (
126+ {
127+ f"{ src_layer } ->{ tgt_layer } " : nn .Parameter (
128+ torch .zeros (self .config .d_model , self .config .d_model , device = self .device , dtype = self .dtype )
129+ )
130+ for src_layer in range (self .config .num_layers )
131+ for tgt_layer in range (src_layer , self .config .num_layers )
132+ }
133+ )
126134 else :
127135 self .skip_weights = None
128-
136+
129137 # Initialize feature_offset and feature_scale (indexed by target layer)
130138 # These match EleutherAI's post_enc and post_enc_scale
131139 # Note: Currently only implemented for tied decoders to match EleutherAI
132140 # For per_source tying, these would need to be indexed differently
133141 if config .decoder_tying in ["per_source" , "per_target" ]:
134142 features_per_rank = config .num_features // self .world_size if self .world_size > 1 else config .num_features
135-
143+
136144 if config .enable_feature_offset :
137145 # Initialize feature_offset for each target layer
138- self .feature_offset = nn .ParameterList ([
139- nn .Parameter (torch .zeros (features_per_rank , device = self .device , dtype = self .dtype ))
140- for _ in range (config .num_layers )
141- ])
146+ self .feature_offset = nn .ParameterList (
147+ [
148+ nn .Parameter (torch .zeros (features_per_rank , device = self .device , dtype = self .dtype ))
149+ for _ in range (config .num_layers )
150+ ]
151+ )
142152 else :
143153 self .feature_offset = None
144-
154+
145155 if config .enable_feature_scale :
146156 # Initialize feature_scale for each target layer
147157 # First target layer gets ones, rest get small non-zero values to allow gradient flow
148- self .feature_scale = nn .ParameterList ([
149- nn .Parameter (
150- torch .ones (features_per_rank , device = self .device , dtype = self .dtype ) if i == 0
151- else torch .full ((features_per_rank ,), 0.1 , device = self .device , dtype = self .dtype )
152- )
153- for i in range (config .num_layers )
154- ])
158+ self .feature_scale = nn .ParameterList (
159+ [
160+ nn .Parameter (
161+ torch .ones (features_per_rank , device = self .device , dtype = self .dtype )
162+ if i == 0
163+ else torch .full ((features_per_rank ,), 0.1 , device = self .device , dtype = self .dtype )
164+ )
165+ for i in range (config .num_layers )
166+ ]
167+ )
155168 else :
156169 self .feature_scale = None
157170 else :
158171 self .feature_offset = None
159172 self .feature_scale = None
160-
173+
161174 self .register_buffer ("_cached_decoder_norms" , None , persistent = False )
162175
163- def decode (self , a : Dict [int , torch .Tensor ], layer_idx : int , source_inputs : Optional [Dict [int , torch .Tensor ]] = None ) -> torch .Tensor :
176+ def decode (
177+ self , a : Dict [int , torch .Tensor ], layer_idx : int , source_inputs : Optional [Dict [int , torch .Tensor ]] = None
178+ ) -> torch .Tensor :
164179 """Decode the feature activations to reconstruct outputs at the specified layer.
165180
166181 Input activations `a` are expected to be the *full* tensors.
@@ -192,8 +207,10 @@ def decode(self, a: Dict[int, torch.Tensor], layer_idx: int, source_inputs: Opti
192207
193208 if self .config .decoder_tying == "per_target" :
194209 # EleutherAI style: sum activations first, then decode once
195- summed_activation = torch .zeros ((batch_dim_size , self .config .num_features ), device = self .device , dtype = self .dtype )
196-
210+ summed_activation = torch .zeros (
211+ (batch_dim_size , self .config .num_features ), device = self .device , dtype = self .dtype
212+ )
213+
197214 for src_layer in range (layer_idx + 1 ):
198215 if src_layer in a :
199216 activation_tensor = a [src_layer ].to (device = self .device , dtype = self .dtype )
@@ -211,48 +228,28 @@ def decode(self, a: Dict[int, torch.Tensor], layer_idx: int, source_inputs: Opti
211228 if self .feature_offset is not None or self .feature_scale is not None :
212229 # Get non-zero positions (selected features)
213230 nonzero_mask = activation_tensor != 0
214-
231+
215232 if nonzero_mask .any ():
216233 # Apply transformations only to selected features
217234 activation_tensor = activation_tensor .clone ()
218235 batch_indices , feature_indices = nonzero_mask .nonzero (as_tuple = True )
219-
236+
220237 if self .feature_offset is not None :
221238 # Apply offset only to non-zero features
222239 offset_values = self .feature_offset [layer_idx ][feature_indices ]
223240 activation_tensor [batch_indices , feature_indices ] += offset_values
224-
241+
225242 if self .feature_scale is not None :
226243 # Apply scale only to non-zero features
227244 scale_values = self .feature_scale [layer_idx ][feature_indices ]
228245 activation_tensor [batch_indices , feature_indices ] *= scale_values
229-
246+
230247 summed_activation += activation_tensor
231-
248+
232249 # Now decode ONCE with the summed activation
233250 decoder = self .decoders [layer_idx ]
234251 reconstruction = decoder (summed_activation )
235-
236- # Apply skip connections from source inputs if enabled
237- if self .skip_weights is not None and source_inputs is not None :
238- skip_weight = self .skip_weights [layer_idx ]
239- # Add skip connections from each source layer that contributed
240- for src_layer in range (layer_idx + 1 ):
241- if src_layer in source_inputs :
242- source_input = source_inputs [src_layer ].to (device = self .device , dtype = self .dtype )
243- # Flatten if needed
244- original_shape = source_input .shape
245- if source_input .dim () == 3 :
246- source_input_2d = source_input .view (- 1 , source_input .shape [- 1 ])
247- else :
248- source_input_2d = source_input
249- # Apply skip: source @ W_skip^T
250- skip_contribution = source_input_2d @ skip_weight .T
251- # Reshape back if needed
252- if source_input .dim () == 3 :
253- skip_contribution = skip_contribution .view (original_shape )
254- reconstruction += skip_contribution
255-
252+
256253 else :
257254 # Original logic for per_source and untied decoders
258255 for src_layer in range (layer_idx + 1 ):
@@ -271,17 +268,17 @@ def decode(self, a: Dict[int, torch.Tensor], layer_idx: int, source_inputs: Opti
271268 if self .config .decoder_tying == "per_source" :
272269 # Get non-zero positions (selected features)
273270 nonzero_mask = activation_tensor != 0
274-
271+
275272 if nonzero_mask .any ():
276273 # Apply transformations only to selected features
277274 activation_tensor = activation_tensor .clone ()
278275 batch_indices , feature_indices = nonzero_mask .nonzero (as_tuple = True )
279-
276+
280277 if self .feature_offset is not None :
281278 # Apply offset indexed by target layer
282279 offset_values = self .feature_offset [layer_idx ][feature_indices ]
283280 activation_tensor [batch_indices , feature_indices ] += offset_values
284-
281+
285282 if self .feature_scale is not None :
286283 # Apply scale indexed by target layer
287284 scale_values = self .feature_scale [layer_idx ][feature_indices ]
@@ -291,10 +288,17 @@ def decode(self, a: Dict[int, torch.Tensor], layer_idx: int, source_inputs: Opti
291288 # Use tied decoder for the source layer
292289 decoder = self .decoders [src_layer ]
293290 decoded = decoder (activation_tensor )
294-
295- # Apply skip connection from this source input if enabled
296- if self .skip_weights is not None and source_inputs is not None and src_layer in source_inputs :
297- skip_weight = self .skip_weights [layer_idx ]
291+
292+ else :
293+ # Use untied decoder for (src, tgt) pair
294+ decoder = self .decoders [f"{ src_layer } ->{ layer_idx } " ]
295+ decoded = decoder (activation_tensor )
296+
297+ # Apply skip connection from this source input if enabled
298+ if self .skip_weights is not None and source_inputs is not None and src_layer in source_inputs :
299+ skip_key = f"{ src_layer } ->{ layer_idx } "
300+ if skip_key in self .skip_weights :
301+ skip_weight = self .skip_weights [skip_key ]
298302 source_input = source_inputs [src_layer ].to (device = self .device , dtype = self .dtype )
299303 # Flatten if needed
300304 original_shape = source_input .shape
@@ -308,32 +312,31 @@ def decode(self, a: Dict[int, torch.Tensor], layer_idx: int, source_inputs: Opti
308312 if source_input .dim () == 3 :
309313 skip_contribution = skip_contribution .view (original_shape )
310314 decoded += skip_contribution
311- else :
312- # Use untied decoder for (src, tgt) pair
313- decoder = self .decoders [f"{ src_layer } ->{ layer_idx } " ]
314- decoded = decoder (activation_tensor )
315-
316- # Apply skip connection from this source input if enabled
317- if self .skip_weights is not None and source_inputs is not None and src_layer in source_inputs :
318- skip_key = f"{ src_layer } ->{ layer_idx } "
319- if skip_key in self .skip_weights :
320- skip_weight = self .skip_weights [skip_key ]
321- source_input = source_inputs [src_layer ].to (device = self .device , dtype = self .dtype )
322- # Flatten if needed
323- original_shape = source_input .shape
324- if source_input .dim () == 3 :
325- source_input_2d = source_input .view (- 1 , source_input .shape [- 1 ])
326- else :
327- source_input_2d = source_input
328- # Apply skip: source @ W_skip^T
329- skip_contribution = source_input_2d @ skip_weight .T
330- # Reshape back if needed
331- if source_input .dim () == 3 :
332- skip_contribution = skip_contribution .view (original_shape )
333- decoded += skip_contribution
334-
315+
335316 reconstruction += decoded
336-
317+
318+ # For tied decoders, apply a single skip connection from the target layer's own input
319+ if self .config .decoder_tying in ["per_source" , "per_target" ]:
320+ if self .skip_weights is not None and source_inputs is not None and layer_idx in source_inputs :
321+ skip_weight = self .skip_weights [layer_idx ]
322+ source_input = source_inputs [layer_idx ].to (device = self .device , dtype = self .dtype )
323+
324+ # Flatten if needed
325+ original_shape = source_input .shape
326+ if source_input .dim () == 3 :
327+ source_input_2d = source_input .view (- 1 , source_input .shape [- 1 ])
328+ else :
329+ source_input_2d = source_input
330+
331+ # Apply skip: source @ W_skip^T
332+ skip_contribution = source_input_2d @ skip_weight .T
333+
334+ # Reshape back if needed
335+ if source_input .dim () == 3 :
336+ skip_contribution = skip_contribution .view (original_shape )
337+
338+ reconstruction += skip_contribution
339+
337340 return reconstruction
338341
339342 def get_decoder_norms (self ) -> torch .Tensor :
0 commit comments