diff --git a/.gitignore b/.gitignore
index 3a866bb..ddc5107 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
+tmp/
 checkpoints/
 *.bin
diff --git a/experiments/mini_transformer.py b/experiments/mini_transformer.py
new file mode 100644
index 0000000..c135d8d
--- /dev/null
+++ b/experiments/mini_transformer.py
@@ -0,0 +1,179 @@
+# /// script
+# dependencies = [
+#     "julax",
+# ]
+#
+# [tool.uv.sources]
+# julax = { path = "../", editable = true }
+# ///
+
+# Reproduce https://sdbuchanan.com/blog/jax-2/
+
+from functools import partial
+import grain
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+from jax.nn.initializers import truncated_normal
+from julax.core import Learner, Trainer
+from julax.einops import Rearrange
+from julax.experiment import Experiment
+from julax.layers import (
+    Chain,
+    Linear,
+    LayerNorm,
+    Parallel,
+    Repeated,
+    RotaryEmbedding,
+    SkipConnection,
+    Embedding,
+    Unembedding,
+)
+from julax.observers import default_observer
+from julax.utils import identity
+
+
+class FakeSource(grain.sources.RandomAccessDataSource):
+    def __init__(self, seq_len: int = 256) -> None:
+        self._seq_len = seq_len
+        self._data = np.array(
+            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 3, 2, 1] * 1024
+        )
+
+    def __getitem__(self, index: int):
+        return {
+            "input_ids": self._data[index : index + self._seq_len],
+            "target_labels": self._data[index + 1 : index + 1 + self._seq_len],
+        }
+
+    def __len__(self) -> int:
+        return len(self._data) - self._seq_len
+
+
+def main(
+    seed: int = 5,
+    seq_len: int = 256,
+    global_batch_size: int = 128,
+    num_steps: int = 1000,
+    num_vocab: int = 10,
+    dim: int = 768,
+    num_heads: int = 12,
+    head_dim: int = 64,
+    num_layers: int = 2,
+    param_std: float = 0.02,
+):
+    return Experiment(
+        name="mini_transformer",
+        trainer=Trainer(
+            learner=Learner(
+                feature_name="input_ids",
+                label_name="target_labels",
+                model=Chain(
+                    emb=Embedding(
+                        in_dim=num_vocab,
+                        out_dim=dim,
+                        w_init=truncated_normal(stddev=param_std),
+                    ),
+                    blocks=Repeated(
+                        n=num_layers,
+                        layer=Chain(
+                            attn=SkipConnection(
+                                layer=Chain(
+                                    norm_attn=LayerNorm(dim=dim),
+                                    attn=Chain(
+                                        # qkv projection
+                                        Linear(
+                                            in_dim=dim,
+                                            out_dim=3 * dim,
+                                            w_init=truncated_normal(stddev=param_std),
+                                            b_init=None,
+                                        ),
+                                        Rearrange(
+                                            "B T (qkv N H) -> B T (qkv N) H",
+                                            B=global_batch_size,
+                                            T=seq_len,
+                                            qkv=3,
+                                            N=num_heads,
+                                            H=head_dim,
+                                        ),
+                                        partial(
+                                            jnp.split, indices_or_sections=3, axis=2
+                                        ),
+                                        Parallel(
+                                            RotaryEmbedding(
+                                                embedding_dims=head_dim,
+                                                fprop_dtype=jnp.float32,
+                                            ),
+                                            RotaryEmbedding(
+                                                embedding_dims=head_dim,
+                                                fprop_dtype=jnp.float32,
+                                            ),
+                                            identity,
+                                        ),
+                                        lambda qkv: jax.nn.dot_product_attention(
+                                            *qkv, is_causal=True
+                                        ),
+                                        Rearrange(
+                                            "B T N H -> B T (N H)",
+                                            B=global_batch_size,
+                                            T=seq_len,
+                                            N=num_heads,
+                                            H=head_dim,
+                                        ),
+                                        Linear(
+                                            in_dim=dim,
+                                            out_dim=dim,
+                                            w_init=truncated_normal(stddev=param_std),
+                                            b_init=None,
+                                        ),
+                                    ),
+                                )
+                            ),
+                            mlp=SkipConnection(
+                                layer=Chain(
+                                    norm_mlp=LayerNorm(dim=dim),
+                                    mlp=Chain(
+                                        up=Linear(
+                                            in_dim=dim,
+                                            out_dim=4 * dim,
+                                            w_init=truncated_normal(stddev=param_std),
+                                            b_init=None,
+                                        ),
+                                        act=jax.nn.gelu,
+                                        down=Linear(
+                                            in_dim=4 * dim,
+                                            out_dim=dim,
+                                            w_init=truncated_normal(stddev=param_std),
+                                            b_init=None,
+                                        ),
+                                    ),
+                                )
+                            ),
+                        ),
+                    ),
+                    unemb=Unembedding(
+                        in_dim=dim,
+                        out_dim=num_vocab,
+                        w_init=truncated_normal(stddev=param_std),
+                    ),
+                ),
+                loss_fn=optax.softmax_cross_entropy_with_integer_labels,
+            ),
+            optimizer=optax.sgd(0.01),
+        ),
+        dataset=(
+            grain.MapDataset.source(FakeSource(seq_len))
+            .shuffle(seed=seed)
+            .repeat()
+            .batch(batch_size=global_batch_size)
+            .slice(slice(num_steps))
+            .to_iter_dataset()
+        ),
+        observer=default_observer(),
+    )
+
+
+x = main()
+x.run()
+x.close()
diff --git a/experiments/mnist.py b/experiments/mnist.py
index 75cc7dc..4c5fcce 100644
--- a/experiments/mnist.py
+++ b/experiments/mnist.py
@@ -11,14 +11,11 @@
 import logging
-from datetime import datetime
-import os
 
 import grain
 import jax
 from jax.nn.initializers import truncated_normal
 import optax
-import orbax.checkpoint as ocp
 import tensorflow_datasets as tfds
 
 from julax import (
@@ -62,17 +59,11 @@ def evaluate(x: Experiment, p: Param, s: State):
         n_total += 32
 
     acc = n_correct / n_total
-    logging.info(f"Accuracy at step {s['trainer']['step']}: {acc}")
+    logging.info(f"Accuracy at step {s['step']}: {acc}")
 
 
 E = Experiment(
     name="mnist",
-    checkpoint_manager=ocp.CheckpointManager(
-        directory=os.path.join(
-            os.getcwd(), "checkpoints", datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
-        ),
-        options=ocp.CheckpointManagerOptions(save_interval_steps=100),
-    ),
     trainer=Trainer(
         learner=Learner(
             model=Chain(
@@ -114,3 +105,4 @@ def evaluate(x: Experiment, p: Param, s: State):
 )
 
 E.run()
+E.close()
diff --git a/experiments/transformer.py b/experiments/transformer.py
deleted file mode 100644
index 4f02a0f..0000000
--- a/experiments/transformer.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# /// script
-# dependencies = [
-#     "julax",
-# ]
-#
-# [tool.uv.sources]
-# julax = { path = "../", editable = true }
-# ///
-
-import grain
-import numpy as np
-
-
-class FakeSource(grain.sources.RandomAccessDataSource):
-    def __init__(self, seq_len: int = 256) -> None:
-        self._seq_len = seq_len
-        self._data = np.array(
-            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 3, 2, 1] * 1024
-        )
-
-    def __getitem__(self, index: int):
-        return {
-            "input_ids": self._data[index : index + self._seq_len],
-            "target_labels": self._data[index + 1 : index + 1 + self._seq_len],
-        }
-
-    def __len__(self) -> int:
-        return len(self._data) - self._seq_len
-
-
-dataset = grain.MapDataset.source(FakeSource()).shuffle(seed=10).batch(batch_size=2)
diff --git a/pyproject.toml b/pyproject.toml
index 2322f21..1ce96e7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "julax"
-version = "0.0.3-dev"
+version = "0.0.3"
 description = "Just Layers over JAX"
 readme = "README.md"
 authors = [
@@ -29,6 +29,9 @@ dev = [
     "pytest>=8.3.2",
     "pytest-cov>=5.0.0",
 ]
+tpu = [
+    "jax[tpu]>=0.7.2",
+]
 
 [tool.ruff.lint]
 ignore = ["E741", "F811"]
diff --git a/src/julax/base.py b/src/julax/base.py
index d891e70..39d1923 100644
--- a/src/julax/base.py
+++ b/src/julax/base.py
@@ -1,8 +1,41 @@
 from typing import TypeAlias, Any
+from pydantic import ConfigDict, RootModel
 from jax import Array
+from jax.sharding import PartitionSpec
 import plum
 
 PRNG: TypeAlias = Array
 PyTree: TypeAlias = Any
+OutShardingType: TypeAlias = PartitionSpec | None
+
+# TODO: isinstance(jnp.dtype, jnp.float32) fails
+Dtype: TypeAlias = Any
 
 dispatch = plum.Dispatcher(warn_redefinition=True)
+
+
+class FrozenDict(RootModel[dict]):
+    model_config = ConfigDict(frozen=True)
+
+    def __getitem__(self, item):
+        return self.root[item]
+
+    def __iter__(self):
+        return iter(self.root)
+
+    def keys(self):
+        return self.root.keys()
+
+    def values(self):
+        return self.root.values()
+
+    def items(self):
+        return self.root.items()
+
+    def __hash__(self):
+        return hash(frozenset(self.root.items()))
+
+    def __eq__(self, other):
+        if isinstance(other, FrozenDict):
+            return self.root == other.root
+        return self.root == other
diff --git a/src/julax/core.py b/src/julax/core.py
index d5494ba..234b658 100644
--- a/src/julax/core.py
+++ b/src/julax/core.py
@@ -11,11 +11,11 @@
 
 import jax
 import jax.numpy as jnp
-from jax import jit, value_and_grad, Array
+from jax import jit, value_and_grad
 
 #####
 
-from julax.base import PRNG, PyTree, dispatch
+from julax.base import PRNG, Dtype, OutShardingType, PyTree, dispatch
 
 # TODO: use RootModel[dict] for better customization
 # Or maybe SimpleNamespace?
@@ -26,6 +26,10 @@
 
 
 class LayerBase(BaseModel, ABC):
+    param_dtype: Dtype | None = None
+    param_sharding: OutShardingType = None
+    out_sharding: OutShardingType = None
+
     model_config = ConfigDict(
         arbitrary_types_allowed=True,
         frozen=True,
@@ -134,7 +138,7 @@ def to_layer(x):
 
 
 class Learner(LayerBase):
-    loss_fn: Callable[[PyTree, PyTree], Array]
+    loss_fn: Callable[[PyTree, PyTree], Any]
     model: LayerBase
     agg: Callable = jnp.mean
     feature_name: str = "feature"
@@ -154,7 +158,7 @@ class Trainer(LayerBase):
     optimizer: Any
 
     def state(self, rng: PRNG) -> State:
-        return State(optimizer=None, step=0, loss=0.0)
+        return State(optimizer=None, loss=0.0)
 
     @dispatch
     def init(
@@ -165,11 +169,9 @@ def init(
 
     def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]:
         loss, state = self.learner(x, p["learner"], s["learner"])
-        return loss, State(
-            learner=state, optimizer=s["optimizer"], step=s["step"] + 1, loss=loss
-        )
+        return loss, State(learner=state, optimizer=s["optimizer"], loss=loss)
 
-    @partial(jit, static_argnums=0)
+    @partial(jit, static_argnums=0, donate_argnames=("p", "s"))
     def forward_and_backward(
         self, x: PyTree, p: Param, s: State
     ) -> tuple[Param, State]:
@@ -178,5 +180,6 @@ def forward_and_backward(
         P = optax.apply_updates(p, updates)
         return P, S
 
+    @dispatch
     def __call__(self, x: PyTree, p: Param, s: State) -> tuple[Param, State]:
         return self.forward_and_backward(x, p, s)
diff --git a/src/julax/einops.py b/src/julax/einops.py
index c7e41ba..1011624 100644
--- a/src/julax/einops.py
+++ b/src/julax/einops.py
@@ -6,6 +6,8 @@
 import jax.numpy as jnp
 from jax.nn.initializers import Initializer
 from pydantic import computed_field
+
+from julax.base import FrozenDict
 
 from .core import LayerBase, Param, PyTree, State, PRNG
 
@@ -26,7 +28,7 @@ def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]:
 
 class Rearrange(LayerBase):
     pattern: str
-    sizes: dict
+    sizes: FrozenDict
 
     def __init__(self, pattern: str, **kwargs):
         super().__init__(pattern=pattern, sizes=kwargs)
diff --git a/src/julax/experiment.py b/src/julax/experiment.py
index 5a9719c..d80ef10 100644
--- a/src/julax/experiment.py
+++ b/src/julax/experiment.py
@@ -21,34 +21,37 @@
 class Experiment(LayerBase):
     name: str = "mnist"
     seed: int = 0
-    checkpoint_manager: ocp.CheckpointManager
     trainer: Trainer
     dataset: grain.IterDataset
     batch_axis_names: list[str] = ["data"]
     mesh_shape: dict[str, int] = {"data": -1}
+    checkpoint_manager: ocp.CheckpointManager | None = None
     observer: ObserverBase = Field(default_factory=default_observer)
 
     def state(self, rng: PRNG) -> State:
-        return State(input=iter(self.dataset))
+        return State(input=iter(self.dataset), step=0)
 
     def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]:
         P, S = self.trainer(x, p["trainer"], s["trainer"])
-        return Param(trainer=P), State(trainer=S, input=s["input"])
+        return Param(trainer=P), State(trainer=S, input=s["input"], step=s["step"] + 1)
 
     def save(self, p: Param, s: State):
-        self.checkpoint_manager.save(
-            s["trainer"]["step"],
-            args=ocp.args.Composite(
-                param=ocp.args.PyTreeSave(item=p),
-                state_trainer=ocp.args.PyTreeSave(item=s["trainer"]),
-                state_dataset_iter=grain.checkpoint.CheckpointSave(item=s["input"]),
-            ),
-        )
+        if self.checkpoint_manager:
+            self.checkpoint_manager.save(
+                s["step"],
+                args=ocp.args.Composite(
+                    param=ocp.args.PyTreeSave(item=p),
+                    state_trainer=ocp.args.PyTreeSave(item=s["trainer"]),
+                    state_dataset_iter=grain.checkpoint.CheckpointSave(item=s["input"]),
+                ),
+            )
 
     def restore(self) -> tuple[Param, State]:
         p, s = self.init(self.seed)
+        if self.checkpoint_manager is None:
+            return p, s
         try:
             restored = self.checkpoint_manager.restore(
                 step=None,
@@ -78,7 +81,11 @@ def restore(self) -> tuple[Param, State]:
         )
         return p, s
 
-    def run(self):
+    def close(self):
+        if self.checkpoint_manager:
+            self.checkpoint_manager.close()
+
+    def run(self) -> tuple[Param, State]:
         with create_mesh(self.mesh_shape) as mesh:
             p, s = self.restore()
             self.observer(self, p, s)
@@ -95,5 +102,4 @@ def restore(self) -> tuple[Param, State]:
                 self.observer(self, p, s)
 
         self.save(p, s)
-        self.checkpoint_manager.wait_until_finished()
         return p, s
diff --git a/src/julax/layers.py b/src/julax/layers.py
index 16bdb58..cc78f55 100644
--- a/src/julax/layers.py
+++ b/src/julax/layers.py
@@ -1,7 +1,8 @@
 from typing import Callable
 
-from jax import Array
 import jax
+from jax import Array
+from jax.sharding import PartitionSpec as P
 import jax.numpy as jnp
 from jax.nn.initializers import (
     Initializer,
@@ -11,6 +12,8 @@
     variance_scaling,
 )
 
+from julax.base import Dtype
+
 from .core import PRNG, LayerBase, LayerLike, PyTree, Param, State, dispatch
 
 
@@ -26,6 +29,31 @@ def to_layer(x: Callable):
     return F(f=x)
 
 
+class SkipConnection(LayerBase):
+    layer: LayerLike
+    connection: Callable = jnp.add
+
+    def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]:
+        S = State()
+        o, S["layer"] = self.layer(x, p["layer"], s["layer"])
+        return self.connection(o, x), S
+
+
+class Repeated(LayerBase):
+    n: int
+    layer: LayerLike
+
+    def sublayers(self) -> dict:
+        return {f"layer_{i}": self.layer for i in range(self.n)}
+
+    def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]:
+        S = State()
+        o = x
+        for i in range(self.n):
+            o, S[f"layer_{i}"] = self.layer(o, p[f"layer_{i}"], s[f"layer_{i}"])
+        return o, S
+
+
 class NamedLayers(LayerBase):
     names: tuple[str, ...]
     layers: tuple[LayerLike, ...]
@@ -40,37 +68,46 @@ def sublayers(self) -> dict:
 
 
 class Chain(NamedLayers):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
     def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]:
         h = x
-        S = {}
+        S = State()
         for name, layer in zip(self.names, self.layers):
             h, S[name] = layer(h, p[name], s[name])
-        return h, State(**S)
+        return h, S
 
 
 class Branch(NamedLayers):
     """1 -> N"""
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
     def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]:
         O = {}
-        S = {}
+        S = State()
         for name, layer in zip(self.names, self.layers):
            O[name], S[name] = layer(x, p[name], s[name])
         # ??? return dict?
-        return tuple(O.values()), State(**S)
+        return tuple(O.values()), S
 
 
 class Parallel(NamedLayers):
     """N -> N"""
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
     def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]:
         assert len(x) == len(self.layers)
         O = {}
-        S = {}
+        S = State()
         for name, layer, xᵢ in zip(self.names, self.layers, x):
             O[name], S[name] = layer(xᵢ, p[name], s[name])
         # ??? return dict?
-        return tuple(O.values()), State(**S)
+        return tuple(O.values()), S
 
 
 #####
@@ -85,13 +122,31 @@ class Linear(LayerBase):
     def param(self, rng: PRNG) -> Param:
         rng_w, rng_b = jax.random.split(rng)
         return Param(
-            w=self.w_init(rng_w, (self.in_dim, self.out_dim)),
-            b=self.b_init(rng_b, (self.out_dim,)) if self.b_init else None,
+            w=self.w_init(
+                rng_w,
+                (self.in_dim, self.out_dim),
+                dtype=self.param_dtype,
+                out_sharding=self.param_sharding,
+            ),
+            b=(
+                self.b_init(
+                    rng_b,
+                    (self.out_dim,),
+                    dtype=self.param_dtype,
+                    out_sharding=(
+                        None
+                        if self.param_sharding is None
+                        else P(self.param_sharding[-1])
+                    ),
+                )
+                if self.b_init
+                else None
+            ),
         )
 
     def forward(self, x: Array, p: Param, s: State) -> tuple[Array, State]:
-        o = jnp.einsum("...d,dh->...h", x, p["w"])
-        if "b" in p:
+        o = jnp.einsum("...d,dh->...h", x, p["w"], out_sharding=self.out_sharding)
+        if p["b"] is not None:
             o += p["b"]
         return o, s
 
@@ -138,28 +193,91 @@ class Embedding(LayerBase):
     w_init: Initializer = variance_scaling(1.0, "fan_in", "normal", out_axis=0)
 
     def param(self, rng: PRNG) -> Param:
-        return Param(w=self.w_init(rng, (self.in_dim, self.out_dim)))
+        return Param(
+            w=self.w_init(
+                rng,
+                (self.in_dim, self.out_dim),
+                dtype=self.param_dtype,
+                out_sharding=self.param_sharding,
+            )
+        )
+
+    def forward(self, x: Array, p: Param, s: State) -> tuple[Array, State]:
+        return p["w"].at[x].get(out_sharding=self.out_sharding), s
 
+
+class RotaryEmbedding(LayerBase):
+    """Rotary Position Embedding."""
+
+    # Adapted from https://github.com/AI-Hypercomputer/maxtext/blob/9204d6bbbf8bb19a05ebed72a55cfec687e0e044/src/MaxText/layers/embeddings.py#L271C11-L356C17
+    embedding_dims: int
+    min_timescale: int = 1
+    max_timescale: int = 10000
+    cast_as_fprop_dtype: bool = True
+    fprop_dtype: Dtype = jnp.bfloat16
+    rope_linear_scaling_factor: float = 1.0
+
+    def state(self, rng: PRNG) -> State:
+        half_embedding_dim = self.embedding_dims // 2
+        fraction = 2 * jnp.arange(0, half_embedding_dim) / self.embedding_dims
+        timescale = (
+            self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction
+        )
+        if self.rope_linear_scaling_factor != 1.0:
+            timescale = timescale * self.rope_linear_scaling_factor
+        return State(timescale=timescale)
+
+    def forward(self, x: Array, p: Param, s: State) -> tuple[Array, State]:
+        seq_length = x.shape[1]
+        position = jnp.arange(seq_length, dtype=jnp.float32)[
+            jnp.newaxis, :, jnp.newaxis, jnp.newaxis
+        ]
+        sinusoid_inp = position / s["timescale"]
+        sin = jnp.sin(sinusoid_inp).astype(x.dtype)
+        cos = jnp.cos(sinusoid_inp).astype(x.dtype)
+        first_half, second_half = jnp.split(x, 2, axis=-1)
+        first_part = first_half * cos - second_half * sin
+        second_part = second_half * cos + first_half * sin
+        if self.cast_as_fprop_dtype:
+            first_part = first_part.astype(self.fprop_dtype)
+            second_part = second_part.astype(self.fprop_dtype)
+        x_out = jnp.concatenate((first_part, second_part), axis=-1)
+        return x_out, s
+
+
+class Unembedding(Embedding):
     def forward(self, x: Array, p: Param, s: State) -> tuple[Array, State]:
-        return p["w"][x], s
+        return jnp.einsum("bld,dn->bln", x, p["w"], out_sharding=self.out_sharding), s
 
 
 class LayerNorm(LayerBase):
     dim: int
-    ϵ: float = 1e-5
+    epsilon: float = 1e-5
     w_init: Initializer = ones
     b_init: Initializer = zeros
+    compute_dtype: Dtype | None = None
 
     def param(self, rng: PRNG) -> Param:
         w_rng, b_rng = jax.random.split(rng)
         return Param(
-            w=self.w_init(w_rng, (self.dim,)), b=self.b_init(b_rng, (self.dim,))
+            w=self.w_init(
+                w_rng,
+                (self.dim,),
+                dtype=self.param_dtype,
+                out_sharding=self.out_sharding,
+            ),
+            b=self.b_init(
+                b_rng,
+                (self.dim,),
+                dtype=self.param_dtype,
+                out_sharding=(
+                    None if self.param_sharding is None else P(self.param_sharding[-1])
+                ),
+            ),
         )
 
     def forward(self, x: Array, p: Param, s: State) -> tuple[Array, State]:
-        x_mean = x.mean(axis=-1, keepdims=True)
-        x -= x_mean
-        var = (x * x).mean(axis=-1, keepdims=True)
-        x = x * jax.lax.rsqrt(var + self.ϵ)
-        # TODO: cast dtype
-        return x * p["w"] + p["b"], s
+        x_std = jax.nn.standardize(
+            x.astype(self.compute_dtype), epsilon=self.epsilon
+        ).astype(self.param_dtype)
+        return x_std * p["w"] + p["b"], s
diff --git a/src/julax/observers.py b/src/julax/observers.py
index 3978532..af6c35b 100644
--- a/src/julax/observers.py
+++ b/src/julax/observers.py
@@ -1,6 +1,8 @@
 import logging
 import time
 from typing import Protocol
+
+import jax
 from .core import Param, State
 
 logger = logging.getLogger(__name__)
@@ -27,7 +29,7 @@ def __init__(self, observer: Observer, n: int = 1):
         self.observer = observer
 
     def __call__(self, x, p: Param, s: State):
-        step = s["trainer"]["step"]
+        step = s["step"]
         if step % self.n == 0:
             return self.observer(x, p, s)
 
@@ -37,7 +39,7 @@ def __init__(self, observer: Observer):
         self.observer = observer
 
     def __call__(self, x, p: Param, s: State):
-        step = s["trainer"]["step"]
+        step = s["step"]
         if step == 0:
             return self.observer(x, p, s)
 
@@ -59,9 +61,8 @@ def __call__(self, x, p: Param, s: State):
 class LossLogger(ObserverBase):
     def __call__(self, x, p: Param, s: State):
         loss = s["trainer"]["loss"]
-        step = s["trainer"]["step"]
-        if step > 0:
-            logger.info(f"Step {step}: loss = {loss}")
+        step = s["step"]
+        jax.debug.print("Step {step}: loss={loss}", step=step, loss=loss)
 
 
 class StepTimeLogger(ObserverBase):
@@ -81,7 +82,7 @@ def __call__(self, x, p: Param, s: State):
         if self.step_count % self.n == 0:
             now = time.perf_counter()
             avg_time = (now - self.last_time) / self.step_count
-            step = s["trainer"]["step"]
+            step = s["step"]
             logger.info(
                 f"Step {step}: avg step time over last {self.step_count} steps: {avg_time:.6f}s"
             )
@@ -90,4 +91,4 @@ def __call__(self, x, p: Param, s: State):
 
 
 def default_observer() -> CompositeObserver:
-    return LossLogger() * StepTimeLogger()
+    return DoEveryNSteps(LossLogger(), n=10) * StepTimeLogger()
diff --git a/src/julax/utils.py b/src/julax/utils.py
index b5176bb..1e15f69 100644
--- a/src/julax/utils.py
+++ b/src/julax/utils.py
@@ -4,6 +4,10 @@
 from jax.experimental import mesh_utils
 
 
+def identity(x):
+    return x
+
+
 def create_mesh(mesh_shape: dict[str, int]) -> Mesh:
     # TODO: support multi-slice
     values = list(mesh_shape.values())