2 changes: 2 additions & 0 deletions i6_models/assemblies/transducer/__init__.py
@@ -0,0 +1,2 @@
from .prediction_network import *
from .joint_network import *
117 changes: 117 additions & 0 deletions i6_models/assemblies/transducer/joint_network.py
@@ -0,0 +1,117 @@
__all__ = ["TransducerJointNetworkV1Config", "TransducerJointNetworkV1"]

from dataclasses import dataclass
from typing import Any, Dict, Tuple, Union

import torch
from torch import nn

from i6_models.config import ModelConfiguration
from i6_models.parts.ffnn import FeedForwardBlockV1Config, FeedForwardBlockV1


@dataclass
class TransducerJointNetworkV1Config(ModelConfiguration):
"""
Configuration for the Transducer Joint Network.

Attributes:
ffnn_cfg: Configuration for the internal feed-forward network.
"""

ffnn_cfg: FeedForwardBlockV1Config


class TransducerJointNetworkV1(nn.Module):
Member: Needs docs.

def __init__(
self,
cfg: TransducerJointNetworkV1Config,
enc_input_dim: int,
pred_input_dim: int,
) -> None:
super().__init__()
hidden_dim = cfg.ffnn_cfg.layer_sizes[0]
self.enc_proj = nn.Linear(enc_input_dim, hidden_dim, bias=True)
self.pred_proj = nn.Linear(pred_input_dim, hidden_dim, bias=False) # Bias handled by enc_proj

self.activation = cfg.ffnn_cfg.layer_activations[0]
self.dropout = nn.Dropout(cfg.ffnn_cfg.dropouts[0]) if cfg.ffnn_cfg.dropouts else nn.Identity()

# Build the rest of the network (if any)
if len(cfg.ffnn_cfg.layer_sizes) > 1:
remaining_cfg = FeedForwardBlockV1Config(
input_dim=hidden_dim,
layer_sizes=cfg.ffnn_cfg.layer_sizes[1:],
dropouts=cfg.ffnn_cfg.dropouts[1:] if cfg.ffnn_cfg.dropouts else None,
layer_activations=cfg.ffnn_cfg.layer_activations[1:],
use_layer_norm=cfg.ffnn_cfg.use_layer_norm,
)
self.ffnn = FeedForwardBlockV1(remaining_cfg)
else:
self.ffnn = nn.Identity()

self.output_dim = cfg.ffnn_cfg.layer_sizes[-1]

def _forward_joint(self, enc: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
# Project individually then broadcast-sum
enc_proj = self.enc_proj(enc).unsqueeze(2) # [B, T, 1, H]
pred_proj = self.pred_proj(pred).unsqueeze(1) # [B, 1, U, H]

combined = enc_proj + pred_proj

if self.activation is not None:
combined = self.activation(combined)
combined = self.dropout(combined)

return self.ffnn(combined)
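
A quick standalone shape sketch of the broadcast-sum above (made-up dimensions, not part of the PR):

    import torch
    B, T, U, H = 2, 5, 3, 8
    enc_proj = torch.randn(B, T, H).unsqueeze(2)   # [B, T, 1, H]
    pred_proj = torch.randn(B, U, H).unsqueeze(1)  # [B, 1, U, H]
    combined = enc_proj + pred_proj                # broadcasts to [B, T, U, H]
    assert combined.shape == (B, T, U, H)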

def forward(
self,
source_encodings: torch.Tensor, # [1, T, E]
target_encodings: torch.Tensor, # [B, S, P]
Contributor (on lines +70 to +71): are source_encodings the output of the acoustic encoder and the target_encodings the output of the prediction network? Maybe we could rename (+ document) this better.

) -> torch.Tensor: # [B, T, S, F]
"""
Forward pass for recognition.
"""
output = self._forward_joint(source_encodings, target_encodings)

if not self.training:
output = torch.log_softmax(output, dim=-1) # [B, T, S, F]
Member (on lines +78 to +79): I don't think I'm a big fan of switching between logits and log probs based on whether it's train time or not. I'd rather pass a parameter or leave the log softmax to the forward_step function.
Contributor: +1, I think we usually get logits in the train step and apply the appropriate softmax function there.

return output
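
Following up on the review thread above, a minimal sketch (hypothetical helper, not part of the PR) of keeping the joint network logits-only and applying log_softmax in the recognition / forward_step code instead:

    import torch

    def joint_log_probs(joint_net, source_encodings, target_encodings):
        # under the proposed change, joint_net(...) would always return raw logits
        logits = joint_net(source_encodings, target_encodings)  # [B, T, S, F]
        return torch.log_softmax(logits, dim=-1)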

def forward_viterbi(
self,
source_encodings: torch.Tensor, # [B, T, E]
source_lengths: torch.Tensor, # [B]
target_encodings: torch.Tensor, # [B, T, P]
target_lengths: torch.Tensor, # [B]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # [B, T, F], [B], [B]
"""
Forward pass for Viterbi training.
"""
# For Viterbi, dimensions align (T=T), so we can sum directly without broadcasting
enc_proj = self.enc_proj(source_encodings)
pred_proj = self.pred_proj(target_encodings)
combined = enc_proj + pred_proj

if self.activation is not None:
combined = self.activation(combined)
combined = self.dropout(combined)

output = self.ffnn(combined) # [B, T, F]
if not self.training:
output = torch.log_softmax(output, dim=-1) # [B, T, F]
Member (on lines +102 to +103): See above.

return output, source_lengths, target_lengths

def forward_fullsum(
self,
source_encodings: torch.Tensor, # [B, T, E]
source_lengths: torch.Tensor, # [B]
target_encodings: torch.Tensor, # [B, S+1, P]
target_lengths: torch.Tensor, # [B]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # [B, T, S+1, F], [B], [B]
"""
Forward pass for fullsum training. Returns output with shape [B, T, S+1, F].
"""
output = self._forward_joint(source_encodings, target_encodings)
return output, source_lengths, target_lengths
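
A hedged construction/usage sketch (not part of the PR; dimensions and config values are made up, and the FeedForwardBlockV1Config kwargs simply mirror the ones used in the constructor above):

    import torch
    from torch import nn

    joint_cfg = TransducerJointNetworkV1Config(
        ffnn_cfg=FeedForwardBlockV1Config(
            input_dim=640,                # effectively replaced by the enc/pred projections
            layer_sizes=[640, 1024, 79],  # last entry becomes the output dim (vocab + blank)
            dropouts=[0.1, 0.1, 0.0],
            layer_activations=[nn.Tanh(), nn.Tanh(), nn.Tanh()],  # trailing activation illustrative only
            use_layer_norm=False,
        )
    )
    joint = TransducerJointNetworkV1(joint_cfg, enc_input_dim=512, pred_input_dim=256)

    enc = torch.randn(4, 100, 512)   # [B, T, E]
    pred = torch.randn(4, 21, 256)   # [B, S+1, P]
    out, _, _ = joint.forward_fullsum(
        enc, torch.full((4,), 100), pred, torch.full((4,), 20)
    )
    # expected: out.shape == (4, 100, 21, 79)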
219 changes: 219 additions & 0 deletions i6_models/assemblies/transducer/prediction_network.py
@@ -0,0 +1,219 @@
__all__ = [
"EmbeddingTransducerPredictionNetworkV1Config",
"EmbeddingTransducerPredictionNetworkV1",
"FfnnTransducerPredictionNetworkV1Config",
"FfnnTransducerPredictionNetworkV1",
]

from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch import nn

from i6_models.config import ModelConfiguration
from i6_models.parts.ffnn import FeedForwardBlockV1Config, FeedForwardBlockV1


@dataclass
class EmbeddingTransducerPredictionNetworkV1Config(ModelConfiguration):
"""
num_outputs: Number of output units (vocabulary size + blank).
blank_id: Index of the blank token.
context_history_size: Number of previous output tokens to consider as context
embedding_dim: Dimension of the embedding layer.
reduce_embedding: Whether to use a reduction mechanism for the context embedding.
num_reduction_heads: Number of reduction heads if reduce_embedding is True.
"""

num_outputs: int
blank_id: int
context_history_size: int
embedding_dim: int
reduce_embedding: bool
num_reduction_heads: Optional[int]

def __post_init__(self):
super().__post_init__()
assert (self.num_reduction_heads is not None) == self.reduce_embedding

@classmethod
def from_child(cls, child_instance):
return cls(
child_instance.num_outputs,
child_instance.blank_id,
child_instance.context_history_size,
child_instance.embedding_dim,
child_instance.reduce_embedding,
child_instance.num_reduction_heads,
)
Member (on lines +41 to +49): What does this do, why is it necessary and different from config2 = copy.deepcopy(config1)?
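
For context (not part of the PR): since the FFNN config subclasses the embedding config, the parent slice could also be rebuilt generically from the dataclass fields rather than listing them positionally, e.g.:

    import dataclasses

    def parent_config_from_child(child):
        # hypothetical helper: copy exactly the fields declared on the parent config
        parent_fields = {
            f.name: getattr(child, f.name)
            for f in dataclasses.fields(EmbeddingTransducerPredictionNetworkV1Config)
        }
        return EmbeddingTransducerPredictionNetworkV1Config(**parent_fields)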



class EmbeddingTransducerPredictionNetworkV1(nn.Module):
Member: Needs docs.

def __init__(self, cfg: EmbeddingTransducerPredictionNetworkV1Config) -> None:
Member: Suggested change:
-    def __init__(self, cfg: EmbeddingTransducerPredictionNetworkV1Config) -> None:
+    def __init__(self, cfg: EmbeddingTransducerPredictionNetworkV1Config):

super().__init__()
self.cfg = cfg
self.blank_id = self.cfg.blank_id
self.context_history_size = self.cfg.context_history_size
self.embedding = nn.Embedding(
num_embeddings=self.cfg.num_outputs,
embedding_dim=self.cfg.embedding_dim,
padding_idx=self.blank_id,
)
self.output_dim = (
self.cfg.embedding_dim * self.cfg.context_history_size
if not self.cfg.reduce_embedding
else self.cfg.embedding_dim
)

self.reduce_embedding = self.cfg.reduce_embedding
self.num_reduction_heads = self.cfg.num_reduction_heads
if self.reduce_embedding:
self.register_buffer(
"position_vectors",
torch.randn(
self.cfg.context_history_size,
self.cfg.num_reduction_heads,
self.cfg.embedding_dim,
),
)

def _reduce_embedding(self, emb: torch.Tensor) -> torch.Tensor:
"""
Reduces the context embedding using a weighted sum based on position vectors.
"""
emb_expanded = emb.unsqueeze(3) # [B, S, H, 1, E]
Member: Consider unsqueezing from the back.

pos_expanded = self.position_vectors.unsqueeze(0).unsqueeze(0) # [1, 1, H, K, E]
alpha = (emb_expanded * pos_expanded).sum(dim=-1, keepdim=True) # [B, S, H, K, 1]
weighted = alpha * emb_expanded # [B, S, H, K, E]
reduced = weighted.sum(dim=2).sum(dim=2) # [B, S, E]
Member: Consider indexing dims from the back.
Contributor: dim can be a tuple of ints, so we could do it in one step: https://docs.pytorch.org/docs/stable/generated/torch.sum.html
Suggested change:
-    reduced = weighted.sum(dim=2).sum(dim=2) # [B, S, E]
+    reduced = weighted.sum(dim=(-2, -1)) # [B, S, E]

reduced *= 1.0 / (self.cfg.num_reduction_heads * self.cfg.context_history_size)
return reduced
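
An equivalent one-step formulation (a sketch, not part of the PR; shapes as annotated above: emb is [B, S, H, E], position_vectors is [H, K, E]):

    alpha = torch.einsum("bshe,hke->bshk", emb, self.position_vectors)  # [B, S, H, K]
    reduced = torch.einsum("bshk,bshe->bse", alpha, emb)                # [B, S, E]
    reduced = reduced / (self.cfg.num_reduction_heads * self.cfg.context_history_size)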

def _forward_embedding(self, history: torch.Tensor) -> torch.Tensor:
"""
Processes the input history through the embedding layer and optional reduction.
"""
if len(history.shape) == 2: # reshape if input shape [B, H]
history = history.view(*history.shape[:-1], 1, history.shape[-1]) # [B, 1, H]
Contributor: *history.shape[:-1] reads odd; that should be the same as history.shape[0], since we have len(history.shape) == 2. But talk to @NeoLegends about making this work with more batch dims.

embed = self.embedding(history) # [B, S, H, E]
if self.reduce_embedding:
embed = self._reduce_embedding(embed) # [B, S, E]
else:
embed = embed.flatten(start_dim=-2) # [B, S, H*E]
return embed

def forward(
self,
history: torch.Tensor, # [B, H]
) -> torch.Tensor: # [B, 1, P]
"""
Forward pass for recognition mode.
"""
embed = self._forward_embedding(history)
return embed

def forward_fullsum(
self,
targets: torch.Tensor, # [B, S]
target_lengths: torch.Tensor, # [B]
) -> Tuple[torch.Tensor, torch.Tensor]: # [B, S + 1, P], [B]
"""
Forward pass for fullsum training.
"""
non_context_padding = torch.full(
(targets.size(0), self.cfg.context_history_size),
fill_value=self.blank_id,
dtype=targets.dtype,
device=targets.device,
) # [B, H]
extended_targets = torch.cat([non_context_padding, targets], dim=1) # [B, S+H]
history = torch.stack(
[
extended_targets[:, self.cfg.context_history_size - 1 - i : (-i if i != 0 else None)]
for i in reversed(range(self.cfg.context_history_size))
],
dim=-1,
) # [B, S+1, H]
Contributor (on lines +124 to +137): ChatGPT suggested this code:

        B, S = targets.shape
        H = self.cfg.context_history_size

        # Pad left with H blanks: [B, S+H]
        extended = F.pad(targets, (H, 0), value=self.blank_id)

        # Unfold over sequence dim to get [B, S+1, H]
        # (PyTorch: unfold(size=H, step=1) "slides" a length-H window)
        history = extended.unfold(dimension=1, size=H, step=1)  # [B, S+1, H]

embed = self._forward_embedding(history)

return embed, target_lengths
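
A tiny standalone check of the window construction above (not part of the PR; the stack-based code and the unfold suggestion produce the same histories for this example):

    import torch
    import torch.nn.functional as F

    targets = torch.tensor([[3, 5, 7]])                      # [B=1, S=3]
    H, blank_id = 2, 0
    extended = F.pad(targets, (H, 0), value=blank_id)        # [[0, 0, 3, 5, 7]]
    history = extended.unfold(dimension=1, size=H, step=1)   # [1, 4, 2] == [B, S+1, H]
    # history[0] == [[0, 0], [0, 3], [3, 5], [5, 7]]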

def forward_viterbi(
self,
targets: torch.Tensor, # [B, T]
target_lengths: torch.Tensor, # [B]
) -> Tuple[torch.Tensor, torch.Tensor]: # [B, T, P], [B]
"""
Forward pass for viterbi training.
"""
B, T = targets.shape
history = torch.zeros(
(B, T, self.cfg.context_history_size),
dtype=targets.dtype,
device=targets.device,
) # [B, T, H]
recent_labels = torch.full(
(B, self.cfg.context_history_size),
fill_value=self.blank_id,
dtype=targets.dtype,
device=targets.device,
) # [B, H]

for t in range(T):
history[:, t, :] = recent_labels
current_labels = targets[:, t]
non_blank_positions = current_labels != self.blank_id
recent_labels[non_blank_positions, :-1] = recent_labels[non_blank_positions, 1:]
recent_labels[non_blank_positions, -1] = current_labels[non_blank_positions]
embed = self._forward_embedding(history)

return embed, target_lengths
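
A small illustrative trace of the history update above (values made up, blank_id = 0, context_history_size = 2):

    # targets[0]       = [5, 0, 7]
    # history[0, 0, :] = [0, 0]   -> nothing emitted before t=0
    # history[0, 1, :] = [0, 5]   -> 5 at t=0 was non-blank, so it is shifted in
    # history[0, 2, :] = [0, 5]   -> blank at t=1 leaves the context unchanged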


@dataclass
class FfnnTransducerPredictionNetworkV1Config(EmbeddingTransducerPredictionNetworkV1Config):
"""
Attributes:
ffnn_cfg: Configuration for FFNN prediction network
"""

ffnn_cfg: FeedForwardBlockV1Config


class FfnnTransducerPredictionNetworkV1(EmbeddingTransducerPredictionNetworkV1):
Member: I think this class would benefit from using composition instead of inheritance. Make it contain/own an EmbeddingTransducerPredictionNetworkV1 instead of inheriting from one. That resolves all your issues wrt. config nesting/updating.

"""
FfnnTransducerPredictionNetworkV1 with feedforward layers.
"""

def __init__(self, cfg: FfnnTransducerPredictionNetworkV1Config):
super().__init__(EmbeddingTransducerPredictionNetworkV1Config.from_child(cfg))
Member: Since the first config inherits from the second one, you are able to just:
Suggested change:
-    super().__init__(EmbeddingTransducerPredictionNetworkV1Config.from_child(cfg))
+    super().__init__(cfg)
Member: EDIT: With composition instead of inheritance, this comment is no longer relevant.

cfg.ffnn_cfg.input_dim = self.output_dim
self.ffnn = FeedForwardBlockV1(cfg.ffnn_cfg)
Member (on lines +191 to +192): Leave the configs immutable. Always safer wrt. bugs.
Suggested change:
-    cfg.ffnn_cfg.input_dim = self.output_dim
-    self.ffnn = FeedForwardBlockV1(cfg.ffnn_cfg)
+    self.ffnn = FeedForwardBlockV1(
+        dataclasses.replace(
+            cfg,
+            ffnn_cfg=dataclasses.replace(cfg.ffnn_cfg, input_dim=self.output_dim),
+        )
+    )
This creates copies of the dataclasses as needed.
Contributor: Or we could not change anything and throw an error if a value is configured that is wrong.

self.output_dim = self.ffnn.output_dim

def forward(
self,
history: torch.Tensor, # [B, H]
) -> torch.Tensor: # [B, 1, P]
embed = super().forward(history)
output = self.ffnn(embed)
return output

def forward_fullsum(
self,
targets: torch.Tensor, # [B, S]
target_lengths: torch.Tensor, # [B]
Contributor: the target_lengths seems to be unused in any of the forward calls. Is it needed?

) -> Tuple[torch.Tensor, torch.Tensor]: # [B, S + 1, P], [B]
embed, _ = super().forward_fullsum(targets, target_lengths)
output = self.ffnn(embed)
return output, target_lengths

def forward_viterbi(
self,
targets: torch.Tensor, # [B, T]
target_lengths: torch.Tensor, # [B]
) -> Tuple[torch.Tensor, torch.Tensor]: # [B, T, P], [B]
embed, _ = super().forward_viterbi(targets, target_lengths)
output = self.ffnn(embed)
return output, target_lengths
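
Picking up the composition-over-inheritance suggestion from the review, a rough sketch (hypothetical class name, not part of the PR) of what the composed variant could look like, reusing the reviewer's dataclasses.replace idea:

    import dataclasses

    class FfnnTransducerPredictionNetworkComposed(nn.Module):
        def __init__(self, cfg: FfnnTransducerPredictionNetworkV1Config):
            super().__init__()
            # own an embedding prediction network instead of inheriting from it
            self.embedding_net = EmbeddingTransducerPredictionNetworkV1(cfg)
            self.ffnn = FeedForwardBlockV1(
                dataclasses.replace(cfg.ffnn_cfg, input_dim=self.embedding_net.output_dim)
            )
            self.output_dim = self.ffnn.output_dim

        def forward(self, history: torch.Tensor) -> torch.Tensor:  # [B, H] -> [B, 1, P]
            return self.ffnn(self.embedding_net(history))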