Add transducer components #79
`__init__.py` (new file, +2 lines):

```python
from .prediction_network import *
from .joint_network import *
```
`joint_network.py` (new file, +117 lines):

```python
__all__ = ["TransducerJointNetworkV1Config", "TransducerJointNetworkV1"]

from dataclasses import dataclass
from typing import Any, Dict, Tuple, Union

import torch
from torch import nn

from i6_models.config import ModelConfiguration
from i6_models.parts.ffnn import FeedForwardBlockV1Config, FeedForwardBlockV1


@dataclass
class TransducerJointNetworkV1Config(ModelConfiguration):
    """
    Configuration for the Transducer Joint Network.

    Attributes:
        ffnn_cfg: Configuration for the internal feed-forward network.
    """

    ffnn_cfg: FeedForwardBlockV1Config


class TransducerJointNetworkV1(nn.Module):
    def __init__(
        self,
        cfg: TransducerJointNetworkV1Config,
        enc_input_dim: int,
        pred_input_dim: int,
    ) -> None:
        super().__init__()
        hidden_dim = cfg.ffnn_cfg.layer_sizes[0]
        self.enc_proj = nn.Linear(enc_input_dim, hidden_dim, bias=True)
        self.pred_proj = nn.Linear(pred_input_dim, hidden_dim, bias=False)  # bias handled by enc_proj

        self.activation = cfg.ffnn_cfg.layer_activations[0]
        self.dropout = nn.Dropout(cfg.ffnn_cfg.dropouts[0]) if cfg.ffnn_cfg.dropouts else nn.Identity()

        # Build the rest of the network (if any)
        if len(cfg.ffnn_cfg.layer_sizes) > 1:
            remaining_cfg = FeedForwardBlockV1Config(
                input_dim=hidden_dim,
                layer_sizes=cfg.ffnn_cfg.layer_sizes[1:],
                dropouts=cfg.ffnn_cfg.dropouts[1:] if cfg.ffnn_cfg.dropouts else None,
                layer_activations=cfg.ffnn_cfg.layer_activations[1:],
                use_layer_norm=cfg.ffnn_cfg.use_layer_norm,
            )
            self.ffnn = FeedForwardBlockV1(remaining_cfg)
        else:
            self.ffnn = nn.Identity()

        self.output_dim = cfg.ffnn_cfg.layer_sizes[-1]

    def _forward_joint(self, enc: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
        # Project individually, then broadcast-sum
        enc_proj = self.enc_proj(enc).unsqueeze(2)  # [B, T, 1, H]
        pred_proj = self.pred_proj(pred).unsqueeze(1)  # [B, 1, U, H]

        combined = enc_proj + pred_proj

        if self.activation is not None:
            combined = self.activation(combined)
        combined = self.dropout(combined)

        return self.ffnn(combined)

    def forward(
        self,
        source_encodings: torch.Tensor,  # [1, T, E]
        target_encodings: torch.Tensor,  # [B, S, P]
    ) -> torch.Tensor:  # [B, T, S, F]
        """
        Forward pass for recognition.
        """
        output = self._forward_joint(source_encodings, target_encodings)

        if not self.training:
            output = torch.log_softmax(output, dim=-1)  # [B, T, S, F]
```

> **Contributor** (on lines +70 to +71): Are `source_encodings` the output of the acoustic encoder and `target_encodings` the output of the prediction network? Maybe we could rename (and document) this better.

> **Member** (on lines +78 to +79): I don't think I'm a big fan of switching between logits and log probs based on whether it's train time or not. I'd rather pass a parameter or leave the log softmax to the …
>
> **Contributor:** +1, I think we usually get logits in the train step and apply the appropriate softmax function there.
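A minimal sketch of the direction both reviewers point at (hypothetical caller-side code, not part of this PR): the module always returns logits, and the step that needs normalized scores applies it explicitly.

```python
import torch

# Hypothetical helper in a train/recognition step (assumption: the joint
# network returns raw logits with the label dimension last).
def to_log_probs(logits: torch.Tensor) -> torch.Tensor:
    return torch.log_softmax(logits, dim=-1)

logits = torch.randn(2, 5, 4, 10)  # e.g. [B, T, S+1, F]
log_probs = to_log_probs(logits)
# each label distribution sums to one after exponentiation
assert torch.allclose(log_probs.exp().sum(dim=-1), torch.ones(2, 5, 4))
```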
```python
        return output

    def forward_viterbi(
        self,
        source_encodings: torch.Tensor,  # [B, T, E]
        source_lengths: torch.Tensor,  # [B]
        target_encodings: torch.Tensor,  # [B, T, P]
        target_lengths: torch.Tensor,  # [B]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:  # [B, T, F], [B], [B]
        """
        Forward pass for Viterbi training.
        """
        # For Viterbi, dimensions align (T=T), so we can sum directly without broadcasting
        enc_proj = self.enc_proj(source_encodings)
        pred_proj = self.pred_proj(target_encodings)
        combined = enc_proj + pred_proj

        if self.activation is not None:
            combined = self.activation(combined)
        combined = self.dropout(combined)

        output = self.ffnn(combined)  # [B, T, F]
        if not self.training:
            output = torch.log_softmax(output, dim=-1)  # [B, T, F]
```

> **Member** (on lines +102 to +103): See above.
```python
        return output, source_lengths, target_lengths

    def forward_fullsum(
        self,
        source_encodings: torch.Tensor,  # [B, T, E]
        source_lengths: torch.Tensor,  # [B]
        target_encodings: torch.Tensor,  # [B, S+1, P]
        target_lengths: torch.Tensor,  # [B]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:  # [B, T, S+1, F], [B], [B]
        """
        Forward pass for fullsum training. Returns output with shape [B, T, S+1, F].
        """
        output = self._forward_joint(source_encodings, target_encodings)
        return output, source_lengths, target_lengths
```
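To see how the broadcast-sum builds the [B, T, S, H] joint lattice, here is a self-contained sketch of the same mechanism in plain torch (dimensions, names, and the tanh activation are illustrative, not taken from the PR):

```python
import torch
from torch import nn

B, T, S, E, P, H = 2, 5, 3, 8, 6, 4
enc = torch.randn(B, T, E)   # acoustic encoder output
pred = torch.randn(B, S, P)  # prediction network output

enc_proj = nn.Linear(E, H, bias=True)(enc).unsqueeze(2)     # [B, T, 1, H]
pred_proj = nn.Linear(P, H, bias=False)(pred).unsqueeze(1)  # [B, 1, S, H]

combined = torch.tanh(enc_proj + pred_proj)  # broadcasts to [B, T, S, H]
assert combined.shape == (B, T, S, H)
```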
`prediction_network.py` (new file, +219 lines):

```python
__all__ = [
    "EmbeddingTransducerPredictionNetworkV1Config",
    "EmbeddingTransducerPredictionNetworkV1",
    "FfnnTransducerPredictionNetworkV1Config",
    "FfnnTransducerPredictionNetworkV1",
]

from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch import nn

from i6_models.config import ModelConfiguration
from i6_models.parts.ffnn import FeedForwardBlockV1Config, FeedForwardBlockV1


@dataclass
class EmbeddingTransducerPredictionNetworkV1Config(ModelConfiguration):
    """
    Attributes:
        num_outputs: Number of output units (vocabulary size + blank).
        blank_id: Index of the blank token.
        context_history_size: Number of previous output tokens to consider as context.
        embedding_dim: Dimension of the embedding layer.
        reduce_embedding: Whether to use a reduction mechanism for the context embedding.
        num_reduction_heads: Number of reduction heads if reduce_embedding is True.
    """

    num_outputs: int
    blank_id: int
    context_history_size: int
    embedding_dim: int
    reduce_embedding: bool
    num_reduction_heads: Optional[int]

    def __post_init__(self):
        super().__post_init__()
        assert (self.num_reduction_heads is not None) == self.reduce_embedding

    @classmethod
    def from_child(cls, child_instance):
        return cls(
            child_instance.num_outputs,
            child_instance.blank_id,
            child_instance.context_history_size,
            child_instance.embedding_dim,
            child_instance.reduce_embedding,
            child_instance.num_reduction_heads,
        )
```

> **Member** (on lines +41 to +49): What does this do, why is it necessary and different from `config2 = copy.deepcopy(config1)`?
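As a toy illustration of the difference the reviewer asks about (hypothetical stand-in classes, not the PR's configs): `from_child` re-packs the fields into an instance of the parent class, dropping child-only fields such as `ffnn_cfg`, while `copy.deepcopy` preserves the child type.

```python
import copy
from dataclasses import dataclass

@dataclass
class ParentCfg:            # stand-in for the embedding config
    num_outputs: int

@dataclass
class ChildCfg(ParentCfg):  # stand-in for the FFNN config
    extra: str = "ffnn_cfg"

child = ChildCfg(num_outputs=10)
as_parent = ParentCfg(child.num_outputs)  # what from_child does: up-cast, drop extras
as_copy = copy.deepcopy(child)            # deepcopy keeps the child type and extras

assert type(as_parent) is ParentCfg
assert type(as_copy) is ChildCfg
```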
```python
class EmbeddingTransducerPredictionNetworkV1(nn.Module):
```

> **Member:** Needs docs.
```python
    def __init__(self, cfg: EmbeddingTransducerPredictionNetworkV1Config) -> None:
        super().__init__()
        self.cfg = cfg
        self.blank_id = self.cfg.blank_id
        self.context_history_size = self.cfg.context_history_size
        self.embedding = nn.Embedding(
            num_embeddings=self.cfg.num_outputs,
            embedding_dim=self.cfg.embedding_dim,
            padding_idx=self.blank_id,
        )
        self.output_dim = (
            self.cfg.embedding_dim * self.cfg.context_history_size
            if not self.cfg.reduce_embedding
            else self.cfg.embedding_dim
        )

        self.reduce_embedding = self.cfg.reduce_embedding
        self.num_reduction_heads = self.cfg.num_reduction_heads
        if self.reduce_embedding:
            self.register_buffer(
                "position_vectors",
                torch.randn(
                    self.cfg.context_history_size,
                    self.cfg.num_reduction_heads,
                    self.cfg.embedding_dim,
                ),
            )

    def _reduce_embedding(self, emb: torch.Tensor) -> torch.Tensor:
        """
        Reduces the context embedding using a weighted sum based on position vectors.
        """
        emb_expanded = emb.unsqueeze(3)  # [B, S, H, 1, E]
```

> **Member:** Consider unsqueezing from the back.
```python
        pos_expanded = self.position_vectors.unsqueeze(0).unsqueeze(0)  # [1, 1, H, K, E]
        alpha = (emb_expanded * pos_expanded).sum(dim=-1, keepdim=True)  # [B, S, H, K, 1]
        weighted = alpha * emb_expanded  # [B, S, H, K, E]
        reduced = weighted.sum(dim=2).sum(dim=2)  # [B, S, E]
```

> **Member:** Consider indexing dims from the back.
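If the dims were indexed from the back, as suggested, the two reductions would read as below; a quick sanity check that the results agree (shapes are illustrative):

```python
import torch

# Verify that summing with negative dims matches the PR's positive-dim sums.
B, S, H, K, E = 2, 3, 4, 2, 8
weighted = torch.randn(B, S, H, K, E)
a = weighted.sum(dim=2).sum(dim=2)    # as in the PR: sum over H, then K
b = weighted.sum(dim=-3).sum(dim=-2)  # same sums, indexed from the back
assert torch.allclose(a, b)
```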
```python
        reduced *= 1.0 / (self.cfg.num_reduction_heads * self.cfg.context_history_size)
        return reduced

    def _forward_embedding(self, history: torch.Tensor) -> torch.Tensor:
        """
        Processes the input history through the embedding layer and optional reduction.
        """
        if len(history.shape) == 2:  # reshape if input shape is [B, H]
            history = history.view(*history.shape[:-1], 1, history.shape[-1])  # [B, 1, H]
```
```python
        embed = self.embedding(history)  # [B, S, H, E]
        if self.reduce_embedding:
            embed = self._reduce_embedding(embed)  # [B, S, E]
        else:
            embed = embed.flatten(start_dim=-2)  # [B, S, H*E]
        return embed

    def forward(
        self,
        history: torch.Tensor,  # [B, H]
    ) -> torch.Tensor:  # [B, 1, P]
        """
        Forward pass for recognition mode.
        """
        embed = self._forward_embedding(history)
        return embed

    def forward_fullsum(
        self,
        targets: torch.Tensor,  # [B, S]
        target_lengths: torch.Tensor,  # [B]
    ) -> Tuple[torch.Tensor, torch.Tensor]:  # [B, S + 1, P], [B]
        """
        Forward pass for fullsum training.
        """
        non_context_padding = torch.full(
            (targets.size(0), self.cfg.context_history_size),
            fill_value=self.blank_id,
            dtype=targets.dtype,
            device=targets.device,
        )  # [B, H]
        extended_targets = torch.cat([non_context_padding, targets], dim=1)  # [B, S+H]
        history = torch.stack(
            [
                extended_targets[:, self.cfg.context_history_size - 1 - i : (-i if i != 0 else None)]
                for i in reversed(range(self.cfg.context_history_size))
            ],
            dim=-1,
        )  # [B, S+1, H]
```

> **Contributor** (on lines +124 to +137): ChatGPT suggested this code.
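As a sanity check on the stacked slicing above, a toy trace (context size and label values are illustrative, not from the PR):

```python
import torch

# History construction for H = 2 context tokens, blank_id = 0.
targets = torch.tensor([[7, 8, 9]])  # [B=1, S=3]
H = 2
padding = torch.zeros((1, H), dtype=targets.dtype)  # blank padding
extended = torch.cat([padding, targets], dim=1)     # [[0, 0, 7, 8, 9]]
history = torch.stack(
    [extended[:, H - 1 - i : (-i if i != 0 else None)] for i in reversed(range(H))],
    dim=-1,
)
# history[0] == [[0, 0], [0, 7], [7, 8], [8, 9]]: one context per step, S+1 rows
print(history)
```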
```python
        embed = self._forward_embedding(history)

        return embed, target_lengths

    def forward_viterbi(
        self,
        targets: torch.Tensor,  # [B, T]
        target_lengths: torch.Tensor,  # [B]
    ) -> Tuple[torch.Tensor, torch.Tensor]:  # [B, T, P], [B]
        """
        Forward pass for Viterbi training.
        """
        B, T = targets.shape
        history = torch.zeros(
            (B, T, self.cfg.context_history_size),
            dtype=targets.dtype,
            device=targets.device,
        )  # [B, T, H]
        recent_labels = torch.full(
            (B, self.cfg.context_history_size),
            fill_value=self.blank_id,
            dtype=targets.dtype,
            device=targets.device,
        )  # [B, H]

        for t in range(T):
            history[:, t, :] = recent_labels
            current_labels = targets[:, t]
            non_blank_positions = current_labels != self.blank_id
            recent_labels[non_blank_positions, :-1] = recent_labels[non_blank_positions, 1:]
            recent_labels[non_blank_positions, -1] = current_labels[non_blank_positions]
        embed = self._forward_embedding(history)

        return embed, target_lengths
```
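And a toy trace of the Viterbi history update, showing that blank frames do not advance the label context while non-blanks shift it left (values illustrative):

```python
import torch

targets = torch.tensor([[5, 0, 6]])  # [B=1, T=3], 0 = blank
recent = torch.zeros((1, 2), dtype=targets.dtype)  # H = 2, starts all-blank
rows = []
for t in range(targets.shape[1]):
    rows.append(recent.clone())  # context seen at frame t
    cur = targets[:, t]
    mask = cur != 0
    recent[mask, :-1] = recent[mask, 1:]  # shift context left on non-blank
    recent[mask, -1] = cur[mask]
print(torch.stack(rows, dim=1))  # [[[0, 0], [0, 5], [0, 5]]]
```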
```python
@dataclass
class FfnnTransducerPredictionNetworkV1Config(EmbeddingTransducerPredictionNetworkV1Config):
    """
    Attributes:
        ffnn_cfg: Configuration for the FFNN prediction network.
    """

    ffnn_cfg: FeedForwardBlockV1Config


class FfnnTransducerPredictionNetworkV1(EmbeddingTransducerPredictionNetworkV1):
```

> **Member:** I think this class would benefit from using composition instead of inheritance. Make it contain/own an …
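The review comment is truncated, so this is only one reading of it: a composition-based layout in which the FFNN prediction network owns an embedding network instead of subclassing it, sketched here with toy stand-in modules rather than the PR's classes.

```python
import torch
from torch import nn

class ToyEmbeddingNet(nn.Module):  # stand-in for the embedding prediction network
    def __init__(self, vocab: int, dim: int):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim)

    def forward(self, history: torch.Tensor) -> torch.Tensor:
        return self.emb(history).flatten(start_dim=-2)  # [B, S, H] -> [B, S, H*dim]

class ToyFfnnPredictionNet(nn.Module):
    def __init__(self, vocab: int, dim: int, hidden: int, context: int):
        super().__init__()
        self.embedding_net = ToyEmbeddingNet(vocab, dim)  # owned, not inherited
        self.ffnn = nn.Sequential(nn.Linear(dim * context, hidden), nn.ReLU())

    def forward(self, history: torch.Tensor) -> torch.Tensor:
        return self.ffnn(self.embedding_net(history))

net = ToyFfnnPredictionNet(vocab=11, dim=8, hidden=16, context=2)
out = net(torch.randint(0, 11, (3, 4, 2)))  # [B, S, H] -> [B, S, hidden]
assert out.shape == (3, 4, 16)
```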
```python
    """
    FfnnTransducerPredictionNetworkV1 with feedforward layers.
    """

    def __init__(self, cfg: FfnnTransducerPredictionNetworkV1Config):
        super().__init__(EmbeddingTransducerPredictionNetworkV1Config.from_child(cfg))
```

> **Member:** Since the first config inherits from the second one, you are able to just: …
>
> **Member:** EDIT: With composition instead of inheritance, this comment is no longer relevant.

```python
        cfg.ffnn_cfg.input_dim = self.output_dim
        self.ffnn = FeedForwardBlockV1(cfg.ffnn_cfg)
```

> **Member** (on lines +191 to +192): Leave the configs immutable. Always safer wrt. bugs. … This creates copies of the dataclasses as needed.
>
> **Contributor:** Or we could not change anything and throw an error if a value is configured that is wrong.
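One common way to keep a dataclass config immutable (possibly what the suggested change contained, though that is an assumption) is `dataclasses.replace`, which copies the config with the corrected field instead of mutating it, illustrated on a toy config:

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class ToyFfnnCfg:  # stand-in for FeedForwardBlockV1Config
    input_dim: int
    layer_sizes: tuple

base = ToyFfnnCfg(input_dim=-1, layer_sizes=(512, 512))
fixed = replace(base, input_dim=256)  # copy with corrected input_dim
assert base.input_dim == -1 and fixed.input_dim == 256
```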
```python
        self.output_dim = self.ffnn.output_dim

    def forward(
        self,
        history: torch.Tensor,  # [B, H]
    ) -> torch.Tensor:  # [B, 1, P]
        embed = super().forward(history)
        output = self.ffnn(embed)
        return output

    def forward_fullsum(
        self,
        targets: torch.Tensor,  # [B, S]
        target_lengths: torch.Tensor,  # [B]
    ) -> Tuple[torch.Tensor, torch.Tensor]:  # [B, S + 1, P], [B]
        embed, _ = super().forward_fullsum(targets, target_lengths)
        output = self.ffnn(embed)
        return output, target_lengths

    def forward_viterbi(
        self,
        targets: torch.Tensor,  # [B, T]
        target_lengths: torch.Tensor,  # [B]
    ) -> Tuple[torch.Tensor, torch.Tensor]:  # [B, T, P], [B]
        embed, _ = super().forward_viterbi(targets, target_lengths)
        output = self.ffnn(embed)
        return output, target_lengths
```
> Needs docs.