From e0c474585ab3bf480c933f1469e28356259925e9 Mon Sep 17 00:00:00 2001 From: Samuel Stanton Date: Sat, 10 May 2025 13:35:10 -0400 Subject: [PATCH 1/9] add autoregressive LM leaf and tasks --- .../tasks/generation/gfp_autoregressive.yaml | 25 ++++ .../stable_proteins_autoregressive.yaml | 25 ++++ cortex/model/leaf/__init__.py | 28 ++-- cortex/model/leaf/_autoregressive_lm_leaf.py | 140 ++++++++++++++++++ cortex/model/tree/_seq_model_tree.py | 19 +++ cortex/task/__init__.py | 1 + cortex/task/_autoregressive_lm_task.py | 86 +++++++++++ 7 files changed, 315 insertions(+), 9 deletions(-) create mode 100644 cortex/config/hydra/tasks/generation/gfp_autoregressive.yaml create mode 100644 cortex/config/hydra/tasks/generation/stable_proteins_autoregressive.yaml create mode 100644 cortex/model/leaf/_autoregressive_lm_leaf.py create mode 100644 cortex/task/_autoregressive_lm_task.py diff --git a/cortex/config/hydra/tasks/generation/gfp_autoregressive.yaml b/cortex/config/hydra/tasks/generation/gfp_autoregressive.yaml new file mode 100644 index 0000000..c4f7f53 --- /dev/null +++ b/cortex/config/hydra/tasks/generation/gfp_autoregressive.yaml @@ -0,0 +1,25 @@ +gfp: + _target_: cortex.task.AutoregressiveLanguageModelTask + tokenizer: + _target_: cortex.tokenization.ProteinSequenceTokenizerFast + input_map: + protein_seq: ['tokenized_seq'] + root_key: protein_seq + # Add BLOSUM62-based substitution corruption for data augmentation + corruption_process: + _target_: cortex.corruption.SubstitutionCorruptionProcess.from_blosum62 + corruption_rate: 0.1 # Apply corruption to 10% of masked tokens + data_module: + _target_: cortex.data.data_module.TaskDataModule + _recursive_: false + batch_size: ${fit.batch_size} + balance_train_partition: null + drop_last: true + lengths: [1.0, 0.0] + train_on_everything: false + num_workers: ${num_workers} + dataset_config: + _target_: cortex.data.dataset.TAPEFluorescenceDataset + root: ${dataset_root_dir} + download: ${download_datasets} + train: ??? diff --git a/cortex/config/hydra/tasks/generation/stable_proteins_autoregressive.yaml b/cortex/config/hydra/tasks/generation/stable_proteins_autoregressive.yaml new file mode 100644 index 0000000..9712672 --- /dev/null +++ b/cortex/config/hydra/tasks/generation/stable_proteins_autoregressive.yaml @@ -0,0 +1,25 @@ +stable_proteins: + _target_: cortex.task.AutoregressiveLanguageModelTask + tokenizer: + _target_: cortex.tokenization.ProteinSequenceTokenizerFast + input_map: + protein_seq: ['tokenized_seq'] + root_key: protein_seq + # Add BLOSUM62-based substitution corruption for data augmentation + corruption_process: + _target_: cortex.corruption.SubstitutionCorruptionProcess.from_blosum62 + corruption_rate: 0.1 # Apply corruption to 10% of masked tokens + data_module: + _target_: cortex.data.data_module.TaskDataModule + _recursive_: false + batch_size: ${fit.batch_size} + balance_train_partition: null + drop_last: true + lengths: [1.0, 0.0] + train_on_everything: false + num_workers: ${num_workers} + dataset_config: + _target_: cortex.data.dataset.TAPEStabilityDataset + root: ${dataset_root_dir} + download: ${download_datasets} + train: ??? 
diff --git a/cortex/model/leaf/__init__.py b/cortex/model/leaf/__init__.py index 2171882..8f1e7b8 100644 --- a/cortex/model/leaf/__init__.py +++ b/cortex/model/leaf/__init__.py @@ -1,4 +1,10 @@ from ._abstract_leaf import LeafNode, LeafNodeOutput +from ._autoregressive_lm_leaf import ( + AutoregressiveLanguageModelLeaf, + AutoregressiveLanguageModelLeafOutput, + autoregressive_log_likelihood, + format_autoregressive_lm_ensemble_output, +) from ._classifier_leaf import ClassifierLeaf, ClassifierLeafOutput, check_probs, format_classifier_ensemble_output from ._denoising_lm_leaf import ( DenoisingLanguageModelLeaf, @@ -11,21 +17,25 @@ from ._seq_regressor_leaf import SequenceRegressorLeaf, adjust_sequence_mask __all__ = [ - "LeafNode", - "LeafNodeOutput", + "adjust_sequence_mask", + "AutoregressiveLanguageModelLeaf", + "AutoregressiveLanguageModelLeafOutput", + "autoregressive_log_likelihood", + "check_probs", + "check_scale", "ClassifierLeaf", "ClassifierLeafOutput", - "check_probs", - "format_classifier_ensemble_output", "DenoisingLanguageModelLeaf", "DenoisingLanguageModelLeafOutput", + "format_autoregressive_lm_ensemble_output", + "format_classifier_ensemble_output", "format_denoising_lm_ensemble_output", - "RegressorLeaf", - "RegressorLeafOutput", - "check_scale", "format_regressor_ensemble_output", - "SequenceRegressorLeaf", - "adjust_sequence_mask", + "LeafNode", + "LeafNodeOutput", "mlm_conditional_log_likelihood", "mlm_pseudo_log_likelihood", + "RegressorLeaf", + "RegressorLeafOutput", + "SequenceRegressorLeaf", ] diff --git a/cortex/model/leaf/_autoregressive_lm_leaf.py b/cortex/model/leaf/_autoregressive_lm_leaf.py new file mode 100644 index 0000000..1fcee02 --- /dev/null +++ b/cortex/model/leaf/_autoregressive_lm_leaf.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional + +import torch +from torch.nn import functional as F + +from cortex.corruption._abstract_corruption import CorruptionProcess +from cortex.model.branch import BranchNodeOutput +from cortex.model.leaf import ClassifierLeaf, LeafNodeOutput +from cortex.model.root import RootNodeOutput + +# avoids circular import +if TYPE_CHECKING: + from cortex.model.tree import NeuralTreeOutput + + +@dataclass +class AutoregressiveLanguageModelLeafOutput(LeafNodeOutput): + logits: torch.Tensor + + +class AutoregressiveLanguageModelLeaf(ClassifierLeaf): + """ + Leaf node which transforms branch sequence features to discrete sequence logits. + + Can optionally apply a corruption process to the masked tokens during training, + which serves as a form of data augmentation to increase sample diversity and + potentially improve embedding quality. This is particularly useful with + biologically-informed corruption processes like BLOSUM62-based substitutions + for protein sequences. + """ + + def __init__( + self, + *args, + corruption_process: Optional[CorruptionProcess] = None, + corruption_rate: float = 0.1, + **kwargs, + ): + """ + Initialize the AutoregressiveLanguageModelLeaf. 
+ + Args: + corruption_process: Optional corruption process to apply to the target tokens during training + corruption_rate: Fixed rate at which to apply corruption to the target tokens (default: 0.1) + *args: Additional positional arguments to pass to the parent class + **kwargs: Additional keyword arguments to pass to the parent class + """ + super().__init__(*args, **kwargs) + self.corruption_process = corruption_process + self.corruption_rate = corruption_rate + + def forward(self, branch_outputs: BranchNodeOutput, *args, **kwargs) -> AutoregressiveLanguageModelLeafOutput: + """ + Args: + branch_outputs: TransformerBranchOutput (is_causal should be true) + Returns: + outputs: AutoregressiveLanguageModelLeafOutput + """ + branch_features = branch_outputs.branch_features + logits = self.encoder(branch_features) + outputs = AutoregressiveLanguageModelLeafOutput(logits=logits) + return outputs + + def loss( + self, + leaf_outputs: AutoregressiveLanguageModelLeafOutput, + root_outputs: RootNodeOutput, + *args, + **kwargs, + ) -> torch.Tensor: + masked_logits, masked_targets = self.format_outputs(leaf_outputs, root_outputs) + return self.loss_fn(masked_logits, masked_targets) + + def format_outputs( + self, leaf_outputs: AutoregressiveLanguageModelLeafOutput, root_outputs: RootNodeOutput + ) -> tuple[torch.Tensor, torch.Tensor]: + logits = leaf_outputs.logits[..., :-1, :] + tgt_tok_idxs = root_outputs.tgt_tok_idxs.to(logits.device)[..., 1:] + + # Apply data augmentation if corruption_process is provided and we're in training mode + if self.corruption_process is not None and self.training: + # Apply the corruption with fixed rate + tgt_tok_idxs, _ = self.corruption_process(tgt_tok_idxs, corrupt_frac=self.corruption_rate) + + return logits.flatten(0, -2), tgt_tok_idxs.flatten() + + def evaluate( + self, + leaf_outputs: AutoregressiveLanguageModelLeafOutput, + root_outputs: RootNodeOutput, + *args, + **kwargs, + ) -> dict: + # The model is already in eval mode during evaluation, so no corruption will be applied + logits, targets = self.format_outputs(leaf_outputs, root_outputs) + pred_class = logits.argmax(-1) + correct = pred_class.eq(targets) + log_prob = F.log_softmax(logits, dim=-1).gather(-1, targets[..., None]).squeeze(-1) + perplexity = math.exp(-log_prob.mean().item()) + + metrics = { + "nll": self.loss_fn(logits, targets).item(), + "acc": correct.float().mean().item(), + "perplexity": perplexity, + } + + return metrics + + +def format_autoregressive_lm_ensemble_output( + leaf_outputs: list[AutoregressiveLanguageModelLeafOutput], + root_outputs: list[RootNodeOutput], + task_key: str, +): + res = {} + logits = [l_out.logits.flatten(0, -2) for l_out in leaf_outputs] + tgt_tok_idxs = [r_out.tgt_tok_idxs.flatten() for r_out in root_outputs] + + res[f"{task_key}_logits"] = torch.stack(logits) + res[f"{task_key}_targets"] = torch.stack(tgt_tok_idxs) + + return res + + +def autoregressive_log_likelihood( + tree_output: NeuralTreeOutput, + x_instances, + root_key: str, +): + """ + Compute the autoregressive log-likelihood of the tokens in `x_instances`. 
+ """ + task_outputs = tree_output.fetch_task_outputs(root_key) + token_probs = task_outputs["logits"].log_softmax(-1) # (ensemble_size, batch_size, seq_len, vocab_size) + token_cll = token_probs.gather(-1, x_instances[None, ..., None]).squeeze(-1) # (ensemble_size, batch_size, seq_len) + return token_cll.mean(dim=(0, -1)) diff --git a/cortex/model/tree/_seq_model_tree.py b/cortex/model/tree/_seq_model_tree.py index a914543..71a063d 100644 --- a/cortex/model/tree/_seq_model_tree.py +++ b/cortex/model/tree/_seq_model_tree.py @@ -13,6 +13,10 @@ from torch import nn from cortex.model import online_weight_update_ +from cortex.model.leaf import ( + AutoregressiveLanguageModelLeaf, + format_autoregressive_lm_ensemble_output, +) from cortex.model.leaf._classifier_leaf import ClassifierLeaf, format_classifier_ensemble_output from cortex.model.leaf._denoising_lm_leaf import ( DenoisingLanguageModelLeaf, @@ -534,6 +538,21 @@ def format_task_outputs(self, task_out, task_keys, task_leaves): format_denoising_lm_ensemble_output(leaf_outputs=values, root_outputs=root_outputs, task_key=t_key) ) + # autoregressive language model leaves + values = [ + l_out + for l_key, l_out in task_out.items() + if l_key in task_leaves[t_key] and isinstance(self.leaf_nodes[l_key], AutoregressiveLanguageModelLeaf) + ] + if len(values) > 0: + root_keys = [self.leaf_nodes[l_key].root_key for l_key in task_leaves[t_key]] + root_outputs = [task_out[r_key] for r_key in root_keys] + predict_out.update( + format_autoregressive_lm_ensemble_output( + leaf_outputs=values, root_outputs=root_outputs, task_key=t_key + ) + ) + return predict_out def call_from_str_array( diff --git a/cortex/task/__init__.py b/cortex/task/__init__.py index cc94391..0257400 100644 --- a/cortex/task/__init__.py +++ b/cortex/task/__init__.py @@ -1,4 +1,5 @@ from ._abstract_task import BaseTask +from ._autoregressive_lm_task import AutoregressiveLanguageModelTask from ._classification import ClassificationTask from ._denoising_lm_task import DenoisingLanguageModelTask from ._regression import RegressionTask diff --git a/cortex/task/_autoregressive_lm_task.py b/cortex/task/_autoregressive_lm_task.py new file mode 100644 index 0000000..b46386e --- /dev/null +++ b/cortex/task/_autoregressive_lm_task.py @@ -0,0 +1,86 @@ +from collections import OrderedDict +from typing import Any, Optional + +import numpy as np +import torch +from transformers import BertTokenizer + +from cortex.data.data_module import TaskDataModule +from cortex.model.leaf import AutoregressiveLanguageModelLeaf +from cortex.task._abstract_task import BaseTask + + +class AutoregressiveLanguageModelTask(BaseTask): + def __init__( + self, + data_module: TaskDataModule, + input_map: dict[str, str], + leaf_key: str, + root_key: str, + tokenizer: BertTokenizer, + corruption_process: Optional[Any] = None, + corruption_rate: float = 0.1, + **kwargs, + ) -> None: + """ + Non-autoregressive text denoising task + + Args: + data_module: The data module for this task + input_map: Mapping from root keys to data column names + leaf_key: Key for this leaf in the neural tree + root_key: Key for the root node in the neural tree + tokenizer: Tokenizer used for tokenizing sequences + corruption_process: Optional corruption process to apply to masked targets during training + corruption_rate: Fixed rate at which to apply corruption to masked targets (default: 0.1) + """ + super().__init__( + data_module=data_module, + input_map=input_map, + leaf_key=leaf_key, + corrupt_train_inputs=True, + 
corrupt_inference_inputs=True, + ) + self.vocab_size = len(tokenizer.vocab) + self.root_key = root_key + self.corruption_process = corruption_process + self.corruption_rate = corruption_rate + + def format_inputs(self, batch: OrderedDict, corrupt_frac: Optional[float] = None) -> dict: + """ + Format input DataFrame for a `NeuralTree` object + """ + inputs = {} + for root_key, input_cols in self.input_map.items(): + inputs[root_key] = { + "seq_array": np.concatenate([np.array(batch[col]).reshape(-1, 1) for col in input_cols], axis=-1), + "corrupt_frac": corrupt_frac, + } + return inputs + + def create_leaf(self, in_dim: int, branch_key: str) -> AutoregressiveLanguageModelLeaf: + """ + Create the leaf node for this task to be added to a `NeuralTree` object. + """ + return AutoregressiveLanguageModelLeaf( + in_dim=in_dim, + num_classes=self.vocab_size, + branch_key=branch_key, + root_key=self.root_key, + last_layer_bias=True, + corruption_process=self.corruption_process, + corruption_rate=self.corruption_rate, + ) + + def compute_eval_metrics(self, ensemble_output, targets, task_key) -> dict: + logit_key = f"{task_key}_logits" + target_key = f"{task_key}_targets" + logits = ensemble_output[logit_key] + targets = ensemble_output[target_key][0] + avg_token_probs = logits.softmax(-1).mean(0) + top_1 = avg_token_probs.argmax(-1) + task_metrics = { + "nll": -1.0 * torch.distributions.Categorical(probs=avg_token_probs).log_prob(targets).mean().item(), + "top_1_acc": top_1.eq(targets).float().mean().item(), + } + return task_metrics From 7fcb2ed45e6271328b02ee3873f343adc27de0d0 Mon Sep 17 00:00:00 2001 From: Samuel Stanton Date: Sat, 10 May 2025 13:51:28 -0400 Subject: [PATCH 2/9] add some experimental features --- cortex/model/branch/_transformer_branch.py | 5 ++++- cortex/model/elemental/__init__.py | 2 ++ cortex/model/leaf/_classifier_leaf.py | 3 ++- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/cortex/model/branch/_transformer_branch.py b/cortex/model/branch/_transformer_branch.py index 29e7a58..137e3f9 100644 --- a/cortex/model/branch/_transformer_branch.py +++ b/cortex/model/branch/_transformer_branch.py @@ -9,6 +9,7 @@ Apply, Expression, MeanPooling, + PoolingSelfAttention, WeightedMeanPooling, identity, ) @@ -32,7 +33,7 @@ def __init__( out_dim: int = 64, channel_dim: int = 64, num_blocks: int = 2, - num_heads: int = 5, + num_heads: int = 4, is_causal: bool = False, dropout_prob: float = 0.0, pooling_type: str = "mean", @@ -74,6 +75,8 @@ def __init__( self.pooling_op = MeanPooling() elif pooling_type == "weighted_mean": self.pooling_op = WeightedMeanPooling(out_dim) + elif pooling_type == "attention": + self.pooling_op = PoolingSelfAttention(num_heads=num_heads, embed_dim=out_dim, dropout_p=dropout_prob) else: raise NotImplementedError diff --git a/cortex/model/elemental/__init__.py b/cortex/model/elemental/__init__.py index f305dbd..f10e517 100644 --- a/cortex/model/elemental/__init__.py +++ b/cortex/model/elemental/__init__.py @@ -7,6 +7,7 @@ from ._layernorm import MaskLayerNorm1d from ._mean_pooling import MeanPooling, WeightedMeanPooling from ._mlp import MLP +from ._pooling_self_attention import PoolingSelfAttention from ._sine_pos_encoder import SinePosEncoder __all__ = [ @@ -20,6 +21,7 @@ "swish", "MaskLayerNorm1d", "MeanPooling", + "PoolingSelfAttention", "WeightedMeanPooling", "SinePosEncoder", ] diff --git a/cortex/model/leaf/_classifier_leaf.py b/cortex/model/leaf/_classifier_leaf.py index 88ec881..2d43ce1 100644 --- a/cortex/model/leaf/_classifier_leaf.py +++ 
b/cortex/model/leaf/_classifier_leaf.py @@ -82,7 +82,8 @@ def __init__( self.branch_key = branch_key self.root_key = root_key - encoder_modules = [] + # testing out normalizing the penultimate activations + encoder_modules = [nn.LayerNorm(in_dim, bias=False)] if num_layers >= 1: for _ in range(num_layers): encoder_modules.extend( From 4e2fa19bd9cbee5de7b3e8140078fda9bd451016 Mon Sep 17 00:00:00 2001 From: Samuel Stanton Date: Sat, 10 May 2025 13:53:56 -0400 Subject: [PATCH 3/9] add some experimental features --- .../elemental/_pooling_self_attention.py | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 cortex/model/elemental/_pooling_self_attention.py diff --git a/cortex/model/elemental/_pooling_self_attention.py b/cortex/model/elemental/_pooling_self_attention.py new file mode 100644 index 0000000..a0afa05 --- /dev/null +++ b/cortex/model/elemental/_pooling_self_attention.py @@ -0,0 +1,43 @@ +from torch import Tensor, nn + + +class PoolingSelfAttention(nn.Module): + def __init__(self, num_heads: int = 4, embed_dim: int = 32, dropout_p: float = 0.0, bias: bool = False): + super().__init__() + if embed_dim % num_heads != 0: + raise ValueError("num_heads must evenly divide embed_dim") + + self.c_attn = nn.Linear(embed_dim, embed_dim * 3, bias=bias) + self.dropout = nn.Dropout(dropout_p) + self.dropout_p = dropout_p + self.head_dim = embed_dim // num_heads + self.num_heads = num_heads + + def forward(self, inputs: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: + x, padding_mask = inputs + seq_len = x.size(-2) + queries, keys, values = self.c_attn(x).chunk(3, dim=-1) + + queries = queries.view(-1, seq_len, self.num_heads, self.head_dim).transpose(-2, -3) + keys = keys.view(-1, seq_len, self.num_heads, self.head_dim).transpose(-2, -3) + values = values.view(-1, seq_len, self.num_heads, self.head_dim).transpose(-2, -3) + + # attn_mask: (*batch_shape, 1, num_queries, 1) + attn_mask = padding_mask[..., None, :, None] + queries = queries.sum(-2, keepdim=True) / attn_mask.sum(-2, keepdim=True) + + # attn_mask (*batch_shape, 1, 1, num_keys) + attn_mask = padding_mask[..., None, None, :] + + res = nn.functional.scaled_dot_product_attention( + queries, + keys, + values, + attn_mask=attn_mask, + dropout_p=self.dropout_p if self.training else 0.0, + is_causal=False, + ) + + res = res.transpose(-2, -3).flatten(start_dim=-2) + res = self.dropout(res)[..., 0, :] # drop 1D query dim + return res From bb233da4084304c846ef489aeab4901775cc9735 Mon Sep 17 00:00:00 2001 From: Samuel Stanton Date: Sat, 10 May 2025 14:02:17 -0400 Subject: [PATCH 4/9] fix failing unit tests --- cortex/model/leaf/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cortex/model/leaf/__init__.py b/cortex/model/leaf/__init__.py index 8f1e7b8..ece10c0 100644 --- a/cortex/model/leaf/__init__.py +++ b/cortex/model/leaf/__init__.py @@ -1,11 +1,13 @@ from ._abstract_leaf import LeafNode, LeafNodeOutput + +# ruff: noqa: I001 +from ._classifier_leaf import ClassifierLeaf, ClassifierLeafOutput, check_probs, format_classifier_ensemble_output from ._autoregressive_lm_leaf import ( AutoregressiveLanguageModelLeaf, AutoregressiveLanguageModelLeafOutput, autoregressive_log_likelihood, format_autoregressive_lm_ensemble_output, ) -from ._classifier_leaf import ClassifierLeaf, ClassifierLeafOutput, check_probs, format_classifier_ensemble_output from ._denoising_lm_leaf import ( DenoisingLanguageModelLeaf, DenoisingLanguageModelLeafOutput, From ae700727f9fcf21ce6f38e47aa479d3fac5f1aa3 Mon Sep 17 
00:00:00 2001 From: Stanton Date: Sat, 10 May 2025 20:34:02 +0000 Subject: [PATCH 5/9] bug fixes --- .../hydra/branches/protein_property_transformer.yaml | 1 + cortex/config/hydra/train_protein_model.yaml | 6 +++--- cortex/data/data_module/_task_data_module.py | 3 ++- cortex/model/elemental/_bidirectional_self_attention.py | 7 ++++--- cortex/model/elemental/_pooling_self_attention.py | 7 +++---- 5 files changed, 13 insertions(+), 11 deletions(-) diff --git a/cortex/config/hydra/branches/protein_property_transformer.yaml b/cortex/config/hydra/branches/protein_property_transformer.yaml index c0e20d4..fd04d5f 100644 --- a/cortex/config/hydra/branches/protein_property_transformer.yaml +++ b/cortex/config/hydra/branches/protein_property_transformer.yaml @@ -6,3 +6,4 @@ protein_property: num_heads: 4 dropout_prob: ${dropout_prob} is_causal: false + pooling_type: attention diff --git a/cortex/config/hydra/train_protein_model.yaml b/cortex/config/hydra/train_protein_model.yaml index 0350d77..9919b9e 100644 --- a/cortex/config/hydra/train_protein_model.yaml +++ b/cortex/config/hydra/train_protein_model.yaml @@ -2,9 +2,9 @@ defaults: - general_settings: default - logging: default - model_globals: default - - roots: [protein_seq] + - roots: [protein_seq_transformer] - trunk: default - - branches: [protein_property, generation] + - branches: [protein_property_transformer, generation] - tree: protein_model - tasks: - protein_property/log_fluorescence @@ -34,7 +34,7 @@ tree: weight_averaging: null optimizer: _target_: torch.optim.Adam - lr: 5e-3 + lr: 6e-4 weight_decay: 0. betas: [0.99, 0.999] fused: false diff --git a/cortex/data/data_module/_task_data_module.py b/cortex/data/data_module/_task_data_module.py index a0849bb..100bcd3 100644 --- a/cortex/data/data_module/_task_data_module.py +++ b/cortex/data/data_module/_task_data_module.py @@ -122,7 +122,8 @@ def get_dataloader(self, split: str = "train"): else: # Full batch for evaluation on the test set if split == "test": - self._dataloader_kwargs["batch_size"] = len(self.datasets[split]) + # self._dataloader_kwargs["batch_size"] = len(self.datasets[split]) + self._dataloader_kwargs["batch_size"] = 2 * self._batch_size dataloader = DataLoader(self.datasets[split], shuffle=True, drop_last=True, **self._dataloader_kwargs) if split == "test": self._dataloader_kwargs["batch_size"] = self._batch_size diff --git a/cortex/model/elemental/_bidirectional_self_attention.py b/cortex/model/elemental/_bidirectional_self_attention.py index 49b1dda..173f4b2 100644 --- a/cortex/model/elemental/_bidirectional_self_attention.py +++ b/cortex/model/elemental/_bidirectional_self_attention.py @@ -15,14 +15,15 @@ def __init__(self, num_heads: int = 4, embed_dim: int = 32, dropout_p: float = 0 def forward(self, inputs: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: x, padding_mask = inputs - seq_len = x.size(-2) - queries, keys, values = self.c_attn(x).chunk(3, dim=-1) + batch_size, seq_len, embed_dim = x.shape + queries, keys, values = self.c_attn(x).split(embed_dim, dim=-1) queries = queries.view(-1, seq_len, self.num_heads, self.head_dim).transpose(-2, -3) keys = keys.view(-1, seq_len, self.num_heads, self.head_dim).transpose(-2, -3) values = values.view(-1, seq_len, self.num_heads, self.head_dim).transpose(-2, -3) attn_mask = padding_mask[..., None, :, None] + attn_mask = attn_mask.expand(-1, -1, -1, seq_len).contiguous() res = nn.functional.scaled_dot_product_attention( queries, @@ -33,5 +34,5 @@ def forward(self, inputs: tuple[Tensor, Tensor]) -> tuple[Tensor, 
Tensor]: is_causal=False, ) - res = res.transpose(-2, -3).flatten(start_dim=-2) + res = res.transpose(-2, -3).contiguous().flatten(start_dim=-2) return self.dropout(res), padding_mask diff --git a/cortex/model/elemental/_pooling_self_attention.py b/cortex/model/elemental/_pooling_self_attention.py index a0afa05..f039e14 100644 --- a/cortex/model/elemental/_pooling_self_attention.py +++ b/cortex/model/elemental/_pooling_self_attention.py @@ -13,8 +13,7 @@ def __init__(self, num_heads: int = 4, embed_dim: int = 32, dropout_p: float = 0 self.head_dim = embed_dim // num_heads self.num_heads = num_heads - def forward(self, inputs: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: - x, padding_mask = inputs + def forward(self, x: Tensor, padding_mask: Tensor) -> tuple[Tensor, Tensor]: seq_len = x.size(-2) queries, keys, values = self.c_attn(x).chunk(3, dim=-1) @@ -27,7 +26,7 @@ def forward(self, inputs: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: queries = queries.sum(-2, keepdim=True) / attn_mask.sum(-2, keepdim=True) # attn_mask (*batch_shape, 1, 1, num_keys) - attn_mask = padding_mask[..., None, None, :] + attn_mask = padding_mask[..., None, None, :].contiguous() res = nn.functional.scaled_dot_product_attention( queries, @@ -38,6 +37,6 @@ def forward(self, inputs: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: is_causal=False, ) - res = res.transpose(-2, -3).flatten(start_dim=-2) + res = res.transpose(-2, -3).contiguous().flatten(start_dim=-2) res = self.dropout(res)[..., 0, :] # drop 1D query dim return res From dce6a4bab6359455b40fed9aa35a1b455b9d6739 Mon Sep 17 00:00:00 2001 From: Stanton Date: Sun, 11 May 2025 03:06:03 +0000 Subject: [PATCH 6/9] tweaking protein model --- cortex/cmdline/train_cortex_model.py | 3 +- cortex/config/hydra/branches/folding.yaml | 9 +++ .../protein_property_transformer.yaml | 2 +- .../config/hydra/model_globals/default.yaml | 4 +- .../hydra/roots/protein_seq_decoder.yaml | 19 +++++ .../hydra/roots/protein_seq_transformer.yaml | 6 +- .../delta_g.yaml | 0 .../stability.yaml | 0 .../stable_proteins_autoregressive.yaml | 2 +- .../hydra/tasks/generation/tape_combined.yaml | 25 ++++++ cortex/config/hydra/train_gpt.yaml | 52 ++++++++++++ cortex/config/hydra/train_protein_model.yaml | 30 +++---- cortex/data/data_module/_task_data_module.py | 6 +- cortex/data/dataset/__init__.py | 2 + cortex/data/dataset/_tape_combined.py | 20 +++++ cortex/logging/_wandb_setup.py | 2 +- cortex/model/tree/_seq_model_tree.py | 81 ++++++++++++++++--- 17 files changed, 225 insertions(+), 38 deletions(-) create mode 100644 cortex/config/hydra/branches/folding.yaml create mode 100644 cortex/config/hydra/roots/protein_seq_decoder.yaml rename cortex/config/hydra/tasks/{protein_property => folding}/delta_g.yaml (100%) rename cortex/config/hydra/tasks/{protein_property => folding}/stability.yaml (100%) create mode 100644 cortex/config/hydra/tasks/generation/tape_combined.yaml create mode 100644 cortex/config/hydra/train_gpt.yaml create mode 100644 cortex/data/dataset/_tape_combined.py diff --git a/cortex/cmdline/train_cortex_model.py b/cortex/cmdline/train_cortex_model.py index 0183114..267fc6c 100644 --- a/cortex/cmdline/train_cortex_model.py +++ b/cortex/cmdline/train_cortex_model.py @@ -6,9 +6,9 @@ import hydra import lightning as L import torch -import wandb from omegaconf import DictConfig, OmegaConf +import wandb from cortex.logging import wandb_setup @@ -35,6 +35,7 @@ def execute(cfg): """ instantiate and train a multitask neural tree """ + 
torch.set_float32_matmul_precision("medium") trainer = hydra.utils.instantiate(cfg.trainer) diff --git a/cortex/config/hydra/branches/folding.yaml b/cortex/config/hydra/branches/folding.yaml new file mode 100644 index 0000000..f453877 --- /dev/null +++ b/cortex/config/hydra/branches/folding.yaml @@ -0,0 +1,9 @@ +folding: + _target_: cortex.model.branch.TransformerBranch + out_dim: 8 + channel_dim: ${channel_dim} + num_blocks: 2 + num_heads: 8 + dropout_prob: ${dropout_prob} + is_causal: false + pooling_type: attention diff --git a/cortex/config/hydra/branches/protein_property_transformer.yaml b/cortex/config/hydra/branches/protein_property_transformer.yaml index fd04d5f..20e5085 100644 --- a/cortex/config/hydra/branches/protein_property_transformer.yaml +++ b/cortex/config/hydra/branches/protein_property_transformer.yaml @@ -3,7 +3,7 @@ protein_property: out_dim: 8 channel_dim: ${channel_dim} num_blocks: 2 - num_heads: 4 + num_heads: 8 dropout_prob: ${dropout_prob} is_causal: false pooling_type: attention diff --git a/cortex/config/hydra/model_globals/default.yaml b/cortex/config/hydra/model_globals/default.yaml index 26212f5..52dc045 100644 --- a/cortex/config/hydra/model_globals/default.yaml +++ b/cortex/config/hydra/model_globals/default.yaml @@ -1,7 +1,7 @@ # @package _global_ -channel_dim: 128 +channel_dim: 512 embed_dim: 32 -ensemble_size: 4 +ensemble_size: 2 dropout_prob: 0.0 kernel_size: 5 pooling_type: mean diff --git a/cortex/config/hydra/roots/protein_seq_decoder.yaml b/cortex/config/hydra/roots/protein_seq_decoder.yaml new file mode 100644 index 0000000..4877b36 --- /dev/null +++ b/cortex/config/hydra/roots/protein_seq_decoder.yaml @@ -0,0 +1,19 @@ +protein_seq: + _target_: cortex.model.root.TransformerRoot + corruption_process: + _target_: cortex.corruption.MaskCorruptionProcess + tokenizer_transform: + _target_: cortex.transforms.HuggingFaceTokenizerTransform + tokenizer: + _target_: cortex.tokenization.ProteinSequenceTokenizerFast + max_len: 256 + out_dim: ${embed_dim} + embed_dim: ${embed_dim} + channel_dim: ${channel_dim} + num_blocks: 2 + num_heads: 4 + is_causal: true + dropout_prob: ${dropout_prob} + pos_encoding: true + train_transforms: null + eval_transforms: null diff --git a/cortex/config/hydra/roots/protein_seq_transformer.yaml b/cortex/config/hydra/roots/protein_seq_transformer.yaml index 3127d25..ad8b937 100644 --- a/cortex/config/hydra/roots/protein_seq_transformer.yaml +++ b/cortex/config/hydra/roots/protein_seq_transformer.yaml @@ -7,11 +7,11 @@ protein_seq: tokenizer: _target_: cortex.tokenization.ProteinSequenceTokenizerFast max_len: 256 - out_dim: ${embed_dim} + out_dim: ${channel_dim} embed_dim: ${embed_dim} channel_dim: ${channel_dim} - num_blocks: 2 - num_heads: 4 + num_blocks: 10 + num_heads: 8 is_causal: false dropout_prob: ${dropout_prob} pos_encoding: true diff --git a/cortex/config/hydra/tasks/protein_property/delta_g.yaml b/cortex/config/hydra/tasks/folding/delta_g.yaml similarity index 100% rename from cortex/config/hydra/tasks/protein_property/delta_g.yaml rename to cortex/config/hydra/tasks/folding/delta_g.yaml diff --git a/cortex/config/hydra/tasks/protein_property/stability.yaml b/cortex/config/hydra/tasks/folding/stability.yaml similarity index 100% rename from cortex/config/hydra/tasks/protein_property/stability.yaml rename to cortex/config/hydra/tasks/folding/stability.yaml diff --git a/cortex/config/hydra/tasks/generation/stable_proteins_autoregressive.yaml b/cortex/config/hydra/tasks/generation/stable_proteins_autoregressive.yaml index 
9712672..28e1cdb 100644 --- a/cortex/config/hydra/tasks/generation/stable_proteins_autoregressive.yaml +++ b/cortex/config/hydra/tasks/generation/stable_proteins_autoregressive.yaml @@ -1,4 +1,4 @@ -stable_proteins: +stable_proteins_autoregressive: _target_: cortex.task.AutoregressiveLanguageModelTask tokenizer: _target_: cortex.tokenization.ProteinSequenceTokenizerFast diff --git a/cortex/config/hydra/tasks/generation/tape_combined.yaml b/cortex/config/hydra/tasks/generation/tape_combined.yaml new file mode 100644 index 0000000..b15e901 --- /dev/null +++ b/cortex/config/hydra/tasks/generation/tape_combined.yaml @@ -0,0 +1,25 @@ +tape_combined: + _target_: cortex.task.DenoisingLanguageModelTask + tokenizer: + _target_: cortex.tokenization.ProteinSequenceTokenizerFast + input_map: + protein_seq: ['tokenized_seq'] + root_key: protein_seq + # Add BLOSUM62-based substitution corruption for data augmentation + corruption_process: + _target_: cortex.corruption.SubstitutionCorruptionProcess.from_blosum62 + corruption_rate: 0.1 # Apply corruption to 10% of masked tokens + data_module: + _target_: cortex.data.data_module.TaskDataModule + _recursive_: false + batch_size: ${fit.batch_size} + balance_train_partition: partition + drop_last: true + lengths: [1.0, 0.0] + train_on_everything: false + num_workers: ${num_workers} + dataset_config: + _target_: cortex.data.dataset.TAPECombinedDataset + root: ${dataset_root_dir} + download: ${download_datasets} + train: ??? diff --git a/cortex/config/hydra/train_gpt.yaml b/cortex/config/hydra/train_gpt.yaml new file mode 100644 index 0000000..bda9e10 --- /dev/null +++ b/cortex/config/hydra/train_gpt.yaml @@ -0,0 +1,52 @@ +defaults: + - general_settings: default + - logging: default + - model_globals: default + - roots: [protein_seq_decoder] + - trunk: default + - branches: [generation] + - tree: protein_model + - tasks: + - generation/stable_proteins_autoregressive + - _self_ + +fit: + batch_size: 128 + +trainer: + _target_: lightning.Trainer + accelerator: gpu + max_epochs: 64 + devices: 1 + num_sanity_val_steps: 0 + + +tree: + _recursive_: false + fit_cfg: + reinitialize_roots: true + linear_probing: false + weight_averaging: null + optimizer: + _target_: torch.optim.Adam + lr: 6e-4 + weight_decay: 0. 
+ betas: [0.99, 0.999] + fused: false + lr_scheduler: + _target_: transformers.get_cosine_schedule_with_warmup + num_warmup_steps: 10 + num_training_steps: ${trainer.max_epochs} + +tasks: + generation: + stable_proteins_autoregressive: + ensemble_size: 1 + +train_on_everything: false +linear_probing: false +dataset_root_dir: /home/stantos5/scratch/datasets +download_datasets: true +num_workers: 2 + +ckpt_name: ${exp_name}_${job_name} diff --git a/cortex/config/hydra/train_protein_model.yaml b/cortex/config/hydra/train_protein_model.yaml index 9919b9e..f5b919c 100644 --- a/cortex/config/hydra/train_protein_model.yaml +++ b/cortex/config/hydra/train_protein_model.yaml @@ -4,13 +4,14 @@ defaults: - model_globals: default - roots: [protein_seq_transformer] - trunk: default - - branches: [protein_property_transformer, generation] + - branches: [folding, protein_property_transformer, generation] - tree: protein_model - tasks: - protein_property/log_fluorescence - - protein_property/stability - - generation/gfp - - generation/stable_proteins + - folding/stability + # - generation/gfp + # - generation/stable_proteins + - generation/tape_combined - _self_ fit: @@ -23,8 +24,8 @@ trainer: devices: 1 # devices: 8 # strategy: ddp - num_sanity_val_steps: 0 - + num_sanity_val_steps: 1 + # precision: 16 tree: _recursive_: false @@ -44,17 +45,18 @@ tree: num_training_steps: ${trainer.max_epochs} tasks: + folding: + stability: + ensemble_size: ${ensemble_size} protein_property: log_fluorescence: - # ensemble_size: ${ensemble_size} - ensemble_size: 2 - stability: - # ensemble_size: ${ensemble_size} - ensemble_size: 2 + ensemble_size: ${ensemble_size} generation: - gfp: - ensemble_size: 1 - stable_proteins: + # gfp: + # ensemble_size: 1 + # stable_proteins: + # ensemble_size: 1 + tape_combined: ensemble_size: 1 train_on_everything: false diff --git a/cortex/data/data_module/_task_data_module.py b/cortex/data/data_module/_task_data_module.py index 100bcd3..2801d23 100644 --- a/cortex/data/data_module/_task_data_module.py +++ b/cortex/data/data_module/_task_data_module.py @@ -122,9 +122,9 @@ def get_dataloader(self, split: str = "train"): else: # Full batch for evaluation on the test set if split == "test": - # self._dataloader_kwargs["batch_size"] = len(self.datasets[split]) - self._dataloader_kwargs["batch_size"] = 2 * self._batch_size - dataloader = DataLoader(self.datasets[split], shuffle=True, drop_last=True, **self._dataloader_kwargs) + self._dataloader_kwargs["batch_size"] = len(self.datasets[split]) + # self._dataloader_kwargs["batch_size"] = 2 * self._batch_size + dataloader = DataLoader(self.datasets[split], shuffle=False, drop_last=False, **self._dataloader_kwargs) if split == "test": self._dataloader_kwargs["batch_size"] = self._batch_size return dataloader diff --git a/cortex/data/dataset/__init__.py b/cortex/data/dataset/__init__.py index 8e2c3b8..600ab6e 100644 --- a/cortex/data/dataset/__init__.py +++ b/cortex/data/dataset/__init__.py @@ -1,6 +1,7 @@ from ._data_frame_dataset import DataFrameDataset, ordered_dict_collator from ._numpy_dataset import NumpyDataset from ._rfp_dataset import RedFluorescentProteinDataset +from ._tape_combined import TAPECombinedDataset from ._tape_fluorescence import TAPEFluorescenceDataset from ._tape_stability import TAPEStabilityDataset from ._transformed_dataset import TransformedDataset @@ -12,5 +13,6 @@ "RedFluorescentProteinDataset", "TAPEFluorescenceDataset", "TAPEStabilityDataset", + "TAPECombinedDataset", "TransformedDataset", ] diff --git 
a/cortex/data/dataset/_tape_combined.py b/cortex/data/dataset/_tape_combined.py new file mode 100644 index 0000000..8576e12 --- /dev/null +++ b/cortex/data/dataset/_tape_combined.py @@ -0,0 +1,20 @@ +import pandas as pd + +from cortex.data.dataset import TAPEFluorescenceDataset, TAPEStabilityDataset +from cortex.data.dataset._data_frame_dataset import DataFrameDataset + + +# hack to combine TAPE datasets for self-supervised training +class TAPECombinedDataset(DataFrameDataset): + columns = [ + "tokenized_seq", + "partition", + ] + + def __init__(self, root: str, download: bool = False, **kwargs): + fluorescence_data = TAPEFluorescenceDataset(root=root, download=download, **kwargs)._data + stability_data = TAPEStabilityDataset(root=root, download=download, **kwargs)._data + + fluorescence_data["partition"] = "fluorescence" + stability_data["partition"] = "stability" + self._data = pd.concat([fluorescence_data[self.columns], stability_data[self.columns]], ignore_index=True) diff --git a/cortex/logging/_wandb_setup.py b/cortex/logging/_wandb_setup.py index 4b26b94..d09e162 100644 --- a/cortex/logging/_wandb_setup.py +++ b/cortex/logging/_wandb_setup.py @@ -1,10 +1,10 @@ import uuid from typing import MutableMapping -import wandb from omegaconf import DictConfig, OmegaConf import cortex +import wandb def wandb_setup(cfg: DictConfig): diff --git a/cortex/model/tree/_seq_model_tree.py b/cortex/model/tree/_seq_model_tree.py index 71a063d..2a979f6 100644 --- a/cortex/model/tree/_seq_model_tree.py +++ b/cortex/model/tree/_seq_model_tree.py @@ -1,6 +1,7 @@ import copy import math -from typing import Dict, Optional +from collections import OrderedDict +from typing import Any, Dict, Optional import hydra import lightning as L @@ -9,7 +10,7 @@ import torch from botorch.models.transforms.outcome import OutcomeTransform from lightning.pytorch.utilities.combined_loader import CombinedLoader -from omegaconf import DictConfig, OmegaConf +from omegaconf import DictConfig from torch import nn from cortex.model import online_weight_update_ @@ -90,7 +91,8 @@ def get_dataloader(self, split="train"): else: raise ValueError(f"Invalid split {split}") - mode = "min_size" if split == "train" else "max_size_cycle" + # mode = "min_size" if split == "train" else "max_size_cycle" + mode = "min_size" return CombinedLoader(loaders, mode=mode) def training_step(self, batch: dict, batch_idx: int, dataloader_idx: Optional[int] = None): @@ -209,6 +211,12 @@ def _weight_average_update( def on_train_epoch_start(self): self.lr_schedulers().step() + fit_cfg = self.hparams.fit_cfg + self.train() + self.requires_grad_(True) + if fit_cfg.linear_probing: + self.freeze_roots() + self.freeze_trunk() def on_fit_start(self) -> None: # Initialize root node weights with pretrained weights or random initialization if training from scratch @@ -255,7 +263,8 @@ def validation_step(self, batch: dict, batch_idx: int, dataloader_idx: Optional[ else: predict_targets[task_key] = None - predict_out = self.predict(data=task_batch, batch_limit=len(task_batch), predict_tasks=[task_key]) + with torch.no_grad(): + predict_out = self.predict(data=task_batch, batch_limit=128, predict_tasks=[task_key]) step_metrics.update( self.prediction_metrics( @@ -272,6 +281,10 @@ def validation_step(self, batch: dict, batch_idx: int, dataloader_idx: Optional[ return step_metrics + def on_validation_epoch_start(self): + self.eval() + self.requires_grad_(False) + def on_validation_epoch_end(self): # In Lightning 2.x, we need to process the accumulated outputs manually 
step_metrics = pd.DataFrame.from_records(self.validation_step_outputs) @@ -322,8 +335,8 @@ def add_task_nodes( self, branch_key: str, task_key: str, - task_cfg: OmegaConf, - branch_cfg: Optional[OmegaConf] = None, + task_cfg: DictConfig, + branch_cfg: Optional[DictConfig] = None, data_dir: str = "./data", ) -> None: task = hydra.utils.instantiate(task_cfg, leaf_key=task_key, data_dir=data_dir) @@ -347,7 +360,7 @@ def initialize_leaves(self, task_dict) -> None: if task_key in task_dict: l_node.initialize() - def build_tree(self, cfg: OmegaConf, skip_task_setup: bool = False) -> dict: + def build_tree(self, cfg: DictConfig, skip_task_setup: bool = False) -> dict: # create root nodes for root_key, root_cfg in cfg.roots.items(): self.root_nodes[root_key] = hydra.utils.instantiate(root_cfg, device=self.device, dtype=self.dtype) @@ -389,11 +402,11 @@ def build_tree(self, cfg: OmegaConf, skip_task_setup: bool = False) -> dict: def predict( self, - data: pd.DataFrame, + data: OrderedDict, batch_limit: Optional[int] = None, predict_tasks: Optional[list[str]] = None, format_outputs: bool = True, - cpu_offload: bool = False, + cpu_offload: bool = True, ) -> dict[str, torch.Tensor]: """ Args: @@ -402,10 +415,13 @@ def predict( predict_out: dict[str, torch.Tensor] of task prediction outputs for `predict_inputs`. """ self.eval() - batch_limit = len(data) if batch_limit is None else batch_limit - num_chunks = math.ceil(len(data) / batch_limit) + + keys = [k for k in data.keys() if k != "batch_size"] + batch_limit = len(data[keys[0]]) if batch_limit is None else batch_limit + num_chunks = math.ceil(len(data[keys[0]]) / batch_limit) + if num_chunks > 1: - batches = np.array_split(data, num_chunks) + batches = split_data(data, batch_limit) else: batches = [data] @@ -615,3 +631,44 @@ def get_param_prefixes(tree_outputs): param_prefixes.append(f"leaf_nodes.{leaf_key}") return param_prefixes + + +def split_data(data: OrderedDict[str, int | list[Any]], batch_size: int): + """ + Split a dictionary into n chunks with size at most batch_size. + If batch_size is not a divisor of the dictionary element length, the last chunk will be smaller. + Args: + data: dict to split + batch_size: max size of each chunk + Returns: + list of dict chunks + """ + chunked_vals = {} + num_chunks = float("inf") + for k, v in data.items(): + if isinstance(v, int): + continue + chunked_vals[k] = split_list(v, batch_size) + num_chunks = min(num_chunks, len(chunked_vals[k])) + + res = [{k: chunked_vals[k][i] for k in chunked_vals} for i in range(num_chunks)] + return res + + +def split_list(lst: list[Any], batch_size: int) -> list[list[Any]]: + """ + Split a list into chunks with size at most batch_size. + If batch_size is not a divisor of the list length, the last chunk will be smaller. 
+ Args: + lst: list to split + batch_size: max size of each chunk + Returns: + list of list chunks + """ + res = [] + num_chunks = math.ceil(len(lst) / batch_size) + for i in range(num_chunks): + start = i * batch_size + end = min((i + 1) * batch_size, len(lst)) + res.append(lst[start:end]) + return res From 348324437cc5b3ec495e50e583401e5f2dfe3ce7 Mon Sep 17 00:00:00 2001 From: Samuel Stanton Date: Mon, 12 May 2025 10:27:18 -0400 Subject: [PATCH 7/9] fix bad isort --- cortex/cmdline/train_cortex_model.py | 2 ++ cortex/logging/_wandb_setup.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/cortex/cmdline/train_cortex_model.py b/cortex/cmdline/train_cortex_model.py index 267fc6c..ef5a55f 100644 --- a/cortex/cmdline/train_cortex_model.py +++ b/cortex/cmdline/train_cortex_model.py @@ -8,7 +8,9 @@ import torch from omegaconf import DictConfig, OmegaConf +# ruff: noqa: I001 import wandb + from cortex.logging import wandb_setup diff --git a/cortex/logging/_wandb_setup.py b/cortex/logging/_wandb_setup.py index d09e162..364dec2 100644 --- a/cortex/logging/_wandb_setup.py +++ b/cortex/logging/_wandb_setup.py @@ -3,9 +3,11 @@ from omegaconf import DictConfig, OmegaConf -import cortex +# ruff: noqa: I001 import wandb +import cortex + def wandb_setup(cfg: DictConfig): """ From c3811cd82c8e0912051a51460dad0c5ef36fcd16 Mon Sep 17 00:00:00 2001 From: Samuel Stanton Date: Mon, 12 May 2025 10:36:07 -0400 Subject: [PATCH 8/9] update config names --- .../{folding.yaml => folding_encoder.yaml} | 0 ...n_property.yaml => protein_property_cnn.yaml} | 0 ...former.yaml => protein_property_encoder.yaml} | 0 .../{protein_seq.yaml => protein_seq_cnn.yaml} | 0 ...transformer.yaml => protein_seq_encoder.yaml} | 0 cortex/config/hydra/train_protein_model.yaml | 16 +++++----------- 6 files changed, 5 insertions(+), 11 deletions(-) rename cortex/config/hydra/branches/{folding.yaml => folding_encoder.yaml} (100%) rename cortex/config/hydra/branches/{protein_property.yaml => protein_property_cnn.yaml} (100%) rename cortex/config/hydra/branches/{protein_property_transformer.yaml => protein_property_encoder.yaml} (100%) rename cortex/config/hydra/roots/{protein_seq.yaml => protein_seq_cnn.yaml} (100%) rename cortex/config/hydra/roots/{protein_seq_transformer.yaml => protein_seq_encoder.yaml} (100%) diff --git a/cortex/config/hydra/branches/folding.yaml b/cortex/config/hydra/branches/folding_encoder.yaml similarity index 100% rename from cortex/config/hydra/branches/folding.yaml rename to cortex/config/hydra/branches/folding_encoder.yaml diff --git a/cortex/config/hydra/branches/protein_property.yaml b/cortex/config/hydra/branches/protein_property_cnn.yaml similarity index 100% rename from cortex/config/hydra/branches/protein_property.yaml rename to cortex/config/hydra/branches/protein_property_cnn.yaml diff --git a/cortex/config/hydra/branches/protein_property_transformer.yaml b/cortex/config/hydra/branches/protein_property_encoder.yaml similarity index 100% rename from cortex/config/hydra/branches/protein_property_transformer.yaml rename to cortex/config/hydra/branches/protein_property_encoder.yaml diff --git a/cortex/config/hydra/roots/protein_seq.yaml b/cortex/config/hydra/roots/protein_seq_cnn.yaml similarity index 100% rename from cortex/config/hydra/roots/protein_seq.yaml rename to cortex/config/hydra/roots/protein_seq_cnn.yaml diff --git a/cortex/config/hydra/roots/protein_seq_transformer.yaml b/cortex/config/hydra/roots/protein_seq_encoder.yaml similarity index 100% rename from 
cortex/config/hydra/roots/protein_seq_transformer.yaml rename to cortex/config/hydra/roots/protein_seq_encoder.yaml diff --git a/cortex/config/hydra/train_protein_model.yaml b/cortex/config/hydra/train_protein_model.yaml index f5b919c..3662566 100644 --- a/cortex/config/hydra/train_protein_model.yaml +++ b/cortex/config/hydra/train_protein_model.yaml @@ -2,15 +2,13 @@ defaults: - general_settings: default - logging: default - model_globals: default - - roots: [protein_seq_transformer] + - roots: [protein_seq_encoder] - trunk: default - - branches: [folding, protein_property_transformer, generation] + - branches: [folding_encoder, protein_property_encoder, generation] - tree: protein_model - tasks: - protein_property/log_fluorescence - folding/stability - # - generation/gfp - # - generation/stable_proteins - generation/tape_combined - _self_ @@ -20,11 +18,11 @@ fit: trainer: _target_: lightning.Trainer accelerator: gpu - max_epochs: 64 + max_epochs: 128 devices: 1 # devices: 8 # strategy: ddp - num_sanity_val_steps: 1 + num_sanity_val_steps: 0 # precision: 16 tree: @@ -35,7 +33,7 @@ tree: weight_averaging: null optimizer: _target_: torch.optim.Adam - lr: 6e-4 + lr: 3e-4 weight_decay: 0. betas: [0.99, 0.999] fused: false @@ -52,10 +50,6 @@ tasks: log_fluorescence: ensemble_size: ${ensemble_size} generation: - # gfp: - # ensemble_size: 1 - # stable_proteins: - # ensemble_size: 1 tape_combined: ensemble_size: 1 From cb3bcac1bbcf0b3a670a1bfb106a8b2b5756e6d1 Mon Sep 17 00:00:00 2001 From: Samuel Stanton Date: Mon, 12 May 2025 10:41:34 -0400 Subject: [PATCH 9/9] fix bad isort --- cortex/data/dataset/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cortex/data/dataset/__init__.py b/cortex/data/dataset/__init__.py index 600ab6e..614edc4 100644 --- a/cortex/data/dataset/__init__.py +++ b/cortex/data/dataset/__init__.py @@ -1,9 +1,11 @@ from ._data_frame_dataset import DataFrameDataset, ordered_dict_collator from ._numpy_dataset import NumpyDataset from ._rfp_dataset import RedFluorescentProteinDataset -from ._tape_combined import TAPECombinedDataset from ._tape_fluorescence import TAPEFluorescenceDataset from ._tape_stability import TAPEStabilityDataset + +# ruff: noqa: I001 +from ._tape_combined import TAPECombinedDataset from ._transformed_dataset import TransformedDataset __all__ = [