Conversation

Copilot AI commented Nov 28, 2025

Adds a mini-transformer example demonstrating how to build and train transformer models in the framework with JAX sharding support.

Changes

  • State management: Moved the step counter from s["trainer"]["step"] to s["step"], promoting it to the experiment level
  • New layer primitives:
    • SkipConnection - residual connections with a configurable combiner function
    • Repeated - sequential repetition of a layer, with independent parameters per instance
    • Unembedding - projects from the hidden dimension to vocabulary logits
    • RoPE - rotary position embeddings (a reference sketch of the rotation follows the usage example below)
  • Enhanced layer attributes: Added param_dtype, param_sharding, and out_sharding for mixed precision and distributed training (see the sketch after the usage example)
  • FrozenDict: Added __eq__ and __len__ methods for proper dict-like behavior (a minimal sketch follows this list)
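
A minimal sketch of the two added methods, assuming FrozenDict wraps an internal dict (the _data attribute is hypothetical; the actual julax implementation may differ):

class FrozenDict:
    """Immutable mapping; only the two new dunder methods are sketched."""

    def __init__(self, data):
        self._data = dict(data)  # hypothetical internal storage

    def __eq__(self, other):
        # Compare by contents, against both FrozenDict and plain dict
        if isinstance(other, FrozenDict):
            return self._data == other._data
        if isinstance(other, dict):
            return self._data == other
        return NotImplemented

    def __len__(self):
        # Number of top-level entries, matching dict semantics
        return len(self._data)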

Example usage

from julax.layers import (
    Attention,
    Chain,
    Embedding,
    LayerNorm,
    Repeated,
    SkipConnection,
    Unembedding,
)

vocab_size = 32_000  # example vocabulary size

# One pre-norm block: LayerNorm then attention, wrapped in a residual connection
transformer_block = SkipConnection(
    layer=Chain(layers=[
        LayerNorm(dim=512),
        Attention(...),
    ])
)

model = Chain(layers=[
    Embedding(in_dim=vocab_size, out_dim=512),
    Repeated(n=6, layer=transformer_block),  # six blocks, independent parameters each
    Unembedding(in_dim=512, out_dim=vocab_size),
])
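
As a hedged illustration of the new layer attributes, here is a sketch of a Linear layer configured for mixed precision and sharding. The keyword names come from the changes list above, but their placement on Linear and the use of jax.sharding.PartitionSpec values are assumptions:

import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from julax.layers import Linear

mlp_in = Linear(
    in_dim=512,
    out_dim=2048,
    param_dtype=jnp.bfloat16,         # assumed: store parameters in bfloat16
    param_sharding=P(None, "model"),  # assumed: shard the weight over a "model" mesh axis
    out_sharding=P("data", None),     # assumed: shard activations over a "data" mesh axis
)

The RoPE primitive does not appear in the snippet above, so here is a standalone reference sketch of the underlying rotation in plain jax.numpy. The rotate-half pairing and the (seq_len, num_heads, head_dim) layout are assumptions, not necessarily julax's RoPE API:

def rope(x, base=10000.0):
    # x: (seq_len, num_heads, head_dim); head_dim must be even
    seq_len, _, head_dim = x.shape
    half = head_dim // 2
    freqs = base ** (-jnp.arange(half) / half)               # theta_i = base^(-2i/d)
    angles = jnp.arange(seq_len)[:, None] * freqs[None, :]   # (seq_len, half)
    cos = jnp.cos(angles)[:, None, :]                        # broadcast over heads
    sin = jnp.sin(angles)[:, None, :]
    x1, x2 = x[..., :half], x[..., half:]
    # Rotate each (x1_i, x2_i) pair by its position-dependent angle
    return jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)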

Copilot AI changed the title from "[WIP] Add a mini-transformer example" to "Add a mini-transformer example" on Nov 28, 2025
Copilot AI requested a review from findmyway November 28, 2025 14:09
@findmyway closed this on Nov 28, 2025