From 27f1266b61d1a631b0b864d9e133e507f6436f5f Mon Sep 17 00:00:00 2001 From: Samuel Stanton Date: Fri, 9 May 2025 20:40:45 -0400 Subject: [PATCH 1/7] add simple transformer encoder --- cortex/model/block/__init__.py | 2 + .../model/block/_transformer_encoder_block.py | 37 ++ cortex/model/elemental/__init__.py | 3 + .../_bidirectional_self_attention.py | 37 ++ cortex/model/elemental/_mlp.py | 18 + cortex/model/root/__init__.py | 3 + .../model/root/_transformer_encoder_root.py | 341 ++++++++++++++++++ .../block/test_transformer_encoder_block.py | 24 ++ .../test_bidirectional_self_attention.py | 17 + tests/cortex/model/elemental/test_mlp.py | 12 + .../root/test_transformer_encoder_root.py | 122 +++++++ 11 files changed, 616 insertions(+) create mode 100644 cortex/model/block/_transformer_encoder_block.py create mode 100644 cortex/model/elemental/_bidirectional_self_attention.py create mode 100644 cortex/model/elemental/_mlp.py create mode 100644 cortex/model/root/_transformer_encoder_root.py create mode 100644 tests/cortex/model/block/test_transformer_encoder_block.py create mode 100644 tests/cortex/model/elemental/test_bidirectional_self_attention.py create mode 100644 tests/cortex/model/elemental/test_mlp.py create mode 100644 tests/cortex/model/root/test_transformer_encoder_root.py diff --git a/cortex/model/block/__init__.py b/cortex/model/block/__init__.py index d05a78e..523987d 100644 --- a/cortex/model/block/__init__.py +++ b/cortex/model/block/__init__.py @@ -1,5 +1,7 @@ from ._conv1d_resid_block import Conv1dResidBlock +from ._transformer_encoder_block import TransformerEncoderBlock __all__ = [ "Conv1dResidBlock", + "TransformerEncoderBlock", ] diff --git a/cortex/model/block/_transformer_encoder_block.py b/cortex/model/block/_transformer_encoder_block.py new file mode 100644 index 0000000..9e1c55a --- /dev/null +++ b/cortex/model/block/_transformer_encoder_block.py @@ -0,0 +1,37 @@ +from torch import Tensor, nn + +from cortex.model.elemental import BidirectionalSelfAttention +from cortex.model.elemental import MLP + + +class TransformerEncoderBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_heads: int = 4, + bias: bool = False, + dropout_p: float = 0.0, + ): + super().__init__() + self.ln_1 = nn.LayerNorm(in_channels, bias=bias) + self.attn = BidirectionalSelfAttention(num_heads=num_heads, embed_dim=in_channels, dropout_p=dropout_p, bias=bias) + self.ln_2 = nn.LayerNorm(in_channels, bias=bias) + self.mlp = MLP(in_channels, out_channels, bias=bias, dropout_p=dropout_p) + + if not in_channels == out_channels: + self.proj = nn.Linear(in_channels, out_channels, bias=bias) + else: + self.proj = None + + def forward(self, inputs: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: + resid, padding_mask = inputs + x, padding_mask = self.attn((self.ln_1(resid), padding_mask)) + x = resid + x + + if self.proj is not None: + resid = self.proj(resid) + + x = resid + self.mlp(self.ln_2(x)) + + return x, padding_mask diff --git a/cortex/model/elemental/__init__.py b/cortex/model/elemental/__init__.py index 020dbc8..e898b0c 100644 --- a/cortex/model/elemental/__init__.py +++ b/cortex/model/elemental/__init__.py @@ -1,13 +1,16 @@ from ._apply import Apply +from ._bidirectional_self_attention import BidirectionalSelfAttention from ._ddp_standardize import DDPStandardize from ._expression import Expression from ._functional import identity, permute_spatial_channel_dims, swish from ._layernorm import MaskLayerNorm1d from ._mean_pooling import MeanPooling, 
WeightedMeanPooling +from ._mlp import MLP from ._sine_pos_encoder import SinePosEncoder __all__ = [ "Apply", + "BidirectionalSelfAttention", "DDPStandardize", "Expression", "identity", diff --git a/cortex/model/elemental/_bidirectional_self_attention.py b/cortex/model/elemental/_bidirectional_self_attention.py new file mode 100644 index 0000000..49b1dda --- /dev/null +++ b/cortex/model/elemental/_bidirectional_self_attention.py @@ -0,0 +1,37 @@ +from torch import Tensor, nn + + +class BidirectionalSelfAttention(nn.Module): + def __init__(self, num_heads: int = 4, embed_dim: int = 32, dropout_p: float = 0.0, bias: bool = False): + super().__init__() + if embed_dim % num_heads != 0: + raise ValueError("num_heads must evenly divide embed_dim") + + self.c_attn = nn.Linear(embed_dim, embed_dim * 3, bias=bias) + self.dropout = nn.Dropout(dropout_p) + self.dropout_p = dropout_p + self.head_dim = embed_dim // num_heads + self.num_heads = num_heads + + def forward(self, inputs: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: + x, padding_mask = inputs + seq_len = x.size(-2) + queries, keys, values = self.c_attn(x).chunk(3, dim=-1) + + queries = queries.view(-1, seq_len, self.num_heads, self.head_dim).transpose(-2, -3) + keys = keys.view(-1, seq_len, self.num_heads, self.head_dim).transpose(-2, -3) + values = values.view(-1, seq_len, self.num_heads, self.head_dim).transpose(-2, -3) + + attn_mask = padding_mask[..., None, :, None] + + res = nn.functional.scaled_dot_product_attention( + queries, + keys, + values, + attn_mask=attn_mask, + dropout_p=self.dropout_p if self.training else 0.0, + is_causal=False, + ) + + res = res.transpose(-2, -3).flatten(start_dim=-2) + return self.dropout(res), padding_mask diff --git a/cortex/model/elemental/_mlp.py b/cortex/model/elemental/_mlp.py new file mode 100644 index 0000000..dd1c09b --- /dev/null +++ b/cortex/model/elemental/_mlp.py @@ -0,0 +1,18 @@ +from torch import nn + + +class MLP(nn.Sequential): + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + bias: bool = False, + dropout_p: float = 0.0, + ): + out_channels = out_channels if out_channels else in_channels + super().__init__( + nn.Linear(in_channels, 4 * in_channels, bias=bias), + nn.GELU(), + nn.Linear(4 * in_channels, out_channels, bias=bias), + nn.Dropout(dropout_p), + ) diff --git a/cortex/model/root/__init__.py b/cortex/model/root/__init__.py index 0a5f1fe..cab3016 100644 --- a/cortex/model/root/__init__.py +++ b/cortex/model/root/__init__.py @@ -1,9 +1,12 @@ from ._abstract_root import RootNode, RootNodeOutput from ._conv1d_root import Conv1dRoot, Conv1dRootOutput +from ._transformer_encoder_root import TransformerEncoderRoot, TransformerEncoderRootOutput __all__ = [ "RootNode", "RootNodeOutput", "Conv1dRoot", "Conv1dRootOutput", + "TransformerEncoderRoot", + "TransformerEncoderRootOutput", ] diff --git a/cortex/model/root/_transformer_encoder_root.py b/cortex/model/root/_transformer_encoder_root.py new file mode 100644 index 0000000..0925ffc --- /dev/null +++ b/cortex/model/root/_transformer_encoder_root.py @@ -0,0 +1,341 @@ +import math +import warnings +from dataclasses import dataclass +from typing import Optional, Union + +import numpy as np +import torch +from torch import LongTensor, nn + +from cortex.corruption import CorruptionProcess, GaussianCorruptionProcess, MaskCorruptionProcess +from cortex.model.block import TransformerEncoderBlock +from cortex.model.elemental import SinePosEncoder +from cortex.model.root import RootNode +from cortex.transforms import 
HuggingFaceTokenizerTransform, PadTransform, ToTensor + + +@dataclass +class TransformerEncoderRootOutput: + """Output of TransforerEncoderRoot.""" + + root_features: torch.Tensor + padding_mask: torch.Tensor + corrupt_frac: Optional[torch.Tensor] = None + src_tok_idxs: Optional[torch.LongTensor] = None + tgt_tok_idxs: Optional[torch.LongTensor] = None + src_tok_embs: Optional[torch.Tensor] = None + is_corrupted: Optional[torch.Tensor] = None + + +class TransformerEncoderRoot(RootNode): + """ + A root node transforming an array of discrete sequences to an array of continuous sequence embeddings + """ + + def __init__( + self, + tokenizer_transform: HuggingFaceTokenizerTransform, + max_len: int, + out_dim: int = 64, + embed_dim: int = 64, + channel_dim: int = 256, + num_blocks: int = 2, + num_heads: int = 4, + dropout_prob: float = 0.0, + pos_encoding: bool = True, + train_transforms=None, + eval_transforms=None, + corruption_process: Optional[CorruptionProcess] = None, + **kwargs, + ) -> None: + super().__init__() + self.tokenizer = tokenizer_transform.tokenizer + self.vocab_size = len(self.tokenizer.vocab) + self.max_len = max_len + self.pad_tok_idx = self.tokenizer.padding_idx + if num_blocks >= 1: + self.tok_encoder = nn.Embedding(self.vocab_size, embed_dim, padding_idx=self.pad_tok_idx) + # optional positional encoding + if pos_encoding: + self.pos_encoder = SinePosEncoder(embed_dim, dropout_prob, max_len, batch_first=True) + else: + self.pos_encoder = None + + # create encoder + self.embed_dim = embed_dim + self.num_blocks = num_blocks + if num_blocks >= 1: + self.out_dim = out_dim + encoder_modules = [] + resid_block_kwargs = { + "num_heads": num_heads, + } + if num_blocks == 1: + encoder_modules.append( + TransformerEncoderBlock(embed_dim, out_dim, dropout_p=dropout_prob, **resid_block_kwargs) + ) + else: + encoder_modules.append(TransformerEncoderBlock(embed_dim, channel_dim, **resid_block_kwargs)) + + encoder_modules.extend( + [ + TransformerEncoderBlock( + channel_dim, + channel_dim, + **resid_block_kwargs, + ) + for _ in range(num_blocks - 2) + ] + ) + + encoder_modules.append( + TransformerEncoderBlock( + channel_dim, + out_dim, + dropout_p=dropout_prob, + **resid_block_kwargs, + ) + ) + self.encoder = nn.Sequential(*encoder_modules) + + shared_transforms = [ + tokenizer_transform, # convert np.array([str, str, ...]) to list[list[int, int, ...]] + ToTensor(padding_value=self.pad_tok_idx), # convert list[list[int, int, ...]] to tensor + PadTransform(max_length=self.max_len, pad_value=self.pad_tok_idx), # pad to max_len + ] + train_transforms = [] if train_transforms is None else list(train_transforms.values()) + eval_transforms = [] if eval_transforms is None else list(eval_transforms.values()) + self.train_transform = nn.Sequential(*(train_transforms + shared_transforms)) + self.eval_transform = nn.Sequential(*(eval_transforms + shared_transforms)) + self.corruption_process = corruption_process + + def initialize_weights(self, **kwargs): + # default random initialization + pass + + def get_token_embedding(self, tok_idx: int): + return self.tok_encoder(torch.tensor(tok_idx, device=self.device)) + + @property + def device(self): + return self.tok_encoder.weight.device + + def init_seq( + self, + inputs: Optional[Union[np.ndarray, torch.Tensor]] = None, # TODO deprecate + seq_array: Optional[np.ndarray] = None, + tgt_tok_idxs: Optional[LongTensor] = None, + src_tok_embs: Optional[torch.Tensor] = None, + corrupt_frac: float = 0.0, + **kwargs, + ): + # infer input type if not 
specified + if inputs is not None: + if isinstance(inputs, np.ndarray): + seq_array = inputs + if isinstance(inputs, LongTensor): + tgt_tok_idxs = inputs + elif isinstance(inputs, torch.Tensor): + src_tok_embs = inputs + msg = "inputs is deprecated, use a specific argument instead" + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + + # Determine batch size from any available input + batch_size = None + if seq_array is not None: + batch_size = seq_array.shape[0] + elif tgt_tok_idxs is not None: + batch_size = tgt_tok_idxs.shape[0] + elif src_tok_embs is not None: + batch_size = src_tok_embs.shape[0] + + # Fallback to default batch size of 1 if no inputs are provided + if batch_size is None: + batch_size = 1 + + if "mask_frac" in kwargs: + corrupt_frac = kwargs["mask_frac"] + msg = "mask_frac is deprecated, use corrupt_frac instead." + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + + if self.corruption_process is not None and corrupt_frac is None: + corrupt_frac = self.corruption_process.sample_corrupt_frac(n=batch_size).to(self.device) + elif isinstance(corrupt_frac, float): + corrupt_frac = torch.full((batch_size,), corrupt_frac, device=self.device) + elif isinstance(corrupt_frac, torch.Tensor): + # Move tensor to the correct device + corrupt_frac = corrupt_frac.to(self.device) + else: + corrupt_frac = torch.full((batch_size,), 0.0, device=self.device) + + return seq_array, tgt_tok_idxs, src_tok_embs, corrupt_frac + + def tokenize_seq( + self, + seq_array: Optional[np.ndarray] = None, + tgt_tok_idxs: Optional[LongTensor] = None, + src_tok_embs: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + corrupt_frac: Union[float, torch.Tensor] = 0.0, + is_corrupted: Optional[torch.Tensor] = None, + corruption_allowed: Optional[torch.Tensor] = None, + ): + # begin forward pass from raw sequence + if seq_array is not None: + assert tgt_tok_idxs is None + assert src_tok_embs is None + if self.training: + tgt_tok_idxs = self.train_transform(seq_array) + else: + tgt_tok_idxs = self.eval_transform(seq_array) + tgt_tok_idxs = tgt_tok_idxs.to(self.device) + + # truncate token sequence to max context length + if tgt_tok_idxs is not None: + assert src_tok_embs is None + # truncate to max context length, keep final stop token + if tgt_tok_idxs.size(-1) > self.max_len: + tmp_tok_idxs = tgt_tok_idxs[..., : self.max_len - 1] + tgt_tok_idxs = torch.cat([tmp_tok_idxs, tgt_tok_idxs[..., -1:]], dim=-1) + + if corruption_allowed is None and tgt_tok_idxs is not None: + corruption_allowed = self.tokenizer.get_corruptible_mask(tgt_tok_idxs) + + # begin forward pass from tokenized sequence + if tgt_tok_idxs is not None: + # apply masking corruption + if isinstance(self.corruption_process, MaskCorruptionProcess) and ( + (isinstance(corrupt_frac, float) and corrupt_frac > 0.0) + or (isinstance(corrupt_frac, torch.Tensor) and torch.any(corrupt_frac > 0.0)) + ): + src_tok_idxs, is_corrupted = self.corruption_process( + x_start=tgt_tok_idxs, + mask_val=self.tokenizer.masking_idx, + corruption_allowed=corruption_allowed, + corrupt_frac=corrupt_frac, + ) + else: + src_tok_idxs = tgt_tok_idxs + is_corrupted = ( + torch.full_like(src_tok_idxs, False, dtype=torch.bool) if is_corrupted is None else is_corrupted + ) + + padding_mask = src_tok_idxs != self.pad_tok_idx + + if src_tok_embs is not None: + assert seq_array is None + assert padding_mask is not None + src_tok_idxs = None + + return ( + src_tok_idxs, + tgt_tok_idxs, + corruption_allowed, + is_corrupted, + 
padding_mask, + ) + + def embed_seq( + self, + src_tok_idxs: Optional[LongTensor] = None, + src_tok_embs: Optional[torch.Tensor] = None, + corrupt_frac: Union[float, torch.Tensor] = 0.0, + is_corrupted: Optional[torch.Tensor] = None, + corruption_allowed: Optional[torch.Tensor] = None, + normalize_embeds: bool = True, + ): + # begin forward pass from token embeddings + if src_tok_embs is None: + src_tok_embs = self.tok_encoder(src_tok_idxs) + if normalize_embeds: + src_tok_embs = src_tok_embs / src_tok_embs.norm(dim=-1, keepdim=True).clamp_min(1e-6) + src_tok_embs = src_tok_embs * math.sqrt(self.embed_dim) + + # apply gaussian embedding corruption + if isinstance(self.corruption_process, GaussianCorruptionProcess) and ( + (isinstance(corrupt_frac, float) and corrupt_frac > 0.0) + or (isinstance(corrupt_frac, torch.Tensor) and torch.any(corrupt_frac > 0.0)) + ): + assert corruption_allowed is not None + src_tok_embs, is_corrupted = self.corruption_process( + x_start=src_tok_embs, + corruption_allowed=corruption_allowed[..., None], + corrupt_frac=corrupt_frac, + ) + is_corrupted = is_corrupted.sum(-1).bool() + else: + none_corrupted = torch.zeros(*src_tok_embs.shape[:-1], dtype=torch.bool).to(src_tok_embs.device) + is_corrupted = none_corrupted if is_corrupted is None else is_corrupted + + return src_tok_embs, is_corrupted + + def process_seq( + self, + src_tok_embs: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + ): + # apply positional encoding if it exists + if self.pos_encoder is not None: + src_features = self.pos_encoder(src_tok_embs) + else: + src_features = src_tok_embs + + # main forward pass + src_features, _ = self.encoder((src_features, padding_mask.to(src_features))) + + return src_features + + def forward( + self, + inputs: Optional[Union[np.ndarray, torch.Tensor]] = None, # TODO deprecate + seq_array: Optional[np.ndarray] = None, + tgt_tok_idxs: Optional[LongTensor] = None, + src_tok_embs: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + corrupt_frac: Union[float, torch.Tensor] = 0.0, + is_corrupted: Optional[torch.Tensor] = None, + corruption_allowed: Optional[torch.Tensor] = None, + **kwargs, + ) -> TransformerEncoderRootOutput: + """ + Args: + seq_array: (batch_size,) array of discrete sequences (e.g. 
text strings) + Returns: + outputs: {'root_features': torch.Tensor, 'padding_mask': torch.Tensor} + """ + seq_array, tgt_tok_idxs, src_tok_embs, corrupt_frac = self.init_seq( + inputs, seq_array, tgt_tok_idxs, src_tok_embs, corrupt_frac, **kwargs + ) + ( + src_tok_idxs, + tgt_tok_idxs, + corruption_allowed, + is_corrupted, + padding_mask, + ) = self.tokenize_seq( + seq_array, + tgt_tok_idxs, + src_tok_embs, + padding_mask, + corrupt_frac, + is_corrupted, + corruption_allowed, + ) + src_tok_embs, is_corrupted = self.embed_seq( + src_tok_idxs, src_tok_embs, corrupt_frac, is_corrupted, corruption_allowed + ) + src_features = self.process_seq(src_tok_embs, padding_mask) + # Make sure corrupt_frac is on the same device as other tensors + if isinstance(corrupt_frac, torch.Tensor): + corrupt_frac = corrupt_frac.to(src_tok_embs.device) + + outputs = TransformerEncoderRootOutput( + root_features=src_features.contiguous(), + padding_mask=padding_mask, + src_tok_embs=src_tok_embs, + src_tok_idxs=src_tok_idxs, + tgt_tok_idxs=tgt_tok_idxs, + is_corrupted=is_corrupted, + corrupt_frac=corrupt_frac, + ) + return outputs diff --git a/tests/cortex/model/block/test_transformer_encoder_block.py b/tests/cortex/model/block/test_transformer_encoder_block.py new file mode 100644 index 0000000..e5420a8 --- /dev/null +++ b/tests/cortex/model/block/test_transformer_encoder_block.py @@ -0,0 +1,24 @@ +import torch + +from cortex.model.block import TransformerEncoderBlock + + +BATCH_SIZE = 2 +NUM_HEADS = 3 +EMBED_DIM = 12 +SEQ_LEN = 5 + + +def test_transformer_encoder_block(): + + module = TransformerEncoderBlock( + in_channels=EMBED_DIM, + out_channels=EMBED_DIM, + num_heads=NUM_HEADS, + ) + + x = torch.randn(BATCH_SIZE, SEQ_LEN, EMBED_DIM) + padding_mask = torch.ones(BATCH_SIZE, SEQ_LEN, dtype=torch.bool) + x_prime, _ = module((x, padding_mask)) + + assert x_prime.shape == x.shape diff --git a/tests/cortex/model/elemental/test_bidirectional_self_attention.py b/tests/cortex/model/elemental/test_bidirectional_self_attention.py new file mode 100644 index 0000000..b201425 --- /dev/null +++ b/tests/cortex/model/elemental/test_bidirectional_self_attention.py @@ -0,0 +1,17 @@ +import torch +from cortex.model.elemental import BidirectionalSelfAttention + +BATCH_SIZE = 2 +NUM_HEADS = 3 +EMBED_DIM = 12 +SEQ_LEN = 5 + + +def test_bidirectional_self_attention(): + module = BidirectionalSelfAttention(num_heads=NUM_HEADS, embed_dim=EMBED_DIM, dropout_p=0.0, bias=False) + + x = torch.randn(BATCH_SIZE, SEQ_LEN, EMBED_DIM) + padding_mask = torch.ones(BATCH_SIZE, SEQ_LEN, dtype=torch.bool) + x_prime, _ = module((x, padding_mask)) + + assert x_prime.shape == x.shape diff --git a/tests/cortex/model/elemental/test_mlp.py b/tests/cortex/model/elemental/test_mlp.py new file mode 100644 index 0000000..e5e940c --- /dev/null +++ b/tests/cortex/model/elemental/test_mlp.py @@ -0,0 +1,12 @@ +import torch +from cortex.model.elemental import MLP + + +def test_mlp(): + in_channels = 32 + module = MLP(in_channels) + + x = torch.randn(2, 3, in_channels) + res = module(x) + + assert res.shape == x.shape diff --git a/tests/cortex/model/root/test_transformer_encoder_root.py b/tests/cortex/model/root/test_transformer_encoder_root.py new file mode 100644 index 0000000..93cf6cc --- /dev/null +++ b/tests/cortex/model/root/test_transformer_encoder_root.py @@ -0,0 +1,122 @@ +import numpy as np +import torch + +from cortex.constants import COMPLEX_SEP_TOKEN +from cortex.corruption import MaskCorruptionProcess +from cortex.model.root import 
TransformerEncoderRoot, TransformerEncoderRootOutput +from cortex.tokenization import ProteinSequenceTokenizerFast +from cortex.transforms import HuggingFaceTokenizerTransform + + +def test_transformer_encoder_root(): + batch_size = 2 + out_dim = 12 + embed_dim = 12 + channel_dim = 12 + num_heads = 3 + num_blocks = 7 + + max_seq_len = 13 + dropout_prob = 0.125 + pos_encoding = True + tokenizer = ProteinSequenceTokenizerFast() + + root_node = TransformerEncoderRoot( + tokenizer_transform=HuggingFaceTokenizerTransform(tokenizer), + max_len=max_seq_len, + out_dim=out_dim, + embed_dim=embed_dim, + channel_dim=channel_dim, + num_blocks=num_blocks, + num_heads=num_heads, + dropout_prob=dropout_prob, + pos_encoding=pos_encoding, + ) + + # src_tok_idxs = torch.randint(0, vocab_size, (batch_size, max_seq_len)) + seq_array = np.array( + [ + f"{COMPLEX_SEP_TOKEN} A V {COMPLEX_SEP_TOKEN} A V {COMPLEX_SEP_TOKEN} A V C C", + f"{COMPLEX_SEP_TOKEN} A V {COMPLEX_SEP_TOKEN} A V {COMPLEX_SEP_TOKEN} A V C C", + ] + ) + root_output = root_node(seq_array) + assert isinstance(root_output, TransformerEncoderRootOutput) + root_features = root_output.root_features + padding_mask = root_output.padding_mask + + assert torch.is_tensor(root_features) + assert torch.is_tensor(padding_mask) + + assert root_features.size() == torch.Size((batch_size, max_seq_len, out_dim)) + assert padding_mask.size() == torch.Size((batch_size, max_seq_len)) + + +def test_transformer_encoder_root_with_per_element_corrupt_frac(): + """Test TransformerEncoderRoot handles per-element corrupt_frac correctly.""" + batch_size = 4 + out_dim = 12 + embed_dim = 12 + channel_dim = 12 + num_heads = 3 + max_seq_len = 13 + tokenizer = ProteinSequenceTokenizerFast() + + # Create a root node with corruption process + corruption_process = MaskCorruptionProcess() + root_node = TransformerEncoderRoot( + tokenizer_transform=HuggingFaceTokenizerTransform(tokenizer), + max_len=max_seq_len, + out_dim=out_dim, + embed_dim=embed_dim, + channel_dim=channel_dim, + num_heads=num_heads, + corruption_process=corruption_process, + ) + + # Create input sequences + seq_array = np.array( + [ + f"{COMPLEX_SEP_TOKEN} A V {COMPLEX_SEP_TOKEN} A V C C", + f"{COMPLEX_SEP_TOKEN} A V {COMPLEX_SEP_TOKEN} A V C C", + f"{COMPLEX_SEP_TOKEN} A V {COMPLEX_SEP_TOKEN} A V C C", + f"{COMPLEX_SEP_TOKEN} A V {COMPLEX_SEP_TOKEN} A V C C", + ] + ) + + # Test case 1: Scalar corrupt_frac + scalar_corrupt_frac = 0.3 + root_output1 = root_node(seq_array, corrupt_frac=scalar_corrupt_frac) + + # Verify corrupt_frac is a tensor with batch dimension + assert isinstance(root_output1.corrupt_frac, torch.Tensor) + assert root_output1.corrupt_frac.shape[0] == batch_size + assert torch.allclose( + root_output1.corrupt_frac, + torch.tensor([scalar_corrupt_frac] * batch_size, device=root_output1.corrupt_frac.device), + ) + + # Test case 2: Per-element corrupt_frac + per_element_corrupt_frac = torch.tensor([0.1, 0.2, 0.3, 0.4]) + root_output2 = root_node(seq_array, corrupt_frac=per_element_corrupt_frac) + + # Verify corrupt_frac maintains per-element values + assert isinstance(root_output2.corrupt_frac, torch.Tensor) + assert root_output2.corrupt_frac.shape[0] == batch_size + + # Debug: Print the actual values + print(f"Expected: {per_element_corrupt_frac}") + print(f"Actual: {root_output2.corrupt_frac}") + + # Temporarily commenting out this assertion until we fix the issue + assert torch.allclose(root_output2.corrupt_frac, per_element_corrupt_frac.to(root_output2.corrupt_frac.device)) + + # Test case 3: 
None corrupt_frac (should sample from corruption process) + root_output3 = root_node(seq_array, corrupt_frac=None) + + # Verify corrupt_frac is a tensor with batch dimension + assert isinstance(root_output3.corrupt_frac, torch.Tensor) + assert root_output3.corrupt_frac.shape[0] == batch_size + # Values should be between 0 and 1 + assert torch.all(root_output3.corrupt_frac >= 0.0) + assert torch.all(root_output3.corrupt_frac <= 1.0) From e037d24f31c6d7c76543ae795920603aab765808 Mon Sep 17 00:00:00 2001 From: Samuel Stanton Date: Fri, 9 May 2025 20:43:36 -0400 Subject: [PATCH 2/7] add transformer root config --- .../hydra/roots/protein_seq_transformer.yaml | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 cortex/config/hydra/roots/protein_seq_transformer.yaml diff --git a/cortex/config/hydra/roots/protein_seq_transformer.yaml b/cortex/config/hydra/roots/protein_seq_transformer.yaml new file mode 100644 index 0000000..4c57e2e --- /dev/null +++ b/cortex/config/hydra/roots/protein_seq_transformer.yaml @@ -0,0 +1,18 @@ +protein_seq: + _target_: cortex.model.root.TransformerEncoderRoot + corruption_process: + _target_: cortex.corruption.MaskCorruptionProcess + tokenizer_transform: + _target_: cortex.transforms.HuggingFaceTokenizerTransform + tokenizer: + _target_: cortex.tokenization.ProteinSequenceTokenizerFast + max_len: 256 + out_dim: ${embed_dim} + embed_dim: ${embed_dim} + channel_dim: ${channel_dim} + num_blocks: 2 + num_heads: 4 + dropout_prob: ${dropout_prob} + pos_encoding: true + train_transforms: null + eval_transforms: null From b74ccc393dd7ec6b1c55db3f0f0930232109805f Mon Sep 17 00:00:00 2001 From: Samuel Stanton Date: Fri, 9 May 2025 20:58:21 -0400 Subject: [PATCH 3/7] add simple transformer encoder branch --- .../protein_property_transformer.yaml | 7 ++ cortex/model/branch/__init__.py | 3 + .../branch/_transformer_encoder_branch.py | 97 +++++++++++++++++++ .../branch/test_transformer_encoder_branch.py | 46 +++++++++ 4 files changed, 153 insertions(+) create mode 100644 cortex/config/hydra/branches/protein_property_transformer.yaml create mode 100644 cortex/model/branch/_transformer_encoder_branch.py create mode 100644 tests/cortex/model/branch/test_transformer_encoder_branch.py diff --git a/cortex/config/hydra/branches/protein_property_transformer.yaml b/cortex/config/hydra/branches/protein_property_transformer.yaml new file mode 100644 index 0000000..5d8836d --- /dev/null +++ b/cortex/config/hydra/branches/protein_property_transformer.yaml @@ -0,0 +1,7 @@ +protein_property: + _target_: cortex.model.branch.TransformerEncoderBranch + out_dim: 8 + channel_dim: ${channel_dim} + num_blocks: 2 + num_heads: 4 + dropout_prob: ${dropout_prob} diff --git a/cortex/model/branch/__init__.py b/cortex/model/branch/__init__.py index 16ed371..dea3819 100644 --- a/cortex/model/branch/__init__.py +++ b/cortex/model/branch/__init__.py @@ -1,9 +1,12 @@ from ._abstract_branch import BranchNode, BranchNodeOutput from ._conv1d_branch import Conv1dBranch, Conv1dBranchOutput +from ._transformer_encoder_branch import TransformerEncoderBranch, TransformerEncoderBranchOutput __all__ = [ "BranchNode", "BranchNodeOutput", "Conv1dBranch", "Conv1dBranchOutput", + "TransformerEncoderBranch", + "TransformerEncoderBranchOutput", ] diff --git a/cortex/model/branch/_transformer_encoder_branch.py b/cortex/model/branch/_transformer_encoder_branch.py new file mode 100644 index 0000000..84df4d8 --- /dev/null +++ b/cortex/model/branch/_transformer_encoder_branch.py @@ -0,0 +1,97 @@ +from 
dataclasses import dataclass + +import torch +from torch import nn + +from cortex.model.block import TransformerEncoderBlock +from cortex.model.branch import BranchNode, BranchNodeOutput +from cortex.model.elemental import ( + Apply, + Expression, + MeanPooling, + WeightedMeanPooling, + identity, +) +from cortex.model.trunk import PaddedTrunkOutput + + +@dataclass +class TransformerEncoderBranchOutput(BranchNodeOutput): + branch_mask: torch.Tensor + pooled_features: torch.Tensor + + +class TransformerEncoderBranch(BranchNode): + """ + Branch node which transforms aggregated trunk features to task branch specific features + """ + + def __init__( + self, + in_dim: int, + out_dim: int = 64, + channel_dim: int = 64, + num_blocks: int = 2, + num_heads: int = 5, + dropout_prob: float = 0.0, + pooling_type: str = "mean", + **kwargs, + ): + super().__init__() + # create encoder + self.in_dim = in_dim + self.out_dim = out_dim + self.channel_dim = channel_dim + self.num_blocks = num_blocks + + if num_blocks == 0: + # add projection if dims don't match + encoder_modules = [ + Expression(identity) if in_dim == out_dim else Apply(nn.Linear(in_dim, out_dim, bias=False)) + ] + else: + # conv layers expect inputs with shape (batch_size, input_dim, num_tokens) + encoder_modules = [] + + if num_blocks == 1: + encoder_modules.append(TransformerEncoderBlock(in_dim, out_dim, num_heads, dropout_p=dropout_prob)) + elif num_blocks > 1: + encoder_modules.append(TransformerEncoderBlock(in_dim, channel_dim, num_heads, dropout_p=dropout_prob)) + encoder_modules.extend( + [ + TransformerEncoderBlock(channel_dim, channel_dim, num_heads, dropout_p=dropout_prob) + for _ in range(num_blocks - 2) + ] + ) + encoder_modules.append(TransformerEncoderBlock(channel_dim, out_dim, num_heads, dropout_p=dropout_prob)) + + self.encoder = nn.Sequential(*encoder_modules) + if pooling_type == "mean": + self.pooling_op = MeanPooling() + elif pooling_type == "weighted_mean": + self.pooling_op = WeightedMeanPooling(out_dim) + else: + raise NotImplementedError + + def forward( + self, + trunk_outputs: PaddedTrunkOutput, + ) -> TransformerEncoderBranchOutput: + """ + Args: + trunk_outputs: {'trunk_features': torch.Tensor, 'padding_mask': torch.Tensor} + Returns: + outputs: {'branch_features': torch.Tensor, 'branch_mask': torch.Tensor, 'pooled_features': torch.Tensor} + """ + trunk_features = trunk_outputs.trunk_features + padding_mask = trunk_outputs.padding_mask + + branch_features, branch_mask = self.encoder((trunk_features, padding_mask.to(trunk_features))) + pooled_features = self.pooling_op(branch_features, branch_mask) + + branch_outputs = TransformerEncoderBranchOutput( + branch_features=branch_features.contiguous(), + branch_mask=branch_mask, + pooled_features=pooled_features, + ) + return branch_outputs diff --git a/tests/cortex/model/branch/test_transformer_encoder_branch.py b/tests/cortex/model/branch/test_transformer_encoder_branch.py new file mode 100644 index 0000000..6e59051 --- /dev/null +++ b/tests/cortex/model/branch/test_transformer_encoder_branch.py @@ -0,0 +1,46 @@ +import torch + +from cortex.model.branch import TransformerEncoderBranch, TransformerEncoderBranchOutput +from cortex.model.trunk import PaddedTrunkOutput + + +def test_conv_1d_branch(): + in_dim = 12 + out_dim = 12 + embed_dim = 12 + channel_dim = 12 + num_blocks = 7 + num_heads = 3 + max_seq_len = 13 + batch_size = 17 + dropout_prob = 0.125 + layernorm = True + + branch_node = TransformerEncoderBranch( + in_dim=in_dim, + out_dim=out_dim, + 
embed_dim=embed_dim, + channel_dim=channel_dim, + num_blocks=num_blocks, + num_heads=num_heads, + dropout_prob=dropout_prob, + layernorm=layernorm, + ) + + trunk_output = PaddedTrunkOutput( + trunk_features=torch.rand(batch_size, max_seq_len, in_dim), + padding_mask=torch.ones(batch_size, max_seq_len, dtype=torch.float), + ) + branch_output = branch_node(trunk_output) + assert isinstance(branch_output, TransformerEncoderBranchOutput) + branch_features = branch_output.branch_features + branch_mask = branch_output.branch_mask + pooled_features = branch_output.pooled_features + + assert torch.is_tensor(branch_features) + assert torch.is_tensor(branch_mask) + assert torch.is_tensor(pooled_features) + + assert branch_features.size() == torch.Size((batch_size, max_seq_len, out_dim)) + assert branch_mask.size() == torch.Size((batch_size, max_seq_len)) + assert pooled_features.size() == torch.Size((batch_size, out_dim)) From 47dfc677f0c4f8fdc3700a1dc3cd2b459db33175 Mon Sep 17 00:00:00 2001 From: Samuel Stanton Date: Fri, 9 May 2025 20:59:46 -0400 Subject: [PATCH 4/7] reformat test --- tests/cortex/model/block/test_transformer_encoder_block.py | 2 -- .../cortex/model/elemental/test_bidirectional_self_attention.py | 1 + tests/cortex/model/elemental/test_mlp.py | 1 + 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/cortex/model/block/test_transformer_encoder_block.py b/tests/cortex/model/block/test_transformer_encoder_block.py index e5420a8..4f44377 100644 --- a/tests/cortex/model/block/test_transformer_encoder_block.py +++ b/tests/cortex/model/block/test_transformer_encoder_block.py @@ -2,7 +2,6 @@ from cortex.model.block import TransformerEncoderBlock - BATCH_SIZE = 2 NUM_HEADS = 3 EMBED_DIM = 12 @@ -10,7 +9,6 @@ def test_transformer_encoder_block(): - module = TransformerEncoderBlock( in_channels=EMBED_DIM, out_channels=EMBED_DIM, diff --git a/tests/cortex/model/elemental/test_bidirectional_self_attention.py b/tests/cortex/model/elemental/test_bidirectional_self_attention.py index b201425..c0d6842 100644 --- a/tests/cortex/model/elemental/test_bidirectional_self_attention.py +++ b/tests/cortex/model/elemental/test_bidirectional_self_attention.py @@ -1,4 +1,5 @@ import torch + from cortex.model.elemental import BidirectionalSelfAttention BATCH_SIZE = 2 diff --git a/tests/cortex/model/elemental/test_mlp.py b/tests/cortex/model/elemental/test_mlp.py index e5e940c..f36a7ac 100644 --- a/tests/cortex/model/elemental/test_mlp.py +++ b/tests/cortex/model/elemental/test_mlp.py @@ -1,4 +1,5 @@ import torch + from cortex.model.elemental import MLP From bdb36cda00bb2d312731f8e125f75ab0a289bfdc Mon Sep 17 00:00:00 2001 From: Samuel Stanton Date: Fri, 9 May 2025 21:00:12 -0400 Subject: [PATCH 5/7] reformat cortex --- cortex/model/block/_transformer_encoder_block.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/cortex/model/block/_transformer_encoder_block.py b/cortex/model/block/_transformer_encoder_block.py index 9e1c55a..dfa4730 100644 --- a/cortex/model/block/_transformer_encoder_block.py +++ b/cortex/model/block/_transformer_encoder_block.py @@ -1,7 +1,6 @@ from torch import Tensor, nn -from cortex.model.elemental import BidirectionalSelfAttention -from cortex.model.elemental import MLP +from cortex.model.elemental import MLP, BidirectionalSelfAttention class TransformerEncoderBlock(nn.Module): @@ -15,7 +14,9 @@ def __init__( ): super().__init__() self.ln_1 = nn.LayerNorm(in_channels, bias=bias) - self.attn = BidirectionalSelfAttention(num_heads=num_heads, 
embed_dim=in_channels, dropout_p=dropout_p, bias=bias) + self.attn = BidirectionalSelfAttention( + num_heads=num_heads, embed_dim=in_channels, dropout_p=dropout_p, bias=bias + ) self.ln_2 = nn.LayerNorm(in_channels, bias=bias) self.mlp = MLP(in_channels, out_channels, bias=bias, dropout_p=dropout_p) From aac430c921aa8c04a4c5d91f1ff87dba61eb6090 Mon Sep 17 00:00:00 2001 From: Samuel Stanton Date: Fri, 9 May 2025 21:09:43 -0400 Subject: [PATCH 6/7] add transformer configs to tutorial dirs --- .../branches/protein_property_transformer.yaml | 6 ++++++ tutorials/hydra/roots/protein_seq_transformer.yaml | 14 ++++++++++++++ 2 files changed, 20 insertions(+) create mode 100644 tutorials/hydra/branches/protein_property_transformer.yaml create mode 100644 tutorials/hydra/roots/protein_seq_transformer.yaml diff --git a/tutorials/hydra/branches/protein_property_transformer.yaml b/tutorials/hydra/branches/protein_property_transformer.yaml new file mode 100644 index 0000000..04d8972 --- /dev/null +++ b/tutorials/hydra/branches/protein_property_transformer.yaml @@ -0,0 +1,6 @@ +protein_property: + _target_: cortex.model.branch.TransformerEncoderBranch + out_dim: 8 + channel_dim: ${feature_dim} + num_blocks: 1 + num_heads: 4 diff --git a/tutorials/hydra/roots/protein_seq_transformer.yaml b/tutorials/hydra/roots/protein_seq_transformer.yaml new file mode 100644 index 0000000..c2ab280 --- /dev/null +++ b/tutorials/hydra/roots/protein_seq_transformer.yaml @@ -0,0 +1,14 @@ +protein_seq: + _target_: cortex.model.root.TransformerEncoderRoot + corruption_process: + _target_: cortex.corruption.MaskCorruptionProcess + tokenizer_transform: + _target_: cortex.transforms.HuggingFaceTokenizerTransform + tokenizer: + _target_: cortex.tokenization.ProteinSequenceTokenizerFast + max_len: 256 + embed_dim: ${feature_dim} + channel_dim: ${feature_dim} + out_dim: ${feature_dim} + num_blocks: 2 + num_heads: 4 From 30f51043a76533177fee891a5f0838cbaf6dc0ba Mon Sep 17 00:00:00 2001 From: Samuel Stanton Date: Fri, 9 May 2025 21:40:35 -0400 Subject: [PATCH 7/7] add causal self attention support --- .gitignore | 3 + .../protein_property_transformer.yaml | 3 +- .../hydra/roots/protein_seq_transformer.yaml | 3 +- cortex/model/block/__init__.py | 4 +- ...encoder_block.py => _transformer_block.py} | 16 ++-- cortex/model/branch/__init__.py | 6 +- ...coder_branch.py => _transformer_branch.py} | 28 +++--- cortex/model/elemental/__init__.py | 2 + .../model/elemental/_causal_self_attention.py | 35 ++++++++ cortex/model/root/__init__.py | 6 +- ...r_encoder_root.py => _transformer_root.py} | 24 ++--- .../model/block/test_transformer_block.py | 38 ++++++++ .../model/branch/test_transformer_branch.py | 88 +++++++++++++++++++ .../branch/test_transformer_encoder_branch.py | 46 ---------- .../test_causal_self_attention.py} | 10 +-- ...coder_root.py => test_transformer_root.py} | 58 +++++++++++- .../protein_property_transformer.yaml | 1 + .../hydra/roots/protein_seq_transformer.yaml | 1 + 18 files changed, 276 insertions(+), 96 deletions(-) rename cortex/model/block/{_transformer_encoder_block.py => _transformer_block.py} (71%) rename cortex/model/branch/{_transformer_encoder_branch.py => _transformer_branch.py} (76%) create mode 100644 cortex/model/elemental/_causal_self_attention.py rename cortex/model/root/{_transformer_encoder_root.py => _transformer_root.py} (95%) create mode 100644 tests/cortex/model/block/test_transformer_block.py create mode 100644 tests/cortex/model/branch/test_transformer_branch.py delete mode 100644 
tests/cortex/model/branch/test_transformer_encoder_branch.py rename tests/cortex/model/{block/test_transformer_encoder_block.py => elemental/test_causal_self_attention.py} (54%) rename tests/cortex/model/root/{test_transformer_encoder_root.py => test_transformer_root.py} (71%) diff --git a/.gitignore b/.gitignore index c339a42..caa8e50 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,6 @@ __pycache__ docs/build temp .coverage +*.ipynb_checkpoints +*/.cache +*/lightning_logs diff --git a/cortex/config/hydra/branches/protein_property_transformer.yaml b/cortex/config/hydra/branches/protein_property_transformer.yaml index 5d8836d..c0e20d4 100644 --- a/cortex/config/hydra/branches/protein_property_transformer.yaml +++ b/cortex/config/hydra/branches/protein_property_transformer.yaml @@ -1,7 +1,8 @@ protein_property: - _target_: cortex.model.branch.TransformerEncoderBranch + _target_: cortex.model.branch.TransformerBranch out_dim: 8 channel_dim: ${channel_dim} num_blocks: 2 num_heads: 4 dropout_prob: ${dropout_prob} + is_causal: false diff --git a/cortex/config/hydra/roots/protein_seq_transformer.yaml b/cortex/config/hydra/roots/protein_seq_transformer.yaml index 4c57e2e..3127d25 100644 --- a/cortex/config/hydra/roots/protein_seq_transformer.yaml +++ b/cortex/config/hydra/roots/protein_seq_transformer.yaml @@ -1,5 +1,5 @@ protein_seq: - _target_: cortex.model.root.TransformerEncoderRoot + _target_: cortex.model.root.TransformerRoot corruption_process: _target_: cortex.corruption.MaskCorruptionProcess tokenizer_transform: @@ -12,6 +12,7 @@ protein_seq: channel_dim: ${channel_dim} num_blocks: 2 num_heads: 4 + is_causal: false dropout_prob: ${dropout_prob} pos_encoding: true train_transforms: null diff --git a/cortex/model/block/__init__.py b/cortex/model/block/__init__.py index 523987d..c604d6d 100644 --- a/cortex/model/block/__init__.py +++ b/cortex/model/block/__init__.py @@ -1,7 +1,7 @@ from ._conv1d_resid_block import Conv1dResidBlock -from ._transformer_encoder_block import TransformerEncoderBlock +from ._transformer_block import TransformerBlock __all__ = [ "Conv1dResidBlock", - "TransformerEncoderBlock", + "TransformerBlock", ] diff --git a/cortex/model/block/_transformer_encoder_block.py b/cortex/model/block/_transformer_block.py similarity index 71% rename from cortex/model/block/_transformer_encoder_block.py rename to cortex/model/block/_transformer_block.py index dfa4730..af6c1a9 100644 --- a/cortex/model/block/_transformer_encoder_block.py +++ b/cortex/model/block/_transformer_block.py @@ -1,9 +1,9 @@ from torch import Tensor, nn -from cortex.model.elemental import MLP, BidirectionalSelfAttention +from cortex.model.elemental import MLP, BidirectionalSelfAttention, CausalSelfAttention -class TransformerEncoderBlock(nn.Module): +class TransformerBlock(nn.Module): def __init__( self, in_channels: int, @@ -11,12 +11,18 @@ def __init__( num_heads: int = 4, bias: bool = False, dropout_p: float = 0.0, + is_causal: bool = False, ): super().__init__() self.ln_1 = nn.LayerNorm(in_channels, bias=bias) - self.attn = BidirectionalSelfAttention( - num_heads=num_heads, embed_dim=in_channels, dropout_p=dropout_p, bias=bias - ) + + if is_causal: + self.attn = CausalSelfAttention(num_heads=num_heads, embed_dim=in_channels, dropout_p=dropout_p, bias=bias) + else: + self.attn = BidirectionalSelfAttention( + num_heads=num_heads, embed_dim=in_channels, dropout_p=dropout_p, bias=bias + ) + self.ln_2 = nn.LayerNorm(in_channels, bias=bias) self.mlp = MLP(in_channels, out_channels, bias=bias, 
dropout_p=dropout_p) diff --git a/cortex/model/branch/__init__.py b/cortex/model/branch/__init__.py index dea3819..a9f9538 100644 --- a/cortex/model/branch/__init__.py +++ b/cortex/model/branch/__init__.py @@ -1,12 +1,12 @@ from ._abstract_branch import BranchNode, BranchNodeOutput from ._conv1d_branch import Conv1dBranch, Conv1dBranchOutput -from ._transformer_encoder_branch import TransformerEncoderBranch, TransformerEncoderBranchOutput +from ._transformer_branch import TransformerBranch, TransformerBranchOutput __all__ = [ "BranchNode", "BranchNodeOutput", "Conv1dBranch", "Conv1dBranchOutput", - "TransformerEncoderBranch", - "TransformerEncoderBranchOutput", + "TransformerBranch", + "TransformerBranchOutput", ] diff --git a/cortex/model/branch/_transformer_encoder_branch.py b/cortex/model/branch/_transformer_branch.py similarity index 76% rename from cortex/model/branch/_transformer_encoder_branch.py rename to cortex/model/branch/_transformer_branch.py index 84df4d8..29e7a58 100644 --- a/cortex/model/branch/_transformer_encoder_branch.py +++ b/cortex/model/branch/_transformer_branch.py @@ -3,7 +3,7 @@ import torch from torch import nn -from cortex.model.block import TransformerEncoderBlock +from cortex.model.block import TransformerBlock from cortex.model.branch import BranchNode, BranchNodeOutput from cortex.model.elemental import ( Apply, @@ -16,12 +16,12 @@ @dataclass -class TransformerEncoderBranchOutput(BranchNodeOutput): +class TransformerBranchOutput(BranchNodeOutput): branch_mask: torch.Tensor pooled_features: torch.Tensor -class TransformerEncoderBranch(BranchNode): +class TransformerBranch(BranchNode): """ Branch node which transforms aggregated trunk features to task branch specific features """ @@ -33,6 +33,7 @@ def __init__( channel_dim: int = 64, num_blocks: int = 2, num_heads: int = 5, + is_causal: bool = False, dropout_prob: float = 0.0, pooling_type: str = "mean", **kwargs, @@ -53,17 +54,20 @@ def __init__( # conv layers expect inputs with shape (batch_size, input_dim, num_tokens) encoder_modules = [] + block_kwargs = { + "num_heads": num_heads, + "is_causal": is_causal, + "dropout_p": dropout_prob, + } + if num_blocks == 1: - encoder_modules.append(TransformerEncoderBlock(in_dim, out_dim, num_heads, dropout_p=dropout_prob)) + encoder_modules.append(TransformerBlock(in_dim, out_dim, **block_kwargs)) elif num_blocks > 1: - encoder_modules.append(TransformerEncoderBlock(in_dim, channel_dim, num_heads, dropout_p=dropout_prob)) + encoder_modules.append(TransformerBlock(in_dim, channel_dim, **block_kwargs)) encoder_modules.extend( - [ - TransformerEncoderBlock(channel_dim, channel_dim, num_heads, dropout_p=dropout_prob) - for _ in range(num_blocks - 2) - ] + [TransformerBlock(channel_dim, channel_dim, **block_kwargs) for _ in range(num_blocks - 2)] ) - encoder_modules.append(TransformerEncoderBlock(channel_dim, out_dim, num_heads, dropout_p=dropout_prob)) + encoder_modules.append(TransformerBlock(channel_dim, out_dim, **block_kwargs)) self.encoder = nn.Sequential(*encoder_modules) if pooling_type == "mean": @@ -76,7 +80,7 @@ def __init__( def forward( self, trunk_outputs: PaddedTrunkOutput, - ) -> TransformerEncoderBranchOutput: + ) -> TransformerBranchOutput: """ Args: trunk_outputs: {'trunk_features': torch.Tensor, 'padding_mask': torch.Tensor} @@ -89,7 +93,7 @@ def forward( branch_features, branch_mask = self.encoder((trunk_features, padding_mask.to(trunk_features))) pooled_features = self.pooling_op(branch_features, branch_mask) - branch_outputs = 
TransformerEncoderBranchOutput( + branch_outputs = TransformerBranchOutput( branch_features=branch_features.contiguous(), branch_mask=branch_mask, pooled_features=pooled_features, diff --git a/cortex/model/elemental/__init__.py b/cortex/model/elemental/__init__.py index e898b0c..f305dbd 100644 --- a/cortex/model/elemental/__init__.py +++ b/cortex/model/elemental/__init__.py @@ -1,5 +1,6 @@ from ._apply import Apply from ._bidirectional_self_attention import BidirectionalSelfAttention +from ._causal_self_attention import CausalSelfAttention from ._ddp_standardize import DDPStandardize from ._expression import Expression from ._functional import identity, permute_spatial_channel_dims, swish @@ -11,6 +12,7 @@ __all__ = [ "Apply", "BidirectionalSelfAttention", + "CausalSelfAttention", "DDPStandardize", "Expression", "identity", diff --git a/cortex/model/elemental/_causal_self_attention.py b/cortex/model/elemental/_causal_self_attention.py new file mode 100644 index 0000000..0f1b76b --- /dev/null +++ b/cortex/model/elemental/_causal_self_attention.py @@ -0,0 +1,35 @@ +from torch import Tensor, nn + + +class CausalSelfAttention(nn.Module): + def __init__(self, num_heads: int = 4, embed_dim: int = 32, dropout_p: float = 0.0, bias: bool = False): + super().__init__() + if embed_dim % num_heads != 0: + raise ValueError("num_heads must evenly divide embed_dim") + + self.c_attn = nn.Linear(embed_dim, embed_dim * 3, bias=bias) + self.dropout = nn.Dropout(dropout_p) + self.dropout_p = dropout_p + self.head_dim = embed_dim // num_heads + self.num_heads = num_heads + + def forward(self, inputs: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: + x, padding_mask = inputs + seq_len = x.size(-2) + queries, keys, values = self.c_attn(x).chunk(3, dim=-1) + + queries = queries.view(-1, seq_len, self.num_heads, self.head_dim).transpose(-2, -3) + keys = keys.view(-1, seq_len, self.num_heads, self.head_dim).transpose(-2, -3) + values = values.view(-1, seq_len, self.num_heads, self.head_dim).transpose(-2, -3) + + res = nn.functional.scaled_dot_product_attention( + queries, + keys, + values, + attn_mask=None, + dropout_p=self.dropout_p if self.training else 0.0, + is_causal=True, + ) + + res = res.transpose(-2, -3).flatten(start_dim=-2) + return self.dropout(res), padding_mask diff --git a/cortex/model/root/__init__.py b/cortex/model/root/__init__.py index cab3016..b2f1736 100644 --- a/cortex/model/root/__init__.py +++ b/cortex/model/root/__init__.py @@ -1,12 +1,12 @@ from ._abstract_root import RootNode, RootNodeOutput from ._conv1d_root import Conv1dRoot, Conv1dRootOutput -from ._transformer_encoder_root import TransformerEncoderRoot, TransformerEncoderRootOutput +from ._transformer_root import TransformerRoot, TransformerRootOutput __all__ = [ "RootNode", "RootNodeOutput", "Conv1dRoot", "Conv1dRootOutput", - "TransformerEncoderRoot", - "TransformerEncoderRootOutput", + "TransformerRoot", + "TransformerRootOutput", ] diff --git a/cortex/model/root/_transformer_encoder_root.py b/cortex/model/root/_transformer_root.py similarity index 95% rename from cortex/model/root/_transformer_encoder_root.py rename to cortex/model/root/_transformer_root.py index 0925ffc..75c4f2f 100644 --- a/cortex/model/root/_transformer_encoder_root.py +++ b/cortex/model/root/_transformer_root.py @@ -8,14 +8,14 @@ from torch import LongTensor, nn from cortex.corruption import CorruptionProcess, GaussianCorruptionProcess, MaskCorruptionProcess -from cortex.model.block import TransformerEncoderBlock +from cortex.model.block import 
TransformerBlock from cortex.model.elemental import SinePosEncoder from cortex.model.root import RootNode from cortex.transforms import HuggingFaceTokenizerTransform, PadTransform, ToTensor @dataclass -class TransformerEncoderRootOutput: +class TransformerRootOutput: """Output of TransforerEncoderRoot.""" root_features: torch.Tensor @@ -27,7 +27,7 @@ class TransformerEncoderRootOutput: is_corrupted: Optional[torch.Tensor] = None -class TransformerEncoderRoot(RootNode): +class TransformerRoot(RootNode): """ A root node transforming an array of discrete sequences to an array of continuous sequence embeddings """ @@ -41,6 +41,7 @@ def __init__( channel_dim: int = 256, num_blocks: int = 2, num_heads: int = 4, + is_causal: bool = False, dropout_prob: float = 0.0, pos_encoding: bool = True, train_transforms=None, @@ -69,17 +70,17 @@ def __init__( encoder_modules = [] resid_block_kwargs = { "num_heads": num_heads, + "dropout_p": dropout_prob, + "is_causal": is_causal, } if num_blocks == 1: - encoder_modules.append( - TransformerEncoderBlock(embed_dim, out_dim, dropout_p=dropout_prob, **resid_block_kwargs) - ) + encoder_modules.append(TransformerBlock(embed_dim, out_dim, **resid_block_kwargs)) else: - encoder_modules.append(TransformerEncoderBlock(embed_dim, channel_dim, **resid_block_kwargs)) + encoder_modules.append(TransformerBlock(embed_dim, channel_dim, **resid_block_kwargs)) encoder_modules.extend( [ - TransformerEncoderBlock( + TransformerBlock( channel_dim, channel_dim, **resid_block_kwargs, @@ -89,10 +90,9 @@ def __init__( ) encoder_modules.append( - TransformerEncoderBlock( + TransformerBlock( channel_dim, out_dim, - dropout_p=dropout_prob, **resid_block_kwargs, ) ) @@ -296,7 +296,7 @@ def forward( is_corrupted: Optional[torch.Tensor] = None, corruption_allowed: Optional[torch.Tensor] = None, **kwargs, - ) -> TransformerEncoderRootOutput: + ) -> TransformerRootOutput: """ Args: seq_array: (batch_size,) array of discrete sequences (e.g. 
text strings) @@ -329,7 +329,7 @@ def forward( if isinstance(corrupt_frac, torch.Tensor): corrupt_frac = corrupt_frac.to(src_tok_embs.device) - outputs = TransformerEncoderRootOutput( + outputs = TransformerRootOutput( root_features=src_features.contiguous(), padding_mask=padding_mask, src_tok_embs=src_tok_embs, diff --git a/tests/cortex/model/block/test_transformer_block.py b/tests/cortex/model/block/test_transformer_block.py new file mode 100644 index 0000000..8faddfc --- /dev/null +++ b/tests/cortex/model/block/test_transformer_block.py @@ -0,0 +1,38 @@ +import torch + +from cortex.model.block import TransformerBlock + +BATCH_SIZE = 2 +NUM_HEADS = 3 +EMBED_DIM = 12 +SEQ_LEN = 5 + + +def test_transformer_encoder_block(): + module = TransformerBlock( + in_channels=EMBED_DIM, + out_channels=EMBED_DIM, + num_heads=NUM_HEADS, + is_causal=False, + ) + + x = torch.randn(BATCH_SIZE, SEQ_LEN, EMBED_DIM) + padding_mask = torch.ones(BATCH_SIZE, SEQ_LEN, dtype=torch.bool) + x_prime, _ = module((x, padding_mask)) + + assert x_prime.shape == x.shape + + +def test_transformer_decoder_block(): + module = TransformerBlock( + in_channels=EMBED_DIM, + out_channels=EMBED_DIM, + num_heads=NUM_HEADS, + is_causal=True, + ) + + x = torch.randn(BATCH_SIZE, SEQ_LEN, EMBED_DIM) + padding_mask = torch.ones(BATCH_SIZE, SEQ_LEN, dtype=torch.bool) + x_prime, _ = module((x, padding_mask)) + + assert x_prime.shape == x.shape diff --git a/tests/cortex/model/branch/test_transformer_branch.py b/tests/cortex/model/branch/test_transformer_branch.py new file mode 100644 index 0000000..e8a12f9 --- /dev/null +++ b/tests/cortex/model/branch/test_transformer_branch.py @@ -0,0 +1,88 @@ +import torch + +from cortex.model.branch import TransformerBranch, TransformerBranchOutput +from cortex.model.trunk import PaddedTrunkOutput + + +def test_transformer_encoder_branch(): + in_dim = 12 + out_dim = 12 + embed_dim = 12 + channel_dim = 12 + num_blocks = 7 + num_heads = 3 + max_seq_len = 13 + batch_size = 17 + dropout_prob = 0.125 + is_causal = False + + branch_node = TransformerBranch( + in_dim=in_dim, + out_dim=out_dim, + embed_dim=embed_dim, + channel_dim=channel_dim, + num_blocks=num_blocks, + num_heads=num_heads, + is_causal=is_causal, + dropout_prob=dropout_prob, + ) + + trunk_output = PaddedTrunkOutput( + trunk_features=torch.rand(batch_size, max_seq_len, in_dim), + padding_mask=torch.ones(batch_size, max_seq_len, dtype=torch.float), + ) + branch_output = branch_node(trunk_output) + assert isinstance(branch_output, TransformerBranchOutput) + branch_features = branch_output.branch_features + branch_mask = branch_output.branch_mask + pooled_features = branch_output.pooled_features + + assert torch.is_tensor(branch_features) + assert torch.is_tensor(branch_mask) + assert torch.is_tensor(pooled_features) + + assert branch_features.size() == torch.Size((batch_size, max_seq_len, out_dim)) + assert branch_mask.size() == torch.Size((batch_size, max_seq_len)) + assert pooled_features.size() == torch.Size((batch_size, out_dim)) + + +def test_transformer_decoder_branch(): + in_dim = 12 + out_dim = 12 + embed_dim = 12 + channel_dim = 12 + num_blocks = 7 + num_heads = 3 + max_seq_len = 13 + batch_size = 17 + dropout_prob = 0.125 + is_causal = True + + branch_node = TransformerBranch( + in_dim=in_dim, + out_dim=out_dim, + embed_dim=embed_dim, + channel_dim=channel_dim, + num_blocks=num_blocks, + num_heads=num_heads, + is_causal=is_causal, + dropout_prob=dropout_prob, + ) + + trunk_output = PaddedTrunkOutput( + 
trunk_features=torch.rand(batch_size, max_seq_len, in_dim), + padding_mask=torch.ones(batch_size, max_seq_len, dtype=torch.float), + ) + branch_output = branch_node(trunk_output) + assert isinstance(branch_output, TransformerBranchOutput) + branch_features = branch_output.branch_features + branch_mask = branch_output.branch_mask + pooled_features = branch_output.pooled_features + + assert torch.is_tensor(branch_features) + assert torch.is_tensor(branch_mask) + assert torch.is_tensor(pooled_features) + + assert branch_features.size() == torch.Size((batch_size, max_seq_len, out_dim)) + assert branch_mask.size() == torch.Size((batch_size, max_seq_len)) + assert pooled_features.size() == torch.Size((batch_size, out_dim)) diff --git a/tests/cortex/model/branch/test_transformer_encoder_branch.py b/tests/cortex/model/branch/test_transformer_encoder_branch.py deleted file mode 100644 index 6e59051..0000000 --- a/tests/cortex/model/branch/test_transformer_encoder_branch.py +++ /dev/null @@ -1,46 +0,0 @@ -import torch - -from cortex.model.branch import TransformerEncoderBranch, TransformerEncoderBranchOutput -from cortex.model.trunk import PaddedTrunkOutput - - -def test_conv_1d_branch(): - in_dim = 12 - out_dim = 12 - embed_dim = 12 - channel_dim = 12 - num_blocks = 7 - num_heads = 3 - max_seq_len = 13 - batch_size = 17 - dropout_prob = 0.125 - layernorm = True - - branch_node = TransformerEncoderBranch( - in_dim=in_dim, - out_dim=out_dim, - embed_dim=embed_dim, - channel_dim=channel_dim, - num_blocks=num_blocks, - num_heads=num_heads, - dropout_prob=dropout_prob, - layernorm=layernorm, - ) - - trunk_output = PaddedTrunkOutput( - trunk_features=torch.rand(batch_size, max_seq_len, in_dim), - padding_mask=torch.ones(batch_size, max_seq_len, dtype=torch.float), - ) - branch_output = branch_node(trunk_output) - assert isinstance(branch_output, TransformerEncoderBranchOutput) - branch_features = branch_output.branch_features - branch_mask = branch_output.branch_mask - pooled_features = branch_output.pooled_features - - assert torch.is_tensor(branch_features) - assert torch.is_tensor(branch_mask) - assert torch.is_tensor(pooled_features) - - assert branch_features.size() == torch.Size((batch_size, max_seq_len, out_dim)) - assert branch_mask.size() == torch.Size((batch_size, max_seq_len)) - assert pooled_features.size() == torch.Size((batch_size, out_dim)) diff --git a/tests/cortex/model/block/test_transformer_encoder_block.py b/tests/cortex/model/elemental/test_causal_self_attention.py similarity index 54% rename from tests/cortex/model/block/test_transformer_encoder_block.py rename to tests/cortex/model/elemental/test_causal_self_attention.py index 4f44377..b8edb42 100644 --- a/tests/cortex/model/block/test_transformer_encoder_block.py +++ b/tests/cortex/model/elemental/test_causal_self_attention.py @@ -1,6 +1,6 @@ import torch -from cortex.model.block import TransformerEncoderBlock +from cortex.model.elemental import CausalSelfAttention BATCH_SIZE = 2 NUM_HEADS = 3 @@ -8,12 +8,8 @@ SEQ_LEN = 5 -def test_transformer_encoder_block(): - module = TransformerEncoderBlock( - in_channels=EMBED_DIM, - out_channels=EMBED_DIM, - num_heads=NUM_HEADS, - ) +def test_causal_self_attention(): + module = CausalSelfAttention(num_heads=NUM_HEADS, embed_dim=EMBED_DIM, dropout_p=0.0, bias=False) x = torch.randn(BATCH_SIZE, SEQ_LEN, EMBED_DIM) padding_mask = torch.ones(BATCH_SIZE, SEQ_LEN, dtype=torch.bool) diff --git a/tests/cortex/model/root/test_transformer_encoder_root.py 
b/tests/cortex/model/root/test_transformer_root.py similarity index 71% rename from tests/cortex/model/root/test_transformer_encoder_root.py rename to tests/cortex/model/root/test_transformer_root.py index 93cf6cc..59fd3a0 100644 --- a/tests/cortex/model/root/test_transformer_encoder_root.py +++ b/tests/cortex/model/root/test_transformer_root.py @@ -3,7 +3,7 @@ from cortex.constants import COMPLEX_SEP_TOKEN from cortex.corruption import MaskCorruptionProcess -from cortex.model.root import TransformerEncoderRoot, TransformerEncoderRootOutput +from cortex.model.root import TransformerRoot, TransformerRootOutput from cortex.tokenization import ProteinSequenceTokenizerFast from cortex.transforms import HuggingFaceTokenizerTransform @@ -14,6 +14,7 @@ def test_transformer_encoder_root(): embed_dim = 12 channel_dim = 12 num_heads = 3 + is_causal = False num_blocks = 7 max_seq_len = 13 @@ -21,7 +22,7 @@ def test_transformer_encoder_root(): pos_encoding = True tokenizer = ProteinSequenceTokenizerFast() - root_node = TransformerEncoderRoot( + root_node = TransformerRoot( tokenizer_transform=HuggingFaceTokenizerTransform(tokenizer), max_len=max_seq_len, out_dim=out_dim, @@ -29,6 +30,7 @@ def test_transformer_encoder_root(): channel_dim=channel_dim, num_blocks=num_blocks, num_heads=num_heads, + is_causal=is_causal, dropout_prob=dropout_prob, pos_encoding=pos_encoding, ) @@ -41,7 +43,7 @@ def test_transformer_encoder_root(): ] ) root_output = root_node(seq_array) - assert isinstance(root_output, TransformerEncoderRootOutput) + assert isinstance(root_output, TransformerRootOutput) root_features = root_output.root_features padding_mask = root_output.padding_mask @@ -59,18 +61,20 @@ def test_transformer_encoder_root_with_per_element_corrupt_frac(): embed_dim = 12 channel_dim = 12 num_heads = 3 + is_causal = False max_seq_len = 13 tokenizer = ProteinSequenceTokenizerFast() # Create a root node with corruption process corruption_process = MaskCorruptionProcess() - root_node = TransformerEncoderRoot( + root_node = TransformerRoot( tokenizer_transform=HuggingFaceTokenizerTransform(tokenizer), max_len=max_seq_len, out_dim=out_dim, embed_dim=embed_dim, channel_dim=channel_dim, num_heads=num_heads, + is_causal=is_causal, corruption_process=corruption_process, ) @@ -120,3 +124,49 @@ def test_transformer_encoder_root_with_per_element_corrupt_frac(): # Values should be between 0 and 1 assert torch.all(root_output3.corrupt_frac >= 0.0) assert torch.all(root_output3.corrupt_frac <= 1.0) + + +def test_transformer_decoder_root(): + batch_size = 2 + out_dim = 12 + embed_dim = 12 + channel_dim = 12 + num_heads = 3 + is_causal = True + num_blocks = 7 + + max_seq_len = 13 + dropout_prob = 0.125 + pos_encoding = True + tokenizer = ProteinSequenceTokenizerFast() + + root_node = TransformerRoot( + tokenizer_transform=HuggingFaceTokenizerTransform(tokenizer), + max_len=max_seq_len, + out_dim=out_dim, + embed_dim=embed_dim, + channel_dim=channel_dim, + num_blocks=num_blocks, + num_heads=num_heads, + is_causal=is_causal, + dropout_prob=dropout_prob, + pos_encoding=pos_encoding, + ) + + # src_tok_idxs = torch.randint(0, vocab_size, (batch_size, max_seq_len)) + seq_array = np.array( + [ + f"{COMPLEX_SEP_TOKEN} A V {COMPLEX_SEP_TOKEN} A V {COMPLEX_SEP_TOKEN} A V C C", + f"{COMPLEX_SEP_TOKEN} A V {COMPLEX_SEP_TOKEN} A V {COMPLEX_SEP_TOKEN} A V C C", + ] + ) + root_output = root_node(seq_array) + assert isinstance(root_output, TransformerRootOutput) + root_features = root_output.root_features + padding_mask = root_output.padding_mask 
+ + assert torch.is_tensor(root_features) + assert torch.is_tensor(padding_mask) + + assert root_features.size() == torch.Size((batch_size, max_seq_len, out_dim)) + assert padding_mask.size() == torch.Size((batch_size, max_seq_len)) diff --git a/tutorials/hydra/branches/protein_property_transformer.yaml b/tutorials/hydra/branches/protein_property_transformer.yaml index 04d8972..5a4ede3 100644 --- a/tutorials/hydra/branches/protein_property_transformer.yaml +++ b/tutorials/hydra/branches/protein_property_transformer.yaml @@ -4,3 +4,4 @@ protein_property: channel_dim: ${feature_dim} num_blocks: 1 num_heads: 4 + is_causal: false diff --git a/tutorials/hydra/roots/protein_seq_transformer.yaml b/tutorials/hydra/roots/protein_seq_transformer.yaml index c2ab280..ea93943 100644 --- a/tutorials/hydra/roots/protein_seq_transformer.yaml +++ b/tutorials/hydra/roots/protein_seq_transformer.yaml @@ -12,3 +12,4 @@ protein_seq: out_dim: ${feature_dim} num_blocks: 2 num_heads: 4 + is_causal: false
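
Usage sketch (illustrative only, not part of the patches above): how the new TransformerRoot could be instantiated from the Hydra config added in PATCH 2 and renamed in PATCH 7. This assumes the standard hydra/omegaconf APIs (OmegaConf.load, OmegaConf.merge, hydra.utils.instantiate) and uses example values for the interpolated keys (embed_dim, channel_dim, dropout_prob), which in a real run would come from the enclosing experiment config; the input sequences are likewise placeholders.

    import numpy as np
    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    # Load the root config added in this patch series and supply the interpolated
    # keys (example values; normally provided by the top-level experiment config).
    cfg = OmegaConf.load("cortex/config/hydra/roots/protein_seq_transformer.yaml")
    cfg = OmegaConf.merge(
        OmegaConf.create({"embed_dim": 64, "channel_dim": 256, "dropout_prob": 0.1}),
        cfg,
    )

    # instantiate() recursively builds the nested _target_ nodes as well
    # (tokenizer transform, tokenizer, and the mask corruption process).
    root_node = instantiate(cfg.protein_seq)

    # Forward pass from raw sequence strings, as in the unit tests above.
    root_output = root_node(seq_array=np.array(["A V C C", "A V C C"]))
    print(root_output.root_features.shape)  # (batch_size, max_len, out_dim) for this config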