1 change: 1 addition & 0 deletions spalbp/config/model/base_model.yaml
@@ -10,3 +10,4 @@ type: transformer
embedding_dim: 64
context_size: ${experiment.context_size}
positional_encoding_type: sinusoidal
vocab_size: ${experiment.num_classes}
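
The new vocab_size entry uses the same ${...} interpolation as context_size above it. A minimal sketch of how the reference is expected to resolve, assuming the project loads its YAML through an OmegaConf/Hydra-style config (an assumption; the PR itself only adds the one line, and the concrete experiment values below are illustrative):

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        "experiment": {"context_size": 128, "num_classes": 10},
        "model": {
            "type": "transformer",
            "embedding_dim": 64,
            "context_size": "${experiment.context_size}",
            "positional_encoding_type": "sinusoidal",
            "vocab_size": "${experiment.num_classes}",  # the added line
        },
    })
    assert cfg.model.vocab_size == 10  # interpolation resolves on access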
21 changes: 17 additions & 4 deletions spalbp/lib/attention/dynamic_sparse_attentions.py
@@ -1,15 +1,27 @@
from typing import Tuple
import torch
from torch import nn, Tensor
import wandb

from _context import sparse
from sparse import util

from spalbp.lib.attention.config import AdaptiveSparseAttentionConfig, NonAdaptiveSparseAttentionConfig


class StepGadditionalDecayer:
    """Step-decay schedule for the number of globally sampled points:
    every `steps` calls, the current value is multiplied by `decay_rate`."""

    def __init__(self, initial_value, steps, decay_rate):
        self.initial_value = initial_value
        self.steps = steps
        self.decay_rate = decay_rate
        self.cur = 0  # number of calls made so far
        self.value = initial_value

    def __call__(self):
        self.cur += 1
        if self.cur % self.steps == 0:  # decay once every `steps` calls
            self.value *= self.decay_rate
        return self.value


class _OneDimensionalSparseAttention(nn.Module):
@@ -27,7 +39,7 @@ def __init__(self,
self.emb = emb
self.n_heads = n_heads
self.k = k
self.gadditional = gadditional
self.gadditional = StepGadditionalDecayer(gadditional, 50_000, 0.9)  # decay global samples by 10% every 50k calls
self.nadditional = nadditional
self.head_size = head_size
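
The decayer replaces what used to be a fixed gadditional count. A quick sketch of the schedule it yields, assuming one call per training step (an assumption; the forward pass below calls it once per batch while training, and the initial value 8 is illustrative):

    decayer = StepGadditionalDecayer(initial_value=8, steps=50_000, decay_rate=0.9)
    values = [decayer() for _ in range(100_000)]
    assert values[0] == 8                    # unchanged for the first 49,999 calls
    assert values[49_999] == 8 * 0.9         # decayed once at call 50,000
    assert values[99_999] == 8 * 0.9 * 0.9   # decayed again at call 100,000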

@@ -56,9 +68,10 @@ def forward(self, x: torch.Tensor, attention_mask: torch.Tensor, output_attentio
means, sigmas, values = self.hyper(x) # (B, H, C, k, 1); (B, H, C, k, 1); (B, H, C, k)
batch, context, emb = x.size() # (B, C, E)
rank = means.size(-1)
gadditional = int(self.gadditional()) if self.training or not self.remove_rand_on_eval else 0  # decayer only ticks when random sampling is active
indices: Tensor = sparse.ngenerate(means,
# For evaluation, only get nearest
self.gadditional if self.training or not self.remove_rand_on_eval else 0,
gadditional,
self.nadditional if self.training or not self.remove_rand_on_eval else 0, # index for each point
rng=(context,),
relative_range=(3,),
@@ -73,7 +86,7 @@ def forward(self, x: torch.Tensor, attention_mask: torch.Tensor, output_attentio
indices_fl = indices.float()
# For each point (self.k), we expect to sample the 2**rank closest points from the first set of sampling,
# then gadditional (now the decayed count) globally-sampled indices, and self.nadditional neighborhood-sampled indices.
num_points = self.k * (2 ** rank + ((self.gadditional + self.nadditional) if self.training or not self.remove_rand_on_eval else 0))
num_points = self.k * (2 ** rank + ((gadditional + self.nadditional) if self.training or not self.remove_rand_on_eval else 0))
assert indices.size() == (
    batch, self.n_heads, context, num_points, 1), \
    f'Expected size {(batch, self.n_heads, context, num_points, 1)}. Got {indices.size()}'
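
A worked example of the num_points bookkeeping in this hunk, under assumed values (k=4, rank=1, gadditional already decayed to 8, nadditional=8; all four numbers are illustrative, not taken from the PR):

    k, rank, gadditional, nadditional = 4, 1, 8, 8
    # Training (or eval without remove_rand_on_eval): nearest plus random samples.
    num_points = k * (2 ** rank + gadditional + nadditional)  # 4 * (2 + 16) = 72
    # Eval with remove_rand_on_eval=True: both random terms drop to zero.
    num_points_eval = k * (2 ** rank)  # 4 * 2 = 8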