From e2c47c0a89c4cbe24c8391bd80e868d35d8f44eb Mon Sep 17 00:00:00 2001 From: Haotian Wu Date: Wed, 21 May 2025 05:19:04 -0400 Subject: [PATCH 01/13] add new transducer parts and assemblies --- i6_models/assemblies/transducer/__init__.py | 2 + .../assemblies/transducer/joint_network.py | 108 ++++++++ .../transducer/prediction_network.py | 236 ++++++++++++++++++ i6_models/parts/ffnn.py | 72 +++++- 4 files changed, 416 insertions(+), 2 deletions(-) create mode 100644 i6_models/assemblies/transducer/__init__.py create mode 100644 i6_models/assemblies/transducer/joint_network.py create mode 100644 i6_models/assemblies/transducer/prediction_network.py diff --git a/i6_models/assemblies/transducer/__init__.py b/i6_models/assemblies/transducer/__init__.py new file mode 100644 index 00000000..04c78687 --- /dev/null +++ b/i6_models/assemblies/transducer/__init__.py @@ -0,0 +1,2 @@ +from .prediction_network import * +from .joint_network import * diff --git a/i6_models/assemblies/transducer/joint_network.py b/i6_models/assemblies/transducer/joint_network.py new file mode 100644 index 00000000..4287eaa2 --- /dev/null +++ b/i6_models/assemblies/transducer/joint_network.py @@ -0,0 +1,108 @@ +__all__ = ["TransducerJointNetworkV1Config", "TransducerJointNetworkV1"] + +from dataclasses import dataclass +from typing import Any, Dict, Tuple, Union + +import torch +from torch import nn + +from i6_models.config import ModelConfiguration +from i6_models.parts.ffnn import FeedForwardBlockV1Config, FeedForwardBlockV1 + + +@dataclass +class TransducerJointNetworkV1Config(ModelConfiguration): + """ + Configuration for the Transducer Joint Network. + + Attributes: + ffnn_cfg: Configuration for the internal feed-forward network. + num_layers: Number of FFNN prediction network layer + """ + + ffnn_cfg: FeedForwardBlockV1Config + + +class TransducerJointNetworkV1(nn.Module): + def __init__( + self, + cfg: TransducerJointNetworkV1Config, + ) -> None: + super().__init__() + self.ffnn = FeedForwardBlockV1(cfg.ffnn_cfg) + + def forward( + self, + source_encodings: torch.Tensor, # [B, T, E] + target_encodings: torch.Tensor, # [B, S, P] + ) -> torch.Tensor: # [B, T, S, F] + """ + Forward pass for recognition. Assume T = 1 and S = 1 + """ + source_encodings = source_encodings.unsqueeze(2).expand( + -1, -1, target_encodings.size(1), -1 + ) # [B, T, S, E] + target_encodings = target_encodings.unsqueeze(1).expand( + -1, source_encodings.size(1), -1, -1 + ) # [B, T, S, P] + joint_network_inputs = torch.cat( + [source_encodings, target_encodings], dim=-1 + ) # [B, T, S, E + P] + output = self.ffnn(joint_network_inputs) # [B, T, S, F] + + if not self.training: + output = torch.log_softmax(output, dim=-1) # [B, T, S, F] + return output + + def forward_viterbi( + self, + source_encodings: torch.Tensor, # [B, T, E] + source_lengths: torch.Tensor, # [B] + target_encodings: torch.Tensor, # [B, T, P] + target_lengths: torch.Tensor, # [B] + ) -> torch.Tensor: # [B, T, F] + """ + Forward pass for Viterbi training. 
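+        :param source_encodings: encoder output of shape [B, T, E]
+        :param source_lengths: encoder sequence lengths of shape [B]
+        :param target_encodings: prediction network output of shape [B, T, P], frame-aligned with the encoder output
+        :param target_lengths: target sequence lengths of shape [B]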
+ """ + joint_network_inputs = torch.cat( + [source_encodings, target_encodings], dim=-1 + ) # [B, T, E + P] + output = self.ffnn(joint_network_inputs) # [B, T, F] + if not self.training: + output = torch.log_softmax(output, dim=-1) # [B, T, F] + return output + + def forward_fullsum( + self, + source_encodings: torch.Tensor, # [B, T, E] + source_lengths: torch.Tensor, # [B] + target_encodings: torch.Tensor, # [B, S+1, P] + target_lengths: torch.Tensor, # [B] + ) -> torch.Tensor: # [B, T, S+1, F] + """ + Forward pass for fullsum training. Returns output with shape [B, T, S+1, F]. + """ + batch_outputs = [] + for b in range(source_encodings.size(0)): + valid_source = source_encodings[b, : source_lengths[b], :] # [T_b, E] + valid_target = target_encodings[b, : target_lengths[b] + 1, :] # [S_b+1, P] + + expanded_source = valid_source.unsqueeze(1).expand( + -1, int(target_lengths[b].item()) + 1, -1 + ) # [T_b, S_b+1, E] + expanded_target = valid_target.unsqueeze(0).expand( + int(source_lengths[b].item()), -1, -1 + ) # [T_b, S_b+1, P] + combination = torch.cat( + [expanded_source, expanded_target], dim=-1 + ) # [T_b, S_b+1, E + P] + output = self.ffnn(combination) # [T_b, S_b+1, F] + batch_outputs.append(output) + # Pad the outputs to a common shape, if necessary. This is crucial for + # handling variable sequence lengths within a batch. The padding + # ensures that the final output tensor has a consistent shape. + padded_outputs = torch.nn.utils.rnn.pad_sequence( + batch_outputs, batch_first=True, padding_value=0.0 + ) # [B, max_T, max_S+1, F] + + return padded_outputs diff --git a/i6_models/assemblies/transducer/prediction_network.py b/i6_models/assemblies/transducer/prediction_network.py new file mode 100644 index 00000000..f56b34e3 --- /dev/null +++ b/i6_models/assemblies/transducer/prediction_network.py @@ -0,0 +1,236 @@ +__all__ = [ + "EmbeddingTransducerPredictionNetworkV1Config", + "EmbeddingTransducerPredictionNetworkV1", + "FfnnTransducerPredictionNetworkV1Config", + "FfnnTransducerPredictionNetworkV1", +] + +from dataclasses import dataclass, field +from typing import Any, Dict, Optional, Tuple, Union + +import torch +from torch import nn + +from i6_models.config import ModelConfiguration +from i6_models.parts.ffnn import FeedForwardBlockV1Config, FeedForwardBlockV1 +from i6_models.parts.lstm import LstmBlockV1Config, LstmBlockV1 + + +@dataclass +class EmbeddingTransducerPredictionNetworkV1Config(ModelConfiguration): + """ + num_outputs: Number of output units (vocabulary size + blank). + blank_id: Index of the blank token. + context_history_size: Number of previous output tokens to consider as context + embedding_dim: Dimension of the embedding layer. + embedding_dropout: Dropout probability for the embedding layer. + reduce_embedding: Whether to use a reduction mechanism for the context embedding. + num_reduction_heads: Number of reduction heads if reduce_embedding is True. 
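+        The embedded context labels are either concatenated, giving a context vector of size
+        embedding_dim * context_history_size, or combined into a single embedding_dim-sized
+        vector when reduce_embedding is enabled.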
+ """ + + num_outputs: int + blank_id: int = 0 + context_history_size: int = 1 + embedding_dim: int = 256 + embedding_dropout: float = 0.1 + reduce_embedding: bool = False + num_reduction_heads: Optional[int] = None + + def __post__init__(self): + super().__post_init__() + assert (num_reduction_heads is not None) == reduce_embedding + + @classmethod + def from_child(cls, child_instance): + return cls( + child_instance.num_outputs, + child_instance.blank_id, + child_instance.context_history_size, + child_instance.embedding_dim, + child_instance.embedding_dropout, + child_instance.reduce_embedding, + child_instance.num_reduction_heads, + ) + + +class EmbeddingTransducerPredictionNetworkV1(nn.Module): + def __init__(self, cfg: EmbeddingTransducerPredictionNetworkV1Config) -> None: + super().__init__() + self.cfg = cfg + self.blank_id = self.cfg.blank_id + self.context_history_size = self.cfg.context_history_size + self.embedding = nn.Embedding( + num_embeddings=self.cfg.num_outputs, + embedding_dim=self.cfg.embedding_dim, + padding_idx=self.blank_id, + ) + self.embed_dropout = nn.Dropout(self.cfg.embedding_dropout) + self.output_dim = ( + self.cfg.embedding_dim * self.cfg.context_history_size + if not self.cfg.reduce_embedding + else self.cfg.embedding_dim + ) + + self.reduce_embedding = self.cfg.reduce_embedding + self.num_reduction_heads = self.cfg.num_reduction_heads + if self.reduce_embedding: + self.register_buffer( + "position_vectors", + torch.randn( + self.cfg.context_history_size, + self.cfg.num_reduction_heads, + self.cfg.embedding_dim, + ), + ) + + def _reduce_embedding(self, emb: torch.Tensor) -> torch.Tensor: + """ + Reduces the context embedding using a weighted sum based on position vectors. + """ + B, _, H, E = emb.shape + emb_expanded = emb.unsqueeze(3) # [B, 1, H, 1, E] + pos_expanded = self.position_vectors.unsqueeze(0).unsqueeze(0) + alpha = (emb_expanded * pos_expanded).sum( + dim=-1, keepdim=True + ) # [B, 1, H, K, 1] + weighted = alpha * emb_expanded # [B, 1, H, K, E] + reduced = weighted.sum(dim=2).sum(dim=2) # [B, 1, E] + reduced *= 1.0 / (self.cfg.num_reduction_heads * self.cfg.context_history_size) + return reduced + + def _forward_embedding(self, history: torch.Tensor) -> torch.Tensor: + """ + Processes the input history through the embedding layer and optional reduction. + """ + if len(history.shape) == 2: # reshape if input shape [B, H] + history = history.view( + *history.shape[:-1], 1, history.shape[-1] + ) # [B, 1, H] + embed = self.embedding(history) # [B, S, H, E] + embed = self.embed_dropout(embed) + if self.reduce_embedding: + embed = self._reduce_embedding(embed) # [B, S, E] + else: + embed = embed.flatten(start_dim=-2) # [B, S, H*E] + return embed + + def forward( + self, + history: torch.Tensor, # [B, H] + ) -> torch.Tensor: # [B, 1, P] + """ + Forward pass for recognition mode. + """ + embed = self._forward_embedding(history) + return embed + + def forward_fullsum( + self, + targets: torch.Tensor, # [B, S] + target_lengths: torch.Tensor, # [B] + ) -> Tuple[torch.Tensor, torch.Tensor]: # [B, S + 1, P], [B] + """ + Forward pass for fullsum training. 
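+        The targets are prepended with `context_history_size` blank labels, so position s of the
+        output encodes the label history available before emitting label s; the output therefore
+        has S + 1 positions.
+        :param targets: target label sequence of shape [B, S]
+        :param target_lengths: target sequence lengths of shape [B]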
+ """ + non_context_padding = torch.full( + (targets.size(0), self.cfg.context_history_size), + fill_value=self.blank_id, + dtype=targets.dtype, + device=targets.device, + ) # [B, H] + extended_targets = torch.cat([non_context_padding, targets], dim=1) # [B, S+H] + history = torch.stack( + [ + extended_targets[ + :, self.cfg.context_history_size - 1 - i : (-i if i != 0 else None) + ] + for i in reversed(range(self.cfg.context_history_size)) + ], + dim=-1, + ) # [B, S+1, H] + embed = self._forward_embedding(history) + + return embed, target_lengths + + def forward_viterbi( + self, + targets: torch.Tensor, # [B, T] + target_lengths: torch.Tensor, # [B] + ) -> Tuple[torch.Tensor, torch.Tensor]: # [B, T, P], [B] + """ + Forward pass for viterbi training. + """ + B, T = targets.shape + history = torch.zeros( + (B, T, self.cfg.context_history_size), + dtype=targets.dtype, + device=targets.device, + ) # [B, T, H] + recent_labels = torch.full( + (B, self.cfg.context_history_size), + fill_value=self.blank_id, + dtype=targets.dtype, + device=targets.device, + ) # [B, H] + + for t in range(T): + history[:, t, :] = recent_labels + current_labels = targets[:, t] + non_blank_positions = current_labels != self.blank_id + recent_labels[non_blank_positions, 1:] = recent_labels[ + non_blank_positions, :-1 + ] + recent_labels[non_blank_positions, 0] = current_labels[non_blank_positions] + embed = self._forward_embedding(history) + + return embed, target_lengths + + +@dataclass +class FfnnTransducerPredictionNetworkV1Config( + EmbeddingTransducerPredictionNetworkV1Config +): + """ + Attributes: + ffnn_cfg: Configuration for FFNN prediction network + """ + + ffnn_cfg: FeedForwardBlockV1Config + + +class FfnnTransducerPredictionNetworkV1(EmbeddingTransducerPredictionNetworkV1): + """ + FfnnTransducerPredictionNetworkV1 with feedforward layers. 
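+    Embeds the label history exactly as EmbeddingTransducerPredictionNetworkV1 and additionally
+    passes the context embedding through a FeedForwardBlockV1, so output_dim is determined by the
+    feed-forward block instead of the embedding.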
+ """ + + def __init__(self, cfg: FfnnTransducerPredictionNetworkV1Config): + super().__init__(EmbeddingTransducerPredictionNetworkV1Config.from_child(cfg)) + cfg.ffnn_cfg.input_dim = self.output_dim + self.ffnn = FeedForwardBlockV1(cfg.ffnn_cfg) + self.output_dim = cfg.ffnn_cfg.output_dim + + def forward( + self, + history: torch.Tensor, # [B, H] + ) -> torch.Tensor: # [B, 1, P] + embed = super().forward(history) + output = self.ffnn(embed) + return output + + def forward_fullsum( + self, + targets: torch.Tensor, # [B, S] + target_lengths: torch.Tensor, # [B] + ) -> Tuple[torch.Tensor, torch.Tensor]: # [B, S + 1, P], [B] + embed, _ = super().forward_fullsum(targets, target_lengths) + output = self.ffnn(embed) + return output, target_lengths + + def forward_viterbi( + self, + targets: torch.Tensor, # [B, T] + target_lengths: torch.Tensor, # [B] + ) -> Tuple[torch.Tensor, torch.Tensor]: # [B, T, P], [B] + embed, _ = super().forward_fullsum(targets, target_lengths) + output = self.ffnn(embed) + return output, target_lengths diff --git a/i6_models/parts/ffnn.py b/i6_models/parts/ffnn.py index 3bd30d38..62c92c1b 100644 --- a/i6_models/parts/ffnn.py +++ b/i6_models/parts/ffnn.py @@ -1,4 +1,9 @@ -__all__ = ["FeedForwardConfig", "FeedForwardModel"] +__all__ = [ + "FeedForwardLayerV1Config", + "FeedForwardLayerV1", + "FeedForwardBlockV1Config", + "FeedForwardBlockV1", +] from dataclasses import dataclass from functools import partial @@ -41,7 +46,9 @@ class FeedForwardLayerV1(nn.Module): def __init__(self, cfg: FeedForwardLayerV1Config): super().__init__() - self.linear_ff = nn.Linear(in_features=cfg.input_dim, out_features=cfg.output_dim, bias=True) + self.linear_ff = nn.Linear( + in_features=cfg.input_dim, out_features=cfg.output_dim, bias=True + ) self.activation = cfg.activation self.dropout = nn.Dropout(cfg.dropout) @@ -57,3 +64,64 @@ def forward( tensor = self.activation(tensor) # [B,T,F] tensor = self.dropout(tensor) # [B,T,F] return tensor, sequence_mask + + +@dataclass +class FeedForwardBlockV1Config: + """ + Configuration for the FeedForwardBlockV1 module. + + Attributes: + input_dim: Input feature dimension. + layer_sizes: List of hidden layer sizes. The length of this list + determines the number of layers. + dropout: Dropout probability. + activation: Activation function applied after each linear layer. + use_layer_norm: Whether to use Layer Normalization. + """ + + input_dim: int + layer_sizes: List[int] + dropout: float + layer_activations: List[ + Optional[Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]] + ] + use_layer_norm: bool = True + + def __post_init__(self): + assert 0.0 <= self.dropout <= 1.0, "Dropout value must be a probability" + assert len(self.layer_sizes) > 0, "layer_sizes must not be empty" + assert len(self.layer_sizes) == len(self.layer_activations) + + +class FeedForwardBlockV1(nn.Module): + """ + A multi-layer feed-forward network block with optional Layer Normalization. 
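+    Layer sizes, per-layer activations and dropout are taken from FeedForwardBlockV1Config;
+    output_dim is set to the last entry of layer_sizes.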
+ """ + + def __init__(self, cfg: FeedForwardBlockV1Config): + super().__init__() + self.cfg = cfg + network_layers: List[nn.Module] = [] + prev_size = cfg.input_dim + + for i, layer_size in enumerate(cfg.layer_sizes): + network_layers.append(nn.Dropout(cfg.dropout)) + network_layers.append(nn.Linear(prev_size, layer_size)) + prev_size = layer_size + if cfg.use_layer_norm: + network_layers.append(nn.LayerNorm(prev_size)) + if cfg.layer_activations[i] is not None: + network_layers.append(cfg.layer_activations[i]) + + self.output_dim = cfg.layer_sizes[-1] + self.network = nn.Sequential(*network_layers) + + def forward(self, tensor: torch.Tensor) -> torch.Tensor: + """ + Forward pass through the feed-forward block. + + :param tensor: Input tensor of shape [B, T, F], where F is input_dim. + :return: Output tensor of shape [B, T, output_dim]. + """ + return self.network(tensor) From 745f6d76a318ccc663dfa62d6196955594446104 Mon Sep 17 00:00:00 2001 From: Haotian Wu Date: Thu, 22 May 2025 09:50:36 -0400 Subject: [PATCH 02/13] some fix after testing --- .../assemblies/transducer/joint_network.py | 49 +++++++++---------- .../transducer/prediction_network.py | 16 +++--- i6_models/parts/ffnn.py | 2 +- 3 files changed, 33 insertions(+), 34 deletions(-) diff --git a/i6_models/assemblies/transducer/joint_network.py b/i6_models/assemblies/transducer/joint_network.py index 4287eaa2..9f00e0e8 100644 --- a/i6_models/assemblies/transducer/joint_network.py +++ b/i6_models/assemblies/transducer/joint_network.py @@ -30,6 +30,7 @@ def __init__( ) -> None: super().__init__() self.ffnn = FeedForwardBlockV1(cfg.ffnn_cfg) + self.output_dim = self.ffnn.output_dim def forward( self, @@ -40,7 +41,7 @@ def forward( Forward pass for recognition. Assume T = 1 and S = 1 """ source_encodings = source_encodings.unsqueeze(2).expand( - -1, -1, target_encodings.size(1), -1 + target_encodings.size(0), -1, target_encodings.size(1), -1 ) # [B, T, S, E] target_encodings = target_encodings.unsqueeze(1).expand( -1, source_encodings.size(1), -1, -1 @@ -70,7 +71,7 @@ def forward_viterbi( output = self.ffnn(joint_network_inputs) # [B, T, F] if not self.training: output = torch.log_softmax(output, dim=-1) # [B, T, F] - return output + return output, source_lengths, target_lengths def forward_fullsum( self, @@ -83,26 +84,24 @@ def forward_fullsum( Forward pass for fullsum training. Returns output with shape [B, T, S+1, F]. """ batch_outputs = [] - for b in range(source_encodings.size(0)): - valid_source = source_encodings[b, : source_lengths[b], :] # [T_b, E] - valid_target = target_encodings[b, : target_lengths[b] + 1, :] # [S_b+1, P] - - expanded_source = valid_source.unsqueeze(1).expand( - -1, int(target_lengths[b].item()) + 1, -1 - ) # [T_b, S_b+1, E] - expanded_target = valid_target.unsqueeze(0).expand( - int(source_lengths[b].item()), -1, -1 - ) # [T_b, S_b+1, P] - combination = torch.cat( - [expanded_source, expanded_target], dim=-1 - ) # [T_b, S_b+1, E + P] - output = self.ffnn(combination) # [T_b, S_b+1, F] - batch_outputs.append(output) - # Pad the outputs to a common shape, if necessary. This is crucial for - # handling variable sequence lengths within a batch. The padding - # ensures that the final output tensor has a consistent shape. 
- padded_outputs = torch.nn.utils.rnn.pad_sequence( - batch_outputs, batch_first=True, padding_value=0.0 - ) # [B, max_T, max_S+1, F] - - return padded_outputs + max_target_length = target_encodings.size(1) # S+1 + max_source_length = source_encodings.size(1) # T + + # Expand source_encodings + expanded_source = source_encodings.unsqueeze(2).expand( + -1, -1, max_target_length, -1 + ) # [B, T, S+1, E] + + # Expand target_encodings + expanded_target = target_encodings.unsqueeze(1).expand( + -1, max_source_length, -1, -1 + ) # [B, T, S+1, P] + + # Concatenate + combination = torch.cat( + [expanded_source, expanded_target], dim=-1 + ) # [B, T, S+1, E + P] + + # Pass through FFNN + output = self.ffnn(combination) # [B, T, S+1, F] + return output, source_lengths, target_lengths diff --git a/i6_models/assemblies/transducer/prediction_network.py b/i6_models/assemblies/transducer/prediction_network.py index f56b34e3..69ff8b1d 100644 --- a/i6_models/assemblies/transducer/prediction_network.py +++ b/i6_models/assemblies/transducer/prediction_network.py @@ -29,12 +29,12 @@ class EmbeddingTransducerPredictionNetworkV1Config(ModelConfiguration): """ num_outputs: int - blank_id: int = 0 - context_history_size: int = 1 - embedding_dim: int = 256 - embedding_dropout: float = 0.1 - reduce_embedding: bool = False - num_reduction_heads: Optional[int] = None + blank_id: int + context_history_size: int + embedding_dim: int + embedding_dropout: float + reduce_embedding: bool + num_reduction_heads: Optional[int] def __post__init__(self): super().__post_init__() @@ -207,7 +207,7 @@ def __init__(self, cfg: FfnnTransducerPredictionNetworkV1Config): super().__init__(EmbeddingTransducerPredictionNetworkV1Config.from_child(cfg)) cfg.ffnn_cfg.input_dim = self.output_dim self.ffnn = FeedForwardBlockV1(cfg.ffnn_cfg) - self.output_dim = cfg.ffnn_cfg.output_dim + self.output_dim = self.ffnn.output_dim def forward( self, @@ -231,6 +231,6 @@ def forward_viterbi( targets: torch.Tensor, # [B, T] target_lengths: torch.Tensor, # [B] ) -> Tuple[torch.Tensor, torch.Tensor]: # [B, T, P], [B] - embed, _ = super().forward_fullsum(targets, target_lengths) + embed, _ = super().forward_viterbi(targets, target_lengths) output = self.ffnn(embed) return output, target_lengths diff --git a/i6_models/parts/ffnn.py b/i6_models/parts/ffnn.py index 62c92c1b..2ae899b3 100644 --- a/i6_models/parts/ffnn.py +++ b/i6_models/parts/ffnn.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from functools import partial -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union, List import torch from torch import nn From 02b4573c264b566c2acd9c00b86ce0c59e8eac7a Mon Sep 17 00:00:00 2001 From: Haotian Wu Date: Thu, 22 May 2025 10:52:00 -0400 Subject: [PATCH 03/13] code style and comment change --- i6_models/assemblies/transducer/joint_network.py | 4 ++-- .../assemblies/transducer/prediction_network.py | 12 +++++------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/i6_models/assemblies/transducer/joint_network.py b/i6_models/assemblies/transducer/joint_network.py index 9f00e0e8..62500543 100644 --- a/i6_models/assemblies/transducer/joint_network.py +++ b/i6_models/assemblies/transducer/joint_network.py @@ -34,7 +34,7 @@ def __init__( def forward( self, - source_encodings: torch.Tensor, # [B, T, E] + source_encodings: torch.Tensor, # [1, T, E] target_encodings: torch.Tensor, # [B, S, P] ) -> torch.Tensor: # [B, T, S, F] """ @@ -81,7 +81,7 @@ def forward_fullsum( target_lengths: torch.Tensor, # 
[B] ) -> torch.Tensor: # [B, T, S+1, F] """ - Forward pass for fullsum training. Returns output with shape [B, T, S+1, F]. + Forward pass for fullsum training. Returns output with shape [B, T, S+1, F]. """ batch_outputs = [] max_target_length = target_encodings.size(1) # S+1 diff --git a/i6_models/assemblies/transducer/prediction_network.py b/i6_models/assemblies/transducer/prediction_network.py index 69ff8b1d..86de6cb2 100644 --- a/i6_models/assemblies/transducer/prediction_network.py +++ b/i6_models/assemblies/transducer/prediction_network.py @@ -13,7 +13,6 @@ from i6_models.config import ModelConfiguration from i6_models.parts.ffnn import FeedForwardBlockV1Config, FeedForwardBlockV1 -from i6_models.parts.lstm import LstmBlockV1Config, LstmBlockV1 @dataclass @@ -87,14 +86,13 @@ def _reduce_embedding(self, emb: torch.Tensor) -> torch.Tensor: """ Reduces the context embedding using a weighted sum based on position vectors. """ - B, _, H, E = emb.shape - emb_expanded = emb.unsqueeze(3) # [B, 1, H, 1, E] - pos_expanded = self.position_vectors.unsqueeze(0).unsqueeze(0) + emb_expanded = emb.unsqueeze(3) # [B, S, H, 1, E] + pos_expanded = self.position_vectors.unsqueeze(0).unsqueeze(0) # [1, 1, H, K, E] alpha = (emb_expanded * pos_expanded).sum( dim=-1, keepdim=True - ) # [B, 1, H, K, 1] - weighted = alpha * emb_expanded # [B, 1, H, K, E] - reduced = weighted.sum(dim=2).sum(dim=2) # [B, 1, E] + ) # [B, S, H, K, 1] + weighted = alpha * emb_expanded # [B, S, H, K, E] + reduced = weighted.sum(dim=2).sum(dim=2) # [B, S, E] reduced *= 1.0 / (self.cfg.num_reduction_heads * self.cfg.context_history_size) return reduced From 5620c3895247e56bb3d5454d77c38db61f4459df Mon Sep 17 00:00:00 2001 From: Haotian Wu Date: Thu, 22 May 2025 11:21:10 -0400 Subject: [PATCH 04/13] use ruff instead of black --- i6_models/parts/ffnn.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/i6_models/parts/ffnn.py b/i6_models/parts/ffnn.py index 2ae899b3..28bc4636 100644 --- a/i6_models/parts/ffnn.py +++ b/i6_models/parts/ffnn.py @@ -46,9 +46,7 @@ class FeedForwardLayerV1(nn.Module): def __init__(self, cfg: FeedForwardLayerV1Config): super().__init__() - self.linear_ff = nn.Linear( - in_features=cfg.input_dim, out_features=cfg.output_dim, bias=True - ) + self.linear_ff = nn.Linear(in_features=cfg.input_dim, out_features=cfg.output_dim, bias=True) self.activation = cfg.activation self.dropout = nn.Dropout(cfg.dropout) @@ -76,19 +74,20 @@ class FeedForwardBlockV1Config: layer_sizes: List of hidden layer sizes. The length of this list determines the number of layers. dropout: Dropout probability. - activation: Activation function applied after each linear layer. + layer_activations: List of activation function applied after each linear layer. + None represents no activation. + Must have the same length as layer_sizes. use_layer_norm: Whether to use Layer Normalization. 
""" input_dim: int layer_sizes: List[int] dropout: float - layer_activations: List[ - Optional[Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]] - ] + layer_activations: List[Optional[Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]]] use_layer_norm: bool = True def __post_init__(self): + super().__post_init__() assert 0.0 <= self.dropout <= 1.0, "Dropout value must be a probability" assert len(self.layer_sizes) > 0, "layer_sizes must not be empty" assert len(self.layer_sizes) == len(self.layer_activations) From 806be263121531f2c50333e196d497662a2cc1ac Mon Sep 17 00:00:00 2001 From: Haotian Wu Date: Fri, 23 May 2025 08:57:27 -0400 Subject: [PATCH 05/13] reformat --- .../assemblies/transducer/joint_network.py | 26 +++++-------------- .../transducer/prediction_network.py | 20 ++++---------- 2 files changed, 12 insertions(+), 34 deletions(-) diff --git a/i6_models/assemblies/transducer/joint_network.py b/i6_models/assemblies/transducer/joint_network.py index 62500543..9b804d3f 100644 --- a/i6_models/assemblies/transducer/joint_network.py +++ b/i6_models/assemblies/transducer/joint_network.py @@ -38,17 +38,13 @@ def forward( target_encodings: torch.Tensor, # [B, S, P] ) -> torch.Tensor: # [B, T, S, F] """ - Forward pass for recognition. Assume T = 1 and S = 1 + Forward pass for recognition. """ source_encodings = source_encodings.unsqueeze(2).expand( target_encodings.size(0), -1, target_encodings.size(1), -1 ) # [B, T, S, E] - target_encodings = target_encodings.unsqueeze(1).expand( - -1, source_encodings.size(1), -1, -1 - ) # [B, T, S, P] - joint_network_inputs = torch.cat( - [source_encodings, target_encodings], dim=-1 - ) # [B, T, S, E + P] + target_encodings = target_encodings.unsqueeze(1).expand(-1, source_encodings.size(1), -1, -1) # [B, T, S, P] + joint_network_inputs = torch.cat([source_encodings, target_encodings], dim=-1) # [B, T, S, E + P] output = self.ffnn(joint_network_inputs) # [B, T, S, F] if not self.training: @@ -65,9 +61,7 @@ def forward_viterbi( """ Forward pass for Viterbi training. 
""" - joint_network_inputs = torch.cat( - [source_encodings, target_encodings], dim=-1 - ) # [B, T, E + P] + joint_network_inputs = torch.cat([source_encodings, target_encodings], dim=-1) # [B, T, E + P] output = self.ffnn(joint_network_inputs) # [B, T, F] if not self.training: output = torch.log_softmax(output, dim=-1) # [B, T, F] @@ -88,19 +82,13 @@ def forward_fullsum( max_source_length = source_encodings.size(1) # T # Expand source_encodings - expanded_source = source_encodings.unsqueeze(2).expand( - -1, -1, max_target_length, -1 - ) # [B, T, S+1, E] + expanded_source = source_encodings.unsqueeze(2).expand(-1, -1, max_target_length, -1) # [B, T, S+1, E] # Expand target_encodings - expanded_target = target_encodings.unsqueeze(1).expand( - -1, max_source_length, -1, -1 - ) # [B, T, S+1, P] + expanded_target = target_encodings.unsqueeze(1).expand(-1, max_source_length, -1, -1) # [B, T, S+1, P] # Concatenate - combination = torch.cat( - [expanded_source, expanded_target], dim=-1 - ) # [B, T, S+1, E + P] + combination = torch.cat([expanded_source, expanded_target], dim=-1) # [B, T, S+1, E + P] # Pass through FFNN output = self.ffnn(combination) # [B, T, S+1, F] diff --git a/i6_models/assemblies/transducer/prediction_network.py b/i6_models/assemblies/transducer/prediction_network.py index 86de6cb2..b7f6cb90 100644 --- a/i6_models/assemblies/transducer/prediction_network.py +++ b/i6_models/assemblies/transducer/prediction_network.py @@ -88,9 +88,7 @@ def _reduce_embedding(self, emb: torch.Tensor) -> torch.Tensor: """ emb_expanded = emb.unsqueeze(3) # [B, S, H, 1, E] pos_expanded = self.position_vectors.unsqueeze(0).unsqueeze(0) # [1, 1, H, K, E] - alpha = (emb_expanded * pos_expanded).sum( - dim=-1, keepdim=True - ) # [B, S, H, K, 1] + alpha = (emb_expanded * pos_expanded).sum(dim=-1, keepdim=True) # [B, S, H, K, 1] weighted = alpha * emb_expanded # [B, S, H, K, E] reduced = weighted.sum(dim=2).sum(dim=2) # [B, S, E] reduced *= 1.0 / (self.cfg.num_reduction_heads * self.cfg.context_history_size) @@ -101,9 +99,7 @@ def _forward_embedding(self, history: torch.Tensor) -> torch.Tensor: Processes the input history through the embedding layer and optional reduction. 
""" if len(history.shape) == 2: # reshape if input shape [B, H] - history = history.view( - *history.shape[:-1], 1, history.shape[-1] - ) # [B, 1, H] + history = history.view(*history.shape[:-1], 1, history.shape[-1]) # [B, 1, H] embed = self.embedding(history) # [B, S, H, E] embed = self.embed_dropout(embed) if self.reduce_embedding: @@ -139,9 +135,7 @@ def forward_fullsum( extended_targets = torch.cat([non_context_padding, targets], dim=1) # [B, S+H] history = torch.stack( [ - extended_targets[ - :, self.cfg.context_history_size - 1 - i : (-i if i != 0 else None) - ] + extended_targets[:, self.cfg.context_history_size - 1 - i : (-i if i != 0 else None)] for i in reversed(range(self.cfg.context_history_size)) ], dim=-1, @@ -175,9 +169,7 @@ def forward_viterbi( history[:, t, :] = recent_labels current_labels = targets[:, t] non_blank_positions = current_labels != self.blank_id - recent_labels[non_blank_positions, 1:] = recent_labels[ - non_blank_positions, :-1 - ] + recent_labels[non_blank_positions, 1:] = recent_labels[non_blank_positions, :-1] recent_labels[non_blank_positions, 0] = current_labels[non_blank_positions] embed = self._forward_embedding(history) @@ -185,9 +177,7 @@ def forward_viterbi( @dataclass -class FfnnTransducerPredictionNetworkV1Config( - EmbeddingTransducerPredictionNetworkV1Config -): +class FfnnTransducerPredictionNetworkV1Config(EmbeddingTransducerPredictionNetworkV1Config): """ Attributes: ffnn_cfg: Configuration for FFNN prediction network From 69444ffcc44e7fe9a2eb4402f8fa6a6f355819af Mon Sep 17 00:00:00 2001 From: Haotian Wu Date: Wed, 28 May 2025 05:33:12 -0400 Subject: [PATCH 06/13] fix dropout problem --- i6_models/assemblies/transducer/prediction_network.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/i6_models/assemblies/transducer/prediction_network.py b/i6_models/assemblies/transducer/prediction_network.py index b7f6cb90..79c9ff95 100644 --- a/i6_models/assemblies/transducer/prediction_network.py +++ b/i6_models/assemblies/transducer/prediction_network.py @@ -22,7 +22,6 @@ class EmbeddingTransducerPredictionNetworkV1Config(ModelConfiguration): blank_id: Index of the blank token. context_history_size: Number of previous output tokens to consider as context embedding_dim: Dimension of the embedding layer. - embedding_dropout: Dropout probability for the embedding layer. reduce_embedding: Whether to use a reduction mechanism for the context embedding. num_reduction_heads: Number of reduction heads if reduce_embedding is True. 
""" @@ -31,7 +30,6 @@ class EmbeddingTransducerPredictionNetworkV1Config(ModelConfiguration): blank_id: int context_history_size: int embedding_dim: int - embedding_dropout: float reduce_embedding: bool num_reduction_heads: Optional[int] @@ -46,7 +44,6 @@ def from_child(cls, child_instance): child_instance.blank_id, child_instance.context_history_size, child_instance.embedding_dim, - child_instance.embedding_dropout, child_instance.reduce_embedding, child_instance.num_reduction_heads, ) @@ -63,7 +60,6 @@ def __init__(self, cfg: EmbeddingTransducerPredictionNetworkV1Config) -> None: embedding_dim=self.cfg.embedding_dim, padding_idx=self.blank_id, ) - self.embed_dropout = nn.Dropout(self.cfg.embedding_dropout) self.output_dim = ( self.cfg.embedding_dim * self.cfg.context_history_size if not self.cfg.reduce_embedding @@ -101,7 +97,6 @@ def _forward_embedding(self, history: torch.Tensor) -> torch.Tensor: if len(history.shape) == 2: # reshape if input shape [B, H] history = history.view(*history.shape[:-1], 1, history.shape[-1]) # [B, 1, H] embed = self.embedding(history) # [B, S, H, E] - embed = self.embed_dropout(embed) if self.reduce_embedding: embed = self._reduce_embedding(embed) # [B, S, E] else: From e73180f57bb2c5da78eca75cd7317325014dad20 Mon Sep 17 00:00:00 2001 From: Haotian Wu Date: Wed, 28 May 2025 06:23:25 -0400 Subject: [PATCH 07/13] fix small issue --- i6_models/parts/ffnn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/i6_models/parts/ffnn.py b/i6_models/parts/ffnn.py index 28bc4636..3ef07296 100644 --- a/i6_models/parts/ffnn.py +++ b/i6_models/parts/ffnn.py @@ -65,7 +65,7 @@ def forward( @dataclass -class FeedForwardBlockV1Config: +class FeedForwardBlockV1Config(ModelConfiguration): """ Configuration for the FeedForwardBlockV1 module. @@ -105,13 +105,13 @@ def __init__(self, cfg: FeedForwardBlockV1Config): prev_size = cfg.input_dim for i, layer_size in enumerate(cfg.layer_sizes): - network_layers.append(nn.Dropout(cfg.dropout)) network_layers.append(nn.Linear(prev_size, layer_size)) prev_size = layer_size if cfg.use_layer_norm: network_layers.append(nn.LayerNorm(prev_size)) if cfg.layer_activations[i] is not None: network_layers.append(cfg.layer_activations[i]) + network_layers.append(nn.Dropout(cfg.dropout)) self.output_dim = cfg.layer_sizes[-1] self.network = nn.Sequential(*network_layers) From a7297c7126c691ba30ab73434ba7b4a61bf2f0b1 Mon Sep 17 00:00:00 2001 From: Haotian Wu Date: Wed, 6 Aug 2025 11:42:53 -0400 Subject: [PATCH 08/13] adjust linear position --- i6_models/parts/ffnn.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/i6_models/parts/ffnn.py b/i6_models/parts/ffnn.py index 3ef07296..5ecc91f5 100644 --- a/i6_models/parts/ffnn.py +++ b/i6_models/parts/ffnn.py @@ -73,7 +73,7 @@ class FeedForwardBlockV1Config(ModelConfiguration): input_dim: Input feature dimension. layer_sizes: List of hidden layer sizes. The length of this list determines the number of layers. - dropout: Dropout probability. + dropouts: Dropout probability for each layer. layer_activations: List of activation function applied after each linear layer. None represents no activation. Must have the same length as layer_sizes. 
@@ -82,15 +82,16 @@ class FeedForwardBlockV1Config(ModelConfiguration): input_dim: int layer_sizes: List[int] - dropout: float + dropouts: List[float] layer_activations: List[Optional[Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]]] use_layer_norm: bool = True def __post_init__(self): super().__post_init__() - assert 0.0 <= self.dropout <= 1.0, "Dropout value must be a probability" + assert all(0.0 <= dropout <= 1.0 for dropout in self.dropouts), "Dropout values must be probabilities" assert len(self.layer_sizes) > 0, "layer_sizes must not be empty" assert len(self.layer_sizes) == len(self.layer_activations) + assert len(self.layer_sizes) == len(self.dropouts) class FeedForwardBlockV1(nn.Module): @@ -105,13 +106,13 @@ def __init__(self, cfg: FeedForwardBlockV1Config): prev_size = cfg.input_dim for i, layer_size in enumerate(cfg.layer_sizes): - network_layers.append(nn.Linear(prev_size, layer_size)) - prev_size = layer_size if cfg.use_layer_norm: network_layers.append(nn.LayerNorm(prev_size)) + network_layers.append(nn.Linear(prev_size, layer_size)) + prev_size = layer_size if cfg.layer_activations[i] is not None: network_layers.append(cfg.layer_activations[i]) - network_layers.append(nn.Dropout(cfg.dropout)) + network_layers.append(nn.Dropout(cfg.dropouts[i])) self.output_dim = cfg.layer_sizes[-1] self.network = nn.Sequential(*network_layers) From 6ce4f644dcf26555339f6c90099730a5b4e290c6 Mon Sep 17 00:00:00 2001 From: Haotian Wu Date: Wed, 6 Aug 2025 11:49:13 -0400 Subject: [PATCH 09/13] change to additive joint network --- i6_models/assemblies/transducer/joint_network.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/i6_models/assemblies/transducer/joint_network.py b/i6_models/assemblies/transducer/joint_network.py index 9b804d3f..52d1afce 100644 --- a/i6_models/assemblies/transducer/joint_network.py +++ b/i6_models/assemblies/transducer/joint_network.py @@ -40,12 +40,8 @@ def forward( """ Forward pass for recognition. """ - source_encodings = source_encodings.unsqueeze(2).expand( - target_encodings.size(0), -1, target_encodings.size(1), -1 - ) # [B, T, S, E] - target_encodings = target_encodings.unsqueeze(1).expand(-1, source_encodings.size(1), -1, -1) # [B, T, S, P] - joint_network_inputs = torch.cat([source_encodings, target_encodings], dim=-1) # [B, T, S, E + P] - output = self.ffnn(joint_network_inputs) # [B, T, S, F] + combined_encodings = source_encodings.unsqueeze(2) + target_encodings.unsqueeze(1) + output = self.ffnn(combined_encodings) # [B, T, S, F] if not self.training: output = torch.log_softmax(output, dim=-1) # [B, T, S, F] From 7385d1f04734e9ab30fd3ad50597106994b53f45 Mon Sep 17 00:00:00 2001 From: Haotian Wu Date: Fri, 8 Aug 2025 06:37:52 -0400 Subject: [PATCH 10/13] add joint normalization --- .../assemblies/transducer/joint_network.py | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/i6_models/assemblies/transducer/joint_network.py b/i6_models/assemblies/transducer/joint_network.py index 52d1afce..4b053975 100644 --- a/i6_models/assemblies/transducer/joint_network.py +++ b/i6_models/assemblies/transducer/joint_network.py @@ -17,10 +17,11 @@ class TransducerJointNetworkV1Config(ModelConfiguration): Attributes: ffnn_cfg: Configuration for the internal feed-forward network. 
- num_layers: Number of FFNN prediction network layer + joint_normalization: whether use normalized joint gradient for fullsum """ ffnn_cfg: FeedForwardBlockV1Config + joint_normalization: bool class TransducerJointNetworkV1(nn.Module): @@ -30,6 +31,7 @@ def __init__( ) -> None: super().__init__() self.ffnn = FeedForwardBlockV1(cfg.ffnn_cfg) + self.joint_normalization = cfg.joint_normalization self.output_dim = self.ffnn.output_dim def forward( @@ -73,19 +75,19 @@ def forward_fullsum( """ Forward pass for fullsum training. Returns output with shape [B, T, S+1, F]. """ - batch_outputs = [] - max_target_length = target_encodings.size(1) # S+1 - max_source_length = source_encodings.size(1) # T - - # Expand source_encodings - expanded_source = source_encodings.unsqueeze(2).expand(-1, -1, max_target_length, -1) # [B, T, S+1, E] + + # additive combination + combined_encodings = source_encodings.unsqueeze(2) + target_encodings.unsqueeze(1) - # Expand target_encodings - expanded_target = target_encodings.unsqueeze(1).expand(-1, max_source_length, -1, -1) # [B, T, S+1, P] + if self.joint_normalization: + source_lengths_safe = torch.clamp(source_lengths, min=1).float() + target_lengths_safe = torch.clamp(target_lengths, min=1).float() + scale_enc = (1.0 / target_lengths_safe).view(-1, 1, 1).to(source_encodings.device) + scale_pred = (1.0 / source_lengths_safe).view(-1, 1, 1).to(target_encodings.device) - # Concatenate - combination = torch.cat([expanded_source, expanded_target], dim=-1) # [B, T, S+1, E + P] + source_encodings.register_hook (lambda g, s=scale_enc : g * s) + target_encodings.register_hook(lambda g, s=scale_pred: g * s) # Pass through FFNN - output = self.ffnn(combination) # [B, T, S+1, F] + output = self.ffnn(combined_encodings) # [B, T, S+1, F] return output, source_lengths, target_lengths From e707f5f3593709d93c13218ae330a1762a3a57f8 Mon Sep 17 00:00:00 2001 From: Haotian Wu Date: Tue, 2 Sep 2025 08:55:58 -0400 Subject: [PATCH 11/13] remove joint normalization(useless), update viterbi forward --- i6_models/assemblies/transducer/joint_network.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/i6_models/assemblies/transducer/joint_network.py b/i6_models/assemblies/transducer/joint_network.py index 4b053975..b4a9ec9b 100644 --- a/i6_models/assemblies/transducer/joint_network.py +++ b/i6_models/assemblies/transducer/joint_network.py @@ -17,11 +17,9 @@ class TransducerJointNetworkV1Config(ModelConfiguration): Attributes: ffnn_cfg: Configuration for the internal feed-forward network. - joint_normalization: whether use normalized joint gradient for fullsum """ ffnn_cfg: FeedForwardBlockV1Config - joint_normalization: bool class TransducerJointNetworkV1(nn.Module): @@ -31,7 +29,6 @@ def __init__( ) -> None: super().__init__() self.ffnn = FeedForwardBlockV1(cfg.ffnn_cfg) - self.joint_normalization = cfg.joint_normalization self.output_dim = self.ffnn.output_dim def forward( @@ -59,8 +56,8 @@ def forward_viterbi( """ Forward pass for Viterbi training. 
""" - joint_network_inputs = torch.cat([source_encodings, target_encodings], dim=-1) # [B, T, E + P] - output = self.ffnn(joint_network_inputs) # [B, T, F] + combined_encodings = source_encodings + target_encodings + output = self.ffnn(combined_encodings) # [B, T, F] if not self.training: output = torch.log_softmax(output, dim=-1) # [B, T, F] return output, source_lengths, target_lengths @@ -79,15 +76,6 @@ def forward_fullsum( # additive combination combined_encodings = source_encodings.unsqueeze(2) + target_encodings.unsqueeze(1) - if self.joint_normalization: - source_lengths_safe = torch.clamp(source_lengths, min=1).float() - target_lengths_safe = torch.clamp(target_lengths, min=1).float() - scale_enc = (1.0 / target_lengths_safe).view(-1, 1, 1).to(source_encodings.device) - scale_pred = (1.0 / source_lengths_safe).view(-1, 1, 1).to(target_encodings.device) - - source_encodings.register_hook (lambda g, s=scale_enc : g * s) - target_encodings.register_hook(lambda g, s=scale_pred: g * s) - # Pass through FFNN output = self.ffnn(combined_encodings) # [B, T, S+1, F] return output, source_lengths, target_lengths From dceb1d1f726f0e003a93ce268bf51dc8f64d14fc Mon Sep 17 00:00:00 2001 From: Haotian Wu Date: Mon, 15 Dec 2025 08:16:19 -0500 Subject: [PATCH 12/13] add projection into joint network --- .../assemblies/transducer/joint_network.py | 60 +++++++++++++++---- 1 file changed, 48 insertions(+), 12 deletions(-) diff --git a/i6_models/assemblies/transducer/joint_network.py b/i6_models/assemblies/transducer/joint_network.py index b4a9ec9b..e3ccc324 100644 --- a/i6_models/assemblies/transducer/joint_network.py +++ b/i6_models/assemblies/transducer/joint_network.py @@ -26,10 +26,44 @@ class TransducerJointNetworkV1(nn.Module): def __init__( self, cfg: TransducerJointNetworkV1Config, + enc_input_dim: int, + pred_input_dim: int, ) -> None: super().__init__() - self.ffnn = FeedForwardBlockV1(cfg.ffnn_cfg) - self.output_dim = self.ffnn.output_dim + hidden_dim = cfg.ffnn_cfg.layer_sizes[0] + self.enc_proj = nn.Linear(enc_input_dim, hidden_dim, bias=True) + self.pred_proj = nn.Linear(pred_input_dim, hidden_dim, bias=False) # Bias handled by enc_proj + + self.activation = cfg.ffnn_cfg.layer_activations[0] + self.dropout = nn.Dropout(cfg.ffnn_cfg.dropouts[0]) if cfg.ffnn_cfg.dropouts else nn.Identity() + + # Build the rest of the network (if any) + if len(cfg.ffnn_cfg.layer_sizes) > 1: + remaining_cfg = FeedForwardBlockV1Config( + input_dim=hidden_dim, + layer_sizes=cfg.ffnn_cfg.layer_sizes[1:], + dropouts=cfg.ffnn_cfg.dropouts[1:] if cfg.ffnn_cfg.dropouts else None, + layer_activations=cfg.ffnn_cfg.layer_activations[1:], + use_layer_norm=cfg.ffnn_cfg.use_layer_norm, + ) + self.ffnn = FeedForwardBlockV1(remaining_cfg) + else: + self.ffnn = nn.Identity() + + self.output_dim = cfg.ffnn_cfg.layer_sizes[-1] + + def _forward_joint(self, enc: torch.Tensor, pred: torch.Tensor) -> torch.Tensor: + # Project individually then broadcast-sum + enc_proj = self.enc_proj(enc).unsqueeze(2) # [B, T, 1, H] + pred_proj = self.pred_proj(pred).unsqueeze(1) # [B, 1, U, H] + + combined = enc_proj + pred_proj + + if self.activation is not None: + combined = self.activation(combined) + combined = self.dropout(combined) + + return self.ffnn(combined) def forward( self, @@ -39,8 +73,7 @@ def forward( """ Forward pass for recognition. 
""" - combined_encodings = source_encodings.unsqueeze(2) + target_encodings.unsqueeze(1) - output = self.ffnn(combined_encodings) # [B, T, S, F] + output = self._forward_joint(source_encodings, target_encodings) if not self.training: output = torch.log_softmax(output, dim=-1) # [B, T, S, F] @@ -56,8 +89,16 @@ def forward_viterbi( """ Forward pass for Viterbi training. """ - combined_encodings = source_encodings + target_encodings - output = self.ffnn(combined_encodings) # [B, T, F] + # For Viterbi, dimensions align (T=T), so we can sum directly without broadcasting + enc_proj = self.enc_proj(source_encodings) + pred_proj = self.pred_proj(target_encodings) + combined = enc_proj + pred_proj + + if self.activation is not None: + combined = self.activation(combined) + combined = self.dropout(combined) + + output = self.ffnn(combined) # [B, T, F] if not self.training: output = torch.log_softmax(output, dim=-1) # [B, T, F] return output, source_lengths, target_lengths @@ -72,10 +113,5 @@ def forward_fullsum( """ Forward pass for fullsum training. Returns output with shape [B, T, S+1, F]. """ - - # additive combination - combined_encodings = source_encodings.unsqueeze(2) + target_encodings.unsqueeze(1) - - # Pass through FFNN - output = self.ffnn(combined_encodings) # [B, T, S+1, F] + output = self._forward_joint(source_encodings, target_encodings) return output, source_lengths, target_lengths From 517d687f2af70b3a70a5a73ef14724361fcd73f9 Mon Sep 17 00:00:00 2001 From: Haotian Wu Date: Mon, 15 Dec 2025 08:19:42 -0500 Subject: [PATCH 13/13] fix prediction net --- i6_models/assemblies/transducer/prediction_network.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/i6_models/assemblies/transducer/prediction_network.py b/i6_models/assemblies/transducer/prediction_network.py index 79c9ff95..a165fd38 100644 --- a/i6_models/assemblies/transducer/prediction_network.py +++ b/i6_models/assemblies/transducer/prediction_network.py @@ -33,9 +33,9 @@ class EmbeddingTransducerPredictionNetworkV1Config(ModelConfiguration): reduce_embedding: bool num_reduction_heads: Optional[int] - def __post__init__(self): + def __post_init__(self): super().__post_init__() - assert (num_reduction_heads is not None) == reduce_embedding + assert (self.num_reduction_heads is not None) == self.reduce_embedding @classmethod def from_child(cls, child_instance): @@ -164,8 +164,8 @@ def forward_viterbi( history[:, t, :] = recent_labels current_labels = targets[:, t] non_blank_positions = current_labels != self.blank_id - recent_labels[non_blank_positions, 1:] = recent_labels[non_blank_positions, :-1] - recent_labels[non_blank_positions, 0] = current_labels[non_blank_positions] + recent_labels[non_blank_positions, :-1] = recent_labels[non_blank_positions, 1:] + recent_labels[non_blank_positions, -1] = current_labels[non_blank_positions] embed = self._forward_embedding(history) return embed, target_lengths