4 changes: 2 additions & 2 deletions examples/data/mnist.py
@@ -44,7 +44,7 @@ def parse_labels(filepath):
test_labels = parse_labels(os.path.join(data_dir, "t10k-labels-idx1-ubyte.gz"))

if normalize:
-    train_images = train_images.astype(np.float32) / 255.0
-    test_images = test_images.astype(np.float32) / 255.0
+    train_images = train_images.astype(np.float32) / 127.5 - 1.
+    test_images = test_images.astype(np.float32) / 127.5 - 1.

return train_images, train_labels, test_images, test_labels
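
A minimal check of the new normalization (standalone NumPy sketch, not part of the diff): pixels are now mapped from [0, 255] to [-1, 1] rather than [0, 1].

import numpy as np

pixels = np.array([0, 127.5, 255], dtype=np.float32)
print(pixels / 255.0)        # old scaling: [0.  0.5 1. ]
print(pixels / 127.5 - 1.)   # new scaling: [-1.  0.  1.]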
479 changes: 237 additions & 242 deletions examples/hello-gpt.ipynb

Large diffs are not rendered by default.

775 changes: 775 additions & 0 deletions examples/hello-mnist-vit.ipynb

Large diffs are not rendered by default.

26 changes: 17 additions & 9 deletions examples/hello-mnist.ipynb

Large diffs are not rendered by default.

27 changes: 26 additions & 1 deletion modula/abstract.py
@@ -1,6 +1,8 @@
-import jax
import copy

+import jax
+import einops

class Module:
def __init__(self):
self.children = []
@@ -204,3 +206,26 @@ def __init__(self, scalar):

def forward(self, x, w):
return x * self.sensitivity

class Mean(Bond):
def __init__(self, axis, size):
super().__init__()
self.smooth = True
self.axis = axis
self.size = size
self.sensitivity = 1 / size

def forward(self, x, w):
assert x.shape[self.axis] == self.size
return jax.numpy.mean(x, axis=self.axis)

class Patchify(Bond):
def __init__(self, size):
super().__init__()
self.smooth = True
self.sensitivity = 1
self.size = size

def forward(self, x, w):
p1, p2 = self.size
return einops.rearrange(x, 'b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=p1, p2=p2)
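
An illustrative shape check for Patchify and Mean (standalone einops/JAX calls mirroring the code above; the MNIST-sized shapes are assumptions): a 28x28 single-channel image with 7x7 patches becomes 16 tokens of dimension 49.

import jax.numpy as jnp
import einops

x = jnp.zeros((2, 28, 28, 1))  # [batch, height, width, channels]
tokens = einops.rearrange(x, 'b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=7, p2=7)
print(tokens.shape)                    # (2, 16, 49): a 4x4 grid of flattened 7x7 patches
print(jnp.mean(tokens, axis=1).shape)  # (2, 49): what Mean(axis=1, size=16) computes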
48 changes: 48 additions & 0 deletions modula/atom.py
@@ -90,6 +90,54 @@ def dualize(self, grad_w, target_norm=1.0):
d_weight = jnp.nan_to_num(d_weight)
return [d_weight]

class Bias(Atom):
def __init__(self, d):
super().__init__()
self.d = d
self.smooth = True
self.mass = 1
self.sensitivity = 1

def forward(self, x, w):
weights = w[0] # shape [d]
return weights

def initialize(self, key):
return [jnp.zeros(shape=self.d)]

def project(self, w):
weight = w[0]
weight = weight / jnp.linalg.norm(weight) * jnp.sqrt(self.d)
return [weight]

def dualize(self, grad_w, target_norm=1.0):
grad = grad_w[0]
d_weight = grad / jnp.linalg.norm(grad) * jnp.sqrt(self.d) * target_norm
d_weight = jnp.nan_to_num(d_weight)
return [d_weight]

class Scale(Atom):
def __init__(self, d):
super().__init__()
self.d = d
self.smooth = True
self.mass = 1
self.sensitivity = 1

def forward(self, x, w):
weights = w[0] # shape [d]
return weights * x

def initialize(self, key):
return [jnp.ones(shape=self.d)]

def project(self, w):
weight = w[0]
return [jnp.sign(weight)]

def dualize(self, grad_w, target_norm=1.0):
grad = grad_w[0]
return [jnp.sign(grad) * target_norm]

if __name__ == "__main__":

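An illustrative check of the new atoms' dualize behavior (assumes importing Bias and Scale from modula.atom as defined above): Bias rescales the gradient to l2 norm sqrt(d) * target_norm, matching project(), while Scale takes the elementwise sign.

import jax.numpy as jnp
from modula.atom import Bias, Scale

grad = [jnp.array([3., 0., 0., 0.])]
print(Bias(d=4).dualize(grad, target_norm=1.0)[0])   # [2. 0. 0. 0.] -> l2 norm sqrt(4) = 2
print(Scale(d=4).dualize(grad, target_norm=1.0)[0])  # [1. 0. 0. 0.] -> signs of the gradient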
39 changes: 29 additions & 10 deletions modula/bond.py
@@ -97,6 +97,28 @@ def forward(self, x, w):
v, scores = x
return scores @ v

class Constant(Bond):
def __init__(self, f):
super().__init__()
self.f = f
self.smooth = True
self.sensitivity = 0

def forward(self, x, w):
return self.f()

class LayerNorm(Bond):
def __init__(self, eps=1e-6):
super().__init__()
self.eps = eps
self.smooth = True
self.sensitivity = 1

def forward(self, x, w):
mean = jnp.mean(x, axis=-1, keepdims=True)
var = jnp.var(x, axis=-1, keepdims=True)
return (x - mean) / jnp.sqrt(var + self.eps)

class Rope(Bond):
"""Rotates queries and keys by relative context window distance."""
def __init__(self, d_head, base=10000):
@@ -106,18 +128,14 @@ def __init__(self, d_head, base=10000):

self.rope_dim = d_head // 2
self.inverse_frequencies = 1/base**(jnp.arange(self.rope_dim) / self.rope_dim)
-self.seq_len_cached = None
-self.sin_cached = None
-self.cos_cached = None

def get_cached(self, seq_len):
-if self.seq_len_cached != seq_len:
-    self.seq_len_cached = seq_len
-    distance = jnp.arange(seq_len)
-    freqs = jnp.outer(distance, self.inverse_frequencies) # shape [seq_len, rope_dim]
-    self.cos_cached = jnp.expand_dims(jnp.cos(freqs), (0, 1)) # shape [seq_len, rope_dim]
-    self.sin_cached = jnp.expand_dims(jnp.sin(freqs), (0, 1)) # shape [seq_len, rope_dim]
-return self.sin_cached, self.cos_cached
+# Actually, caching the return value may lead to a leaked intermediate value error
+distance = jnp.arange(seq_len)
+freqs = jnp.outer(distance, self.inverse_frequencies) # shape [seq_len, rope_dim]
+cos = jnp.expand_dims(jnp.cos(freqs), (0, 1)) # shape [1, 1, seq_len, rope_dim]
+sin = jnp.expand_dims(jnp.sin(freqs), (0, 1)) # shape [1, 1, seq_len, rope_dim]
+return sin, cos

def rotate(self, x):
batch, n_heads, seq_len, d_head = x.shape
@@ -126,6 +144,7 @@ def rotate(self, x):
x1 = x[..., self.rope_dim:] # shape [batch, n_heads, seq_len, rope_dim]
x2 = x[..., :self.rope_dim] # shape [batch, n_heads, seq_len, rope_dim]

+# Why is the order reversed!?
[Review comment] I noticed this too and corrected it in #13. I guess it doesn't really matter for performance?

cos, sin = self.get_cached(seq_len)
y1 = cos * x1 + sin * x2
y2 = -sin * x1 + cos * x2
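A standalone sketch of what the rewritten get_cached and rotate compute (illustrative shapes only, not using the Bond class; the recombination order at the end is an assumption): the sin/cos tables broadcast against [batch, n_heads, seq_len, d_head] activations and the rotation preserves the shape.

import jax.numpy as jnp

d_head, seq_len, rope_dim = 8, 5, 4                # rope_dim = d_head // 2
inv_freq = 1 / 10000 ** (jnp.arange(rope_dim) / rope_dim)
freqs = jnp.outer(jnp.arange(seq_len), inv_freq)   # [seq_len, rope_dim]
cos = jnp.expand_dims(jnp.cos(freqs), (0, 1))      # [1, 1, seq_len, rope_dim]
sin = jnp.expand_dims(jnp.sin(freqs), (0, 1))

x = jnp.ones((2, 4, seq_len, d_head))              # [batch, n_heads, seq_len, d_head]
x1, x2 = x[..., rope_dim:], x[..., :rope_dim]
y = jnp.concatenate([cos * x1 + sin * x2, -sin * x1 + cos * x2], axis=-1)
print(y.shape)                                     # (2, 4, 5, 8)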
71 changes: 63 additions & 8 deletions modula/compound.py
@@ -1,3 +1,5 @@
import jax.numpy as jnp

from modula.abstract import *
from modula.atom import *
from modula.bond import *
@@ -8,14 +10,21 @@ def MLP(output_dim, input_dim, width, depth):
m = m @ Linear(width, width) @ ReLU()
return m @ Linear(width, input_dim)

-def Attention(num_heads, d_embed, d_query, d_value, softmax_scale, causal):
+def Attention(num_heads, d_embed, d_query, d_value, softmax_scale, causal, posemb="rope", bias=False):
"""Multi-head attention"""
-Q = SplitIntoHeads(num_heads) @ Linear(num_heads * d_query, d_embed)
-K = SplitIntoHeads(num_heads) @ Linear(num_heads * d_query, d_embed)
-V = SplitIntoHeads(num_heads) @ Linear(num_heads * d_value, d_embed)
-W = Linear(d_embed, num_heads * d_value) @ MergeHeads()

-AttentionScores = Softmax(softmax_scale) @ CausalMask() @ AttentionQK() @ Rope(d_query) @ (Q, K)
+Q, K, V = Linear(num_heads * d_query, d_embed), Linear(num_heads * d_query, d_embed), Linear(num_heads * d_value, d_embed)
+Q = SplitIntoHeads(num_heads) @ (Q + Bias(num_heads * d_query) if bias else Q)
+K = SplitIntoHeads(num_heads) @ (K + Bias(num_heads * d_query) if bias else K)
+V = SplitIntoHeads(num_heads) @ (V + Bias(num_heads * d_value) if bias else V)
+W = Linear(d_embed, num_heads * d_value)
+W = (W + Bias(d_embed) if bias else W) @ MergeHeads()
+QK = (Q, K)
+if posemb == "rope":
+    QK = Rope(d_query) @ QK
+attn = AttentionQK() @ QK
+if causal:
+    attn = CausalMask() @ attn
+AttentionScores = Softmax(softmax_scale) @ attn
return W @ (1/3 * ApplyAttentionScores()) @ (V, AttentionScores)
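
For reference, a usage sketch of the extended signature (mirroring how ViT below calls it; the keyword values are just the ViT defaults, not prescribed by the diff):

from modula.compound import Attention

att = Attention(num_heads=4, d_embed=32, d_query=8, d_value=8,
                softmax_scale=1.0, causal=False, posemb="none", bias=True)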

def GPT(vocab_size, num_heads, d_embed, d_query, d_value, num_blocks, blocks_mass=5, attention_scale=1.0, final_scale=1.0):
@@ -31,4 +40,50 @@ def GPT(vocab_size, num_heads, d_embed, d_query, d_value, num_blocks, blocks_mass=5, attention_scale=1.0, final_scale=1.0):

out = final_scale * Linear(vocab_size, d_embed)

-return out @ blocks @ embed
+return out @ blocks @ embed

def posemb_sincos_2d(h, w, width, temperature=10_000., dtype=jnp.float32):
"""Follows the MoCo v3 logic."""
y, x = jnp.mgrid[:h, :w]

assert width % 4 == 0, "Width must be mult of 4 for sincos posemb"
omega = jnp.arange(width // 4) / (width // 4 - 1)
omega = 1. / (temperature**omega)
y = jnp.einsum("m,d->md", y.flatten(), omega)
x = jnp.einsum("m,d->md", x.flatten(), omega)
pe = jnp.concatenate([jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)], axis=1)
return jnp.asarray(pe, dtype)[None, :, :]
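
A quick shape check (illustrative): for an h x w grid of patch tokens the table is [1, h*w, width], ready to be added to the patch embeddings.

from modula.compound import posemb_sincos_2d

pe = posemb_sincos_2d(4, 4, 32)
print(pe.shape)   # (1, 16, 32): one sincos vector per patch position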

def ViT(num_classes, image_size=(28, 28), patch_size=(7, 7), num_heads=4, d_embed=32, d_query=8, d_value=8, num_blocks=4, blocks_mass=5, attention_scale=1.0, final_scale=1.0, channels=1, LN=True, bias=True, scale=True):
i1, i2 = image_size
p1, p2 = patch_size
h, w = i1 // p1, i2 // p2
patchify = Linear(d_embed, p1 * p2 * channels) @ Patchify(patch_size)
if bias:
    patchify = patchify + Bias(d_embed)
posemb = Constant(lambda: posemb_sincos_2d(h, w, d_embed))

att = Attention(num_heads, d_embed, d_query, d_value, attention_scale, causal=False, posemb="none", bias=bias)
mlp = (Linear(d_embed, 4*d_embed) + Bias(d_embed) if bias else Linear(d_embed, 4*d_embed)) @ GeLU() @ (Linear(4*d_embed, d_embed) + Bias(4*d_embed) if bias else Linear(4*d_embed, d_embed))
if LN:
    ln = LayerNorm()
    if bias and scale:
        ln = (Scale(d_embed) + Bias(d_embed)) @ ln
    elif bias:
        ln = ln + Bias(d_embed)
    elif scale:
        ln = Scale(d_embed) @ ln
    att = att @ ln
    mlp = mlp @ ln
att_block = (1-1/(2*num_blocks)) * Identity() + 1/(2*num_blocks) * att
mlp_block = (1-1/(2*num_blocks)) * Identity() + 1/(2*num_blocks) * mlp
blocks = (mlp_block @ att_block) ** num_blocks
blocks.tare(absolute=blocks_mass)

gap = Mean(axis=1, size=h * w)
out = final_scale * (Linear(num_classes, d_embed) + Bias(num_classes) if bias else Linear(num_classes, d_embed))

ret = blocks @ (patchify + posemb)
if LN: # Final LN
    ret = ln @ ret
return out @ gap @ ret
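
A hedged construction sketch for the new ViT (hyperparameters mirror the defaults above; the initialize and forward calls assume the Module API whose atom-level pieces are defined in modula/atom.py, and are not taken verbatim from the diff):

import jax
from modula.compound import ViT

model = ViT(num_classes=10, image_size=(28, 28), patch_size=(7, 7),
            num_heads=4, d_embed=32, d_query=8, d_value=8, num_blocks=4)
w = model.initialize(jax.random.PRNGKey(0))   # assumed entry point, per the Atom.initialize methods above
# logits = model.forward(images, w)           # images: [batch, 28, 28, 1], the layout Patchify expects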