Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions spalbp/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class ExperimentConfig:
@dataclass
class SynthMaskExperimentConfig(ExperimentConfig):
offset: int
data_gen: str


@dataclass
Expand Down
1 change: 1 addition & 0 deletions spalbp/config/experiment/synth.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ data_source: synthetic
num_classes: 32000
wandb_project: sparse_transformer_synth
offset: 70
data_gen: simple
1 change: 1 addition & 0 deletions spalbp/config/model/base_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ type: transformer
embedding_dim: 64
context_size: ${experiment.context_size}
positional_encoding_type: sinusoidal
vocab_size: 32000
12 changes: 6 additions & 6 deletions spalbp/lib/attention/dynamic_sparse_attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,8 @@ def forward(self, x: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor
# For each point (self.k), we expect to sample the 2**rank closest points from the first set of sampling,
# then self.gadditional globally-sampled indices, and self.nadditional neighborhood-sampled indices.
num_points = self.k * (2 ** rank + ((self.gadditional + self.nadditional) if self.training else 0))
assert indices.size() == (
batch, self.n_heads, context, num_points, 1), f'Expected size {(batch, context, num_points, 1)}. ' \
f'Got {indices.size()}'
assert indices.size() == (batch, self.n_heads, context, num_points, 1), \
f'Expected size {(batch, self.n_heads, context, num_points, 1)}. Got {indices.size()}'
densities = sparse.densities(indices_fl, means, sigmas).clone() # (B, H, C, P, self.k)
duplicates = util.nduplicates(indices).to(torch.bool) # (B, C, P) boolean mask of duplicates all-but-one
densities[duplicates, :] = 0 # Removes all duplicates
Expand Down Expand Up @@ -147,10 +146,11 @@ def __init__(self,

def hyper(self, x: torch.Tensor):
b, c, e = x.size()
h = self.n_heads
k = self.k
means = self.pmeans[None, :, :, :].expand(b, c, k, 1)
sigmas = self.psigmas[None, :, :].expand(b, c, k)
values = self.pvalues[None, None, :].expand(b, c, k)
means = self.pmeans[None, :, :, :].expand(b, h, c, k, 1)
sigmas = self.psigmas[None, :, :].expand(b, h, c, k)
values = self.pvalues[None, None, :].expand(b, h, c, k)

means = sparse.transform_means(means, (c,), method=self.transformation_method)
sigmas = sparse.transform_sigmas(sigmas, (c,)) * self.sigma_scale
Expand Down
Binary file modified spalbp/lib/models/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
57 changes: 57 additions & 0 deletions spalbp/synth_data_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import torch
from utils import cuda


def random_sample_simple(batch_size, seq_len, mask_i=0, target_i=-1,
                         vocab_size=32000, low=100, mask_token=4):
    """Build a synthetic single-position copy-task batch.

    Each row is filled with random token ids drawn uniformly from
    [low, vocab_size). The token at column ``mask_i`` is hidden behind
    ``mask_token`` and a copy of it is planted at column ``target_i``;
    ``targets`` keeps the unmasked ids and ``mask`` is False only at the
    masked column, so the loss can be restricted to the position the
    model must recover.

    Args:
        batch_size: number of sequences in the batch.
        seq_len: length of each sequence.
        mask_i: column to mask (default 0).
        target_i: column where the answer token is planted (default -1).
        vocab_size: exclusive upper bound for sampled token ids
            (default 32000, matching the model config's vocab_size).
        low: inclusive lower bound for sampled token ids; ids below this
            are reserved for special tokens (default 100).
        mask_token: id of the [MASK] token (default 4).

    Returns:
        (seqs_inputs, attention_masks, targets, mask); attention_masks is
        expanded to an all-ones (batch, seq_len, seq_len) matrix. Tensors
        are moved to GPU when the project-level ``cuda`` flag is set.
    """
    seqs_inputs = torch.randint(size=(batch_size, seq_len), low=low, high=vocab_size)
    attention_masks = torch.ones_like(seqs_inputs)
    # False (0) marks the single position whose prediction we score.
    mask = torch.ones((batch_size, seq_len))
    mask[:, mask_i] = 0
    targets = seqs_inputs.detach().clone()
    seqs_inputs[:, mask_i] = mask_token
    # Plant the answer elsewhere in the sequence so attention can find it.
    seqs_inputs[:, target_i] = targets[:, mask_i]
    # Expand the attention mask to a symmetric (seq_len x seq_len) matrix.
    attention_masks = attention_masks[:, None, :].expand(-1, seq_len, -1)
    mask = mask.bool()
    if cuda:
        seqs_inputs = seqs_inputs.cuda()
        attention_masks = attention_masks.cuda()
        targets = targets.cuda()
        mask = mask.cuda()
    return seqs_inputs, attention_masks, targets, mask


def random_sample_data2(batch_size, seq_len, offset=70, vocab_size=32000,
                        low=100, mask_token=4, mask_prob=0.05):
    """Build a synthetic offset-retrieval batch.

    Roughly ``mask_prob`` of all positions (independently per position)
    are masked. Each masked token is replaced by ``mask_token`` and its
    original value is stashed at ``(i + offset) mod seq_len`` via
    ``apply_offset_mask``, so the model must attend a fixed distance away
    to recover it. ``targets`` keeps the unmasked ids; ``mask`` is False
    at masked positions.

    Args:
        batch_size: number of sequences in the batch.
        seq_len: length of each sequence.
        offset: fixed distance between a masked position and its answer
            (default 70, matching the experiment config).
        vocab_size: exclusive upper bound for sampled token ids
            (default 32000, matching the model config's vocab_size).
        low: inclusive lower bound for sampled token ids (default 100).
        mask_token: id of the [MASK] token (default 4).
        mask_prob: per-position masking probability (default 0.05).

    Returns:
        (seqs_inputs, attention_masks, targets, mask); attention_masks is
        expanded to an all-ones (batch, seq_len, seq_len) matrix. Tensors
        are moved to GPU when the project-level ``cuda`` flag is set.
    """
    seqs_inputs = torch.randint(size=(batch_size, seq_len), low=low, high=vocab_size)
    attention_masks = torch.ones_like(seqs_inputs)
    # False marks masked positions; each position is masked independently.
    mask = (torch.rand((batch_size, seq_len)) > mask_prob).bool()
    targets = seqs_inputs.detach().clone()
    # Replace each masked token with [MASK] and stash its original value
    # at position (i + offset) mod seq_len so the model can retrieve it.
    for b, m_i in (~mask).nonzero():
        seqs_inputs[b] = apply_offset_mask(seqs_inputs[b], m_i, mask_token, offset)
    # Expand the attention mask to a symmetric (seq_len x seq_len) matrix.
    attention_masks = attention_masks[:, None, :].expand(-1, seq_len, -1)
    if cuda:
        seqs_inputs = seqs_inputs.cuda()
        attention_masks = attention_masks.cuda()
        targets = targets.cuda()
        mask = mask.cuda()
    return seqs_inputs, attention_masks, targets, mask


def apply_offset_mask(seq_input, i, mask_token, offset):
    """Mask position *i* in-place and stash its original token ``offset``
    steps later (wrapping around the sequence end).

    ``seq_input[i]`` is overwritten with ``mask_token``; its previous
    value becomes the token at ``(i + offset) mod len(seq_input)``. The
    mutated sequence is also returned for convenience.
    """
    length = seq_input.size(0)
    hidden_token = seq_input[i].item()
    seq_input[i] = mask_token
    seq_input[(i + offset) % length] = hidden_token
    return seq_input
48 changes: 8 additions & 40 deletions spalbp/synthetic_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,52 +12,17 @@
import torch.nn.functional as F
import wandb
from omegaconf import OmegaConf
from lib.models import GeneratingTransformer
from utils import (
cuda,
device,
setup,
learners,
save_model,
init_wandb,
post_process_cfg,
get_model,
)


def random_sample_data2(batch_size, seq_len, offset=70):
seqs_inputs = torch.randint(size=(batch_size, seq_len), low=100, high=32000)
attention_masks = torch.ones_like(seqs_inputs)
mask_token = 4
mask = torch.rand((batch_size, seq_len)) > 0.05
# mask = torch.ones((batch_size, seq_len))
# mask[:, 45:55] = 0
mask = mask.bool()
targets = seqs_inputs.detach().clone()
# Modify the input so that the masked token positions are filled with [MASK] tokens
# and the token at position mask + offset is the target token.
for b, m_i in (~mask).nonzero():
seqs_inputs[b] = apply_offset_mask(seqs_inputs[b], m_i, mask_token, offset)
# Expand the attention mask to a symmetric matrix
attention_masks = attention_masks[:, None, :].expand(-1, seq_len, -1)
if cuda:
seqs_inputs = seqs_inputs.cuda()
attention_masks = attention_masks.cuda()
targets = targets.cuda()
mask = mask.cuda()
return seqs_inputs, attention_masks, targets, mask


def apply_offset_mask(seq_input, i, mask_token, offset):
"""
This function replaces seq_input[i] with the mask_token and replaces
seq_input[i+offset] with the target token.
"""
target_token = seq_input[i].item()
seq_input[i] = mask_token
new_pos = (i + offset) % seq_input.size(0)
seq_input[new_pos] = target_token
return seq_input
from synth_data_gen import (
random_sample_data2, random_sample_simple
)


def _train(cfg: RunConfig):
Expand All @@ -66,13 +31,16 @@ def _train(cfg: RunConfig):
optimizer, scheduler = learners(model, cfg)
tokens_seen = 0
train_cfg = cfg.experiment.training
data_sample_function = (lambda: random_sample_simple(cfg.experiment.training.batch_size, cfg.experiment.context_size)) \
if cfg.experiment.data_gen == 'simple' \
else (lambda: random_sample_data2(cfg.experiment.training.batch_size, cfg.experiment.context_size, cfg.experiment.offset))
if cfg.experiment.watch_model:
wandb.watch(model)
for i in range(train_cfg.num_batches):
model.train()
optimizer.zero_grad()

data_sample = random_sample_data2(train_cfg.batch_size, cfg.experiment.context_size, cfg.experiment.offset)
breakpoint()
data_sample = data_sample_function()
seqs_inputs, attention_masks, targets, mask = data_sample

logits, aux_loss = model(seqs_inputs, attention_masks)
Expand Down