Add a mini-transformer example #1
Conversation
Pull request overview
This PR adds a mini-transformer example to demonstrate the framework's capabilities for building and training transformer models. The changes refactor state management by moving the step counter from the trainer to the experiment level, add new layer primitives for building transformers, and enhance support for distributed training with JAX sharding.
Key changes:
- Moved step tracking from `s["trainer"]["step"]` to `s["step"]` at the Experiment level for cleaner state architecture
- Added `SkipConnection`, `Repeated`, and `Unembedding` layer classes to support transformer architectures
- Enhanced layers with `param_dtype`, `param_sharding`, and `out_sharding` attributes for distributed training and mixed-precision support
Reviewed changes
Copilot reviewed 9 out of 10 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| src/julax/observers.py | Updated observers to use top-level s["step"] instead of s["trainer"]["step"]; switched LossLogger to use jax.debug.print for JIT compatibility |
| src/julax/layers.py | Added SkipConnection, Repeated, and Unembedding layers; updated existing layers with sharding/dtype support; refactored LayerNorm to use jax.nn.standardize |
| src/julax/experiment.py | Moved step counter to top-level state; made checkpoint_manager optional; added close() method; updated run() to return final params/state |
| src/julax/einops.py | Changed Rearrange.sizes type from dict to FrozenDict for immutability |
| src/julax/core.py | Added param_dtype, param_sharding, out_sharding to LayerBase; removed step tracking from Trainer state; added donate_argnames to JIT decorator |
| src/julax/base.py | Added FrozenDict class and OutShardingType type alias for sharding support |
| experiments/transformer.py | Removed old transformer stub (replaced by mini_transformer.py) |
| experiments/mnist.py | Removed hardcoded checkpoint manager; updated to use top-level step in evaluate function; added E.close() call |
| experiments/mini_transformer.py | New file implementing a complete mini-transformer example with attention, MLP blocks, and proper initialization |
| .gitignore | Added tmp/ directory to gitignore |
```python
class Repeated(LayerBase):
    n: int
    layer: LayerLike

    def sublayers(self) -> dict:
        return {f"layer_{i}": self.layer for i in range(self.n)}

    def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]:
        S = State()
        o = x
        for i in range(self.n):
            o, S[f"layer_{i}"] = self.layer(o, p[f"layer_{i}"], s[f"layer_{i}"])
        return o, S
```

Copilot AI · Nov 28, 2025
[nitpick] The Repeated class creates independent copies of parameters for each repetition (via sublayers() returning the same layer instance multiple times). This means each repetition will have different, independently trained parameters.
If weight sharing across repetitions is desired (common in some architectures), the current implementation won't support that. If independent parameters per layer is intended (as in typical transformer blocks), this is correct but should be documented.
Consider adding a docstring to clarify that each repetition has independent parameters.
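One possible docstring along those lines (a sketch, not wording from this PR):

```python
class Repeated(LayerBase):
    """Apply `layer` n times in sequence.

    Each repetition owns independent parameters under the keys
    layer_0 ... layer_{n-1}, as in a typical transformer stack;
    weights are NOT shared across repetitions.
    """

    n: int
    layer: LayerLike
```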
@copilot open a new pull request to apply changes based on this feedback
```python
    def __hash__(self):
        return hash(frozenset(self.root.items()))
```

Copilot AI · Nov 28, 2025
The FrozenDict.__hash__ implementation has a bug: it attempts to hash a dict's items directly via frozenset(self.root.items()), but dict items may contain unhashable values (like lists or other dicts). This will raise a TypeError when the dict contains unhashable values.
Since this is a Pydantic RootModel[dict] and dicts are generally unhashable, consider either:
- Removing the `__hash__` method entirely (letting it raise TypeError when hashing is attempted)
- Implementing a more robust hash that handles unhashable values
- Using `id(self)` if object identity is sufficient
```diff
-    def __hash__(self):
-        return hash(frozenset(self.root.items()))
```
> but dict items may contain unhashable values (like lists or other dicts)
It will trigger an unhashable error at runtime. Left it for users.
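For illustration, the failure mode described above looks like this (hypothetical values, plain Python):

```python
fd = FrozenDict({"sizes": [1, 2, 3]})  # a list value is not hashable
hash(fd)                               # raises TypeError: unhashable type: 'list'

fd2 = FrozenDict({"B": 4, "T": 128})   # all-hashable values hash fine
hash(fd2)
```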
```python
class SkipConnection(LayerBase):
    layer: LayerLike
    connection: Callable = jnp.add

    def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]:
        S = State()
        o, S["layer"] = self.layer(x, p["layer"], s["layer"])
        return self.connection(o, x), S


class Repeated(LayerBase):
    n: int
    layer: LayerLike
```

Copilot AI · Nov 28, 2025
The new SkipConnection and Repeated layer classes lack docstrings, unlike other layer classes in the file (e.g., Branch has """1 -> N""" and Parallel has """N -> N"""). Adding docstrings would help users understand their purpose and usage.
Consider adding docstrings:
```python
class SkipConnection(LayerBase):
    """Applies a layer and combines its output with the input using a connection function (default: addition)."""
    # ...

class Repeated(LayerBase):
    """Applies the same layer n times sequentially."""
    # ...
```

```diff
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)

     def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]:
         h = x
-        S = {}
+        S = State()
         for name, layer in zip(self.names, self.layers):
             h, S[name] = layer(h, p[name], s[name])
-        return h, State(**S)
+        return h, S


 class Branch(NamedLayers):
     """1 -> N"""

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)

     def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]:
         O = {}
-        S = {}
+        S = State()
         for name, layer in zip(self.names, self.layers):
             O[name], S[name] = layer(x, p[name], s[name])
         # ??? return dict?
-        return tuple(O.values()), State(**S)
+        return tuple(O.values()), S


 class Parallel(NamedLayers):
     """N -> N"""

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
```

Copilot AI · Nov 28, 2025
[nitpick] The __init__ methods in Chain, Branch, and Parallel classes appear to be redundant as they only call super().__init__(*args, **kwargs) with no additional logic. Since these classes inherit from NamedLayers which already defines a custom __init__, and Pydantic's BaseModel automatically handles initialization, these explicit __init__ methods are unnecessary.
Consider removing these redundant __init__ methods unless there's a specific reason for their presence.
This is intended to avoid pylint complaints.
```python
        return p["w"].at[x].get(out_sharding=self.out_sharding), s


class Unembedding(Embedding):
```

Copilot AI · Nov 28, 2025
[nitpick] The Unembedding class inherits from Embedding, which may cause confusion about the meaning of in_dim and out_dim. In the usage (line 140-141 in mini_transformer.py), Unembedding(in_dim=dim, out_dim=num_vocab) creates a weight matrix of shape (dim, num_vocab), but the parameter names suggest the opposite direction compared to how they're used in the einsum operation.
For Unembedding, the einsum "bld,dn->bln" expects:
- Input: (batch, length, dim)
- Weight: (dim, num_vocab)
This matches the current implementation, but the inherited parameter names in_dim=dim and out_dim=num_vocab are semantically backward from a typical linear layer perspective where in_dim would be the input feature dimension and out_dim would be the output dimension.
Consider either:
- Adding a docstring to clarify the parameter semantics
- Overriding `__init__` to swap or rename parameters for clarity
- Not inheriting from `Embedding` and defining parameters explicitly
```python
class Unembedding(Embedding):
    """
    Unembedding layer: projects from hidden dimension to vocabulary logits.

    Args:
        dim (int): Input feature dimension (hidden size).
        num_vocab (int): Output vocabulary size.
        w_init (Initializer, optional): Initializer for the weight matrix.

    The weight matrix has shape (dim, num_vocab), matching the einsum "bld,dn->bln".
    """

    def __init__(self, dim: int, num_vocab: int, w_init: Initializer = variance_scaling(1.0, "fan_in", "normal", out_axis=0)):
        super().__init__(in_dim=dim, out_dim=num_vocab, w_init=w_init)
```
@copilot create an issue to review it later.
```python
dispatch = plum.Dispatcher(warn_redefinition=True)


class FrozenDict(RootModel[dict]):
```

Copilot AI · Nov 28, 2025
This class implements `__hash__`, but does not implement `__eq__`.
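A minimal sketch of a matching `__eq__` (one possible shape, not code from this PR):

```python
    def __eq__(self, other):
        # Compare the underlying dicts so that equal FrozenDicts also hash equally.
        if isinstance(other, FrozenDict):
            return self.root == other.root
        return NotImplemented
```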
@copilot open a new pull request to apply changes based on this feedback
@findmyway I've opened a new pull request, #2, to work on those changes. Once the pull request is ready, I'll request review from you.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Pull request overview
Copilot reviewed 9 out of 10 changed files in this pull request and generated 13 comments.
```python
class Repeated(LayerBase):
    n: int
    layer: LayerLike

    def sublayers(self) -> dict:
        return {f"layer_{i}": self.layer for i in range(self.n)}

    def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]:
        S = State()
        o = x
        for i in range(self.n):
            o, S[f"layer_{i}"] = self.layer(o, p[f"layer_{i}"], s[f"layer_{i}"])
        return o, S
```

Copilot AI · Nov 28, 2025
Missing test coverage: The new Repeated layer lacks test coverage. Consider adding tests to verify that it correctly applies the same layer n times in sequence, properly manages parameter and state namespacing (layer_0, layer_1, etc.), and that the sublayers() method returns the correct dictionary structure.
```diff
 class Unembedding(Embedding):
     def forward(self, x: Array, p: Param, s: State) -> tuple[Array, State]:
-        return p["w"][x], s
+        return jnp.einsum("bld,dn->bln", x, p["w"], out_sharding=self.out_sharding), s
```

Copilot AI · Nov 28, 2025
Missing test coverage: The new Unembedding layer lacks test coverage. Consider adding tests to verify the einsum operation produces the correct output shape and values, and that it properly supports the out_sharding parameter for distributed training.
```python
class FrozenDict(RootModel[dict]):
    model_config = ConfigDict(frozen=True)

    def __getitem__(self, item):
        return self.root[item]

    def __iter__(self):
        return iter(self.root)

    def keys(self):
        return self.root.keys()

    def values(self):
        return self.root.values()

    def items(self):
        return self.root.items()

    def __hash__(self):
        return hash(frozenset(self.root.items()))
```

Copilot AI · Nov 28, 2025
[nitpick] Missing __len__ method: The FrozenDict class implements dict-like methods but is missing __len__, which is commonly expected for dict-like objects. Consider adding def __len__(self): return len(self.root) for completeness.
src/julax/layers.py (outdated)

```python
    epsilon: float = 1e-5
    w_init: Initializer = ones
    b_init: Initializer = zeros
    compute_dtype: jnp.dtype | None = None
```

Copilot AI · Nov 28, 2025
[nitpick] Missing documentation: The new compute_dtype parameter in LayerNorm lacks documentation. Consider adding a comment or docstring explaining its purpose (likely for controlling precision during normalization computation) and what happens when it's None.
```python
from julax.observers import default_observer


class FakeSource(grain.sources.RandomAccessDataSource):
```

Copilot AI · Nov 28, 2025
[nitpick] Missing documentation: The FakeSource class lacks a docstring explaining its purpose as a simple synthetic data source for testing the transformer. Consider adding documentation to explain the repeating pattern in _data and why this particular sequence is used for training.
```python
class FakeSource(grain.sources.RandomAccessDataSource):
    """
    A simple synthetic data source for testing transformer models.

    The data consists of a repeating pattern:
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 3, 2, 1]
    repeated 1024 times to form a long sequence.

    This predictable, non-random sequence is useful for sequence modeling tasks,
    allowing the model to learn to predict the next token in a known pattern.
    """
```
```python
Rearrange(
    "B T (qkv N H) -> B T (qkv N) H",
    B=global_batch_size,
    T=seq_len,
    qkv=3,
    N=num_heads,
    H=head_dim,
),
partial(
    jnp.split, indices_or_sections=3, axis=2
),
lambda qkv: jax.nn.dot_product_attention(
    *qkv, is_causal=True
),
Rearrange(
    "B T N H -> B T (N H)",
    B=global_batch_size,
    T=seq_len,
    N=num_heads,
    H=head_dim,
```

Copilot AI · Nov 28, 2025
Hardcoded batch size in Rearrange patterns: The transformer architecture has hardcoded B=global_batch_size in the Rearrange operations (lines 89, 103). This makes the model inflexible to different batch sizes at inference time. Consider using einops' ability to infer dimensions or making the model work with variable batch sizes by removing the explicit B constraint or using a more flexible approach.
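As an illustration of the inference point, plain einops only needs the sizes of axes that are being split; B and T come from the input shape, so nothing ties the pattern to a fixed batch size (a sketch with made-up dimensions, independent of julax's Rearrange wrapper):

```python
import jax.numpy as jnp
from einops import rearrange

x = jnp.zeros((2, 16, 3 * 4 * 8))  # (batch, seq, 3 * num_heads * head_dim)
# Only the axes being unpacked need explicit sizes; B and T are inferred.
y = rearrange(x, "B T (qkv N H) -> B T (qkv N) H", qkv=3, N=4, H=8)
print(y.shape)  # (2, 16, 12, 8)

z = jnp.zeros((5, 16, 4, 8))  # a different batch size works with the same pattern
print(rearrange(z, "B T N H -> B T (N H)").shape)  # (5, 16, 32)
```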
```python
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
```

Copilot AI · Nov 28, 2025
Redundant __init__ method: This __init__ method only calls super().__init__() with the same arguments and doesn't add any functionality. Since NamedLayers already defines the __init__ behavior, this override is unnecessary and can be removed.
```python
class SkipConnection(LayerBase):
    layer: LayerLike
    connection: Callable = jnp.add

    def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]:
        S = State()
        o, S["layer"] = self.layer(x, p["layer"], s["layer"])
        return self.connection(o, x), S
```

Copilot AI · Nov 28, 2025
Missing test coverage: The new SkipConnection layer lacks test coverage. Consider adding tests to verify that it correctly combines the output of the wrapped layer with the input using the specified connection function, and that it properly handles state propagation.
```python
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
```

Copilot AI · Nov 28, 2025
Redundant __init__ method: This __init__ method only calls super().__init__() with the same arguments and doesn't add any functionality. Since NamedLayers already defines the __init__ behavior, this override is unnecessary and can be removed.
```python
            w_rng,
            (self.dim,),
            dtype=self.param_dtype,
            out_sharding=self.out_sharding,
```

Copilot AI · Nov 28, 2025
Potential inconsistency: The w parameter uses out_sharding instead of param_sharding. This is inconsistent with the b parameter which correctly uses param_sharding. The weight parameter sharding should likely use self.param_sharding instead of self.out_sharding.
```diff
-            out_sharding=self.out_sharding,
+            out_sharding=(
+                None if self.param_sharding is None else P(self.param_sharding[-1])
+            ),
```
@findmyway I've opened a new pull request, #3, to work on those changes. Once the pull request is ready, I'll request review from you.

@findmyway I've opened a new pull request, #4, to work on those changes. Once the pull request is ready, I'll request review from you.
Co-authored-by: findmyway <5612003+findmyway@users.noreply.github.com>
Add __eq__ method to FrozenDict class
Pull request overview
Copilot reviewed 10 out of 11 changed files in this pull request and generated 10 comments.
```python
        return hash(frozenset(self.root.items()))
```

Copilot AI · Nov 28, 2025
The __hash__ implementation attempts to hash dictionary items directly using frozenset(self.root.items()). This will fail if any values in the dictionary are unhashable (e.g., lists, dicts, or mutable objects). Consider handling unhashable values or documenting this limitation. A safer approach might be to make values recursively frozen or to raise a more descriptive error.
```diff
-        return hash(frozenset(self.root.items()))
+        # Recursively freeze the dictionary to ensure all values are hashable
+        return hash(self._recursive_freeze(self.root))
+
+    @staticmethod
+    def _recursive_freeze(obj):
+        """Recursively convert obj to a hashable structure."""
+        if isinstance(obj, dict):
+            # Sort items to ensure consistent ordering
+            return tuple(sorted((k, FrozenDict._recursive_freeze(v)) for k, v in obj.items()))
+        elif isinstance(obj, (list, tuple)):
+            return tuple(FrozenDict._recursive_freeze(v) for v in obj)
+        elif isinstance(obj, set):
+            return frozenset(FrozenDict._recursive_freeze(v) for v in obj)
+        # Add more types here if needed (e.g., numpy arrays)
+        # If the object is already hashable, return as is
+        try:
+            hash(obj)
+        except TypeError:
+            raise TypeError(f"Unhashable type encountered in FrozenDict: {type(obj)}")
+        return obj
```
```python
    def forward(self, x: Array, p: Param, s: State) -> tuple[Array, State]:
        seq_length = x.shape[1]
        position = jnp.arange(seq_length, dtype=jnp.float32)[
            jnp.newaxis, :, jnp.newaxis, jnp.newaxis
        ]
        sinusoid_inp = position / s["timescale"]
        sin = jnp.sin(sinusoid_inp).astype(x.dtype)
        cos = jnp.cos(sinusoid_inp).astype(x.dtype)
        first_half, second_half = jnp.split(x, 2, axis=-1)
        first_part = first_half * cos - second_half * sin
        second_part = second_half * cos + first_half * sin
        if self.cast_as_fprop_dtype:
            first_part = first_part.astype(self.fprop_dtype)
            second_part = second_part.astype(self.fprop_dtype)
        x_out = jnp.concatenate((first_part, second_part), axis=-1)
        return x_out, s
```

Copilot AI · Nov 28, 2025
[nitpick] The RoPE (Rotary Position Embedding) timescale computation in the state() method is performed once during initialization and stored. However, the forward() method recomputes the position array and sinusoids for every forward pass, even though these only depend on seq_length. For sequences of the same length (which is common in training), consider caching these computed values indexed by sequence length to avoid redundant computation.
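A minimal sketch of that caching idea (not the PR's code; the helper name and signature are made up): sin/cos depend only on the static sequence length and the fixed timescale parameters, so they can be memoized per length instead of recomputed every call.

```python
from functools import lru_cache

import jax.numpy as jnp


@lru_cache(maxsize=8)
def _rope_sin_cos(seq_length: int, embedding_dims: int, max_timescale: int = 10000):
    # Same quantities the quoted forward() recomputes on every pass (min_timescale=1 assumed).
    fraction = 2 * jnp.arange(0, embedding_dims // 2) / embedding_dims
    timescale = max_timescale**fraction
    position = jnp.arange(seq_length, dtype=jnp.float32)[jnp.newaxis, :, jnp.newaxis, jnp.newaxis]
    sinusoid = position / timescale
    return jnp.sin(sinusoid), jnp.cos(sinusoid)
```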
```python
from jax.experimental import mesh_utils


def identity(x):
```

Copilot AI · Nov 28, 2025
[nitpick] The identity function lacks a docstring explaining its purpose. While the implementation is trivial, documentation would clarify its intended use case (e.g., as a placeholder layer in the Parallel composition as seen in mini_transformer.py line 110).
```python
def identity(x):
    """
    Returns the input unchanged.

    This function can be used as a placeholder layer, for example in model compositions
    such as the `Parallel` composition, where a no-op function is required.
    """
```
```diff
-        # TODO: cast dtype
-        return x * p["w"] + p["b"], s
+        x_std = jax.nn.standardize(
+            x.astype(self.compute_dtype), epsilon=self.epsilon
```

Copilot AI · Nov 28, 2025
The x.astype(self.compute_dtype) call will fail when compute_dtype is None (the default value). This should check if compute_dtype is not None before casting, or use a conditional expression like x if self.compute_dtype is None else x.astype(self.compute_dtype).
```diff
-            x.astype(self.compute_dtype), epsilon=self.epsilon
+            x if self.compute_dtype is None else x.astype(self.compute_dtype), epsilon=self.epsilon
```
```python
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
```

Copilot AI · Nov 28, 2025
The __init__ method added to Chain, Branch, and Parallel classes (lines 71-72, 85-86, 100-101) only calls super().__init__(*args, **kwargs) without any additional logic. Since these methods don't add any behavior, they are redundant and can be removed. The parent class NamedLayers already has an __init__ method that will be called automatically.
```diff
     def save(self, p: Param, s: State):
-        self.checkpoint_manager.save(
-            s["trainer"]["step"],
-            args=ocp.args.Composite(
-                param=ocp.args.PyTreeSave(item=p),
-                state_trainer=ocp.args.PyTreeSave(item=s["trainer"]),
-                state_dataset_iter=grain.checkpoint.CheckpointSave(item=s["input"]),
-            ),
-        )
+        if self.checkpoint_manager:
+            self.checkpoint_manager.save(
+                s["step"],
+                args=ocp.args.Composite(
+                    param=ocp.args.PyTreeSave(item=p),
+                    state_trainer=ocp.args.PyTreeSave(item=s["trainer"]),
+                    state_dataset_iter=grain.checkpoint.CheckpointSave(item=s["input"]),
+                ),
+            )
```

Copilot AI · Nov 28, 2025
The save() method doesn't include s["step"] in the checkpoint composite, and the restore() method on line 77 (not in changed region) doesn't include step in the returned State. This will cause a KeyError when observers try to access s["step"] after restoration. The step should be either: (1) saved in the composite as state_step=ocp.args.PyTreeSave(item=s["step"]) and restored, or (2) inferred from the checkpoint step number returned by the restore operation.
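A rough sketch of option (2), not the PR's code (the surrounding restore() context and the grain CheckpointRestore handler are assumptions; treat names as illustrative): since save() keys each checkpoint by s["step"], the step can be recovered from the checkpoint index itself.

```python
latest = self.checkpoint_manager.latest_step()
restored = self.checkpoint_manager.restore(
    latest,
    args=ocp.args.Composite(
        param=ocp.args.PyTreeRestore(item=p),
        state_trainer=ocp.args.PyTreeRestore(item=s["trainer"]),
        state_dataset_iter=grain.checkpoint.CheckpointRestore(item=s["input"]),
    ),
)
s["step"] = jnp.asarray(latest)  # reinstate the top-level step that observers read
```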
```python
            w_rng,
            (self.dim,),
            dtype=self.param_dtype,
            out_sharding=self.out_sharding,
```

Copilot AI · Nov 28, 2025
The w parameter initialization uses out_sharding=self.out_sharding (line 266), but should use out_sharding=self.param_sharding to be consistent with other layers like Linear (line 129) and with the b parameter initialization in this same method (line 272). The out_sharding field should be used for operation outputs, not parameter initialization.
```diff
-            out_sharding=self.out_sharding,
+            out_sharding=(
+                None if self.param_sharding is None else P(self.param_sharding[-1])
+            ),
```
```python
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
```

Copilot AI · Nov 28, 2025
The __init__ method added to Chain, Branch, and Parallel classes (lines 71-72, 85-86, 100-101) only calls super().__init__(*args, **kwargs) without any additional logic. Since these methods don't add any behavior, they are redundant and can be removed. The parent class NamedLayers already has an __init__ method that will be called automatically.
```python
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
```

Copilot AI · Nov 28, 2025
The __init__ method added to Chain, Branch, and Parallel classes (lines 71-72, 85-86, 100-101) only calls super().__init__(*args, **kwargs) without any additional logic. Since these methods don't add any behavior, they are redundant and can be removed. The parent class NamedLayers already has an __init__ method that will be called automatically.
@copilot Merge it

@findmyway I've opened a new pull request, #5, to work on those changes. Once the pull request is ready, I'll request review from you.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Pull request overview
Copilot reviewed 11 out of 12 changed files in this pull request and generated 7 comments.
```python
        x_std = jax.nn.standardize(
            x.astype(self.compute_dtype), epsilon=self.epsilon
        ).astype(self.param_dtype)
        return x_std * p["w"] + p["b"], s
```

Copilot AI · Nov 28, 2025
The dtype handling in LayerNorm has issues:
- When `self.compute_dtype` is `None`, `.astype(None)` will raise a TypeError
- The output is cast to `self.param_dtype`, which is likely incorrect; the output should match the input dtype or remain in compute dtype
Consider this fix:
```python
compute_dtype = self.compute_dtype if self.compute_dtype is not None else x.dtype
x_std = jax.nn.standardize(x.astype(compute_dtype), epsilon=self.epsilon)
return (x_std * p["w"] + p["b"]).astype(x.dtype), s
```

```diff
-        x_std = jax.nn.standardize(
-            x.astype(self.compute_dtype), epsilon=self.epsilon
-        ).astype(self.param_dtype)
-        return x_std * p["w"] + p["b"], s
+        compute_dtype = self.compute_dtype if self.compute_dtype is not None else x.dtype
+        x_std = jax.nn.standardize(
+            x.astype(compute_dtype), epsilon=self.epsilon
+        )
+        return (x_std * p["w"] + p["b"]).astype(x.dtype), s
```
```python
            w_rng,
            (self.dim,),
            dtype=self.param_dtype,
            out_sharding=self.out_sharding,
```

Copilot AI · Nov 28, 2025
Incorrect sharding parameter. The w_init call should use out_sharding=self.param_sharding instead of out_sharding=self.out_sharding. The pattern used in other layers (e.g., Linear at lines 125-130, Embedding at lines 197-202) shows that parameter initialization should use param_sharding, while out_sharding is used for computation outputs.
```diff
-            out_sharding=self.out_sharding,
+            out_sharding=self.param_sharding,
```
```python
        return x_out, s


class Unembedding(Embedding):
```

Copilot AI · Nov 28, 2025
[nitpick] The Unembedding class inherits from Embedding but uses a completely different forward operation (matrix multiplication vs. embedding lookup). This inheritance relationship is semantically incorrect - the two operations are fundamentally different.
Consider making Unembedding a standalone class that doesn't inherit from Embedding, or use composition instead of inheritance. The current design violates the Liskov Substitution Principle since an Unembedding cannot be used interchangeably with an Embedding.
```diff
-class Unembedding(Embedding):
+class Unembedding(LayerBase):
+    in_dim: int
+    out_dim: int
+    w_init: Initializer = variance_scaling(1.0, "fan_in", "normal", out_axis=0)
+
+    def param(self, rng: PRNG) -> Param:
+        return Param(
+            w=self.w_init(
+                rng,
+                (self.in_dim, self.out_dim),
+                dtype=self.param_dtype,
+                out_sharding=self.param_sharding,
+            )
+        )
```
```python
class SkipConnection(LayerBase):
    layer: LayerLike
    connection: Callable = jnp.add

    def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]:
        S = State()
        o, S["layer"] = self.layer(x, p["layer"], s["layer"])
        return self.connection(o, x), S
```

Copilot AI · Nov 28, 2025
The new SkipConnection layer lacks test coverage. Consider adding tests to verify the following (a minimal sketch follows the list):
- The skip connection correctly adds the input to the layer output
- The state is properly propagated
- Custom connection functions work as expected
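A property sketch for the first and third bullets, written against plain JAX so it does not depend on the julax initialization API (the residual helper is a stand-in for SkipConnection.forward):

```python
import jax.numpy as jnp


def residual(f, x, connection=jnp.add):
    # Stand-in for SkipConnection: combine the wrapped layer's output with the input.
    return connection(f(x), x)


x = jnp.arange(6.0).reshape(2, 3)
# Default connection (addition): f(x) + x.
assert jnp.allclose(residual(lambda h: 2.0 * h, x), 3.0 * x)
# A custom connection function (elementwise product): f(x) * x.
assert jnp.allclose(residual(lambda h: h + 1.0, x, connection=jnp.multiply), (x + 1.0) * x)
```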
```python
class Repeated(LayerBase):
    n: int
    layer: LayerLike

    def sublayers(self) -> dict:
        return {f"layer_{i}": self.layer for i in range(self.n)}

    def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]:
        S = State()
        o = x
        for i in range(self.n):
            o, S[f"layer_{i}"] = self.layer(o, p[f"layer_{i}"], s[f"layer_{i}"])
        return o, S
```

Copilot AI · Nov 28, 2025
The new Repeated layer lacks test coverage. Consider adding tests to verify the following (see the sketch after this list):
- The layer is correctly repeated n times
- Parameters and states are properly namespaced (layer_0, layer_1, etc.)
- The output of one iteration correctly feeds into the next
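A sketch of the third bullet as a pure-JAX property (the helper is a stand-in for Repeated.forward; the julax parameter-initialization API is not shown in this diff):

```python
import jax.numpy as jnp


def repeated(f, n, x):
    # Stand-in for Repeated: the output of each iteration feeds the next one.
    o = x
    for _ in range(n):
        o = f(o)
    return o


f = lambda h: 2.0 * h + 1.0
x = jnp.arange(4.0)
# Applying f three times must equal the explicit composition f(f(f(x))).
assert jnp.allclose(repeated(f, 3, x), f(f(f(x))))
# Against the real layer, the namespacing bullet would assert something like:
#   set(Repeated(n=3, layer=...).sublayers()) == {"layer_0", "layer_1", "layer_2"}
```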
```python
class RotaryEmbedding(LayerBase):
    """Rotary Position Embedding."""

    # Adapted from https://github.com/AI-Hypercomputer/maxtext/blob/9204d6bbbf8bb19a05ebed72a55cfec687e0e044/src/MaxText/layers/embeddings.py#L271C11-L356C17
    embedding_dims: int
    min_timescale: int = 1
    max_timescale: int = 10000
    cast_as_fprop_dtype: bool = True
    fprop_dtype: Dtype = jnp.bfloat16
    rope_linear_scaling_factor: float = 1.0

    def state(self, rng: PRNG) -> State:
        half_embedding_dim = self.embedding_dims // 2
        fraction = 2 * jnp.arange(0, half_embedding_dim) / self.embedding_dims
        timescale = (
            self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction
        )
        if self.rope_linear_scaling_factor != 1.0:
            timescale = timescale * self.rope_linear_scaling_factor
        return State(timescale=timescale)

    def forward(self, x: Array, p: Param, s: State) -> tuple[Array, State]:
        seq_length = x.shape[1]
        position = jnp.arange(seq_length, dtype=jnp.float32)[
            jnp.newaxis, :, jnp.newaxis, jnp.newaxis
        ]
        sinusoid_inp = position / s["timescale"]
        sin = jnp.sin(sinusoid_inp).astype(x.dtype)
        cos = jnp.cos(sinusoid_inp).astype(x.dtype)
        first_half, second_half = jnp.split(x, 2, axis=-1)
        first_part = first_half * cos - second_half * sin
        second_part = second_half * cos + first_half * sin
        if self.cast_as_fprop_dtype:
            first_part = first_part.astype(self.fprop_dtype)
            second_part = second_part.astype(self.fprop_dtype)
        x_out = jnp.concatenate((first_part, second_part), axis=-1)
        return x_out, s
```

Copilot AI · Nov 28, 2025
The new RotaryEmbedding layer lacks test coverage. Consider adding tests to verify the following (a property sketch follows the list):
- The rotary position embeddings are computed correctly
- The timescale calculation matches the expected behavior
- The linear scaling factor works as intended
- The output shape matches the input shape
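One concrete property such a test could check, reproduced in plain JAX from the forward() quoted above: rotary embedding rotates each (first_half, second_half) pair, so it preserves the input shape and the per-pair norm (default timescales assumed).

```python
import jax.numpy as jnp


def rope(x, min_timescale=1, max_timescale=10000):
    # Same math as the quoted forward(): rotate (first_half, second_half) pairs.
    dims = x.shape[-1]
    fraction = 2 * jnp.arange(0, dims // 2) / dims
    timescale = min_timescale * (max_timescale / min_timescale) ** fraction
    position = jnp.arange(x.shape[1], dtype=jnp.float32)[jnp.newaxis, :, jnp.newaxis, jnp.newaxis]
    sinusoid = position / timescale
    sin, cos = jnp.sin(sinusoid), jnp.cos(sinusoid)
    first, second = jnp.split(x, 2, axis=-1)
    return jnp.concatenate((first * cos - second * sin, second * cos + first * sin), axis=-1)


x = jnp.ones((2, 7, 3, 8))  # (batch, seq, heads, head_dim)
y = rope(x)
assert y.shape == x.shape
# A rotation preserves the norm of each (first_half[i], second_half[i]) pair.
x1, x2 = jnp.split(x, 2, axis=-1)
y1, y2 = jnp.split(y, 2, axis=-1)
assert jnp.allclose(x1**2 + x2**2, y1**2 + y2**2, atol=1e-5)
```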
```diff
 class Unembedding(Embedding):
     def forward(self, x: Array, p: Param, s: State) -> tuple[Array, State]:
-        return p["w"][x], s
+        return jnp.einsum("bld,dn->bln", x, p["w"], out_sharding=self.out_sharding), s
```

Copilot AI · Nov 28, 2025
The new Unembedding layer lacks test coverage. Consider adding tests to verify the following (see the shape sketch after this list):
- The matrix multiplication with the embedding weights is computed correctly
- The output shape is correct (batch, length, vocab_size)
- Sharding parameters work as expected
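A shape-level sketch of the first two bullets in plain jax.numpy (out_sharding omitted; single-device assumption, made-up sizes):

```python
import jax.numpy as jnp

batch, length, dim, vocab = 2, 5, 8, 11
x = jnp.ones((batch, length, dim))  # hidden states
w = jnp.ones((dim, vocab))          # unembedding matrix of shape (dim, num_vocab)

logits = jnp.einsum("bld,dn->bln", x, w)
assert logits.shape == (batch, length, vocab)
# Each logit is a dot product over the hidden dimension, i.e. a plain matmul.
assert jnp.allclose(logits, x @ w)
```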
No description provided.