From 8cf7073a990c41e2a9e74b783943eb5036f5a5bd Mon Sep 17 00:00:00 2001 From: rsaeta Date: Tue, 20 Jun 2023 12:21:37 +0200 Subject: [PATCH 1/3] Simpler synth experiment --- spalbp/config.py | 1 + spalbp/config/experiment/synth.yaml | 1 + spalbp/config/model/base_model.yaml | 1 + .../__pycache__/__init__.cpython-39.pyc | Bin 268 -> 316 bytes spalbp/synth_data_gen.py | 57 ++++++++++++++++++ spalbp/synthetic_mask.py | 46 +++----------- 6 files changed, 67 insertions(+), 39 deletions(-) create mode 100644 spalbp/synth_data_gen.py diff --git a/spalbp/config.py b/spalbp/config.py index 4fb11a8..301a6d8 100644 --- a/spalbp/config.py +++ b/spalbp/config.py @@ -70,6 +70,7 @@ class ExperimentConfig: @dataclass class SynthMaskExperimentConfig(ExperimentConfig): offset: int + data_gen: str @dataclass diff --git a/spalbp/config/experiment/synth.yaml b/spalbp/config/experiment/synth.yaml index dd10e62..315056d 100644 --- a/spalbp/config/experiment/synth.yaml +++ b/spalbp/config/experiment/synth.yaml @@ -9,3 +9,4 @@ data_source: synthetic num_classes: 32000 wandb_project: sparse_transformer_synth offset: 70 +data_gen: simple diff --git a/spalbp/config/model/base_model.yaml b/spalbp/config/model/base_model.yaml index 05d9d19..4016860 100644 --- a/spalbp/config/model/base_model.yaml +++ b/spalbp/config/model/base_model.yaml @@ -10,3 +10,4 @@ type: transformer embedding_dim: 64 context_size: ${experiment.context_size} positional_encoding_type: sinusoidal +vocab_size: 32000 diff --git a/spalbp/lib/models/__pycache__/__init__.cpython-39.pyc b/spalbp/lib/models/__pycache__/__init__.cpython-39.pyc index 20ce493668fd192f018709ab97b07f908d83d626..9e599f33b5539cdd1417948fb3703fd03e6d240f 100644 GIT binary patch delta 85 zcmeBS+QY=1$ji&c00i~hYbSDVvrE?xElw>e)-S3|%qvbzElJdONiEJU$uH0kE6U7D iRS3?CX9AUj)#^?>ZwmmaFdh^D delta 37 rcmdnP)WgJ`$ji&c00i$ol}+T{#%Hddk)NBYUsRl!T9T+g@xLtq$kYr? 
diff --git a/spalbp/synth_data_gen.py b/spalbp/synth_data_gen.py new file mode 100644 index 0000000..e720389 --- /dev/null +++ b/spalbp/synth_data_gen.py @@ -0,0 +1,57 @@ +import torch +from utils import cuda + + +def random_sample_simple(batch_size, seq_len, mask_i=0, target_i=-1): + seqs_inputs = torch.randint(size=(batch_size, seq_len), low=100, high=32000) + attention_masks = torch.ones_like(seqs_inputs) + mask = torch.ones((batch_size, seq_len)) + mask[:, mask_i] = 0 + mask_token = 4 + targets = seqs_inputs.detach().clone() + seqs_inputs[:, mask_i] = mask_token + seqs_inputs[:, target_i] = targets[:, mask_i] + # Expand the attention mask to a symmetric matrix + attention_masks = attention_masks[:, None, :].expand(-1, seq_len, -1) + mask = mask.bool() + if cuda: + seqs_inputs = seqs_inputs.cuda() + attention_masks = attention_masks.cuda() + targets = targets.cuda() + mask = mask.cuda() + return seqs_inputs, attention_masks, targets, mask + + +def random_sample_data2(batch_size, seq_len, offset=70): + seqs_inputs = torch.randint(size=(batch_size, seq_len), low=100, high=32000) + attention_masks = torch.ones_like(seqs_inputs) + mask_token = 4 + mask = torch.rand((batch_size, seq_len)) > 0.05 + # mask = torch.ones((batch_size, seq_len)) + # mask[:, 45:55] = 0 + mask = mask.bool() + targets = seqs_inputs.detach().clone() + # Modify the input so that the masked token positions are filled with [MASK] tokens + # and the token at position mask + offset is the target token. 
+ for b, m_i in (~mask).nonzero(): + seqs_inputs[b] = apply_offset_mask(seqs_inputs[b], m_i, mask_token, offset) + # Expand the attention mask to a symmetric matrix + attention_masks = attention_masks[:, None, :].expand(-1, seq_len, -1) + if cuda: + seqs_inputs = seqs_inputs.cuda() + attention_masks = attention_masks.cuda() + targets = targets.cuda() + mask = mask.cuda() + return seqs_inputs, attention_masks, targets, mask + + +def apply_offset_mask(seq_input, i, mask_token, offset): + """ + This function replaces seq_input[i] with the mask_token and replaces + seq_input[i+offset] with the target token. + """ + target_token = seq_input[i].item() + seq_input[i] = mask_token + new_pos = (i + offset) % seq_input.size(0) + seq_input[new_pos] = target_token + return seq_input \ No newline at end of file diff --git a/spalbp/synthetic_mask.py b/spalbp/synthetic_mask.py index 6bd44e0..abc7f97 100644 --- a/spalbp/synthetic_mask.py +++ b/spalbp/synthetic_mask.py @@ -12,10 +12,7 @@ import torch.nn.functional as F import wandb from omegaconf import OmegaConf -from lib.models import GeneratingTransformer from utils import ( - cuda, - device, setup, learners, save_model, @@ -23,41 +20,9 @@ post_process_cfg, get_model, ) - - -def random_sample_data2(batch_size, seq_len, offset=70): - seqs_inputs = torch.randint(size=(batch_size, seq_len), low=100, high=32000) - attention_masks = torch.ones_like(seqs_inputs) - mask_token = 4 - mask = torch.rand((batch_size, seq_len)) > 0.05 - # mask = torch.ones((batch_size, seq_len)) - # mask[:, 45:55] = 0 - mask = mask.bool() - targets = seqs_inputs.detach().clone() - # Modify the input so that the masked token positions are filled with [MASK] tokens - # and the token at position mask + offset is the target token. 
- for b, m_i in (~mask).nonzero(): - seqs_inputs[b] = apply_offset_mask(seqs_inputs[b], m_i, mask_token, offset) - # Expand the attention mask to a symmetric matrix - attention_masks = attention_masks[:, None, :].expand(-1, seq_len, -1) - if cuda: - seqs_inputs = seqs_inputs.cuda() - attention_masks = attention_masks.cuda() - targets = targets.cuda() - mask = mask.cuda() - return seqs_inputs, attention_masks, targets, mask - - -def apply_offset_mask(seq_input, i, mask_token, offset): - """ - This function replaces seq_input[i] with the mask_token and replaces - seq_input[i+offset] with the target token. - """ - target_token = seq_input[i].item() - seq_input[i] = mask_token - new_pos = (i + offset) % seq_input.size(0) - seq_input[new_pos] = target_token - return seq_input +from synth_data_gen import ( + random_sample_data2, random_sample_simple +) def _train(cfg: RunConfig): @@ -66,13 +31,16 @@ def _train(cfg: RunConfig): optimizer, scheduler = learners(model, cfg) tokens_seen = 0 train_cfg = cfg.experiment.training + data_sample_function = lambda: random_sample_simple(cfg.experiment.training.batch_size, cfg.experiment.context_size) \ + if cfg.experiment.data_gen == 'simple' \ + else lambda: random_sample_data2(cfg.experiment.training.batch_size, cfg.experiment.context_size, cfg.experiment.offset) if cfg.experiment.watch_model: wandb.watch(model) for i in range(train_cfg.num_batches): model.train() optimizer.zero_grad() - data_sample = random_sample_data2(train_cfg.batch_size, cfg.experiment.context_size, cfg.experiment.offset) + data_sample = data_sample_function() seqs_inputs, attention_masks, targets, mask = data_sample logits, aux_loss = model(seqs_inputs, attention_masks) From cdd24c638a90283e322c54b70715b75b00201b65 Mon Sep 17 00:00:00 2001 From: rsaeta Date: Tue, 20 Jun 2023 13:31:53 +0200 Subject: [PATCH 2/3] Fix nonadaptive with heads --- spalbp/lib/attention/dynamic_sparse_attentions.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff 
--git a/spalbp/lib/attention/dynamic_sparse_attentions.py b/spalbp/lib/attention/dynamic_sparse_attentions.py index 51abd1b..2bc42b1 100644 --- a/spalbp/lib/attention/dynamic_sparse_attentions.py +++ b/spalbp/lib/attention/dynamic_sparse_attentions.py @@ -57,9 +57,8 @@ def forward(self, x: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor # For each point (self.k), we expect to sample the 2**rank closest points from the first set of sampling, # then self.gadditional globally-sampled indices, and self.nadditional neighborhood-sampled indices. num_points = self.k * (2 ** rank + ((self.gadditional + self.nadditional) if self.training else 0)) - assert indices.size() == ( - batch, self.n_heads, context, num_points, 1), f'Expected size {(batch, context, num_points, 1)}. ' \ - f'Got {indices.size()}' + assert indices.size() == (batch, self.n_heads, context, num_points, 1), \ + f'Expected size {(batch, self.n_heads, context, num_points, 1)}. Got {indices.size()}' densities = sparse.densities(indices_fl, means, sigmas).clone() # (B, H, C, P, self.k) duplicates = util.nduplicates(indices).to(torch.bool) # (B, C, P) boolean mask of duplicates all-but-one densities[duplicates, :] = 0 # Removes all duplicates @@ -147,10 +146,11 @@ def __init__(self, def hyper(self, x: torch.Tensor): b, c, e = x.size() + h = self.n_heads k = self.k - means = self.pmeans[None, :, :, :].expand(b, c, k, 1) - sigmas = self.psigmas[None, :, :].expand(b, c, k) - values = self.pvalues[None, None, :].expand(b, c, k) + means = self.pmeans[None, :, :, :].expand(b, h, c, k, 1) + sigmas = self.psigmas[None, :, :].expand(b, h, c, k) + values = self.pvalues[None, None, :].expand(b, h, c, k) means = sparse.transform_means(means, (c,), method=self.transformation_method) sigmas = sparse.transform_sigmas(sigmas, (c,)) * self.sigma_scale From 6fb04f439ef055382d326b2784860f22f053ca44 Mon Sep 17 00:00:00 2001 From: rsaeta Date: Tue, 20 Jun 2023 14:09:33 +0200 Subject: [PATCH 3/3] Fix python ternary --- 
 spalbp/synthetic_mask.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/spalbp/synthetic_mask.py b/spalbp/synthetic_mask.py
index abc7f97..0278049 100644
--- a/spalbp/synthetic_mask.py
+++ b/spalbp/synthetic_mask.py
@@ -31,15 +31,15 @@ def _train(cfg: RunConfig):
     optimizer, scheduler = learners(model, cfg)
     tokens_seen = 0
     train_cfg = cfg.experiment.training
-    data_sample_function = lambda: random_sample_simple(cfg.experiment.training.batch_size, cfg.experiment.context_size) \
+    data_sample_function = (lambda: random_sample_simple(cfg.experiment.training.batch_size, cfg.experiment.context_size)) \
         if cfg.experiment.data_gen == 'simple' \
-        else lambda: random_sample_data2(cfg.experiment.training.batch_size, cfg.experiment.context_size, cfg.experiment.offset)
+        else (lambda: random_sample_data2(cfg.experiment.training.batch_size, cfg.experiment.context_size, cfg.experiment.offset))
     if cfg.experiment.watch_model:
         wandb.watch(model)
 
     for i in range(train_cfg.num_batches):
         model.train()
         optimizer.zero_grad()
 
         data_sample = data_sample_function()
         seqs_inputs, attention_masks, targets, mask = data_sample