Add Llama 3.2 1B #6
base: main
Conversation
No description provided.
Pull request overview
This pull request adds support for training the Llama 3.2 1B model and includes a significant refactoring of the julax framework:
- Major layer class renames (Repeated→Repeat, SkipConnection→Residual)
- New visualization capabilities with rich console output
- A new RMSNorm layer (sketched below) and improved embedding layers
- New test infrastructure for input preprocessing
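Since RMSNorm is new to the framework here, a minimal JAX sketch of the standard RMSNorm computation may help reviewers. This follows the usual formula and is not the julax layers.py implementation; the rms_norm name and the weight/eps parameters are assumptions:

```python
import jax.numpy as jnp

def rms_norm(x, weight, eps: float = 1e-6):
    """Standard RMSNorm: y = x / sqrt(mean(x^2) + eps) * weight."""
    rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
    return (x / rms) * weight

x = jnp.ones((2, 4))
w = jnp.full((4,), 2.0)
print(rms_norm(x, w))  # every entry is ~2.0: the RMS of an all-ones row is ~1.0
```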
Reviewed changes
Copilot reviewed 10 out of 13 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| uv.lock | Updates dependencies including new packages for Llama support (certifi, requests, tiktoken, safetensors, tensorflow-datasets) |
| tests/test_inputs.py | Adds comprehensive test cases for text input preprocessing with packed sequences |
| src/julax/pprint.py | Implements rich console pretty-printing for model structure visualization |
| src/julax/layers.py | Major refactoring: renames Repeated→Repeat, SkipConnection→Residual; adds RMSNorm, Select layers; improves param/state length tracking |
| src/julax/inputs.py | Removes placeholder implementation (file deleted) |
| src/julax/experiment.py | Adds cached mesh property for improved performance |
| src/julax/core.py | Removes common dtype/sharding fields from LayerBase; adds numel() method for parameter counting (sketched after this table); adds IPython display support |
| src/julax/base.py | Adds `__len__` method to PyTree wrapper |
| pyproject.toml | Updates version to 0.0.4-dev; moves dev dependencies to dependency-groups; adds new dev dependencies |
| experiments/03_Llama_3.2_1B.py | Implements complete Llama 3.2 1B architecture with HuggingFace weight loading capability |
| experiments/02_mini_transformer.py | Updates to use refactored API (Repeat, Residual) |
| experiments/01_mnist.py | New MNIST training example demonstrating framework usage |
| .gitignore | Adds models/ directory to ignore list |
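To illustrate the numel() idea mentioned for src/julax/core.py above, here is a generic pytree parameter counter; the function name matches the PR's description, but the body is a sketch for clarity, not the actual julax code:

```python
import jax
import jax.numpy as jnp

def numel(params) -> int:
    """Total number of scalar parameters across all array leaves of a pytree."""
    return sum(int(leaf.size) for leaf in jax.tree_util.tree_leaves(params))

params = {"embed": jnp.zeros((1000, 64)),
          "head": {"w": jnp.zeros((64, 10)), "b": jnp.zeros((10,))}}
print(numel(params))  # 64000 + 640 + 10 = 64650
```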
Comments suppressed due to low confidence (2)
experiments/02_mini_transformer.py:84
- The parameter name changed from "layer" to "processor" in the Residual class. However, in the Residual implementation in layers.py, the constructor takes "processor" as its first parameter but stores it internally in a "skip_through" field along with another field. This API inconsistency should be reviewed to ensure the class works as expected.
experiments/02_mini_transformer.py:33
- Alongside the Repeated→Repeat rename in the layers module, the import here still references the old "Unembedding" name. Looking at the diff, the Unembedding class definition was removed from layers.py (lines 305-307 in the old version), so this import will fail at runtime.
pyproject.toml:
```diff
 [project]
 name = "julax"
-version = "0.0.3"
+version = "0.0.4-dev"
```
Copilot AI (Dec 19, 2025):
The version string "0.0.4-dev" uses a hyphen separator, which is not the canonical PEP 440 form. PEP 440 specifies ".devN" notation for development releases (e.g., "0.0.4.dev0") rather than "-dev". Non-canonical version strings may cause issues with package managers and version comparisons.
| version = "0.0.4-dev" | |
| version = "0.0.4.dev0" |
src/julax/layers.py:
```diff
 class Parallel(NamedLayers):
+class Residual(Branch):
+    def __init__(self, processor, *, skip_through=identity, reduce: Callable = jnp.add):
```
Copilot AI (Dec 19, 2025):
There's an inconsistency in the Residual class constructor. It accepts "skip_through" and "reduce" keyword parameters, with "skip_through" defaulting to the identity function, while the Branch parent class takes a "reduce" parameter of its own. Residual then passes both processor and skip_through to Branch's super().__init__() as positional arguments, even though Branch expects named layers. This design may cause confusion about which parameters go where.
Suggested change:
```diff
-    def __init__(self, processor, *, skip_through=identity, reduce: Callable = jnp.add):
+    def __init__(self, processor, *, skip_through: LayerLike | None = None, reduce: Callable = jnp.add):
+        if skip_through is None:
+            skip_through = identity
```
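For reference, here is a self-contained sketch of the residual pattern this constructor encodes. It mirrors the reviewed signature but deliberately avoids subclassing julax's Branch; the identity default and the field names are assumptions:

```python
from typing import Callable
import jax.numpy as jnp

def identity(x):
    return x

class Residual:
    """out = reduce(processor(x), skip_through(x)); defaults give x + processor(x)."""

    def __init__(self, processor, *, skip_through=identity, reduce: Callable = jnp.add):
        self.processor = processor
        self.skip_through = skip_through
        self.reduce = reduce

    def __call__(self, x):
        return self.reduce(self.processor(x), self.skip_through(x))

block = Residual(lambda x: 2 * x)
print(block(jnp.ones(3)))  # [3. 3. 3.]
```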
```python
def verify():
    m = create_transformer()
    p, s = m.init()
```
Copilot AI (Dec 19, 2025):
This assignment to 'p' is unnecessary as it is redefined before this value is used.
Suggested change:
```diff
-    p, s = m.init()
```
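Generically, the flagged dead-store pattern looks like the toy example below, where init and restore are hypothetical stand-ins for whatever produces and then overwrites the values:

```python
def init():
    return "fresh_params", "fresh_state"

def restore(ckpt):
    return "loaded_params", "loaded_state"

p, s = init()           # dead store: this result is never read
p, s = restore("ckpt")  # both names are rebound here before any use
print(p, s)             # loaded_params loaded_state
```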
```python
def verify():
    m = create_transformer()
    p, s = m.init()
```
Copilot AI (Dec 19, 2025):
This assignment to 's' is unnecessary as it is redefined before this value is used.
Suggested change:
```diff
-    p, s = m.init()
```